diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 270364a2e25..ddeb7d3150d 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -1,6 +1,7 @@ import math import torch import torch.nn as nn +import torch.nn.init as init import torch.utils.model_zoo as model_zoo @@ -87,13 +88,10 @@ def __init__(self, version=1.0, num_classes=1000): for m in self.modules(): if isinstance(m, nn.Conv2d): - gain = 2.0 if m is final_conv: - m.weight.data.normal_(0, 0.01) + init.normal(m.weight.data, mean=0.0, std=0.01) else: - fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels - u = math.sqrt(3.0 * gain / fan_in) - m.weight.data.uniform_(-u, u) + init.kaiming_uniform(m.weight.data) if m.bias is not None: m.bias.data.zero_()