diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py
index 06baea35fa8..65b8ffb9e40 100644
--- a/test/test_prototype_models.py
+++ b/test/test_prototype_models.py
@@ -3,12 +3,9 @@
 
 import pytest
 import test_models as TM
-import torch
 import torchvision
-from common_utils import cpu_and_gpu, needs_cuda
 from torchvision.models._api import WeightsEnum, Weights
 from torchvision.models._utils import handle_legacy_interface
-from torchvision.prototype import models
 
 run_if_test_with_prototype = pytest.mark.skipif(
     os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
@@ -76,9 +73,9 @@ def test_get_weight(name, weight):
     TM.get_models_from_module(torchvision.models)
     + TM.get_models_from_module(torchvision.models.detection)
     + TM.get_models_from_module(torchvision.models.quantization)
-    + TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video)
-    + TM.get_models_from_module(models.optical_flow),
+    + TM.get_models_from_module(torchvision.models.segmentation)
+    + TM.get_models_from_module(torchvision.models.video)
+    + TM.get_models_from_module(torchvision.models.optical_flow),
 )
 def test_naming_conventions(model_fn):
     weights_enum = _get_model_weights(model_fn)
@@ -92,9 +89,9 @@ def test_naming_conventions(model_fn):
     TM.get_models_from_module(torchvision.models)
     + TM.get_models_from_module(torchvision.models.detection)
     + TM.get_models_from_module(torchvision.models.quantization)
-    + TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video)
-    + TM.get_models_from_module(models.optical_flow),
+    + TM.get_models_from_module(torchvision.models.segmentation)
+    + TM.get_models_from_module(torchvision.models.video)
+    + TM.get_models_from_module(torchvision.models.optical_flow),
 )
 @run_if_test_with_prototype
 def test_schema_meta_validation(model_fn):
@@ -143,85 +140,6 @@ def test_schema_meta_validation(model_fn):
     assert not bad_names
 
 
-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_segmentation_model(model_fn, dev):
-    TM.test_segmentation_model(model_fn, dev)
-
-
-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.video))
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_video_model(model_fn, dev):
-    TM.test_video_model(model_fn, dev)
-
-
-@needs_cuda
-@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow))
-@pytest.mark.parametrize("scripted", (False, True))
-@run_if_test_with_prototype
-def test_raft(model_builder, scripted):
-    TM.test_raft(model_builder, scripted)
-
-
-@pytest.mark.parametrize(
-    "model_fn",
-    TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video)
-    + TM.get_models_from_module(models.optical_flow),
-)
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_old_vs_new_factory(model_fn, dev):
-    defaults = {
-        "models": {
-            "input_shape": (1, 3, 224, 224),
-        },
-        "detection": {
-            "input_shape": (3, 300, 300),
-        },
-        "quantization": {
-            "input_shape": (1, 3, 224, 224),
-            "quantize": True,
-        },
-        "segmentation": {
-            "input_shape": (1, 3, 520, 520),
-        },
-        "video": {
-            "input_shape": (1, 3, 4, 112, 112),
-        },
-        "optical_flow": {
-            "input_shape": (1, 3, 128, 128),
-        },
-    }
-    model_name = model_fn.__name__
-    module_name = model_fn.__module__.split(".")[-2]
-    kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
-    input_shape = kwargs.pop("input_shape")
-    kwargs.pop("num_classes", None)  # ignore this as it's an incompatible speed optimization for pre-trained models
-    x = torch.rand(input_shape).to(device=dev)
-    if module_name == "detection":
-        x = [x]
-
-    if module_name == "optical_flow":
-        args = [x, x]  # RAFT model requires img1, img2 as input
-    else:
-        args = [x]
-
-    # compare with new model builder parameterized in the old fashion way
-    try:
-        model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
-        model_new = _build_model(model_fn, **kwargs).to(device=dev)
-    except ModuleNotFoundError:
-        pytest.skip(f"Model '{model_name}' not available in both modules.")
-    torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False)
-
-
-def test_smoke():
-    import torchvision.prototype.models  # noqa: F401
-
-
 # With this filter, every unexpected warning will be turned into an error
 @pytest.mark.filterwarnings("error")
 class TestHandleLegacyInterface:
diff --git a/torchvision/models/optical_flow/__init__.py b/torchvision/models/optical_flow/__init__.py
index 9dd32f25dec..89d2302f825 100644
--- a/torchvision/models/optical_flow/__init__.py
+++ b/torchvision/models/optical_flow/__init__.py
@@ -1 +1 @@
-from .raft import RAFT, raft_large, raft_small
+from .raft import *
diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index 4dfd232d499..95469f78d3c 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
@@ -8,8 +8,10 @@
 from torch.nn.modules.instancenorm import InstanceNorm2d
 from torchvision.ops import Conv2dNormActivation
 
-from ..._internally_replaced_utils import load_state_dict_from_url
+from ...transforms import OpticalFlowEval, InterpolationMode
 from ...utils import _log_api_usage_once
+from .._api import Weights, WeightsEnum
+from .._utils import handle_legacy_interface
 from ._utils import grid_sample, make_coords_grid, upsample_flow
 
 
@@ -17,15 +19,11 @@
     "RAFT",
     "raft_large",
     "raft_small",
+    "Raft_Large_Weights",
+    "Raft_Small_Weights",
 )
 
 
-_MODELS_URLS = {
-    "raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
-    "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
-}
-
-
 class ResidualBlock(nn.Module):
     """Slightly modified Residual block with extra relu and biases."""
 
@@ -500,10 +498,139 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
         return flow_predictions
 
 
+_COMMON_META = {
+    "task": "optical_flow",
+    "architecture": "RAFT",
+    "publication_year": 2020,
+    "interpolation": InterpolationMode.BILINEAR,
+}
+
+
+class Raft_Large_Weights(WeightsEnum):
+    C_T_V1 = Weights(
+        # Chairs + Things, ported from original paper repo (raft-things.pth)
+        url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
+        transforms=OpticalFlowEval,
+        meta={
+            **_COMMON_META,
+            "num_params": 5257536,
+            "recipe": "https://github.com/princeton-vl/RAFT",
+            "sintel_train_cleanpass_epe": 1.4411,
+            "sintel_train_finalpass_epe": 2.7894,
+            "kitti_train_per_image_epe": 5.0172,
+            "kitti_train_f1-all": 17.4506,
+        },
+    )
+
+    C_T_V2 = Weights(
+        # Chairs + Things
+        url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
+        transforms=OpticalFlowEval,
+        meta={
+            **_COMMON_META,
+            "num_params": 5257536,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_train_cleanpass_epe": 1.3822, + "sintel_train_finalpass_epe": 2.7161, + "kitti_train_per_image_epe": 4.5118, + "kitti_train_f1-all": 16.0679, + }, + ) + + C_T_SKHT_V1 = Weights( + # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_test_cleanpass_epe": 1.94, + "sintel_test_finalpass_epe": 3.18, + }, + ) + + C_T_SKHT_V2 = Weights( + # Chairs + Things + Sintel fine-tuning, i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_test_cleanpass_epe": 1.819, + "sintel_test_finalpass_epe": 3.067, + }, + ) + + C_T_SKHT_K_V1 = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "kitti_test_f1-all": 5.10, + }, + ) + + C_T_SKHT_K_V2 = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti + # Same as CT_SKHT with extra fine-tuning on Kitti + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "kitti_test_f1-all": 5.19, + }, + ) + + DEFAULT = C_T_SKHT_V2 + + +class Raft_Small_Weights(WeightsEnum): + C_T_V1 = Weights( + # Chairs + Things, ported from original paper repo (raft-small.pth) + url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_train_cleanpass_epe": 2.1231, + "sintel_train_finalpass_epe": 3.2790, + "kitti_train_per_image_epe": 7.6557, + "kitti_train_f1-all": 25.2801, + }, + ) + C_T_V2 = Weights( + # Chairs + Things + url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_train_cleanpass_epe": 1.9901, + "sintel_train_finalpass_epe": 3.2831, + "kitti_train_per_image_epe": 7.5978, + "kitti_train_f1-all": 25.2369, + }, + ) + + DEFAULT = C_T_V2 + + def _raft( *, - arch=None, - pretrained=False, + weights=None, progress=False, # Feature encoder feature_encoder_layers, @@ -577,38 +704,34 @@ def _raft( mask_predictor=mask_predictor, **kwargs, # not really needed, all params should be consumed by now ) - if pretrained: - state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) - model.load_state_dict(state_dict) + 
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
 
     return model
 
 
-def raft_large(*, pretrained=False, progress=True, **kwargs):
+@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
+def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT:
     """RAFT model from
     `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
 
     Please see the example below for a tutorial on how to use this model.
 
     Args:
-        pretrained (bool): Whether to use weights that have been pre-trained on
-            :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`
-            with two fine-tuning steps:
-
-                - one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D`
-                - one on :class:`~torchvsion.datasets.KittiFlow`.
-
-            This corresponds to the ``C+T+S/K`` strategy in the paper.
-
-        progress (bool): If True, displays a progress bar of the download to stderr.
+        weights (Raft_Large_Weights, optional): The pretrained weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
+            to override any default.
 
     Returns:
-        nn.Module: The model.
+        RAFT: The model.
     """
+    weights = Raft_Large_Weights.verify(weights)
+
     return _raft(
-        arch="raft_large",
-        pretrained=pretrained,
+        weights=weights,
         progress=progress,
         # Feature encoder
         feature_encoder_layers=(64, 64, 96, 128, 256),
@@ -637,25 +760,27 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
     )
 
 
-def raft_small(*, pretrained=False, progress=True, **kwargs):
+@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
+def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
     """RAFT "small" model from
     `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
 
     Please see the example below for a tutorial on how to use this model.
 
     Args:
-        pretrained (bool): Whether to use weights that have been pre-trained on
-            :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`.
+        weights (Raft_Small_Weights, optional): The pretrained weights for the model
         progress (bool): If True, displays a progress bar of the download to stderr
+        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
+            to override any default.
 
     Returns:
-        nn.Module: The model.
+        RAFT: The model.
""" + weights = Raft_Small_Weights.verify(weights) return _raft( - arch="raft_small", - pretrained=pretrained, + weights=weights, progress=progress, # Feature encoder feature_encoder_layers=(32, 32, 64, 96, 128), diff --git a/torchvision/models/segmentation/__init__.py b/torchvision/models/segmentation/__init__.py index 1765502d693..3d6f37f958a 100644 --- a/torchvision/models/segmentation/__init__.py +++ b/torchvision/models/segmentation/__init__.py @@ -1,3 +1,3 @@ -from .fcn import * from .deeplabv3 import * +from .fcn import * from .lraspp import * diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 0bbea5d3e81..44a60a95c54 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -4,7 +4,6 @@ from torch import nn, Tensor from torch.nn import functional as F -from ..._internally_replaced_utils import load_state_dict_from_url from ...utils import _log_api_usage_once @@ -36,10 +35,3 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: result["aux"] = x return result - - -def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None: - if model_url is None: - raise ValueError(f"No checkpoint is available for {arch}") - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 15ab5fffa5e..6e8bf0c398b 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -1,31 +1,31 @@ -from typing import List, Optional +from functools import partial +from typing import Any, List, Optional import torch from torch import nn from torch.nn import functional as F -from .. import mobilenetv3 -from .. 
import resnet -from .._utils import IntermediateLayerGetter -from ._utils import _SimpleSegmentationModel, _load_weights +from ...transforms import SemanticSegmentationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large +from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights +from ._utils import _SimpleSegmentationModel from .fcn import FCNHead __all__ = [ "DeepLabV3", + "DeepLabV3_ResNet50_Weights", + "DeepLabV3_ResNet101_Weights", + "DeepLabV3_MobileNet_V3_Large_Weights", + "deeplabv3_mobilenet_v3_large", "deeplabv3_resnet50", "deeplabv3_resnet101", - "deeplabv3_mobilenet_v3_large", ] -model_urls = { - "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", -} - - class DeepLabV3(_SimpleSegmentationModel): """ Implements DeepLabV3 model from @@ -114,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _deeplabv3_resnet( - backbone: resnet.ResNet, + backbone: ResNet, num_classes: int, aux: Optional[bool], ) -> DeepLabV3: @@ -128,8 +128,62 @@ def _deeplabv3_resnet( return DeepLabV3(backbone, classifier, aux_classifier) +_COMMON_META = { + "task": "image_semantic_segmentation", + "architecture": "DeepLabV3", + "publication_year": 2017, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class DeepLabV3_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 42004074, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", + "mIoU": 66.4, + "acc": 92.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 60996202, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", + "mIoU": 67.4, + "acc": 92.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 11029328, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", + "mIoU": 60.3, + "acc": 91.2, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + def _deeplabv3_mobilenetv3( - backbone: mobilenetv3.MobileNetV3, + backbone: MobileNetV3, num_classes: int, aux: Optional[bool], ) -> DeepLabV3: @@ -151,91 +205,124 @@ def _deeplabv3_mobilenetv3( return DeepLabV3(backbone, classifier, aux_classifier) +@handle_legacy_interface( + weights=("pretrained", 
DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def deeplabv3_resnet50( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_ResNet50_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet50_Weights] = None, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) - backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_resnet50_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) def deeplabv3_resnet101( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_ResNet101_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet101_Weights] = None, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-101 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr num_classes (int): The number of classes aux_loss (bool, optional): If True, include an auxiliary classifier - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) - backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_resnet101_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def deeplabv3_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_mobilenet_v3_large_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 307781ebf00..5a3ca1f654f 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -1,19 +1,17 @@ -from typing import Optional +from functools import partial +from typing import Any, Optional from torch import nn -from .. import resnet -from .._utils import IntermediateLayerGetter -from ._utils import _SimpleSegmentationModel, _load_weights +from ...transforms import SemanticSegmentationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet, ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 +from ._utils import _SimpleSegmentationModel -__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"] - - -model_urls = { - "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", -} +__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] class FCN(_SimpleSegmentationModel): @@ -49,8 +47,47 @@ def __init__(self, in_channels: int, channels: int) -> None: super().__init__(*layers) +_COMMON_META = { + "task": "image_semantic_segmentation", + "architecture": "FCN", + "publication_year": 2014, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class FCN_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 35322218, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50", + "mIoU": 60.5, + "acc": 91.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class FCN_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 54314346, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101", + "mIoU": 63.7, + "acc": 91.9, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + def _fcn_resnet( - backbone: resnet.ResNet, + backbone: ResNet, num_classes: 
int, aux: Optional[bool], ) -> FCN: @@ -64,61 +101,83 @@ def _fcn_resnet( return FCN(backbone, classifier, aux_classifier) +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def fcn_resnet50( - pretrained: bool = False, + *, + weights: Optional[FCN_ResNet50_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet50_Weights] = None, + **kwargs: Any, ) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (FCN_ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = FCN_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 - backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _fcn_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "fcn_resnet50_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) def fcn_resnet101( - pretrained: bool = False, + *, + weights: Optional[FCN_ResNet101_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet101_Weights] = None, + **kwargs: Any, ) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (FCN_ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = FCN_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 - backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _fcn_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "fcn_resnet101_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index ca73140661b..d1fe15a350d 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -1,21 +1,19 @@ from collections import OrderedDict -from typing import Any, Dict +from functools import partial +from typing import Any, Dict, Optional from torch import nn, Tensor from torch.nn import functional as F +from ...transforms import SemanticSegmentationEval, InterpolationMode from ...utils import _log_api_usage_once -from .. import mobilenetv3 -from .._utils import IntermediateLayerGetter -from ._utils import _load_weights +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large -__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"] - - -model_urls = { - "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", -} +__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] class LRASPP(nn.Module): @@ -30,7 +28,7 @@ class LRASPP(nn.Module): "high" for the high level feature map and "low" for the low level feature map. low_channels (int): the number of channels of the low level features. high_channels (int): the number of channels of the high level features. - num_classes (int): number of output classes of the model (including the background). + num_classes (int, optional): number of output classes of the model (including the background). inter_channels (int, optional): the number of channels for intermediate computations. """ @@ -81,7 +79,7 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor: return self.low_classifier(low) + self.high_classifier(x) -def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP: +def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. 
@@ -95,31 +93,61 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> return LRASPP(backbone, low_channels, high_channels, num_classes) +class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + "task": "image_semantic_segmentation", + "architecture": "LRASPP", + "publication_year": 2019, + "num_params": 3221538, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", + "mIoU": 57.9, + "acc": 91.2, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +@handle_legacy_interface( + weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def lraspp_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 21, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, **kwargs: Any, ) -> LRASPP: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (LRASPP_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone """ if kwargs.pop("aux_loss", False): raise NotImplementedError("This model does not use auxiliary loss") - if pretrained: - pretrained_backbone = False - backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) + weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) model = _lraspp_mobilenetv3(backbone, num_classes) - if pretrained: - arch = "lraspp_mobilenet_v3_large_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 4ac781a7c4c..a6b779d10f1 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -1,18 +1,25 @@ +from functools import partial from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union import torch.nn as nn from torch import Tensor -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import VideoClassificationEval, InterpolationMode from ...utils import _log_api_usage_once +from .._api import WeightsEnum, Weights +from .._meta import _KINETICS400_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["r3d_18", "mc3_18", "r2plus1d_18"] -model_urls = { - "r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", -} +__all__ = [ + "VideoResNet", + "R3D_18_Weights", + "MC3_18_Weights", + "R2Plus1D_18_Weights", + "r3d_18", + "mc3_18", + "r2plus1d_18", +] class Conv3DSimple(nn.Conv3d): @@ -281,80 +288,152 @@ def _make_layer( return nn.Sequential(*layers) -def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: - model = VideoResNet(**kwargs) +def _video_resnet( + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> VideoResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = VideoResNet(block, conv_makers, layers, stem, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) return model -def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +_COMMON_META = { + "task": "video_classification", + "publication_year": 2017, + "size": (112, 112), + "min_size": (1, 1), + "categories": _KINETICS400_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": 
"https://github.com/pytorch/vision/tree/main/references/video_classification", +} + + +class R3D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", + transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": "R3D", + "num_params": 33371472, + "acc@1": 52.75, + "acc@5": 75.45, + }, + ) + DEFAULT = KINETICS400_V1 + + +class MC3_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", + transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": "MC3", + "num_params": 11695440, + "acc@1": 53.90, + "acc@5": 76.29, + }, + ) + DEFAULT = KINETICS400_V1 + + +class R2Plus1D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", + transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": "R(2+1)D", + "num_params": 31505325, + "acc@1": 57.50, + "acc@5": 78.81, + }, + ) + DEFAULT = KINETICS400_V1 + + +@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) +def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Construct 18 layer Resnet3D model as in https://arxiv.org/abs/1711.11248 Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (R3D_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: R3D-18 network + VideoResNet: R3D-18 network """ + weights = R3D_18_Weights.verify(weights) return _video_resnet( - "r3d_18", - pretrained, + BasicBlock, + [Conv3DSimple] * 4, + [2, 2, 2, 2], + BasicStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] * 4, - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs, ) -def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) +def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Constructor for 18 layer Mixed Convolution network as in https://arxiv.org/abs/1711.11248 Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (MC3_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: MC3 Network definition + VideoResNet: MC3 Network definition """ + weights = MC3_18_Weights.verify(weights) + return _video_resnet( - "mc3_18", - pretrained, + BasicBlock, + [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] + [2, 2, 2, 2], + BasicStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs, ) -def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) +def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Constructor for the 18 layer deep R(2+1)D network as in https://arxiv.org/abs/1711.11248 
Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: R(2+1)D-18 network + VideoResNet: R(2+1)D-18 network """ + weights = R2Plus1D_18_Weights.verify(weights) + return _video_resnet( - "r2plus1d_18", - pretrained, + BasicBlock, + [Conv2Plus1D] * 4, + [2, 2, 2, 2], + R2Plus1dStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv2Plus1D] * 4, - layers=[2, 2, 2, 2], - stem=R2Plus1dStem, **kwargs, ) diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index e1be6c81f59..bd35d31dcfd 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -1,5 +1,4 @@ from . import datasets from . import features -from . import models from . import transforms from . import utils diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py deleted file mode 100644 index 3d7baca6284..00000000000 --- a/torchvision/prototype/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import optical_flow -from . import segmentation -from . import video diff --git a/torchvision/prototype/models/optical_flow/__init__.py b/torchvision/prototype/models/optical_flow/__init__.py deleted file mode 100644 index 9b78f70b768..00000000000 --- a/torchvision/prototype/models/optical_flow/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py deleted file mode 100644 index 33e3243c2a0..00000000000 --- a/torchvision/prototype/models/optical_flow/raft.py +++ /dev/null @@ -1,249 +0,0 @@ -from typing import Optional - -from torch.nn.modules.batchnorm import BatchNorm2d -from torch.nn.modules.instancenorm import InstanceNorm2d -from torchvision.models._api import Weights -from torchvision.models._api import WeightsEnum -from torchvision.models._utils import handle_legacy_interface -from torchvision.models.optical_flow import RAFT -from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock -from torchvision.transforms import OpticalFlowEval, InterpolationMode - - -__all__ = ( - "RAFT", - "raft_large", - "raft_small", - "Raft_Large_Weights", - "Raft_Small_Weights", -) - - -_COMMON_META = { - "task": "optical_flow", - "architecture": "RAFT", - "publication_year": 2020, - "interpolation": InterpolationMode.BILINEAR, -} - - -class Raft_Large_Weights(WeightsEnum): - C_T_V1 = Weights( - # Chairs + Things, ported from original paper repo (raft-things.pth) - url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_train_cleanpass_epe": 1.4411, - "sintel_train_finalpass_epe": 2.7894, - "kitti_train_per_image_epe": 5.0172, - "kitti_train_f1-all": 17.4506, - }, - ) - - C_T_V2 = Weights( - # Chairs + Things - url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_train_cleanpass_epe": 1.3822, - "sintel_train_finalpass_epe": 2.7161, - "kitti_train_per_image_epe": 4.5118, - 
"kitti_train_f1-all": 16.0679, - }, - ) - - C_T_SKHT_V1 = Weights( - # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_test_cleanpass_epe": 1.94, - "sintel_test_finalpass_epe": 3.18, - }, - ) - - C_T_SKHT_V2 = Weights( - # Chairs + Things + Sintel fine-tuning, i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_test_cleanpass_epe": 1.819, - "sintel_test_finalpass_epe": 3.067, - }, - ) - - C_T_SKHT_K_V1 = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "kitti_test_f1-all": 5.10, - }, - ) - - C_T_SKHT_K_V2 = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti - # Same as CT_SKHT with extra fine-tuning on Kitti - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "kitti_test_f1-all": 5.19, - }, - ) - - DEFAULT = C_T_SKHT_V2 - - -class Raft_Small_Weights(WeightsEnum): - C_T_V1 = Weights( - # Chairs + Things, ported from original paper repo (raft-small.pth) - url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 990162, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_train_cleanpass_epe": 2.1231, - "sintel_train_finalpass_epe": 3.2790, - "kitti_train_per_image_epe": 7.6557, - "kitti_train_f1-all": 25.2801, - }, - ) - C_T_V2 = Weights( - # Chairs + Things - url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 990162, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_train_cleanpass_epe": 1.9901, - "sintel_train_finalpass_epe": 3.2831, - "kitti_train_per_image_epe": 7.5978, - "kitti_train_f1-all": 25.2369, - }, - ) - - DEFAULT = C_T_V2 - - -@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2)) -def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): - """RAFT model from - `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. - - Args: - weights(Raft_Large_weights, optional): pretrained weights to use. - progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. 
-
-    Returns:
-        nn.Module: The model.
-    """
-
-    weights = Raft_Large_Weights.verify(weights)
-
-    model = _raft(
-        # Feature encoder
-        feature_encoder_layers=(64, 64, 96, 128, 256),
-        feature_encoder_block=ResidualBlock,
-        feature_encoder_norm_layer=InstanceNorm2d,
-        # Context encoder
-        context_encoder_layers=(64, 64, 96, 128, 256),
-        context_encoder_block=ResidualBlock,
-        context_encoder_norm_layer=BatchNorm2d,
-        # Correlation block
-        corr_block_num_levels=4,
-        corr_block_radius=4,
-        # Motion encoder
-        motion_encoder_corr_layers=(256, 192),
-        motion_encoder_flow_layers=(128, 64),
-        motion_encoder_out_channels=128,
-        # Recurrent block
-        recurrent_block_hidden_state_size=128,
-        recurrent_block_kernel_size=((1, 5), (5, 1)),
-        recurrent_block_padding=((0, 2), (2, 0)),
-        # Flow head
-        flow_head_hidden_size=256,
-        # Mask predictor
-        use_mask_predictor=True,
-        **kwargs,
-    )
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
-
-
-@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
-def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
-    """RAFT "small" model from
-    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
-
-    Args:
-        weights(Raft_Small_weights, optional): pretrained weights to use.
-        progress (bool): If True, displays a progress bar of the download to stderr
-        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
-            to override any default.
-
-    Returns:
-        nn.Module: The model.
-
-    """
-
-    weights = Raft_Small_Weights.verify(weights)
-
-    model = _raft(
-        # Feature encoder
-        feature_encoder_layers=(32, 32, 64, 96, 128),
-        feature_encoder_block=BottleneckBlock,
-        feature_encoder_norm_layer=InstanceNorm2d,
-        # Context encoder
-        context_encoder_layers=(32, 32, 64, 96, 160),
-        context_encoder_block=BottleneckBlock,
-        context_encoder_norm_layer=None,
-        # Correlation block
-        corr_block_num_levels=4,
-        corr_block_radius=3,
-        # Motion encoder
-        motion_encoder_corr_layers=(96,),
-        motion_encoder_flow_layers=(64, 32),
-        motion_encoder_out_channels=82,
-        # Recurrent block
-        recurrent_block_hidden_state_size=96,
-        recurrent_block_kernel_size=(3,),
-        recurrent_block_padding=(1,),
-        # Flow head
-        flow_head_hidden_size=128,
-        # Mask predictor
-        use_mask_predictor=False,
-        **kwargs,
-    )
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-    return model
diff --git a/torchvision/prototype/models/segmentation/__init__.py b/torchvision/prototype/models/segmentation/__init__.py
deleted file mode 100644
index 20273be2170..00000000000
--- a/torchvision/prototype/models/segmentation/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .fcn import *
-from .lraspp import *
-from .deeplabv3 import *
diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py
deleted file mode 100644
index 2c8d7f6ad84..00000000000
--- a/torchvision/prototype/models/segmentation/deeplabv3.py
+++ /dev/null
@@ -1,171 +0,0 @@
-from functools import partial
-from typing import Any, Optional
-
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _VOC_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
-from torchvision.models.resnet import resnet50, resnet101, ResNet50_Weights, ResNet101_Weights
-from torchvision.models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
-from torchvision.transforms import SemanticSegmentationEval, InterpolationMode
-
-
-__all__ = [
-    "DeepLabV3",
-    "DeepLabV3_ResNet50_Weights",
-    "DeepLabV3_ResNet101_Weights",
-    "DeepLabV3_MobileNet_V3_Large_Weights",
-    "deeplabv3_mobilenet_v3_large",
-    "deeplabv3_resnet50",
-    "deeplabv3_resnet101",
-]
-
-
-_COMMON_META = {
-    "task": "image_semantic_segmentation",
-    "architecture": "DeepLabV3",
-    "publication_year": 2017,
-    "categories": _VOC_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
-}
-
-
-class DeepLabV3_ResNet50_Weights(WeightsEnum):
-    COCO_WITH_VOC_LABELS_V1 = Weights(
-        url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
-        transforms=partial(SemanticSegmentationEval, resize_size=520),
-        meta={
-            **_COMMON_META,
-            "num_params": 42004074,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
-            "mIoU": 66.4,
-            "acc": 92.4,
-        },
-    )
-    DEFAULT = COCO_WITH_VOC_LABELS_V1
-
-
-class DeepLabV3_ResNet101_Weights(WeightsEnum):
-    COCO_WITH_VOC_LABELS_V1 = Weights(
-        url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
-        transforms=partial(SemanticSegmentationEval, resize_size=520),
-        meta={
-            **_COMMON_META,
-            "num_params": 60996202,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
-            "mIoU": 67.4,
-            "acc": 92.4,
-        },
-    )
-    DEFAULT = COCO_WITH_VOC_LABELS_V1
-
-
-class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
-    COCO_WITH_VOC_LABELS_V1 = Weights(
-        url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
-        transforms=partial(SemanticSegmentationEval, resize_size=520),
-        meta={
-            **_COMMON_META,
-            "num_params": 11029328,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
-            "mIoU": 60.3,
-            "acc": 91.2,
-        },
-    )
-    DEFAULT = COCO_WITH_VOC_LABELS_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def deeplabv3_resnet50(
-    *,
-    weights: Optional[DeepLabV3_ResNet50_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    aux_loss: Optional[bool] = None,
-    weights_backbone: Optional[ResNet50_Weights] = None,
-    **kwargs: Any,
-) -> DeepLabV3:
-    weights = DeepLabV3_ResNet50_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-        aux_loss = _ovewrite_value_param(aux_loss, True)
-    elif num_classes is None:
-        num_classes = 21
-
-    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
-    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
-
-
-@handle_legacy_interface(
-    weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
-    weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
-)
-def deeplabv3_resnet101(
-    *,
-    weights: Optional[DeepLabV3_ResNet101_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    aux_loss: Optional[bool] = None,
-    weights_backbone: Optional[ResNet101_Weights] = None,
-    **kwargs: Any,
-) -> DeepLabV3:
-    weights = DeepLabV3_ResNet101_Weights.verify(weights)
-    weights_backbone = ResNet101_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-        aux_loss = _ovewrite_value_param(aux_loss, True)
-    elif num_classes is None:
-        num_classes = 21
-
-    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
-    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
-
-
-@handle_legacy_interface(
-    weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
-    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
-)
-def deeplabv3_mobilenet_v3_large(
-    *,
-    weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    aux_loss: Optional[bool] = None,
-    weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
-    **kwargs: Any,
-) -> DeepLabV3:
-    weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
-    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-        aux_loss = _ovewrite_value_param(aux_loss, True)
-    elif num_classes is None:
-        num_classes = 21
-
-    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
-    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py
deleted file mode 100644
index e7b12621940..00000000000
--- a/torchvision/prototype/models/segmentation/fcn.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from functools import partial
-from typing import Any, Optional
-
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _VOC_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101
-from torchvision.models.segmentation.fcn import FCN, _fcn_resnet
-from torchvision.transforms import SemanticSegmentationEval, InterpolationMode
-
-
-__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"]
-
-
-_COMMON_META = {
-    "task": "image_semantic_segmentation",
-    "architecture": "FCN",
-    "publication_year": 2014,
-    "categories": _VOC_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
-}
-
-
-class FCN_ResNet50_Weights(WeightsEnum):
-    COCO_WITH_VOC_LABELS_V1 = Weights(
-        url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
-        transforms=partial(SemanticSegmentationEval, resize_size=520),
-        meta={
-            **_COMMON_META,
-            "num_params": 35322218,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50",
-            "mIoU": 60.5,
-            "acc": 91.4,
-        },
-    )
-    DEFAULT = COCO_WITH_VOC_LABELS_V1
-
-
-class FCN_ResNet101_Weights(WeightsEnum):
-    COCO_WITH_VOC_LABELS_V1 = Weights(
-        url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
-        transforms=partial(SemanticSegmentationEval, resize_size=520),
-        meta={
-            **_COMMON_META,
-            "num_params": 54314346,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101",
-            "mIoU": 63.7,
-            "acc": 91.9,
-        },
-    )
-    DEFAULT = COCO_WITH_VOC_LABELS_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def fcn_resnet50(
-    *,
-    weights: Optional[FCN_ResNet50_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    aux_loss: Optional[bool] = None,
-    weights_backbone: Optional[ResNet50_Weights] = None,
-    **kwargs: Any,
-) -> FCN:
-    weights = FCN_ResNet50_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-        aux_loss = _ovewrite_value_param(aux_loss, True)
-    elif num_classes is None:
-        num_classes = 21
-
-    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
-    model = _fcn_resnet(backbone, num_classes, aux_loss)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
-
-
-@handle_legacy_interface(
-    weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
-    weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
-)
-def fcn_resnet101(
-    *,
-    weights: Optional[FCN_ResNet101_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    aux_loss: Optional[bool] = None,
-    weights_backbone: Optional[ResNet101_Weights] = None,
-    **kwargs: Any,
-) -> FCN:
-    weights = FCN_ResNet101_Weights.verify(weights)
-    weights_backbone = ResNet101_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-        aux_loss = _ovewrite_value_param(aux_loss, True)
-    elif num_classes is None:
-        num_classes = 21
-
-    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
-    model = _fcn_resnet(backbone, num_classes, aux_loss)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py
deleted file mode 100644
index 21c15373089..00000000000
--- a/torchvision/prototype/models/segmentation/lraspp.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from functools import partial
-from typing import Any, Optional
-
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _VOC_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
-from torchvision.models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
-from torchvision.transforms import SemanticSegmentationEval, InterpolationMode
-
-
-__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"]
-
-
-class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
-    COCO_WITH_VOC_LABELS_V1 = Weights(
-        url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
-        transforms=partial(SemanticSegmentationEval, resize_size=520),
-        meta={
-            "task": "image_semantic_segmentation",
-            "architecture": "LRASPP",
-            "publication_year": 2019,
-            "num_params": 3221538,
-            "categories": _VOC_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
-            "mIoU": 57.9,
-            "acc": 91.2,
-        },
-    )
-    DEFAULT = COCO_WITH_VOC_LABELS_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
-    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
-)
-def lraspp_mobilenet_v3_large(
-    *,
-    weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
-    **kwargs: Any,
-) -> LRASPP:
-    if kwargs.pop("aux_loss", False):
-        raise NotImplementedError("This model does not use auxiliary loss")
-
-    weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
-    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-    elif num_classes is None:
-        num_classes = 21
-
-    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
-    model = _lraspp_mobilenetv3(backbone, num_classes)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
diff --git a/torchvision/prototype/models/video/__init__.py b/torchvision/prototype/models/video/__init__.py
deleted file mode 100644
index b792ca6ecf7..00000000000
--- a/torchvision/prototype/models/video/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .resnet import *
diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py
deleted file mode 100644
index 0f4c0dd1dc9..00000000000
--- a/torchvision/prototype/models/video/resnet.py
+++ /dev/null
@@ -1,150 +0,0 @@
-from functools import partial
-from typing import Any, Callable, List, Optional, Sequence, Type, Union
-
-from torch import nn
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _KINETICS400_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param
-from torchvision.models.video.resnet import (
-    BasicBlock,
-    BasicStem,
-    Bottleneck,
-    Conv2Plus1D,
-    Conv3DSimple,
-    Conv3DNoTemporal,
-    R2Plus1dStem,
-    VideoResNet,
-)
-from torchvision.transforms import VideoClassificationEval, InterpolationMode
-
-
-__all__ = [
-    "VideoResNet",
-    "R3D_18_Weights",
-    "MC3_18_Weights",
-    "R2Plus1D_18_Weights",
-    "r3d_18",
-    "mc3_18",
-    "r2plus1d_18",
-]
-
-
-def _video_resnet(
-    block: Type[Union[BasicBlock, Bottleneck]],
-    conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
-    layers: List[int],
-    stem: Callable[..., nn.Module],
-    weights: Optional[WeightsEnum],
-    progress: bool,
-    **kwargs: Any,
-) -> VideoResNet:
-    if weights is not None:
-        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
-
-    model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
-
-
-_COMMON_META = {
-    "task": "video_classification",
-    "publication_year": 2017,
-    "size": (112, 112),
-    "min_size": (1, 1),
-    "categories": _KINETICS400_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
-    "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
-}
-
-
-class R3D_18_Weights(WeightsEnum):
-    KINETICS400_V1 = Weights(
-        url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
-        transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
-        meta={
-            **_COMMON_META,
-            "architecture": "R3D",
-            "num_params": 33371472,
-            "acc@1": 52.75,
-            "acc@5": 75.45,
-        },
-    )
-    DEFAULT = KINETICS400_V1
-
-
-class MC3_18_Weights(WeightsEnum):
-    KINETICS400_V1 = Weights(
-        url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
-        transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
-        meta={
-            **_COMMON_META,
-            "architecture": "MC3",
-            "num_params": 11695440,
-            "acc@1": 53.90,
-            "acc@5": 76.29,
-        },
-    )
-    DEFAULT = KINETICS400_V1
-
-
-class R2Plus1D_18_Weights(WeightsEnum):
-    KINETICS400_V1 = Weights(
-        url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
-        transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
-        meta={
-            **_COMMON_META,
-            "architecture": "R(2+1)D",
-            "num_params": 31505325,
-            "acc@1": 57.50,
-            "acc@5": 78.81,
-        },
-    )
-    DEFAULT = KINETICS400_V1
-
-
-@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
-def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
-    weights = R3D_18_Weights.verify(weights)
-
-    return _video_resnet(
-        BasicBlock,
-        [Conv3DSimple] * 4,
-        [2, 2, 2, 2],
-        BasicStem,
-        weights,
-        progress,
-        **kwargs,
-    )
-
-
-@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
-def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
-    weights = MC3_18_Weights.verify(weights)
-
-    return _video_resnet(
-        BasicBlock,
-        [Conv3DSimple] + [Conv3DNoTemporal] * 3,  # type: ignore[list-item]
-        [2, 2, 2, 2],
-        BasicStem,
-        weights,
-        progress,
-        **kwargs,
-    )
-
-
-@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
-def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
-    weights = R2Plus1D_18_Weights.verify(weights)
-
-    return _video_resnet(
-        BasicBlock,
-        [Conv2Plus1D] * 4,
-        [2, 2, 2, 2],
-        R2Plus1dStem,
-        weights,
-        progress,
-        **kwargs,
-    )