Merged
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -347,6 +347,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
metrics = meta.pop("_metrics")
for dataset, dataset_metrics in metrics.items():
for metric_name, metric_value in dataset_metrics.items():
metric_name = metric_name.replace("_", "-")
Review comment (Contributor): Should we do this? It's the documentation that shows the actual name of the variable. (I know you made it private but still)

table.append((f"{metric_name} (on {dataset})", str(metric_value)))

for k, v in meta.items():
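Taken in isolation, the hunk above only normalizes how metric keys are displayed. A self-contained sketch of that table-building loop, using made-up metadata (not real torchvision entries), behaves like this:

```python
# Illustrative sample metadata; the real values live on each Weights entry.
meta = {
    "_metrics": {
        "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506},
    },
}

table = []
metrics = meta.pop("_metrics")
for dataset, dataset_metrics in metrics.items():
    for metric_name, metric_value in dataset_metrics.items():
        # The change under review: snake_case -> hyphenated display name,
        # so docs show "per-image-epe" rather than the variable name.
        metric_name = metric_name.replace("_", "-")
        table.append((f"{metric_name} (on {dataset})", str(metric_value)))

print(table)
```

The rename only affects the rendered docs table; the underlying dict keys (and hence the code-facing names) are untouched, which is what the review comment above is questioning.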
25 changes: 25 additions & 0 deletions docs/source/models/raft.rst
@@ -0,0 +1,25 @@
RAFT
====

.. currentmodule:: torchvision.models.optical_flow

The RAFT model is based on the `RAFT: Recurrent All-Pairs Field Transforms for
Optical Flow <https://arxiv.org/abs/2003.12039>`__ paper.


Model builders
--------------

The following model builders can be used to instantiate a RAFT model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.optical_flow.RAFT`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ for
more details about this class.

.. autosummary::
:toctree: generated/
:template: function.rst

raft_large
raft_small
13 changes: 13 additions & 0 deletions docs/source/models_new.rst
@@ -376,6 +376,7 @@ Box MAPs are reported on COCO val2017:

.. include:: generated/detection_table.rst


Instance Segmentation
---------------------

@@ -481,6 +482,18 @@ Accuracies are reported on Kinetics-400 using single crops for clip length 16:

.. include:: generated/video_table.rst

Optical Flow
============

.. currentmodule:: torchvision.models.optical_flow

The following Optical Flow models are available, with or without pre-trained weights:

.. toctree::
:maxdepth: 1

models/raft

Using models from Hub
=====================

119 changes: 93 additions & 26 deletions torchvision/models/optical_flow/raft.py
@@ -517,6 +517,19 @@ def forward(self, image1, image2, num_flow_updates: int = 12):


class Raft_Large_Weights(WeightsEnum):
"""The metrics reported here are as follows.

``epe`` is the "end-point-error" and indicates how far (in pixels) the
predicted flow is from its true value. This is averaged over all pixels
of all images. ``per_image_epe`` is similar, but the average is different:
the epe is first computed on each image independently, and then averaged
over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
in the original paper, and it's only used on Kitti. ``fl-all`` is also a
Kitti-specific metric, defined by the author of the dataset and used for the
Kitti leaderboard. It corresponds to the average of pixels whose epe is
either <3px, or <5% of flow's 2-norm.
Review comment (Member Author) on lines +522 to +530: I wanted this to be a ``.. note::`` but it looks like it doesn't render properly. It's likely because these docstrings are treated differently from the rest.
"""

C_T_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
@@ -530,7 +543,9 @@ class Raft_Large_Weights(WeightsEnum):
"Sintel-Train-Finalpass": {"epe": 2.7894},
"Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506},
},
"_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""",
"_docs": """These weights were ported from the original paper. They
are trained on :class:`~torchvision.datasets.FlyingChairs` +
:class:`~torchvision.datasets.FlyingThings3D`.""",
},
)

@@ -546,7 +561,9 @@ class Raft_Large_Weights(WeightsEnum):
"Sintel-Train-Finalpass": {"epe": 2.7161},
"Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679},
},
"_docs": """These weights were trained from scratch on Chairs + Things.""",
"_docs": """These weights were trained from scratch on
:class:`~torchvision.datasets.FlyingChairs` +
:class:`~torchvision.datasets.FlyingThings3D`.""",
},
)

@@ -563,8 +580,14 @@ class Raft_Large_Weights(WeightsEnum):
"Sintel-Test-Finalpass": {"epe": 3.18},
},
"_docs": """
These weights were ported from the original paper. They are trained on Chairs + Things and fine-tuned on
Sintel (C+T+S+K+H).
These weights were ported from the original paper. They are
trained on :class:`~torchvision.datasets.FlyingChairs` +
:class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on
Sintel. The Sintel fine-tuning step is a combination of
:class:`~torchvision.datasets.Sintel`,
:class:`~torchvision.datasets.KittiFlow`,
:class:`~torchvision.datasets.HD1K`, and
:class:`~torchvision.datasets.FlyingThings3D` (clean pass).
""",
},
)
@@ -581,7 +604,14 @@ class Raft_Large_Weights(WeightsEnum):
"Sintel-Test-Finalpass": {"epe": 3.067},
},
"_docs": """
These weights were trained from scratch on Chairs + Things and fine-tuned on Sintel (C+T+S+K+H).
These weights were trained from scratch. They are
pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
:class:`~torchvision.datasets.FlyingThings3D` and then
fine-tuned on Sintel. The Sintel fine-tuning step is a
combination of :class:`~torchvision.datasets.Sintel`,
:class:`~torchvision.datasets.KittiFlow`,
:class:`~torchvision.datasets.HD1K`, and
:class:`~torchvision.datasets.FlyingThings3D` (clean pass).
""",
},
)
@@ -598,8 +628,12 @@ class Raft_Large_Weights(WeightsEnum):
"Kitti-Test": {"fl_all": 5.10},
},
"_docs": """
These weights were ported from the original paper. They are trained on Chairs + Things, fine-tuned on
Sintel and then on Kitti.
These weights were ported from the original paper. They are
pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
:class:`~torchvision.datasets.FlyingThings3D`,
fine-tuned on Sintel, and then fine-tuned on
:class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
step was described above.
""",
},
)
@@ -615,7 +649,12 @@ class Raft_Large_Weights(WeightsEnum):
"Kitti-Test": {"fl_all": 5.19},
},
"_docs": """
These weights were trained from scratch on Chairs + Things, fine-tuned on Sintel and then on Kitti.
These weights were trained from scratch. They are
pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
:class:`~torchvision.datasets.FlyingThings3D`,
fine-tuned on Sintel, and then fine-tuned on
:class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
step was described above.
""",
},
)
@@ -624,6 +663,19 @@ class Raft_Large_Weights(WeightsEnum):


class Raft_Small_Weights(WeightsEnum):
"""The metrics reported here are as follows.

``epe`` is the "end-point-error" and indicates how far (in pixels) the
predicted flow is from its true value. This is averaged over all pixels
of all images. ``per_image_epe`` is similar, but the average is different:
the epe is first computed on each image independently, and then averaged
over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
in the original paper, and it's only used on Kitti. ``fl-all`` is also a
Kitti-specific metric, defined by the author of the dataset and used for the
Kitti leaderboard. It corresponds to the average of pixels whose epe is
either <3px, or <5% of flow's 2-norm.
"""

C_T_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
@@ -637,7 +689,9 @@ class Raft_Small_Weights(WeightsEnum):
"Sintel-Train-Finalpass": {"epe": 3.2790},
"Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801},
},
"_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""",
"_docs": """These weights were ported from the original paper. They
are trained on :class:`~torchvision.datasets.FlyingChairs` +
:class:`~torchvision.datasets.FlyingThings3D`.""",
},
)
C_T_V2 = Weights(
@@ -652,7 +706,9 @@ class Raft_Small_Weights(WeightsEnum):
"Sintel-Train-Finalpass": {"epe": 3.2831},
"Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369},
},
"_docs": """These weights were trained from scratch on Chairs + Things.""",
"_docs": """These weights were trained from scratch on
:class:`~torchvision.datasets.FlyingChairs` +
:class:`~torchvision.datasets.FlyingThings3D`.""",
},
)

@@ -750,13 +806,19 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
Please see the example below for a tutorial on how to use this model.

Args:
weights(Raft_Large_weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.

Returns:
RAFT: The model.
weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.optical_flow.Raft_Large_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
for more details about this class.

.. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights
:members:
"""

weights = Raft_Large_Weights.verify(weights)
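The builder's first statement above calls ``Raft_Large_Weights.verify(weights)``. A rough pure-Python sketch of what such a verify step typically does (the class below is a hypothetical stand-in, not torchvision's actual implementation) might look like:

```python
from enum import Enum

class WeightsEnumSketch(Enum):
    # Hypothetical stand-in for a weights enum such as Raft_Large_Weights.
    C_T_V2 = "raft_large_C_T_V2.pth"

    @classmethod
    def verify(cls, obj):
        # Accept None (no pre-trained weights), an enum member, or the
        # member's string name; fail early on anything else.
        if obj is None or isinstance(obj, cls):
            return obj
        if isinstance(obj, str):
            # Allow both "C_T_V2" and "WeightsEnumSketch.C_T_V2" spellings.
            return cls[obj.replace(f"{cls.__name__}.", "")]
        raise TypeError(f"Invalid weights argument: {obj!r}")

weights = WeightsEnumSketch.verify("C_T_V2")
```

Normalizing at the top of the builder means the rest of the function can assume ``weights`` is either ``None`` or a valid enum member, which keeps the download-and-load logic free of input checking.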
@@ -794,19 +856,24 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
"""RAFT "small" model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__.

Please see the example below for a tutorial on how to use this model.

Args:
weights(Raft_Small_weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.

Returns:
RAFT: The model.

weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.optical_flow.Raft_Small_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
for more details about this class.

.. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights
:members:
"""
weights = Raft_Small_Weights.verify(weights)
