From 17fb5d0b94f14d71d2975e992a2563abca78581f Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 29 Jul 2021 18:56:53 +0200 Subject: [PATCH 1/7] style: Added typing annotations to segmentation/_utils --- torchvision/models/segmentation/_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 176b7490038..144720e28db 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -1,19 +1,26 @@ from collections import OrderedDict +from typing import Optional, Dict -from torch import nn +from torch import nn, Tensor from torch.nn import functional as F +from .._utils import IntermediateLayerGetter class _SimpleSegmentationModel(nn.Module): __constants__ = ['aux_classifier'] - def __init__(self, backbone, classifier, aux_classifier=None): + def __init__( + self, + backbone: IntermediateLayerGetter, + classifier: nn.Module, + aux_classifier: Optional[nn.Module] = None + ) -> None: super(_SimpleSegmentationModel, self).__init__() self.backbone = backbone self.classifier = classifier self.aux_classifier = aux_classifier - def forward(self, x): + def forward(self, x: Tensor) -> Dict[str, Tensor]: input_shape = x.shape[-2:] # contract: features is a dict of tensors features = self.backbone(x) From dc05b90efa6a3f3a49fd5645ef531511a1a038d9 Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 29 Jul 2021 18:57:02 +0200 Subject: [PATCH 2/7] style: Added typing annotations to segmentation/segmentation --- .../models/segmentation/segmentation.py | 78 +++++++++++++++---- 1 file changed, 63 insertions(+), 15 deletions(-) diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index 7b3a0258ddb..189fb0dc8ed 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -1,3 +1,6 @@ +from torch import nn +from typing import Any +from ._utils import _SimpleSegmentationModel from .._utils import IntermediateLayerGetter from ..._internally_replaced_utils import load_state_dict_from_url from .. import mobilenetv3 @@ -22,7 +25,13 @@ } -def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True): +def _segm_model( + name: str, + backbone_name: str, + num_classes: int, + aux: bool, + pretrained_backbone: bool = True +) -> _SimpleSegmentationModel: if 'resnet' in backbone_name: backbone = resnet.__dict__[backbone_name]( pretrained=pretrained_backbone, @@ -66,7 +75,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True) return model -def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs): +def _load_model( + arch_type: str, + backbone: str, + pretrained: bool, + progress: bool, + num_classes: int, + aux_loss: bool, + **kwargs: Any +) -> _SimpleSegmentationModel: if pretrained: aux_loss = True kwargs["pretrained_backbone"] = False @@ -76,7 +93,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss return model -def _load_weights(model, arch_type, backbone, progress): +def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None: arch = arch_type + '_' + backbone + '_coco' model_url = model_urls.get(arch, None) if model_url is None: @@ -86,7 +103,7 @@ def _load_weights(model, arch_type, backbone, progress): model.load_state_dict(state_dict) -def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True): +def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP: backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. @@ -103,8 +120,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru return model -def fcn_resnet50(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def fcn_resnet50( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: bool = False, + **kwargs: Any +) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. Args: @@ -117,8 +139,13 @@ def fcn_resnet50(pretrained=False, progress=True, return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) -def fcn_resnet101(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def fcn_resnet101( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: bool = False, + **kwargs: Any +) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. Args: @@ -131,8 +158,13 @@ def fcn_resnet101(pretrained=False, progress=True, return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_resnet50(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_resnet50( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: bool = False, + **kwargs: Any +) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: @@ -145,8 +177,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True, return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_resnet101(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_resnet101( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: bool = False, + **kwargs: Any +) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-101 backbone. Args: @@ -159,8 +196,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True, return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: bool = False, + **kwargs: Any +) -> DeepLabV3: """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. Args: @@ -173,7 +215,13 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs) -def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs): +def lraspp_mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: bool = False, + **kwargs: Any +) -> LRASPP: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. Args: From 70e4b5dd7bd0602ce01c94df93b26e42711129c7 Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 29 Jul 2021 19:03:32 +0200 Subject: [PATCH 3/7] style: Added typing annotations to remaining segmentation models --- torchvision/models/segmentation/deeplabv3.py | 13 +++++++------ torchvision/models/segmentation/fcn.py | 2 +- torchvision/models/segmentation/lraspp.py | 19 ++++++++++++++++--- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 7acc013ccb1..42f79bc6e59 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -1,6 +1,7 @@ import torch from torch import nn from torch.nn import functional as F +from typing import List from ._utils import _SimpleSegmentationModel @@ -27,7 +28,7 @@ class DeepLabV3(_SimpleSegmentationModel): class DeepLabHead(nn.Sequential): - def __init__(self, in_channels, num_classes): + def __init__(self, in_channels: int, num_classes: int) -> None: super(DeepLabHead, self).__init__( ASPP(in_channels, [12, 24, 36]), nn.Conv2d(256, 256, 3, padding=1, bias=False), @@ -38,7 +39,7 @@ def __init__(self, in_channels, num_classes): class ASPPConv(nn.Sequential): - def __init__(self, in_channels, out_channels, dilation): + def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None: modules = [ nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), nn.BatchNorm2d(out_channels), @@ -48,14 +49,14 @@ def __init__(self, in_channels, out_channels, dilation): class ASPPPooling(nn.Sequential): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels: int, out_channels: int) -> None: super(ASPPPooling, self).__init__( nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU()) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: size = x.shape[-2:] for mod in self: x = mod(x) @@ -63,7 +64,7 @@ def forward(self, x): class ASPP(nn.Module): - def __init__(self, in_channels, atrous_rates, out_channels=256): + def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None: super(ASPP, self).__init__() modules = [] modules.append(nn.Sequential( @@ -85,7 +86,7 @@ def __init__(self, in_channels, atrous_rates, out_channels=256): nn.ReLU(), nn.Dropout(0.5)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: res = [] for conv in self.convs: res.append(conv(x)) diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 3c695b53167..9c8db1e1211 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -23,7 +23,7 @@ class FCN(_SimpleSegmentationModel): class FCNHead(nn.Sequential): - def __init__(self, in_channels, channels): + def __init__(self, in_channels: int, channels: int) -> None: inter_channels = in_channels // 4 layers = [ nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 44cd9b1e773..0e5fb5ee898 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -24,12 +24,19 @@ class LRASPP(nn.Module): inter_channels (int, optional): the number of channels for intermediate computations. """ - def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128): + def __init__( + self, + backbone: nn.Module, + low_channels: int, + high_channels: int, + num_classes: int, + inter_channels: int = 128 + ) -> None: super().__init__() self.backbone = backbone self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels) - def forward(self, input): + def forward(self, input: Tensor) -> Dict[str, Tensor]: features = self.backbone(input) out = self.classifier(features) out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False) @@ -42,7 +49,13 @@ def forward(self, input): class LRASPPHead(nn.Module): - def __init__(self, low_channels, high_channels, num_classes, inter_channels): + def __init__( + self, + low_channels: int, + high_channels: int, + num_classes: int, + inter_channels: int + ) -> None: super().__init__() self.cbr = nn.Sequential( nn.Conv2d(high_channels, inter_channels, 1, bias=False), From e8f14f803f31b0a904b430d344b9febbcf99b203 Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 29 Jul 2021 19:05:44 +0200 Subject: [PATCH 4/7] style: Fixed typing of DeepLab --- torchvision/models/segmentation/deeplabv3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 42f79bc6e59..15ab8846e7d 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -87,8 +87,8 @@ def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int nn.Dropout(0.5)) def forward(self, x: torch.Tensor) -> torch.Tensor: - res = [] + _res = [] for conv in self.convs: - res.append(conv(x)) - res = torch.cat(res, dim=1) + _res.append(conv(x)) + res = torch.cat(_res, dim=1) return self.project(res) From 9dd0ce74cff662b26603321ba1e99e1665e6fc58 Mon Sep 17 00:00:00 2001 From: frgfm Date: Mon, 23 Aug 2021 10:13:04 +0200 Subject: [PATCH 5/7] style: Fixed typing --- torchvision/models/segmentation/_utils.py | 3 +- .../models/segmentation/segmentation.py | 35 +++++++++---------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 144720e28db..fb94b9b1528 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -3,7 +3,6 @@ from torch import nn, Tensor from torch.nn import functional as F -from .._utils import IntermediateLayerGetter class _SimpleSegmentationModel(nn.Module): @@ -11,7 +10,7 @@ class _SimpleSegmentationModel(nn.Module): def __init__( self, - backbone: IntermediateLayerGetter, + backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None ) -> None: diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index 7e035f4035b..e687ed77dc3 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -1,6 +1,5 @@ from torch import nn -from typing import Any -from ._utils import _SimpleSegmentationModel +from typing import Any, Optional from .._utils import IntermediateLayerGetter from ..._internally_replaced_utils import load_state_dict_from_url from .. import mobilenetv3 @@ -29,9 +28,9 @@ def _segm_model( name: str, backbone_name: str, num_classes: int, - aux: bool, + aux: Optional[bool] = None, pretrained_backbone: bool = True -) -> _SimpleSegmentationModel: +) -> nn.Module: if 'resnet' in backbone_name: backbone = resnet.__dict__[backbone_name]( pretrained=pretrained_backbone, @@ -81,9 +80,9 @@ def _load_model( pretrained: bool, progress: bool, num_classes: int, - aux_loss: bool, + aux_loss: Optional[bool] = None, **kwargs: Any -) -> _SimpleSegmentationModel: +) -> nn.Module: if pretrained: aux_loss = True kwargs["pretrained_backbone"] = False @@ -124,9 +123,9 @@ def fcn_resnet50( pretrained: bool = False, progress: bool = True, num_classes: int = 21, - aux_loss: bool = False, + aux_loss: Optional[bool] = None, **kwargs: Any -) -> FCN: +) -> nn.Module: """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. Args: @@ -143,9 +142,9 @@ def fcn_resnet101( pretrained: bool = False, progress: bool = True, num_classes: int = 21, - aux_loss: bool = False, + aux_loss: Optional[bool] = None, **kwargs: Any -) -> FCN: +) -> nn.Module: """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. Args: @@ -162,9 +161,9 @@ def deeplabv3_resnet50( pretrained: bool = False, progress: bool = True, num_classes: int = 21, - aux_loss: bool = False, + aux_loss: Optional[bool] = None, **kwargs: Any -) -> DeepLabV3: +) -> nn.Module: """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: @@ -181,9 +180,9 @@ def deeplabv3_resnet101( pretrained: bool = False, progress: bool = True, num_classes: int = 21, - aux_loss: bool = False, + aux_loss: Optional[bool] = None, **kwargs: Any -) -> DeepLabV3: +) -> nn.Module: """Constructs a DeepLabV3 model with a ResNet-101 backbone. Args: @@ -200,9 +199,9 @@ def deeplabv3_mobilenet_v3_large( pretrained: bool = False, progress: bool = True, num_classes: int = 21, - aux_loss: bool = False, + aux_loss: Optional[bool] = None, **kwargs: Any -) -> DeepLabV3: +) -> nn.Module: """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. Args: @@ -219,9 +218,9 @@ def lraspp_mobilenet_v3_large( pretrained: bool = False, progress: bool = True, num_classes: int = 21, - aux_loss: bool = False, + aux_loss: Optional[bool] = None, **kwargs: Any -) -> LRASPP: +) -> nn.Module: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. Args: From 09c92bb7e46afa04bc047a73969f8b6ffef47ce0 Mon Sep 17 00:00:00 2001 From: frgfm Date: Mon, 23 Aug 2021 11:42:14 +0200 Subject: [PATCH 6/7] fix: Fixed typing annotations & default values --- torchvision/models/segmentation/segmentation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index e687ed77dc3..dc443ca31ec 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -28,7 +28,7 @@ def _segm_model( name: str, backbone_name: str, num_classes: int, - aux: Optional[bool] = None, + aux: bool, pretrained_backbone: bool = True ) -> nn.Module: if 'resnet' in backbone_name: @@ -80,7 +80,7 @@ def _load_model( pretrained: bool, progress: bool, num_classes: int, - aux_loss: Optional[bool] = None, + aux_loss: bool, **kwargs: Any ) -> nn.Module: if pretrained: @@ -218,7 +218,6 @@ def lraspp_mobilenet_v3_large( pretrained: bool = False, progress: bool = True, num_classes: int = 21, - aux_loss: Optional[bool] = None, **kwargs: Any ) -> nn.Module: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. From 9eaddd957d023bfbd7b06ffd73cce1fa5ef61498 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 23 Aug 2021 10:59:42 +0100 Subject: [PATCH 7/7] Fixing python_type_check --- torchvision/models/segmentation/segmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index dc443ca31ec..938965e330b 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -28,7 +28,7 @@ def _segm_model( name: str, backbone_name: str, num_classes: int, - aux: bool, + aux: Optional[bool], pretrained_backbone: bool = True ) -> nn.Module: if 'resnet' in backbone_name: @@ -80,7 +80,7 @@ def _load_model( pretrained: bool, progress: bool, num_classes: int, - aux_loss: bool, + aux_loss: Optional[bool], **kwargs: Any ) -> nn.Module: if pretrained: