docs/source/models_new.rst (198 additions, 0 deletions)
keypoint detection, video classification, and optical flow.
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_ for loading different weights into the
existing model builder methods:

.. code:: python

from torchvision.models import resnet50, ResNet50_Weights

# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)

# Strings are also supported
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None) # or resnet50()
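
You can also enumerate the weights available for a builder, or look a weights entry up by its name. A small sketch, assuming ``get_weight`` is importable from ``torchvision.models`` in your version:

.. code:: python

from torchvision.models import get_weight, ResNet50_Weights

# Each enum entry is one set of pre-trained weights for this builder
for weights in ResNet50_Weights:
    print(weights)

# Look a weights entry up by its fully qualified name
weights = get_weight("ResNet50_Weights.IMAGENET1K_V2")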


Migrating to the new API is straightforward. The following method calls between the two APIs are all equivalent:

.. code:: python

from torchvision.models import resnet50, ResNet50_Weights

# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(pretrained=True) # deprecated
resnet50(True) # deprecated

# Using no weights:
resnet50(weights=None)
resnet50(pretrained=False) # deprecated
resnet50(False) # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.
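
As a quick illustration (a sketch only; the exact warning text may differ between versions), the deprecation warning can be observed with the standard library's ``warnings`` module:

.. code:: python

import warnings

from torchvision.models import resnet50

# Deprecated call: triggers (and records) the deprecation warning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resnet50(pretrained=True)

print([str(w.message) for w in caught])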


Classification
==============
weights:
models/vision_transformer
models/wide_resnet

|

Here is an example of how to use the pre-trained image classification models:

.. code:: python

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
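
The same ``prediction`` tensor can report more than the single best class; a short sketch continuing from the snippet above, using ``topk``:

.. code:: python

# Optional: inspect the top 5 categories instead of only the best one
top5_scores, top5_ids = prediction.topk(5)
for score, class_id in zip(top5_scores.tolist(), top5_ids.tolist()):
    print(f"{weights.meta['categories'][class_id]}: {100 * score:.1f}%")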

Table of all available classification weights
---------------------------------------------
pre-trained weights:
models/googlenet_quant
models/mobilenetv2_quant

|

Here is an example of how to use the pre-trained quantized image classification models:

[Review comment — Member] I wonder if we should just refer to the snippet above instead, and mention to pass ``quantize=True``? This seems to be the only difference, apart from the imports.

[Reply — Contributor Author] I'm OK giving a second example, given we have a different namespace. Hopefully, in the future, this will be deprecated and all model quantization will move into the main model builders (with FX quant).


.. code:: python

from torchvision.io import read_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
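
As with the float models, the quantized weight enums can be enumerated to see what is available in the ``torchvision.models.quantization`` namespace; a minimal sketch:

.. code:: python

from torchvision.models.quantization import ResNet50_QuantizedWeights

# Each entry is one set of pre-trained quantized weights for this builder
for weights in ResNet50_QuantizedWeights:
    print(weights)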


Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pre-trained weights:
models/fcn
models/lraspp

|

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

from torchvision.io.image import read_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

img = read_image("gallery/assets/dog1.jpg")

# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()
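
The soft mask can also be binarized and drawn on top of the input image with ``torchvision.utils.draw_segmentation_masks``. A hedged sketch continuing from the snippet above (it assumes the transforms resized the input, so the model output must be resized back to the original image size first):

.. code:: python

import torch.nn.functional as F
from torchvision.utils import draw_segmentation_masks

# Resize the logits back to the original image size before thresholding
full_size = F.interpolate(prediction, size=img.shape[-2:], mode="bilinear")
dog_mask = full_size.softmax(dim=1).argmax(dim=1) == class_to_idx["dog"]
overlay = draw_segmentation_masks(img, masks=dog_mask[0], alpha=0.7, colors="blue")
to_pil_image(overlay).show()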


Table of all available semantic segmentation weights
----------------------------------------------------

weights:
models/ssd
models/ssdlite

|

Here is an example of how to use the pre-trained object detection models:

.. code:: python


from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]

# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()
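
The ``prediction`` dict also carries per-box confidence scores, which can be printed alongside the labels; a short sketch continuing from the snippet above:

.. code:: python

# Optional: print each detection with its confidence score
for label, score in zip(labels, prediction["scores"].tolist()):
    print(f"{label}: {100 * score:.1f}%")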

[Review comment — @NicolasHug (Member), May 13, 2022] I'm getting ``RuntimeError: "qnms_kernel" not implemented for 'Float'``, but I haven't compiled or updated the nightlies recently, so maybe that's why.

[Reply — Contributor Author] That's strange; it works for me. Can you try the recent nightly in case this is a bug?


Table of all available object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

pre-trained weights:

models/video_resnet

|

Here is an example of how to use the pre-trained video classification models:

.. code:: python


from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights

vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
vid = vid[:32] # optionally shorten duration

# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score:.1f}%")

[Review comment on the ``read_video`` call — @datumbox (Contributor Author), May 13, 2022] @bjuncek Apparently this emits a vague deprecation warning (without a fixed removal date). Does it make sense to remove the warning, given that we are still uncertain about what we will do with this specific API? Or perhaps I should modify my example to avoid the warnings?

[Reply — Contributor] Yeah, feel free to remove it. I'm in the process of rewriting the video module to have a more sensible structure.

[Reply — Contributor Author] See #6056.
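
If you prefer to avoid that warning in the example, passing ``pts_unit="sec"`` to ``read_video`` (as the warning text itself suggests) should work; a one-line sketch:

.. code:: python

vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", pts_unit="sec")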


Table of all available video classification weights
---------------------------------------------------
