94 changes: 6 additions & 88 deletions test/test_prototype_models.py
@@ -3,12 +3,9 @@

import pytest
import test_models as TM
import torch
import torchvision
from common_utils import cpu_and_gpu, needs_cuda
from torchvision.models._api import WeightsEnum, Weights
from torchvision.models._utils import handle_legacy_interface
from torchvision.prototype import models

run_if_test_with_prototype = pytest.mark.skipif(
os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
@@ -76,9 +73,9 @@ def test_get_weight(name, weight):
TM.get_models_from_module(torchvision.models)
+ TM.get_models_from_module(torchvision.models.detection)
+ TM.get_models_from_module(torchvision.models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
+ TM.get_models_from_module(torchvision.models.segmentation)
+ TM.get_models_from_module(torchvision.models.video)
+ TM.get_models_from_module(torchvision.models.optical_flow),
)
def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn)
@@ -92,9 +89,9 @@ def test_naming_conventions(model_fn):
TM.get_models_from_module(torchvision.models)
+ TM.get_models_from_module(torchvision.models.detection)
+ TM.get_models_from_module(torchvision.models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
+ TM.get_models_from_module(torchvision.models.segmentation)
+ TM.get_models_from_module(torchvision.models.video)
+ TM.get_models_from_module(torchvision.models.optical_flow),
)
@run_if_test_with_prototype
def test_schema_meta_validation(model_fn):
@@ -143,85 +140,6 @@ def test_schema_meta_validation(model_fn):
assert not bad_names


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_segmentation_model(model_fn, dev):
TM.test_segmentation_model(model_fn, dev)


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.video))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_video_model(model_fn, dev):
TM.test_video_model(model_fn, dev)


@needs_cuda
@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow))
@pytest.mark.parametrize("scripted", (False, True))
@run_if_test_with_prototype
def test_raft(model_builder, scripted):
TM.test_raft(model_builder, scripted)


@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
)
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_old_vs_new_factory(model_fn, dev):
defaults = {
"models": {
"input_shape": (1, 3, 224, 224),
},
"detection": {
"input_shape": (3, 300, 300),
},
"quantization": {
"input_shape": (1, 3, 224, 224),
"quantize": True,
},
"segmentation": {
"input_shape": (1, 3, 520, 520),
},
"video": {
"input_shape": (1, 3, 4, 112, 112),
},
"optical_flow": {
"input_shape": (1, 3, 128, 128),
},
}
model_name = model_fn.__name__
module_name = model_fn.__module__.split(".")[-2]
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
x = torch.rand(input_shape).to(device=dev)
if module_name == "detection":
x = [x]

if module_name == "optical_flow":
args = [x, x] # RAFT model requires img1, img2 as input
else:
args = [x]

# compare with the new model builder parameterized in the old-fashioned way
try:
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
model_new = _build_model(model_fn, **kwargs).to(device=dev)
except ModuleNotFoundError:
pytest.skip(f"Model '{model_name}' not available in both modules.")
torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False)


def test_smoke():
import torchvision.prototype.models # noqa: F401


# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
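
The new `TestHandleLegacyInterface` suite covers the decorator used throughout this PR. For orientation, here is a minimal sketch of what such a `pretrained` to `weights` shim can look like (an illustration under assumed semantics, not torchvision's actual `handle_legacy_interface`):

```python
import functools
import warnings


def legacy_pretrained_shim(default_weights):
    """Illustrative stand-in: route the deprecated pretrained=True kwarg
    to an explicit default weights entry, warning on the legacy path."""

    def decorator(builder):
        @functools.wraps(builder)
        def wrapper(*args, weights=None, **kwargs):
            if kwargs.pop("pretrained", False):
                warnings.warn("'pretrained' is deprecated; use 'weights' instead.", DeprecationWarning)
                weights = weights or default_weights
            return builder(*args, weights=weights, **kwargs)

        return wrapper

    return decorator
```
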
2 changes: 1 addition & 1 deletion torchvision/models/optical_flow/__init__.py
@@ -1 +1 @@
from .raft import RAFT, raft_large, raft_small
from .raft import *
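
With the star import, the package's public surface is now driven by `__all__` in raft.py (extended below to include the weight enums), so for example:

```python
# The weight enums are re-exported alongside the model builders:
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
```
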
191 changes: 158 additions & 33 deletions torchvision/models/optical_flow/raft.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

import torch
import torch.nn as nn
@@ -8,24 +8,22 @@
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.ops import Conv2dNormActivation

from ..._internally_replaced_utils import load_state_dict_from_url
from ...transforms import OpticalFlowEval, InterpolationMode
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._utils import handle_legacy_interface
from ._utils import grid_sample, make_coords_grid, upsample_flow


__all__ = (
"RAFT",
"raft_large",
"raft_small",
"Raft_Large_Weights",
"Raft_Small_Weights",
)


_MODELS_URLS = {
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
}


class ResidualBlock(nn.Module):
"""Slightly modified Residual block with extra relu and biases."""

@@ -500,10 +498,139 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
return flow_predictions


_COMMON_META = {
"task": "optical_flow",
"architecture": "RAFT",
"publication_year": 2020,
"interpolation": InterpolationMode.BILINEAR,
}


class Raft_Large_Weights(WeightsEnum):
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-things.pth)
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_train_cleanpass_epe": 1.4411,
"sintel_train_finalpass_epe": 2.7894,
"kitti_train_per_image_epe": 5.0172,
"kitti_train_f1-all": 17.4506,
},
)

C_T_V2 = Weights(
# Chairs + Things
url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"sintel_train_cleanpass_epe": 1.3822,
"sintel_train_finalpass_epe": 2.7161,
"kitti_train_per_image_epe": 4.5118,
"kitti_train_f1-all": 16.0679,
},
)

C_T_SKHT_V1 = Weights(
# Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth)
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_test_cleanpass_epe": 1.94,
"sintel_test_finalpass_epe": 3.18,
},
)

C_T_SKHT_V2 = Weights(
# Chairs + Things + Sintel fine-tuning, i.e.:
# Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"sintel_test_cleanpass_epe": 1.819,
"sintel_test_finalpass_epe": 3.067,
},
)

C_T_SKHT_K_V1 = Weights(
# Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth)
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/princeton-vl/RAFT",
"kitti_test_f1-all": 5.10,
},
)

C_T_SKHT_K_V2 = Weights(
# Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
# Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
# Same as C_T_SKHT with extra fine-tuning on Kitti
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"kitti_test_f1-all": 5.19,
},
)

DEFAULT = C_T_SKHT_V2


class Raft_Small_Weights(WeightsEnum):
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-small.pth)
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 990162,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_train_cleanpass_epe": 2.1231,
"sintel_train_finalpass_epe": 3.2790,
"kitti_train_per_image_epe": 7.6557,
"kitti_train_f1-all": 25.2801,
},
)
C_T_V2 = Weights(
# Chairs + Things
url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 990162,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"sintel_train_cleanpass_epe": 1.9901,
"sintel_train_finalpass_epe": 3.2831,
"kitti_train_per_image_epe": 7.5978,
"kitti_train_f1-all": 25.2369,
},
)

DEFAULT = C_T_V2
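
A short sketch of how these enum entries are consumed. The `url`, `transforms`, and `meta` attribute names come from the `Weights(...)` constructor calls above; the printed values are the ones recorded in this diff:

```python
weights = Raft_Small_Weights.DEFAULT  # alias for C_T_V2

# Each entry bundles a checkpoint URL, a transforms factory, and metadata.
print(weights.url)                 # .../raft_small_C_T_V2-01064c6d.pth
print(weights.meta["num_params"])  # 990162
preprocess = weights.transforms()  # OpticalFlowEval instance for evaluation preprocessing
```
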


def _raft(
*,
arch=None,
pretrained=False,
weights=None,
progress=False,
# Feature encoder
feature_encoder_layers,
@@ -577,38 +704,34 @@ def _raft(
mask_predictor=mask_predictor,
**kwargs, # not really needed, all params should be consumed by now
)
if pretrained:
state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
model.load_state_dict(state_dict)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
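
State-dict loading thus moves from a module-level URL table keyed on `arch` to the weights entry itself. A sketch of what `weights.get_state_dict(progress=progress)` plausibly amounts to, assuming it delegates to the same hub download that `load_state_dict_from_url` performed before:

```python
from torch.hub import load_state_dict_from_url


def get_state_dict(weights, progress: bool = True):
    # Assumed delegation: download (or read from cache) the entry's checkpoint.
    return load_state_dict_from_url(weights.url, progress=progress)
```
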


def raft_large(*, pretrained=False, progress=True, **kwargs):
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT:
"""RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

Please see the example below for a tutorial on how to use this model.

Args:
pretrained (bool): Whether to use weights that have been pre-trained on
:class:`~torchvision.datasets.FlyingChairs` + :class:`~torchvision.datasets.FlyingThings3D`
with two fine-tuning steps:

- one on :class:`~torchvision.datasets.Sintel` + :class:`~torchvision.datasets.FlyingThings3D`
- one on :class:`~torchvision.datasets.KittiFlow`.

This corresponds to the ``C+T+S/K`` strategy in the paper.

progress (bool): If True, displays a progress bar of the download to stderr.
weights (Raft_Large_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.

Returns:
nn.Module: The model.
RAFT: The model.
"""

weights = Raft_Large_Weights.verify(weights)

return _raft(
arch="raft_large",
pretrained=pretrained,
weights=weights,
progress=progress,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
@@ -637,25 +760,27 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
)


def raft_small(*, pretrained=False, progress=True, **kwargs):
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
"""RAFT "small" model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

Please see the example below for a tutorial on how to use this model.

Args:
pretrained (bool): Whether to use weights that have been pre-trained on
:class:`~torchvision.datasets.FlyingChairs` + :class:`~torchvision.datasets.FlyingThings3D`.
weights (Raft_Small_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.

Returns:
nn.Module: The model.
RAFT: The model.

"""
weights = Raft_Small_Weights.verify(weights)

return _raft(
arch="raft_small",
pretrained=pretrained,
weights=weights,
progress=progress,
# Feature encoder
feature_encoder_layers=(32, 32, 64, 96, 128),
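
Putting the builder changes together, a usage sketch of the new API next to the still-supported legacy call. The input shape follows the test defaults above and RAFT's divisible-by-8 requirement; the exact deprecation behavior is assumed from the decorator's intent:

```python
import torch
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights

# New multi-weight API: pick an explicit checkpoint.
model = raft_small(weights=Raft_Small_Weights.C_T_V2).eval()

# Legacy path: pretrained=True is remapped to the default (C_T_V2) with a warning.
legacy = raft_small(pretrained=True).eval()

img1 = torch.rand(1, 3, 128, 128)
img2 = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    flow_predictions = model(img1, img2, num_flow_updates=12)  # list of refined flow fields
```
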
2 changes: 1 addition & 1 deletion torchvision/models/segmentation/__init__.py
@@ -1,3 +1,3 @@
from .fcn import *
from .deeplabv3 import *
from .fcn import *
from .lraspp import *