-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Doc revamp for optical flow models #5895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
8b84e8a
Doc revamp for optical flow models
NicolasHug 992050e
Merge branch 'main' of github.com:pytorch/vision into optical_Flow_do…
NicolasHug a9be003
Merge branch 'main' of github.com:pytorch/vision into optical_Flow_do…
NicolasHug 653f5df
Some more
NicolasHug File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
RAFT | ||
==== | ||
|
||
.. currentmodule:: torchvision.models.optical_flow | ||
|
||
The RAFT model is based on the `RAFT: Recurrent All-Pairs Field Transforms for | ||
Optical Flow <https://arxiv.org/abs/2003.12039>`__ paper. | ||
|
||
|
||
Model builders | ||
-------------- | ||
|
||
The following model builders can be used to instantiate a RAFT model, with or | ||
without pre-trained weights. All the model builders internally rely on the | ||
``torchvision.models.optical_flow.RAFT`` base class. Please refer to the `source | ||
code | ||
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ for | ||
more details about this class. | ||
|
||
.. autosummary:: | ||
:toctree: generated/ | ||
:template: function.rst | ||
|
||
raft_large | ||
raft_small |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -517,6 +517,19 @@ def forward(self, image1, image2, num_flow_updates: int = 12): | |
|
||
|
||
class Raft_Large_Weights(WeightsEnum): | ||
"""The metrics reported here are as follows. | ||
|
||
``epe`` is the "end-point-error" and indicates how far (in pixels) the | ||
predicted flow is from its true value. This is averaged over all pixels | ||
of all images. ``per_image_epe`` is similar, but the average is different: | ||
the epe is first computed on each image independently, and then averaged | ||
over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe") | ||
in the original paper, and it's only used on Kitti. ``fl-all`` is also a | ||
Kitti-specific metric, defined by the author of the dataset and used for the | ||
Kitti leaderboard. It corresponds to the average of pixels whose epe is | ||
either <3px, or <5% of flow's 2-norm. | ||
Comment on lines
+522
to
+530
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wanted this to be a |
||
""" | ||
|
||
C_T_V1 = Weights( | ||
# Weights ported from https://github.com/princeton-vl/RAFT | ||
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", | ||
|
@@ -530,7 +543,9 @@ class Raft_Large_Weights(WeightsEnum): | |
"Sintel-Train-Finalpass": {"epe": 2.7894}, | ||
"Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506}, | ||
}, | ||
"_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""", | ||
"_docs": """These weights were ported from the original paper. They | ||
are trained on :class:`~torchvision.datasets.FlyingChairs` + | ||
:class:`~torchvision.datasets.FlyingThings3D`.""", | ||
}, | ||
) | ||
|
||
|
@@ -546,7 +561,9 @@ class Raft_Large_Weights(WeightsEnum): | |
"Sintel-Train-Finalpass": {"epe": 2.7161}, | ||
"Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679}, | ||
}, | ||
"_docs": """These weights were trained from scratch on Chairs + Things.""", | ||
"_docs": """These weights were trained from scratch on | ||
:class:`~torchvision.datasets.FlyingChairs` + | ||
:class:`~torchvision.datasets.FlyingThings3D`.""", | ||
}, | ||
) | ||
|
||
|
@@ -563,8 +580,14 @@ class Raft_Large_Weights(WeightsEnum): | |
"Sintel-Test-Finalpass": {"epe": 3.18}, | ||
}, | ||
"_docs": """ | ||
These weights were ported from the original paper. They are trained on Chairs + Things and fine-tuned on | ||
Sintel (C+T+S+K+H). | ||
These weights were ported from the original paper. They are | ||
trained on :class:`~torchvision.datasets.FlyingChairs` + | ||
:class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on | ||
Sintel. The Sintel fine-tuning step is a combination of | ||
:class:`~torchvision.datasets.Sintel`, | ||
:class:`~torchvision.datasets.KittiFlow`, | ||
:class:`~torchvision.datasets.HD1K`, and | ||
:class:`~torchvision.datasets.FlyingThings3D` (clean pass). | ||
""", | ||
}, | ||
) | ||
|
@@ -581,7 +604,14 @@ class Raft_Large_Weights(WeightsEnum): | |
"Sintel-Test-Finalpass": {"epe": 3.067}, | ||
}, | ||
"_docs": """ | ||
These weights were trained from scratch on Chairs + Things and fine-tuned on Sintel (C+T+S+K+H). | ||
These weights were trained from scratch. They are | ||
pre-trained on :class:`~torchvision.datasets.FlyingChairs` + | ||
:class:`~torchvision.datasets.FlyingThings3D` and then | ||
fine-tuned on Sintel. The Sintel fine-tuning step is a | ||
combination of :class:`~torchvision.datasets.Sintel`, | ||
:class:`~torchvision.datasets.KittiFlow`, | ||
:class:`~torchvision.datasets.HD1K`, and | ||
:class:`~torchvision.datasets.FlyingThings3D` (clean pass). | ||
""", | ||
}, | ||
) | ||
|
@@ -598,8 +628,12 @@ class Raft_Large_Weights(WeightsEnum): | |
"Kitti-Test": {"fl_all": 5.10}, | ||
}, | ||
"_docs": """ | ||
These weights were ported from the original paper. They are trained on Chairs + Things, fine-tuned on | ||
Sintel and then on Kitti. | ||
These weights were ported from the original paper. They are | ||
pre-trained on :class:`~torchvision.datasets.FlyingChairs` + | ||
:class:`~torchvision.datasets.FlyingThings3D`, | ||
fine-tuned on Sintel, and then fine-tuned on | ||
:class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning | ||
step was described above. | ||
""", | ||
}, | ||
) | ||
|
@@ -615,7 +649,12 @@ class Raft_Large_Weights(WeightsEnum): | |
"Kitti-Test": {"fl_all": 5.19}, | ||
}, | ||
"_docs": """ | ||
These weights were trained from scratch on Chairs + Things, fine-tuned on Sintel and then on Kitti. | ||
These weights were trained from scratch. They are | ||
pre-trained on :class:`~torchvision.datasets.FlyingChairs` + | ||
:class:`~torchvision.datasets.FlyingThings3D`, | ||
fine-tuned on Sintel, and then fine-tuned on | ||
:class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning | ||
step was described above. | ||
""", | ||
}, | ||
) | ||
|
@@ -624,6 +663,19 @@ class Raft_Large_Weights(WeightsEnum): | |
|
||
|
||
class Raft_Small_Weights(WeightsEnum): | ||
"""The metrics reported here are as follows. | ||
|
||
``epe`` is the "end-point-error" and indicates how far (in pixels) the | ||
predicted flow is from its true value. This is averaged over all pixels | ||
of all images. ``per_image_epe`` is similar, but the average is different: | ||
the epe is first computed on each image independently, and then averaged | ||
over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe") | ||
in the original paper, and it's only used on Kitti. ``fl-all`` is also a | ||
Kitti-specific metric, defined by the author of the dataset and used for the | ||
Kitti leaderboard. It corresponds to the average of pixels whose epe is | ||
either <3px, or <5% of flow's 2-norm. | ||
""" | ||
|
||
C_T_V1 = Weights( | ||
# Weights ported from https://github.com/princeton-vl/RAFT | ||
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", | ||
|
@@ -637,7 +689,9 @@ class Raft_Small_Weights(WeightsEnum): | |
"Sintel-Train-Finalpass": {"epe": 3.2790}, | ||
"Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801}, | ||
}, | ||
"_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""", | ||
"_docs": """These weights were ported from the original paper. They | ||
are trained on :class:`~torchvision.datasets.FlyingChairs` + | ||
:class:`~torchvision.datasets.FlyingThings3D`.""", | ||
}, | ||
) | ||
C_T_V2 = Weights( | ||
|
@@ -652,7 +706,9 @@ class Raft_Small_Weights(WeightsEnum): | |
"Sintel-Train-Finalpass": {"epe": 3.2831}, | ||
"Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369}, | ||
}, | ||
"_docs": """These weights were trained from scratch on Chairs + Things.""", | ||
"_docs": """These weights were trained from scratch on | ||
:class:`~torchvision.datasets.FlyingChairs` + | ||
:class:`~torchvision.datasets.FlyingThings3D`.""", | ||
}, | ||
) | ||
|
||
|
@@ -750,13 +806,19 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * | |
Please see the example below for a tutorial on how to use this model. | ||
|
||
Args: | ||
weights(Raft_Large_weights, optional): The pretrained weights for the model | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class | ||
to override any default. | ||
|
||
Returns: | ||
RAFT: The model. | ||
weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The | ||
pretrained weights to use. See | ||
:class:`~torchvision.models.optical_flow.Raft_Large_Weights` | ||
below for more details, and possible values. By default, no | ||
pre-trained weights are used. | ||
progress (bool): If True, displays a progress bar of the download to stderr. Default is True. | ||
**kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT`` | ||
base class. Please refer to the `source code | ||
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ | ||
for more details about this class. | ||
|
||
.. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights | ||
:members: | ||
""" | ||
|
||
weights = Raft_Large_Weights.verify(weights) | ||
|
@@ -794,19 +856,24 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * | |
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) | ||
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT: | ||
"""RAFT "small" model from | ||
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. | ||
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__. | ||
|
||
Please see the example below for a tutorial on how to use this model. | ||
|
||
Args: | ||
weights(Raft_Small_weights, optional): The pretrained weights for the model | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class | ||
to override any default. | ||
|
||
Returns: | ||
RAFT: The model. | ||
|
||
weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The | ||
pretrained weights to use. See | ||
:class:`~torchvision.models.optical_flow.Raft_Small_Weights` | ||
below for more details, and possible values. By default, no | ||
pre-trained weights are used. | ||
progress (bool): If True, displays a progress bar of the download to stderr. Default is True. | ||
**kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT`` | ||
base class. Please refer to the `source code | ||
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ | ||
for more details about this class. | ||
|
||
.. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights | ||
:members: | ||
""" | ||
weights = Raft_Small_Weights.verify(weights) | ||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we do this? It's the documentation that shows the actual name of the variable. (I know you made it private but still)