diff --git a/test/test_models.py b/test/test_models.py
index 7a8e1d83b6e..0455573c2a0 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -339,6 +339,18 @@ def get_gn(num_channels):
         self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
         self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
 
+    def test_inceptionv3_eval(self):
+        # replacement for models.inception_v3(pretrained=True) that does not download weights
+        kwargs = {}
+        kwargs['transform_input'] = True
+        kwargs['aux_logits'] = True
+        kwargs['init_weights'] = False
+        model = models.Inception3(**kwargs)
+        model.aux_logits = False
+        model.AuxLogits = None
+        m = torch.jit.script(model.eval())
+        self.checkModule(m, "inception_v3", torch.rand(1, 3, 299, 299))
+
     def test_fasterrcnn_double(self):
         model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
         model.double()
diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py
index 9e49e446849..f2156018327 100644
--- a/torchvision/models/inception.py
+++ b/torchvision/models/inception.py
@@ -55,7 +55,7 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any)
         model.load_state_dict(state_dict)
         if not original_aux_logits:
             model.aux_logits = False
-            del model.AuxLogits
+            model.AuxLogits = None
         return model
 
     return Inception3(**kwargs)
@@ -108,6 +108,7 @@ def __init__(
         self.Mixed_6c = inception_c(768, channels_7x7=160)
         self.Mixed_6d = inception_c(768, channels_7x7=160)
         self.Mixed_6e = inception_c(768, channels_7x7=192)
+        self.AuxLogits: Optional[nn.Module] = None
         if aux_logits:
             self.AuxLogits = inception_aux(768, num_classes)
         self.Mixed_7a = inception_d(768)
@@ -170,11 +171,10 @@ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
         # N x 768 x 17 x 17
         x = self.Mixed_6e(x)
         # N x 768 x 17 x 17
-        aux_defined = self.training and self.aux_logits
-        if aux_defined:
-            aux = self.AuxLogits(x)
-        else:
-            aux = None
+        aux = torch.jit.annotate(Optional[Tensor], None)
+        if self.AuxLogits is not None:
+            if self.training:
+                aux = self.AuxLogits(x)
         # N x 768 x 17 x 17
         x = self.Mixed_7a(x)
         # N x 1280 x 8 x 8
diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py
index f452de02815..cdc72ce4851 100644
--- a/torchvision/models/quantization/inception.py
+++ b/torchvision/models/quantization/inception.py
@@ -67,7 +67,7 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
         if quantize:
             if not original_aux_logits:
                 model.aux_logits = False
-                del model.AuxLogits
+                model.AuxLogits = None
             model_url = quant_model_urls['inception_v3_google' + '_' + backend]
         else:
             model_url = inception_module.model_urls['inception_v3_google']
@@ -80,7 +80,7 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
         if not quantize:
             if not original_aux_logits:
                 model.aux_logits = False
-                del model.AuxLogits
+                model.AuxLogits = None
     return model
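
The gist of the change: `AuxLogits` is declared as an `Optional[nn.Module]` and cleared by assigning `None` instead of `del`-ing it, so `torch.jit.script` still sees a typed attribute and can prune the aux branch at compile time. Below is a minimal standalone sketch of that pattern, not the actual Inception3 code; `ToyNet`, `Head`, and `aux_head` are made-up names used only for illustration:

```python
from typing import Optional

import torch
from torch import nn


class Head(nn.Module):
    # Hypothetical stand-in for the auxiliary classifier; any small module works.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mean(dim=(2, 3))


class ToyNet(nn.Module):
    def __init__(self, aux: bool = True) -> None:
        super().__init__()
        # Declare the submodule as Optional so the attribute exists (and is typed)
        # even when the auxiliary branch is disabled, mirroring
        # `self.AuxLogits: Optional[nn.Module] = None` in the diff.
        self.aux_head: Optional[nn.Module] = None
        if aux:
            self.aux_head = Head()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        x = self.conv(x)
        # Same control flow as the patched Inception3._forward: the None check
        # lets TorchScript drop the branch when the submodule is absent.
        aux = torch.jit.annotate(Optional[torch.Tensor], None)
        if self.aux_head is not None:
            if self.training:
                aux = self.aux_head(x)
        return aux


model = ToyNet(aux=True)
# Disable the branch the way the updated factory functions do: assign None
# rather than `del model.aux_head`, so the attribute still exists when
# torch.jit.script inspects the module.
model.aux_head = None
scripted = torch.jit.script(model.eval())
print(scripted(torch.rand(1, 3, 16, 16)))  # prints None
```

With the old `del`, scripting a model configured this way would fail because `forward` references an attribute that no longer exists on the module; assigning `None` keeps the attribute around as a `NoneType` constant that the compiler can reason about.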