diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 95d8dce0f28..7d71930fae7 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo @@ -12,8 +13,9 @@ class AlexNet(nn.Module): - def __init__(self, num_classes=1000): + def __init__(self, num_classes=1000, transform_input=False): super(AlexNet, self).__init__() + self.transform_input = transform_input self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), @@ -41,6 +43,14 @@ def __init__(self, num_classes=1000): ) def forward(self, x): + + # imagenet normalisation + if self.transform_input: + x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 + x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 + x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + x = self.features(x) x = self.avgpool(x) x = x.view(x.size(0), 256 * 6 * 6)