From 1fb3b5fef342bb7cd64d4a7087bf587bca2376a1 Mon Sep 17 00:00:00 2001 From: ekka Date: Sat, 9 Mar 2019 01:09:17 +0530 Subject: [PATCH 1/2] Internal Imagenet normalisation for pretrained squeezenet models Makes it easier to normalise the model weights with imagenet mean and std when using transfer learning. Useful for beginners who forget to normalise the input images while using transfer learning. --- torchvision/models/squeezenet.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index a3e51e3b953..90c31a9a51e 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -38,8 +38,9 @@ def forward(self, x): class SqueezeNet(nn.Module): - def __init__(self, version=1.0, num_classes=1000): + def __init__(self, version=1.0, num_classes=1000, transform_input=False): super(SqueezeNet, self).__init__() + self.transform_input = transform_input if version not in [1.0, 1.1]: raise ValueError("Unsupported SqueezeNet version {version}:" "1.0 or 1.1 expected".format(version=version)) @@ -95,6 +96,14 @@ def __init__(self, version=1.0, num_classes=1000): init.constant_(m.bias, 0) def forward(self, x): + + #imagenet normalisation + if self.transform_input: + x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 + x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 + x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + x = self.features(x) x = self.classifier(x) return x.view(x.size(0), self.num_classes) From d0c358561e56dea7574c0fb149c4ddc72670e3e5 Mon Sep 17 00:00:00 2001 From: ekka Date: Sat, 9 Mar 2019 01:32:19 +0530 Subject: [PATCH 2/2] fixed E265 --- torchvision/models/squeezenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 90c31a9a51e..89316a59c3e 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -97,7 +97,7 @@ def __init__(self, version=1.0, num_classes=1000, transform_input=False): def forward(self, x): - #imagenet normalisation + # imagenet normalisation if self.transform_input: x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224