From ae8dfd216480a4fbe8e20b03a461c3445ba0d1ba Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 25 Apr 2022 14:49:27 +0100
Subject: [PATCH 1/3] Start doc revamp for detection models

---
 docs/source/conf.py                       | 20 +++++++-------
 docs/source/models/retinanet.rst          | 23 ++++++++++++++++
 docs/source/models_new.rst                | 17 +++++++++++-
 torchvision/models/detection/retinanet.py | 32 +++++++++++++++++------
 4 files changed, 74 insertions(+), 18 deletions(-)
 create mode 100644 docs/source/models/retinanet.rst

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 9966fc0cbf8..e50cda3a5e2 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -22,6 +22,7 @@
 
 import os
 import textwrap
+from copy import copy
 from pathlib import Path
 
 import pytorch_sphinx_theme
@@ -330,7 +331,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             # the `meta` dict contains another embedded `metrics` dict. To
             # simplify the table generation below, we create the
             # `meta_with_metrics` dict, where the metrics dict has been "flattened"
-            meta = field.meta
+            meta = copy(field.meta)
             metrics = meta.pop("metrics", {})
             meta_with_metrics = dict(meta, **metrics)
 
@@ -346,17 +347,17 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             lines.append("")
 
 
-def generate_classification_table():
+def generate_weights_table(module, table_name, metrics_names):
 
-    weight_enums = [getattr(M, name) for name in dir(M) if name.endswith("_Weights")]
+    weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
     weights = [w for weight_enum in weight_enums for w in weight_enum]
 
-    column_names = ("**Weight**", "**Acc@1**", "**Acc@5**", "**Params**", "**Recipe**")
+    column_names = ["weight"] + list(metrics_names) + ["params", "recipe"]
+    column_names = [f"**{name}**" for name in column_names]
     content = [
         (
             f":class:`{w} <{type(w).__name__}>`",
-            w.meta["metrics"]["acc@1"],
-            w.meta["metrics"]["acc@5"],
+            *tuple(w.meta["metrics"][metric.lower()] for metric in metrics_names),
             f"{w.meta['num_params']/1e6:.1f}M",
             f"`link <{w.meta['recipe']}>`__",
         )
@@ -366,13 +367,14 @@ def generate_classification_table():
 
     generated_dir = Path("generated")
     generated_dir.mkdir(exist_ok=True)
-    with open(generated_dir / "classification_table.rst", "w+") as table_file:
+    with open(generated_dir / f"{table_name}_table.rst", "w+") as table_file:
         table_file.write(".. table::\n")
-        table_file.write("    :widths: 100 10 10 20 10\n\n")
+        table_file.write(f"    :widths: 100 {'15 ' * len(metrics_names)} 20 10\n\n")
         table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
 
 
-generate_classification_table()
+generate_weights_table(module=M, table_name="classification", metrics_names=["Acc@1", "Acc@5"])
+generate_weights_table(module=M.detection, table_name="detection", metrics_names=["box_map"])
 
 
 def setup(app):
diff --git a/docs/source/models/retinanet.rst b/docs/source/models/retinanet.rst
new file mode 100644
index 00000000000..3475cc783c3
--- /dev/null
+++ b/docs/source/models/retinanet.rst
@@ -0,0 +1,23 @@
+RetinaNet
+=========
+
+.. currentmodule:: torchvision.models.detection
+
+The RetinaNet model is based on the `Focal Loss for Dense Object Detection
+<https://arxiv.org/abs/1708.02002>`__ paper.
+
+Model builders
+--------------
+
+The following model builders can be used to instantiate a RetinaNet model, with or
+without pre-trained weights. All the model builders internally rely on the
+``torchvision.models.detection.retinanet.RetinaNet`` base class. Please refer to the `source code
+<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_ for
+more details about this class.
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    retinanet_resnet50_fpn
+    retinanet_resnet50_fpn_v2
diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst
index 77756b634b2..782fa2b5bde 100644
--- a/docs/source/models_new.rst
+++ b/docs/source/models_new.rst
@@ -58,4 +58,19 @@ Accuracies are reported on ImageNet
 Object Detection, Instance Segmentation and Person Keypoint Detection
 =====================================================================
 
-TODO: Something similar to classification models: list of models + table of weights
+.. currentmodule:: torchvision.models.detection
+
+The following detection models are available, with or without pre-trained
+weights:
+
+.. toctree::
+   :maxdepth: 1
+
+   models/retinanet
+
+Table of all available detection weights
+----------------------------------------
+
+Box MAPs are reported on COCO
+
+.. include:: generated/detection_table.rst
diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index 671eab864a2..3ea739d68d0 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -727,7 +727,7 @@ def retinanet_resnet50_fpn(
     """
     Constructs a RetinaNet model with a ResNet-50-FPN backbone.
 
-    Reference: `"Focal Loss for Dense Object Detection" <https://arxiv.org/abs/1708.02002>`_.
+    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.
 
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.
@@ -763,13 +763,21 @@ def retinanet_resnet50_fpn(
         >>> predictions = model(x)
 
     Args:
-        weights (RetinaNet_ResNet50_FPN_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
         num_classes (int, optional): number of output classes of the model (including the background)
-        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
         trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
             final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
+        :members:
     """
     weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
@@ -811,19 +819,27 @@ def retinanet_resnet50_fpn_v2(
     """
     Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.
 
-    Reference: `"Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection"
+    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
     <https://arxiv.org/abs/1912.02424>`_.
 
     :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
 
     Args:
-        weights (RetinaNet_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
         num_classes (int, optional): number of output classes of the model (including the background)
-        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
         trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
+        :members:
     """
     weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)

From 374d512e3a80e9d4a31a5ab367d64d29caf1a446 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 25 Apr 2022 15:32:07 +0100
Subject: [PATCH 2/3] Minor cleanup

---
 docs/source/conf.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index e50cda3a5e2..1bb5c30ce09 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -352,12 +352,17 @@ def generate_weights_table(module, table_name, metrics_names):
     weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
     weights = [w for weight_enum in weight_enums for w in weight_enum]
 
-    column_names = ["weight"] + list(metrics_names) + ["params", "recipe"]
-    column_names = [f"**{name}**" for name in column_names]
+    def clean_name(name):
+        if name == "box_map":
+            name = "Box MAP"
+        return f"**{name}**"  # Add bold
+
+    column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
+    column_names = [clean_name(name) for name in column_names]
     content = [
         (
             f":class:`{w} <{type(w).__name__}>`",
-            *tuple(w.meta["metrics"][metric.lower()] for metric in metrics_names),
+            *(w.meta["metrics"][metric.lower()] for metric in metrics_names),
             f"{w.meta['num_params']/1e6:.1f}M",
             f"`link <{w.meta['recipe']}>`__",
         )
@@ -369,7 +374,7 @@ def clean_name(name):
     generated_dir.mkdir(exist_ok=True)
     with open(generated_dir / f"{table_name}_table.rst", "w+") as table_file:
         table_file.write(".. table::\n")
-        table_file.write(f"    :widths: 100 {'15 ' * len(metrics_names)} 20 10\n\n")
+        table_file.write(f"    :widths: 100 {'20 ' * len(metrics_names)} 20 10\n\n")
         table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
 
 

From 56f8a48c8293a3f948ceb63a891d41610b34fb74 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 26 Apr 2022 14:19:40 +0100
Subject: [PATCH 3/3] Use list of tuples for metrics

---
 docs/source/conf.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1bb5c30ce09..01f48282e32 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -347,22 +347,18 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             lines.append("")
 
 
-def generate_weights_table(module, table_name, metrics_names):
-
+def generate_weights_table(module, table_name, metrics):
     weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
     weights = [w for weight_enum in weight_enums for w in weight_enum]
 
-    def clean_name(name):
-        if name == "box_map":
-            name = "Box MAP"
-        return f"**{name}**"  # Add bold
-
+    metrics_keys, metrics_names = zip(*metrics)
     column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
-    column_names = [clean_name(name) for name in column_names]
+    column_names = [f"**{name}**" for name in column_names]  # Add bold
+
     content = [
         (
             f":class:`{w} <{type(w).__name__}>`",
-            *(w.meta["metrics"][metric.lower()] for metric in metrics_names),
+            *(w.meta["metrics"][metric] for metric in metrics_keys),
             f"{w.meta['num_params']/1e6:.1f}M",
             f"`link <{w.meta['recipe']}>`__",
         )
@@ -378,8 +374,8 @@ def clean_name(name):
         table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
 
 
-generate_weights_table(module=M, table_name="classification", metrics_names=["Acc@1", "Acc@5"])
-generate_weights_table(module=M.detection, table_name="detection", metrics_names=["box_map"])
+generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
+generate_weights_table(module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")])
 
 
 def setup(app):
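
A minimal, self-contained sketch of the list-of-tuples metrics handling that PATCH 3/3 settles on, runnable outside Sphinx. FakeWeight, its meta numbers, and the use of tabulate for the RST rendering are illustrative stand-ins rather than torchvision code; only the key/column-name handling mirrors the patched generate_weights_table:

# Illustrative stand-ins: FakeWeight and its meta values are made up; the real
# conf.py code iterates over torchvision's *_Weights enums.
from tabulate import tabulate  # assumed available; the hunks above do not show conf.py's table rendering


class FakeWeight:
    def __init__(self, name, meta):
        self.name = name
        self.meta = meta

    def __str__(self):
        return self.name


weights = [
    FakeWeight(
        "RetinaNet_ResNet50_FPN_Weights.COCO_V1",
        # placeholder metrics/params/recipe, in the shape the table code expects
        {"metrics": {"box_map": 36.4}, "num_params": 34000000, "recipe": "https://example.com/retinanet-recipe"},
    )
]
metrics = [("box_map", "Box MAP")]  # (key in meta["metrics"], column header)

# Same column/key handling as the patched generate_weights_table
metrics_keys, metrics_names = zip(*metrics)
column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
column_names = [f"**{name}**" for name in column_names]  # Add bold

content = [
    (
        f":class:`{w} <{type(w).__name__}>`",
        *(w.meta["metrics"][key] for key in metrics_keys),
        f"{w.meta['num_params']/1e6:.1f}M",
        f"`link <{w.meta['recipe']}>`__",
    )
    for w in weights
]

print(tabulate(content, headers=column_names, tablefmt="rst"))

The point of the (key, display name) pairs is that the lookup key in meta["metrics"] (acc@1, box_map) stays decoupled from the header text shown in the generated table.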