diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 43f8e7c3eba..beb5857bbe8 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -69,10 +69,12 @@ class DenseNet(nn.Module): """ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), - num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): + num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, transform_input=False): super(DenseNet, self).__init__() + self.transform_input = transform_input + # First convolution self.features = nn.Sequential(OrderedDict([ ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), @@ -110,6 +112,14 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), nn.init.constant_(m.bias, 0) def forward(self, x): + + # imagenet normalisation + if self.transform_input: + x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 + x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 + x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + features = self.features(x) out = F.relu(features, inplace=True) out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) @@ -123,6 +133,8 @@ def densenet121(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + transform_input (bool): If True, preprocesses the input according to the method with which it + was trained on ImageNet. Default: *False* """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) @@ -150,6 +162,8 @@ def densenet169(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + transform_input (bool): If True, preprocesses the input according to the method with which it + was trained on ImageNet. Default: *False* """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) @@ -177,6 +191,8 @@ def densenet201(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + transform_input (bool): If True, preprocesses the input according to the method with which it + was trained on ImageNet. Default: *False* """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) @@ -204,6 +220,8 @@ def densenet161(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + transform_input (bool): If True, preprocesses the input according to the method with which it + was trained on ImageNet. Default: *False* """ model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)