Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ docs/build
docs/source/auto_examples/
docs/source/gen_modules/
docs/source/generated/
docs/source/models/generated/
# pytorch-sphinx-theme gets installed here
docs/src

Expand Down
4 changes: 3 additions & 1 deletion docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ ifneq ($(EXAMPLES_PATTERN),)
endif

# You can set these variables from the command line.
SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS)
# TODO: Once the models doc revamp is done, set back the -W option to raise
# errors on warnings. See https://github.com/pytorch/vision/pull/5821#discussion_r850500693
SPHINXOPTS = -j auto $(EXAMPLES_PATTERN_OPTS)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The -W flag will make the docs error on warnings. It's good to have to avoid having broken references. But since I removed almost all of the models.rst file, a few of the refs (e.g. from gallery examples) are broken. To avoid noise I'm removing the flag for now.

Will definitely put it back once everything is done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a TODO or create an issue to avoid forgetting?

SPHINXBUILD = sphinx-build
SPHINXPROJ = torchvision
SOURCEDIR = source
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ numpy
sphinx-copybutton>=0.3.1
sphinx-gallery>=0.9.0
sphinx==3.5.4
tabulate
# This pin is only needed for sphinx<4.0.2. See https://github.com/pytorch/vision/issues/5673 for details
Jinja2<3.1.*
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
56 changes: 56 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
# sys.path.insert(0, os.path.abspath('.'))

import os
import textwrap
from pathlib import Path

import pytorch_sphinx_theme
import torchvision
import torchvision.models as M
from tabulate import tabulate


# -- General configuration ------------------------------------------------
Expand Down Expand Up @@ -292,5 +296,57 @@ def inject_minigalleries(app, what, name, obj, options, lines):
lines.append("\n")


def inject_weight_metadata(app, what, name, obj, options, lines):

if obj.__name__.endswith("_Weights"):
lines[:] = ["The model builder above accepts the following values as the ``weights`` parameter:"]
lines.append("")
for field in obj:
lines += [f"**{str(field)}**:", ""]

table = []
for k, v in field.meta.items():
if k == "categories":
continue
elif k == "recipe":
v = f"`link <{v}>`__"
table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst")
lines += [".. table::", ""]
lines += textwrap.indent(table, " " * 4).split("\n")
lines.append("")


def generate_classification_table():

weight_enums = [getattr(M, name) for name in dir(M) if name.endswith("_Weights")]
weights = [w for weight_enum in weight_enums for w in weight_enum]

column_names = ("**Weight**", "**Acc@1**", "**Acc@5**", "**Params**", "**Recipe**")
content = [
(
f":class:`{w} <{type(w).__name__}>`",
w.meta["acc@1"],
w.meta["acc@5"],
f"{w.meta['num_params']/1e6:.1f}M",
f"`link <{w.meta['recipe']}>`__",
)
for w in weights
]
table = tabulate(content, headers=column_names, tablefmt="rst")

generated_dir = Path("generated")
generated_dir.mkdir(exist_ok=True)
with open(generated_dir / "classification_table.rst", "w+") as table_file:
table_file.write(".. table::\n")
table_file.write(" :widths: 100 10 10 20 10\n\n")
table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")


generate_classification_table()


def setup(app):

app.connect("autodoc-process-docstring", inject_minigalleries)
app.connect("autodoc-process-docstring", inject_weight_metadata)
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ architectures, and common image transformations for computer vision.
ops
io
feature_extraction
models_new

.. toctree::
:maxdepth: 1
Expand Down
28 changes: 28 additions & 0 deletions docs/source/models/resnet.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
ResNet
======

.. currentmodule:: torchvision.models

The ResNet model is based on the `Deep Residual Learning for Image Recognition
<https://arxiv.org/abs/1512.03385>`_ paper.


Model builders
--------------

The following model builders can be used to instanciate a ResNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.resnet.ResNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_ for
more details about this class.

.. autosummary::
:toctree: generated/
:template: function.rst

resnet18
resnet34
resnet50
resnet101
resnet152
30 changes: 30 additions & 0 deletions docs/source/models/vgg.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
VGG
===

.. currentmodule:: torchvision.models

The VGG model is based on the `Very Deep Convolutional Networks for Large-Scale
Image Recognition <https://arxiv.org/abs/1409.1556>`_ paper.


Model builders
--------------

The following model builders can be used to instanciate a VGG model, with or
without pre-trained weights. All the model buidlers internally rely on the
``torchvision.models.vgg.VGG`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_ for
more details about this class.

.. autosummary::
:toctree: generated/
:template: function.rst

vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn
54 changes: 54 additions & 0 deletions docs/source/models_new.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
.. _models_new:

Models and pre-trained weights - New
####################################

.. note::

These are the new models docs, documenting the new multi-weight API.
TODO: Once all is done, remove the "- New" part in the title above, and
rename this file as models.rst


The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

.. note ::
Backward compatibility is guaranteed for loading a serialized
``state_dict`` to the model created using old PyTorch version.
On the contrary, loading entire saved models or serialized
``ScriptModules`` (seralized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_


Classification
==============

.. currentmodule:: torchvision.models

The following classification models are available, with or without pre-trained
weights:

.. toctree::
:maxdepth: 1

models/resnet
models/vgg


Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet

.. include:: generated/classification_table.rst


Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================

TODO: Something similar to classification models: list of models + table of weights
95 changes: 75 additions & 20 deletions torchvision/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,12 +556,23 @@ class Wide_ResNet101_2_Weights(WeightsEnum):

@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

Args:
weights (ResNet18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet18_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.

.. autoclass:: torchvision.models.ResNet18_Weights
:members:
"""
weights = ResNet18_Weights.verify(weights)

Expand All @@ -570,12 +581,23 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru

@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

Args:
weights (ResNet34_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet34_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.

.. autoclass:: torchvision.models.ResNet34_Weights
:members:
"""
weights = ResNet34_Weights.verify(weights)

Expand All @@ -584,12 +606,23 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru

@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

Args:
weights (ResNet50_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet50_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.

.. autoclass:: torchvision.models.ResNet50_Weights
:members:
"""
weights = ResNet50_Weights.verify(weights)

Expand All @@ -598,12 +631,23 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru

@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

Args:
weights (ResNet101_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet101_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.

.. autoclass:: torchvision.models.ResNet101_Weights
:members:
"""
weights = ResNet101_Weights.verify(weights)

Expand All @@ -612,12 +656,23 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T

@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

Args:
weights (ResNet152_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet152_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.

.. autoclass:: torchvision.models.ResNet152_Weights
:members:
"""
weights = ResNet152_Weights.verify(weights)

Expand Down
Loading