diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 7e7f27bbde0..9c49944a73a 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -118,7 +118,8 @@ def forward(self, x): # 8 x 8 x 2048 x = self.Mixed_7c(x) # 8 x 8 x 2048 - x = F.avg_pool2d(x, kernel_size=8) + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) # 1 x 1 x 2048 x = F.dropout(x, training=self.training) # 1 x 1 x 2048 @@ -311,6 +312,9 @@ def forward(self, x): # 5 x 5 x 128 x = self.conv1(x) # 1 x 1 x 768 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # 1 x 1 x 768 x = x.view(x.size(0), -1) # 768 x = self.fc(x)