From 0506b5363178c9ebb7e8fcb73011871c34b0ed0b Mon Sep 17 00:00:00 2001 From: ekka Date: Sat, 9 Mar 2019 00:55:33 +0530 Subject: [PATCH 1/3] Internal Imagenet normalisation for pretrained alexnet model Consistent with pytorch inceptionV3 implementation --- torchvision/models/alexnet.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 95d8dce0f28..90068dc9e1f 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -12,8 +12,9 @@ class AlexNet(nn.Module): - def __init__(self, num_classes=1000): + def __init__(self, num_classes=1000, transform_input=False): super(AlexNet, self).__init__() + self.transform_input = transform_input self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), @@ -41,6 +42,14 @@ def __init__(self, num_classes=1000): ) def forward(self, x): + + #imagenet normalisation + if self.transform_input: + x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 + x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 + x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + x = self.features(x) x = self.avgpool(x) x = x.view(x.size(0), 256 * 6 * 6) From 3c87ffcabd2bb354c4b2e4ff36c4d786ba1bdaf6 Mon Sep 17 00:00:00 2001 From: ekka Date: Sat, 9 Mar 2019 01:29:21 +0530 Subject: [PATCH 2/3] fixed flakes8 --- torchvision/models/alexnet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 90068dc9e1f..a9241032585 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo @@ -42,8 +43,8 @@ def __init__(self, num_classes=1000, transform_input=False): ) def forward(self, x): - - #imagenet normalisation + + # imagenet normalisation if self.transform_input: x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 From 9ac4bc7b5825a2ba14b174f505808ff1edf007dc Mon Sep 17 00:00:00 2001 From: ekka Date: Sat, 9 Mar 2019 02:22:11 +0530 Subject: [PATCH 3/3] Update alexnet.py --- torchvision/models/alexnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index a9241032585..7d71930fae7 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -50,7 +50,7 @@ def forward(self, x): x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 x = torch.cat((x_ch0, x_ch1, x_ch2), 1) - + x = self.features(x) x = self.avgpool(x) x = x.view(x.size(0), 256 * 6 * 6)