diff --git a/super_resolution/model.py b/super_resolution/model.py index 4e28ba3063..7a7de30efa 100644 --- a/super_resolution/model.py +++ b/super_resolution/model.py @@ -24,7 +24,7 @@ def forward(self, x): return x def _initialize_weights(self): - init.orthogonal(self.conv1.weight, init.gain('relu')) - init.orthogonal(self.conv2.weight, init.gain('relu')) - init.orthogonal(self.conv3.weight, init.gain('relu')) + init.orthogonal(self.conv1.weight, init.calculate_gain('relu')) + init.orthogonal(self.conv2.weight, init.calculate_gain('relu')) + init.orthogonal(self.conv3.weight, init.calculate_gain('relu')) init.orthogonal(self.conv4.weight)