diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst
index 43374582f2a..1d349afdbdd 100644
--- a/docs/source/models_new.rst
+++ b/docs/source/models_new.rst
@@ -24,6 +24,49 @@
 keypoint detection, video classification, and optical flow.
 `documentation `_
 
+As of v0.13, TorchVision offers a new `Multi-weight support API
+<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
+for loading different weights into the existing model builder methods:
+
+.. code:: python
+
+    from torchvision.models import resnet50, ResNet50_Weights
+
+    # Old weights with accuracy 76.130%
+    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+
+    # New weights with accuracy 80.858%
+    resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
+
+    # Best available weights (currently an alias for IMAGENET1K_V2)
+    # Note that these weights may change across versions
+    resnet50(weights=ResNet50_Weights.DEFAULT)
+
+    # Strings are also supported
+    resnet50(weights="IMAGENET1K_V2")
+
+    # No weights - random initialization
+    resnet50(weights=None)  # or resnet50()
+
+Migrating to the new API is straightforward. The following method calls across
+the two APIs are all equivalent:
+
+.. code:: python
+
+    from torchvision.models import resnet50, ResNet50_Weights
+
+    # Using pretrained weights:
+    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+    resnet50(pretrained=True)  # deprecated
+    resnet50(True)  # deprecated
+
+    # Using no weights:
+    resnet50(weights=None)
+    resnet50(pretrained=False)  # deprecated
+    resnet50(False)  # deprecated
+
+Note that the ``pretrained`` parameter is now deprecated; using it will emit a
+warning, and it will be removed in v0.15.
+
 Classification
 ==============
 
@@ -56,6 +99,34 @@ weights:
    models/vision_transformer
    models/wide_resnet
 
+|
+
+Here is an example of how to use the pre-trained image classification models:
+
+.. code:: python
+
+    from torchvision.io import read_image
+    from torchvision.models import resnet50, ResNet50_Weights
+
+    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+
+    # Step 1: Initialize model with the best available weights
+    weights = ResNet50_Weights.DEFAULT
+    model = resnet50(weights=weights)
+    model.eval()
+
+    # Step 2: Initialize the inference transforms
+    preprocess = weights.transforms()
+
+    # Step 3: Apply inference preprocessing transforms
+    batch = preprocess(img).unsqueeze(0)
+
+    # Step 4: Use the model and print the predicted category
+    prediction = model(batch).squeeze(0).softmax(0)
+    class_id = prediction.argmax().item()
+    score = prediction[class_id].item()
+    category_name = weights.meta["categories"][class_id]
+    print(f"{category_name}: {100 * score:.1f}%")
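+
+Building on the snippet above, the following sketch (an illustration, not part
+of the documented example; it reuses the ``prediction`` and ``weights``
+variables from Step 4) prints the five most likely categories instead of only
+the top one:
+
+.. code:: python
+
+    import torch
+
+    # Scores and class ids of the 5 highest-probability categories
+    top5_scores, top5_ids = torch.topk(prediction, k=5)
+    for score, class_id in zip(top5_scores.tolist(), top5_ids.tolist()):
+        print(f"{weights.meta['categories'][class_id]}: {100 * score:.1f}%")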
 
 Table of all available classification weights
 ---------------------------------------------
 
@@ -78,6 +149,35 @@ pre-trained weights:
    models/googlenet_quant
    models/mobilenetv2_quant
 
+|
+
+Here is an example of how to use the pre-trained quantized image classification models:
+
+.. code:: python
+
+    from torchvision.io import read_image
+    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
+
+    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+
+    # Step 1: Initialize model with the best available weights
+    weights = ResNet50_QuantizedWeights.DEFAULT
+    model = resnet50(weights=weights, quantize=True)
+    model.eval()
+
+    # Step 2: Initialize the inference transforms
+    preprocess = weights.transforms()
+
+    # Step 3: Apply inference preprocessing transforms
+    batch = preprocess(img).unsqueeze(0)
+
+    # Step 4: Use the model and print the predicted category
+    prediction = model(batch).squeeze(0).softmax(0)
+    class_id = prediction.argmax().item()
+    score = prediction[class_id].item()
+    category_name = weights.meta["categories"][class_id]
+    print(f"{category_name}: {100 * score:.1f}%")
+
 Table of all available quantized classification weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -101,6 +201,37 @@ pre-trained weights:
    models/fcn
    models/lraspp
 
+|
+
+Here is an example of how to use the pre-trained semantic segmentation models:
+
+.. code:: python
+
+    from torchvision.io.image import read_image
+    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
+    from torchvision.transforms.functional import to_pil_image
+
+    img = read_image("gallery/assets/dog1.jpg")
+
+    # Step 1: Initialize model with the best available weights
+    weights = FCN_ResNet50_Weights.DEFAULT
+    model = fcn_resnet50(weights=weights)
+    model.eval()
+
+    # Step 2: Initialize the inference transforms
+    preprocess = weights.transforms()
+
+    # Step 3: Apply inference preprocessing transforms
+    batch = preprocess(img).unsqueeze(0)
+
+    # Step 4: Use the model and visualize the prediction
+    prediction = model(batch)["out"]
+    normalized_masks = prediction.softmax(dim=1)
+    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
+    mask = normalized_masks[0, class_to_idx["dog"]]
+    to_pil_image(mask).show()
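+
+As a further sketch (illustrative, not part of the documented example; it
+reuses ``img``, ``normalized_masks``, ``class_to_idx`` and ``to_pil_image``
+from above), the predicted "dog" pixels can be overlaid on the input image.
+Since the inference transforms may resize the input, the mask is resized back
+to the original image resolution before drawing:
+
+.. code:: python
+
+    from torchvision.transforms.functional import resize
+    from torchvision.utils import draw_segmentation_masks
+
+    # Binarize the "dog" probability mask, then resize it to the input image
+    dog_mask = normalized_masks[0, class_to_idx["dog"]] > 0.5
+    dog_mask = resize(dog_mask.unsqueeze(0).float(), list(img.shape[-2:])) > 0.5
+
+    # Blend the mask onto the original (uint8) image and display the result
+    overlay = draw_segmentation_masks(img, masks=dog_mask.squeeze(0), alpha=0.7)
+    to_pil_image(overlay).show()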
 
 Table of all available semantic segmentation weights
 ----------------------------------------------------
 
@@ -130,6 +261,41 @@ weights:
    models/ssd
    models/ssdlite
 
+|
+
+Here is an example of how to use the pre-trained object detection models:
+
+.. code:: python
+
+    from torchvision.io.image import read_image
+    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
+    from torchvision.utils import draw_bounding_boxes
+    from torchvision.transforms.functional import to_pil_image
+
+    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+
+    # Step 1: Initialize model with the best available weights
+    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
+    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
+    model.eval()
+
+    # Step 2: Initialize the inference transforms
+    preprocess = weights.transforms()
+
+    # Step 3: Apply inference preprocessing transforms
+    batch = [preprocess(img)]
+
+    # Step 4: Use the model and visualize the prediction
+    prediction = model(batch)[0]
+    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
+    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
+                              labels=labels,
+                              colors="red",
+                              width=4, font_size=30)
+    im = to_pil_image(box.detach())
+    im.show()
+
 Table of all available object detection weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -191,6 +357,38 @@ pre-trained weights:
    models/video_resnet
 
+|
+
+Here is an example of how to use the pre-trained video classification models:
+
+.. code:: python
+
+    from torchvision.io.video import read_video
+    from torchvision.models.video import r3d_18, R3D_18_Weights
+
+    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
+    vid = vid[:32]  # optionally shorten duration
+
+    # Step 1: Initialize model with the best available weights
+    weights = R3D_18_Weights.DEFAULT
+    model = r3d_18(weights=weights)
+    model.eval()
+
+    # Step 2: Initialize the inference transforms
+    preprocess = weights.transforms()
+
+    # Step 3: Apply inference preprocessing transforms
+    batch = preprocess(vid).unsqueeze(0)
+
+    # Step 4: Use the model and print the predicted category
+    prediction = model(batch).squeeze(0).softmax(0)
+    label = prediction.argmax().item()
+    score = prediction[label].item()
+    category_name = weights.meta["categories"][label]
+    print(f"{category_name}: {100 * score:.1f}%")
+
 Table of all available video classification weights
 ---------------------------------------------------
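+
+The entries of the weight tables in this document can also be discovered
+programmatically, since every weights class is a Python enum. As a small
+illustrative sketch (not part of this diff's examples; it assumes only the
+``R3D_18_Weights`` enum shown above), the available video classification
+weights can be listed like this:
+
+.. code:: python
+
+    from torchvision.models.video import R3D_18_Weights
+
+    # Iterate over every entry of the weights enum and show its metadata
+    for weights in R3D_18_Weights:
+        print(weights.name, "->", len(weights.meta["categories"]), "categories")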