diff --git a/docs/source/conf.py b/docs/source/conf.py
index b0fe63cb288..bedef5a5215 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -347,10 +347,6 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             metrics = meta.pop("metrics", {})
             meta_with_metrics = dict(meta, **metrics)
 
-            # We don't want to document these, they can be too long
-            for k in ["categories", "keypoint_names"]:
-                meta_with_metrics.pop(k, None)
-
             custom_docs = meta_with_metrics.pop("_docs", None)  # Custom per-Weights docs
             if custom_docs is not None:
                 lines += [custom_docs, ""]
@@ -360,6 +356,10 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
                     v = f"`link <{v}>`__"
                 elif k == "min_size":
                     v = f"height={v[0]}, width={v[1]}"
+                elif k in {"categories", "keypoint_names"} and isinstance(v, list):
+                    max_visible = 3
+                    v_sample = ", ".join(v[:max_visible])
+                    v = f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample
                 table.append((str(k), str(v)))
             table = tabulate(table, tablefmt="rst")
             lines += [".. rst-class:: table-weights"]  # Custom CSS class, see custom_torchvision.css
@@ -367,7 +367,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             lines += textwrap.indent(table, " " * 4).split("\n")
             lines.append("")
             lines.append(
-                f"The inference transforms are available at ``{str(field)}.transforms`` and "
+                f"The preprocessing/inference transforms are available at ``{str(field)}.transforms`` and "
                 f"perform the following operations: {field.transforms().describe()}"
            )
             lines.append("")
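
For reference, the truncation branch added to ``inject_weight_metadata`` above can be mirrored as a small standalone sketch; the sample lists below are made-up placeholders, not real weight metadata:

.. code:: python

    def truncate_list(v, max_visible=3):
        # Mirrors the new elif branch: show at most `max_visible` entries,
        # then summarize how many were left out.
        v_sample = ", ".join(v[:max_visible])
        return f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample

    print(truncate_list(["cat", "dog", "bird", "fish"]))  # cat, dog, bird, ... (1 omitted)
    print(truncate_list(["nose", "left_eye"]))            # nose, left_eye
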
diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst
index c208b2c54d3..fe101913f70 100644
--- a/docs/source/models_new.rst
+++ b/docs/source/models_new.rst
@@ -3,30 +3,42 @@
 Models and pre-trained weights - New
 ####################################
 
-.. note::
-
-    These are the new models docs, documenting the new multi-weight API.
-    TODO: Once all is done, remove the "- New" part in the title above, and
-    rename this file as models.rst
-
-
 The ``torchvision.models`` subpackage contains definitions of models for addressing
 different tasks, including: image classification, pixelwise semantic segmentation,
 object detection, instance segmentation, person keypoint detection, video
 classification, and optical flow.
 
+General information on pre-trained weights
+==========================================
+
+TorchVision offers pre-trained weights for every provided architecture, using
+the PyTorch :mod:`torch.hub`. Instantiating a pre-trained model will download its
+weights to a cache directory. This directory can be set using the `TORCH_HOME`
+environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
+
+.. note::
+
+    The pre-trained models provided in this library may have their own licenses or
+    terms and conditions derived from the dataset used for training. It is your
+    responsibility to determine whether you have permission to use the models for
+    your use case.
+
 .. note ::
 
-    Backward compatibility is guaranteed for loading a serialized
-    ``state_dict`` to the model created using old PyTorch version.
-    On the contrary, loading entire saved models or serialized
-    ``ScriptModules`` (seralized using older versions of PyTorch)
-    may not preserve the historic behaviour. Refer to the following
-    `documentation
-    `_
+    Backward compatibility is guaranteed for loading a serialized
+    ``state_dict`` to the model created using old PyTorch version.
+    On the contrary, loading entire saved models or serialized
+    ``ScriptModules`` (serialized using older versions of PyTorch)
+    may not preserve the historic behaviour. Refer to the following
+    `documentation
+    `_
+
+
+Initializing pre-trained models
+-------------------------------
 
 As of v0.13, TorchVision offers a new `Multi-weight support API
-`_ for loading different weights to the
-existing model builder methods:
+`_
+for loading different weights to the existing model builder methods:
 
 .. code:: python
@@ -46,7 +58,7 @@ existing model builder methods:
     resnet50(weights="IMAGENET1K_V2")
 
     # No weights - random initialization
-    resnet50(weights=None)  # or resnet50()
+    resnet50(weights=None)
 
 
 Migrating to the new API is very straightforward. The following method calls between the 2 APIs are all equivalent:
@@ -57,16 +69,57 @@ Migrating to the new API is very straightforward. The following method calls bet
 
     # Using pretrained weights:
     resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+    resnet50(weights="IMAGENET1K_V1")
     resnet50(pretrained=True)  # deprecated
     resnet50(True)  # deprecated
 
     # Using no weights:
     resnet50(weights=None)
+    resnet50()
     resnet50(pretrained=False)  # deprecated
     resnet50(False)  # deprecated
 
 Note that the ``pretrained`` parameter is now deprecated, using it will emit warnings and will be removed on v0.15.
 
+Using the pre-trained models
+----------------------------
+
+Before using the pre-trained models, one must preprocess the image
+(resize with right resolution/interpolation, apply inference transforms,
+rescale the values, etc.). There is no standard way to do this as it depends on
+how a given model was trained. It can vary across model families, variants or
+even weight versions. Using the correct preprocessing method is critical and
+failing to do so may lead to decreased accuracy or incorrect outputs.
+
+All the necessary information for the inference transforms of each pre-trained
+model is provided in its weights documentation. To simplify inference, TorchVision
+bundles the necessary preprocessing transforms into each model weight. These are
+accessible via the ``weight.transforms`` attribute:
+
+.. code:: python
+
+    # Initialize the Weight Transforms
+    weights = ResNet50_Weights.DEFAULT
+    preprocess = weights.transforms()
+
+    # Apply it to the input image
+    img_transformed = preprocess(img)
+
+
+Some models use modules which have different training and evaluation
+behavior, such as batch normalization. To switch between these modes, use
+``model.train()`` or ``model.eval()`` as appropriate. See
+:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
+
+.. code:: python
+
+    # Initialize model
+    weights = ResNet50_Weights.DEFAULT
+    model = resnet50(weights=weights)
+
+    # Set model to eval mode
+    model.eval()
+
 Classification
 ==============
@@ -128,10 +181,12 @@ Here is an example of how to use the pre-trained image classification models:
     category_name = weights.meta["categories"][class_id]
     print(f"{category_name}: {100 * score:.1f}%")
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+
 Table of all available classification weights
 ---------------------------------------------
 
-Accuracies are reported on ImageNet
+Accuracies are reported on ImageNet-1K using single crops:
 
 .. include:: generated/classification_table.rst
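
Putting the pieces from the new "Using the pre-trained models" section together, a minimal end-to-end classification sketch looks roughly as follows; the image path is a placeholder and the snippet assumes a multi-weight capable TorchVision (v0.13+):

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("path/to/image.jpg")  # placeholder path

    # Initialize the model with the best available weights and switch to eval mode
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Apply the preprocessing transforms bundled with the weights
    preprocess = weights.transforms()
    batch = preprocess(img).unsqueeze(0)

    # Run inference and map the prediction to a human-readable category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    print(weights.meta["categories"][class_id])
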
@@ -140,7 +195,7 @@ Quantized models
 
 .. currentmodule:: torchvision.models.quantization
 
-The following quantized classification models are available, with or without
+The following architectures provide support for INT8 quantized models, with or without
 pre-trained weights:
 
 .. toctree::
@@ -181,11 +236,13 @@ Here is an example of how to use the pre-trained quantized image classification
     category_name = weights.meta["categories"][class_id]
     print(f"{category_name}: {100 * score}%")
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+
 Table of all available quantized classification weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Accuracies are reported on ImageNet
+Accuracies are reported on ImageNet-1K using single crops:
 
 .. include:: generated/classification_quant_table.rst
@@ -234,11 +291,14 @@ Here is an example of how to use the pre-trained semantic segmentation models:
     mask = normalized_masks[0, class_to_idx["dog"]]
     to_pil_image(mask).show()
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+The output format of the models is illustrated in :ref:`semantic_seg_output`.
+
 Table of all available semantic segmentation weights
 ----------------------------------------------------
 
-All models are evaluated on COCO val2017:
+All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:
 
 .. include:: generated/segmentation_table.rst
@@ -247,6 +307,11 @@ All models are evaluated on COCO val2017:
 
 Object Detection, Instance Segmentation and Person Keypoint Detection
 =====================================================================
 
+The pre-trained models for detection, instance segmentation and
+keypoint detection are initialized with the classification models
+in torchvision. The models expect a list of ``Tensor[C, H, W]``.
+Check the constructor of the models for more information.
+
 Object Detection
 ----------------
@@ -299,10 +364,13 @@ Here is an example of how to use the pre-trained object detection models:
     im = to_pil_image(box.detach())
     im.show()
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.
+
 Table of all available Object detection weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Box MAPs are reported on COCO
+Box MAPs are reported on COCO val2017:
 
 .. include:: generated/detection_table.rst
@@ -319,10 +387,15 @@ weights:
 
     models/mask_rcnn
 
+|
+
+
+For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.
+
 Table of all available Instance segmentation weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Box and Mask MAPs are reported on COCO
+Box and Mask MAPs are reported on COCO val2017:
 
 .. include:: generated/instance_segmentation_table.rst
@@ -331,7 +404,7 @@ Keypoint Detection
 
 .. currentmodule:: torchvision.models.detection
 
-The following keypoint detection models are available, with or without
+The following person keypoint detection models are available, with or without
 pre-trained weights:
 
 .. toctree::
@@ -339,10 +412,15 @@ pre-trained weights:
 
     models/keypoint_rcnn
 
+|
+
+The classes of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
+For details on how to plot the bounding boxes of the models, you may refer to :ref:`keypoint_output`.
+
 Table of all available Keypoint detection weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Box and Keypoint MAPs are reported on COCO:
+Box and Keypoint MAPs are reported on COCO val2017:
 
 .. include:: generated/detection_keypoint_table.rst
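
As a rough illustration of the input convention mentioned above (detection models expect a list of ``Tensor[C, H, W]``), a minimal Faster R-CNN sketch could look like this; the image path is a placeholder and the prediction dictionary keys follow the standard torchvision detection output format:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn(weights=weights)
    model.eval()

    preprocess = weights.transforms()
    img = read_image("path/to/image.jpg")  # placeholder path, Tensor[C, H, W]

    # The model takes a list of images and returns one prediction dict per image
    predictions = model([preprocess(img)])
    labels = [weights.meta["categories"][i] for i in predictions[0]["labels"]]
    boxes, scores = predictions[0]["boxes"], predictions[0]["scores"]
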
@@ -391,10 +469,32 @@ Here is an example of how to use the pre-trained video classification models:
     category_name = weights.meta["categories"][label]
     print(f"{category_name}: {100 * score}%")
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+
 Table of all available video classification weights
 ---------------------------------------------------
 
-Accuracies are reported on Kinetics-400
+Accuracies are reported on Kinetics-400 using single crops for clip length 16:
 
 .. include:: generated/video_table.rst
+
+Using models from Hub
+=====================
+
+Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
+
+.. code:: python
+
+    import torch
+
+    # Option 1: passing weights param as string
+    model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
+
+    # Option 2: passing weights param as enum
+    weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
+    model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
+
+The only exceptions to the above are the detection models included in
+:mod:`torchvision.models.detection`. These models require TorchVision
+to be installed because they depend on custom C++ operators.
diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py
index 7f92d54ebdd..843ebd3c247 100644
--- a/gallery/plot_visualization_utils.py
+++ b/gallery/plot_visualization_utils.py
@@ -379,6 +379,8 @@ def show(imgs):
 # instance with class 15 (which corresponds to 'bench') was not selected.
 
 #####################################
+# .. _keypoint_output:
+#
 # Visualizing keypoints
 # ------------------------------
 # The :func:`~torchvision.utils.draw_keypoints` function can be used to
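
Complementing the new "Using models from Hub" section, the entry points exposed by the ``pytorch/vision`` hubconf (the model builders plus ``get_weight``) can be discovered with the stock ``torch.hub`` helpers; this is a generic sketch, not something added by this patch:

.. code:: python

    import torch

    # List every entry point exposed by the pytorch/vision hubconf
    print(torch.hub.list("pytorch/vision"))

    # Show the docstring of a specific entry point before loading it
    print(torch.hub.help("pytorch/vision", "resnet50"))
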
diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py
index b009d45f1a4..765ae8ec3c4 100644
--- a/torchvision/transforms/_presets.py
+++ b/torchvision/transforms/_presets.py
@@ -71,8 +71,8 @@ def __repr__(self) -> str:
     def describe(self) -> str:
         return (
             f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
-            f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
-            f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
+            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
         )
 
@@ -127,8 +127,8 @@ def __repr__(self) -> str:
     def describe(self) -> str:
         return (
             f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
-            f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
-            f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
+            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
         )
 
@@ -168,7 +168,8 @@ def __repr__(self) -> str:
     def describe(self) -> str:
         return (
             f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
-            f"Then the values are rescaled to ``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
+            f"``std={self.std}``."
         )
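
The reworded ``describe()`` strings above are what ``conf.py`` injects into the generated weights tables via ``field.transforms().describe()``; they can also be inspected directly from a weights enum, for example:

.. code:: python

    from torchvision.models import ResNet50_Weights

    # The preset object returned by .transforms() is defined in _presets.py
    preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()

    # Prints the sentence describing resize, central crop, rescaling to [0.0, 1.0] and normalization
    print(preprocess.describe())
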