Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import os
import textwrap
from copy import copy
from pathlib import Path

import pytorch_sphinx_theme
Expand Down Expand Up @@ -330,7 +331,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
# the `meta` dict contains another embedded `metrics` dict. To
# simplify the table generation below, we create the
# `meta_with_metrics` dict, where the metrics dict has been "flattened"
meta = field.meta
meta = copy(field.meta)
metrics = meta.pop("metrics", {})
meta_with_metrics = dict(meta, **metrics)

Expand All @@ -346,17 +347,18 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
lines.append("")


def generate_classification_table():

weight_enums = [getattr(M, name) for name in dir(M) if name.endswith("_Weights")]
def generate_weights_table(module, table_name, metrics):
weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
weights = [w for weight_enum in weight_enums for w in weight_enum]

column_names = ("**Weight**", "**Acc@1**", "**Acc@5**", "**Params**", "**Recipe**")
metrics_keys, metrics_names = zip(*metrics)
column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
column_names = [f"**{name}**" for name in column_names] # Add bold

content = [
(
f":class:`{w} <{type(w).__name__}>`",
w.meta["metrics"]["acc@1"],
w.meta["metrics"]["acc@5"],
*(w.meta["metrics"][metric] for metric in metrics_keys),
f"{w.meta['num_params']/1e6:.1f}M",
f"`link <{w.meta['recipe']}>`__",
)
Expand All @@ -366,13 +368,14 @@ def generate_classification_table():

generated_dir = Path("generated")
generated_dir.mkdir(exist_ok=True)
with open(generated_dir / "classification_table.rst", "w+") as table_file:
with open(generated_dir / f"{table_name}_table.rst", "w+") as table_file:
table_file.write(".. table::\n")
table_file.write(" :widths: 100 10 10 20 10\n\n")
table_file.write(f" :widths: 100 {'20 ' * len(metrics_names)} 20 10\n\n")
table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")


generate_classification_table()
generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
generate_weights_table(module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")])


def setup(app):
Expand Down
23 changes: 23 additions & 0 deletions docs/source/models/retinanet.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
RetinaNet
=========

.. currentmodule:: torchvision.models.detection

The RetinaNet model is based on the `Focal Loss for Dense Object Detection
<https://arxiv.org/abs/1708.02002>`__ paper.

Model builders
--------------

The following model builders can be used to instantiate a RetinaNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.detection.retinanet.RetinaNet`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_ for
more details about this class.

.. autosummary::
:toctree: generated/
:template: function.rst

retinanet_resnet50_fpn
retinanet_resnet50_fpn_v2
17 changes: 16 additions & 1 deletion docs/source/models_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,19 @@ Accuracies are reported on ImageNet
Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================

TODO: Something similar to classification models: list of models + table of weights
.. currentmodule:: torchvision.models.detection

The following detection models are available, with or without pre-trained
weights:

.. toctree::
:maxdepth: 1

models/retinanet

Table of all available detection weights
----------------------------------------

Box MAPs are reported on COCO

.. include:: generated/detection_table.rst
32 changes: 24 additions & 8 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def retinanet_resnet50_fpn(
"""
Constructs a RetinaNet model with a ResNet-50-FPN backbone.

Reference: `"Focal Loss for Dense Object Detection" <https://arxiv.org/abs/1708.02002>`_.
Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.

The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes.
Expand Down Expand Up @@ -763,13 +763,21 @@ def retinanet_resnet50_fpn(
>>> predictions = model(x)

Args:
weights (RetinaNet_ResNet50_FPN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
the backbone.
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.

.. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
:members:
"""
weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
Expand Down Expand Up @@ -811,19 +819,27 @@ def retinanet_resnet50_fpn_v2(
"""
Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.

Reference: `"Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection"
Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
<https://arxiv.org/abs/1912.02424>`_.

See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.

Args:
weights (RetinaNet_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
the backbone.
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.

.. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
:members:
"""
weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
Expand Down