diff --git a/docs/source/conf.py b/docs/source/conf.py index 014eb3c3ae9..3e1b5c95a7b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -347,6 +347,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines): metrics = meta.pop("_metrics") for dataset, dataset_metrics in metrics.items(): for metric_name, metric_value in dataset_metrics.items(): + metric_name = metric_name.replace("_", "-") table.append((f"{metric_name} (on {dataset})", str(metric_value))) for k, v in meta.items(): diff --git a/docs/source/models/raft.rst b/docs/source/models/raft.rst new file mode 100644 index 00000000000..7ea477698b4 --- /dev/null +++ b/docs/source/models/raft.rst @@ -0,0 +1,25 @@ +RAFT +==== + +.. currentmodule:: torchvision.models.optical_flow + +The RAFT model is based on the `RAFT: Recurrent All-Pairs Field Transforms for +Optical Flow <https://arxiv.org/abs/2003.12039>`__ paper. + + +Model builders +-------------- + +The following model builders can be used to instantiate a RAFT model, with or +without pre-trained weights. All the model builders internally rely on the +``torchvision.models.optical_flow.RAFT`` base class. Please refer to the `source +code +<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + raft_large + raft_small diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst index 94ffe733538..8d6306139fe 100644 --- a/docs/source/models_new.rst +++ b/docs/source/models_new.rst @@ -376,6 +376,7 @@ Box MAPs are reported on COCO val2017: .. include:: generated/detection_table.rst + Instance Segmentation --------------------- @@ -481,6 +482,18 @@ Accuracies are reported on Kinetics-400 using single crops for clip length 16: .. include:: generated/video_table.rst +Optical Flow +============ + +.. currentmodule:: torchvision.models.optical_flow + +The following Optical Flow models are available, with or without pre-trained weights: + +..
toctree:: + :maxdepth: 1 + + models/raft + Using models from Hub ===================== diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index b382906517d..65a40fa1927 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -517,6 +517,19 @@ def forward(self, image1, image2, num_flow_updates: int = 12): class Raft_Large_Weights(WeightsEnum): + """The metrics reported here are as follows. + + ``epe`` is the "end-point-error" and indicates how far (in pixels) the + predicted flow is from its true value. This is averaged over all pixels + of all images. ``per_image_epe`` is similar, but the average is different: + the epe is first computed on each image independently, and then averaged + over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe") + in the original paper, and it's only used on Kitti. ``fl-all`` is also a + Kitti-specific metric, defined by the author of the dataset and used for the + Kitti leaderboard. It corresponds to the average of pixels whose epe is + either <3px, or <5% of flow's 2-norm. + """ + C_T_V1 = Weights( # Weights ported from https://github.com/princeton-vl/RAFT url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", @@ -530,7 +543,9 @@ class Raft_Large_Weights(WeightsEnum): "Sintel-Train-Finalpass": {"epe": 2.7894}, "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506}, }, - "_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""", + "_docs": """These weights were ported from the original paper. 
They + are trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", }, ) @@ -546,7 +561,9 @@ class Raft_Large_Weights(WeightsEnum): "Sintel-Train-Finalpass": {"epe": 2.7161}, "Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679}, }, - "_docs": """These weights were trained from scratch on Chairs + Things.""", + "_docs": """These weights were trained from scratch on + :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", }, ) @@ -563,8 +580,14 @@ class Raft_Large_Weights(WeightsEnum): "Sintel-Test-Finalpass": {"epe": 3.18}, }, "_docs": """ - These weights were ported from the original paper. They are trained on Chairs + Things and fine-tuned on - Sintel (C+T+S+K+H). + These weights were ported from the original paper. They are + trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on + Sintel. The Sintel fine-tuning step is a combination of + :class:`~torchvision.datasets.Sintel`, + :class:`~torchvision.datasets.KittiFlow`, + :class:`~torchvision.datasets.HD1K`, and + :class:`~torchvision.datasets.FlyingThings3D` (clean pass). """, }, ) @@ -581,7 +604,14 @@ class Raft_Large_Weights(WeightsEnum): "Sintel-Test-Finalpass": {"epe": 3.067}, }, "_docs": """ - These weights were trained from scratch on Chairs + Things and fine-tuned on Sintel (C+T+S+K+H). + These weights were trained from scratch. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D` and then + fine-tuned on Sintel. The Sintel fine-tuning step is a + combination of :class:`~torchvision.datasets.Sintel`, + :class:`~torchvision.datasets.KittiFlow`, + :class:`~torchvision.datasets.HD1K`, and + :class:`~torchvision.datasets.FlyingThings3D` (clean pass). 
""", }, ) @@ -598,8 +628,12 @@ class Raft_Large_Weights(WeightsEnum): "Kitti-Test": {"fl_all": 5.10}, }, "_docs": """ - These weights were ported from the original paper. They are trained on Chairs + Things, fine-tuned on - Sintel and then on Kitti. + These weights were ported from the original paper. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`, + fine-tuned on Sintel, and then fine-tuned on + :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning + step was described above. """, }, ) @@ -615,7 +649,12 @@ class Raft_Large_Weights(WeightsEnum): "Kitti-Test": {"fl_all": 5.19}, }, "_docs": """ - These weights were trained from scratch on Chairs + Things, fine-tuned on Sintel and then on Kitti. + These weights were trained from scratch. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`, + fine-tuned on Sintel, and then fine-tuned on + :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning + step was described above. """, }, ) @@ -624,6 +663,19 @@ class Raft_Large_Weights(WeightsEnum): class Raft_Small_Weights(WeightsEnum): + """The metrics reported here are as follows. + + ``epe`` is the "end-point-error" and indicates how far (in pixels) the + predicted flow is from its true value. This is averaged over all pixels + of all images. ``per_image_epe`` is similar, but the average is different: + the epe is first computed on each image independently, and then averaged + over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe") + in the original paper, and it's only used on Kitti. ``fl-all`` is also a + Kitti-specific metric, defined by the author of the dataset and used for the + Kitti leaderboard. It corresponds to the average of pixels whose epe is + either <3px, or <5% of flow's 2-norm. 
+ """ + C_T_V1 = Weights( # Weights ported from https://github.com/princeton-vl/RAFT url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", @@ -637,7 +689,9 @@ class Raft_Small_Weights(WeightsEnum): "Sintel-Train-Finalpass": {"epe": 3.2790}, "Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801}, }, - "_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""", + "_docs": """These weights were ported from the original paper. They + are trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", }, ) C_T_V2 = Weights( @@ -652,7 +706,9 @@ class Raft_Small_Weights(WeightsEnum): "Sintel-Train-Finalpass": {"epe": 3.2831}, "Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369}, }, - "_docs": """These weights were trained from scratch on Chairs + Things.""", + "_docs": """These weights were trained from scratch on + :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", }, ) @@ -750,13 +806,19 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * Please see the example below for a tutorial on how to use this model. Args: - weights(Raft_Large_weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. - - Returns: - RAFT: The model. + weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.optical_flow.Raft_Large_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. 
+ **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT`` + base class. Please refer to the `source code + <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ + for more details about this class. + + .. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights + :members: """ weights = Raft_Large_Weights.verify(weights) @@ -794,19 +856,24 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * @handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT: """RAFT "small" model from - `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__. Please see the example below for a tutorial on how to use this model. Args: - weights(Raft_Small_weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. - - Returns: - RAFT: The model. - + weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.optical_flow.Raft_Small_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT`` + base class. Please refer to the `source code + <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ + for more details about this class. + + .. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights + :members: """ weights = Raft_Small_Weights.verify(weights)