diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 176b7490038..fb94b9b1528 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -1,19 +1,25 @@ from collections import OrderedDict +from typing import Optional, Dict -from torch import nn +from torch import nn, Tensor from torch.nn import functional as F class _SimpleSegmentationModel(nn.Module): __constants__ = ['aux_classifier'] - def __init__(self, backbone, classifier, aux_classifier=None): + def __init__( + self, + backbone: nn.Module, + classifier: nn.Module, + aux_classifier: Optional[nn.Module] = None + ) -> None: super(_SimpleSegmentationModel, self).__init__() self.backbone = backbone self.classifier = classifier self.aux_classifier = aux_classifier - def forward(self, x): + def forward(self, x: Tensor) -> Dict[str, Tensor]: input_shape = x.shape[-2:] # contract: features is a dict of tensors features = self.backbone(x) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 7acc013ccb1..15ab8846e7d 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -1,6 +1,7 @@ import torch from torch import nn from torch.nn import functional as F +from typing import List from ._utils import _SimpleSegmentationModel @@ -27,7 +28,7 @@ class DeepLabV3(_SimpleSegmentationModel): class DeepLabHead(nn.Sequential): - def __init__(self, in_channels, num_classes): + def __init__(self, in_channels: int, num_classes: int) -> None: super(DeepLabHead, self).__init__( ASPP(in_channels, [12, 24, 36]), nn.Conv2d(256, 256, 3, padding=1, bias=False), @@ -38,7 +39,7 @@ def __init__(self, in_channels, num_classes): class ASPPConv(nn.Sequential): - def __init__(self, in_channels, out_channels, dilation): + def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None: modules = [ nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), nn.BatchNorm2d(out_channels), @@ -48,14 +49,14 @@ def __init__(self, in_channels, out_channels, dilation): class ASPPPooling(nn.Sequential): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels: int, out_channels: int) -> None: super(ASPPPooling, self).__init__( nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU()) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: size = x.shape[-2:] for mod in self: x = mod(x) @@ -63,7 +64,7 @@ def forward(self, x): class ASPP(nn.Module): - def __init__(self, in_channels, atrous_rates, out_channels=256): + def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None: super(ASPP, self).__init__() modules = [] modules.append(nn.Sequential( @@ -85,9 +86,9 @@ def __init__(self, in_channels, atrous_rates, out_channels=256): nn.ReLU(), nn.Dropout(0.5)) - def forward(self, x): - res = [] + def forward(self, x: torch.Tensor) -> torch.Tensor: + _res = [] for conv in self.convs: - res.append(conv(x)) - res = torch.cat(res, dim=1) + _res.append(conv(x)) + res = torch.cat(_res, dim=1) return self.project(res) diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 3c695b53167..9c8db1e1211 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -23,7 +23,7 @@ class FCN(_SimpleSegmentationModel): class FCNHead(nn.Sequential): - def __init__(self, in_channels, channels): + def __init__(self, in_channels: int, channels: int) -> None: inter_channels = in_channels // 4 layers = [ nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 44cd9b1e773..0e5fb5ee898 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -24,12 +24,19 @@ class LRASPP(nn.Module): inter_channels (int, optional): the number of channels for intermediate computations. """ - def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128): + def __init__( + self, + backbone: nn.Module, + low_channels: int, + high_channels: int, + num_classes: int, + inter_channels: int = 128 + ) -> None: super().__init__() self.backbone = backbone self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels) - def forward(self, input): + def forward(self, input: Tensor) -> Dict[str, Tensor]: features = self.backbone(input) out = self.classifier(features) out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False) @@ -42,7 +49,13 @@ def forward(self, input): class LRASPPHead(nn.Module): - def __init__(self, low_channels, high_channels, num_classes, inter_channels): + def __init__( + self, + low_channels: int, + high_channels: int, + num_classes: int, + inter_channels: int + ) -> None: super().__init__() self.cbr = nn.Sequential( nn.Conv2d(high_channels, inter_channels, 1, bias=False), diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index 0f2f14c97ba..938965e330b 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -1,3 +1,5 @@ +from torch import nn +from typing import Any, Optional from .._utils import IntermediateLayerGetter from ..._internally_replaced_utils import load_state_dict_from_url from .. import mobilenetv3 @@ -22,7 +24,13 @@ } -def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True): +def _segm_model( + name: str, + backbone_name: str, + num_classes: int, + aux: Optional[bool], + pretrained_backbone: bool = True +) -> nn.Module: if 'resnet' in backbone_name: backbone = resnet.__dict__[backbone_name]( pretrained=pretrained_backbone, @@ -66,7 +74,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True) return model -def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs): +def _load_model( + arch_type: str, + backbone: str, + pretrained: bool, + progress: bool, + num_classes: int, + aux_loss: Optional[bool], + **kwargs: Any +) -> nn.Module: if pretrained: aux_loss = True kwargs["pretrained_backbone"] = False @@ -76,7 +92,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss return model -def _load_weights(model, arch_type, backbone, progress): +def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None: arch = arch_type + '_' + backbone + '_coco' model_url = model_urls.get(arch, None) if model_url is None: @@ -86,7 +102,7 @@ def _load_weights(model, arch_type, backbone, progress): model.load_state_dict(state_dict) -def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True): +def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP: backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. @@ -103,8 +119,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru return model -def fcn_resnet50(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def fcn_resnet50( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + **kwargs: Any +) -> nn.Module: """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. Args: @@ -117,8 +138,13 @@ def fcn_resnet50(pretrained=False, progress=True, return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) -def fcn_resnet101(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def fcn_resnet101( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + **kwargs: Any +) -> nn.Module: """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. Args: @@ -131,8 +157,13 @@ def fcn_resnet101(pretrained=False, progress=True, return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_resnet50(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_resnet50( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + **kwargs: Any +) -> nn.Module: """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: @@ -145,8 +176,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True, return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_resnet101(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_resnet101( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + **kwargs: Any +) -> nn.Module: """Constructs a DeepLabV3 model with a ResNet-101 backbone. Args: @@ -159,8 +195,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True, return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + **kwargs: Any +) -> nn.Module: """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. Args: @@ -173,7 +214,12 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs) -def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs): +def lraspp_mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + **kwargs: Any +) -> nn.Module: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. Args: