Add examples of Multi-weight support + model usage #6013
Changes from all commits: a5231db, 73797d2, 6cc9be4, 21c2c67, 3736ca8, 8c32ee0, 8f2a23c, a99cb39
@@ -24,6 +24,49 @@ keypoint detection, video classification, and optical flow.

`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_ for loading different weights into
the existing model builder methods:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Old weights with accuracy 76.130%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # New weights with accuracy 80.858%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Best available weights (currently an alias for IMAGENET1K_V2)
    # Note that these weights may change across versions
    resnet50(weights=ResNet50_Weights.DEFAULT)

    # Strings are also supported
    resnet50(weights="IMAGENET1K_V2")

    # No weights - random initialization
    resnet50(weights=None)  # or resnet50()

Migrating to the new API is straightforward. The following method calls between the two APIs are all equivalent:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Using pretrained weights:
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    resnet50(pretrained=True)  # deprecated
    resnet50(True)  # deprecated

    # Using no weights:
    resnet50(weights=None)
    resnet50(pretrained=False)  # deprecated
    resnet50(False)  # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.
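The deprecation path described above can be sketched with a toy builder. Everything here (``ToyWeights``, ``toy_resnet50``) is an illustrative stand-in, not torchvision's actual implementation:

```python
import warnings
from enum import Enum

class ToyWeights(Enum):
    # Hypothetical entries standing in for e.g. ResNet50_Weights
    IMAGENET1K_V1 = "v1"
    IMAGENET1K_V2 = "v2"
    DEFAULT = "v2"  # same value, so Enum makes this an alias for IMAGENET1K_V2

def toy_resnet50(*, weights=None, pretrained=None):
    """Sketch of a builder that accepts the new argument and shims the old one."""
    if pretrained is not None:
        warnings.warn(
            "The 'pretrained' parameter is deprecated, please use 'weights' instead.",
            DeprecationWarning,
        )
        weights = ToyWeights.DEFAULT if pretrained else None
    if isinstance(weights, str):
        weights = ToyWeights[weights]  # strings map onto enum member names
    return weights  # a real builder would construct the model and load a checkpoint

print(toy_resnet50(weights="IMAGENET1K_V2"))
print(toy_resnet50(pretrained=True))  # warns, then falls back to DEFAULT
```

The real builders also accept the argument positionally (which is why ``resnet50(True)`` still resolves); the sketch keeps keyword-only arguments for clarity.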
Classification | ||
============== | ||
@@ -56,6 +99,34 @@ weights:

    models/vision_transformer
    models/wide_resnet

Here is an example of how to use the pre-trained image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")
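Step 4 above is just a softmax followed by an argmax. The arithmetic can be sketched without torchvision, using made-up logits and category names:

```python
import math

# Toy stand-ins: three made-up categories and raw model outputs (logits)
categories = ["cat", "dog", "mug"]
logits = [1.0, 3.0, 0.5]

# Softmax turns logits into probabilities that sum to 1
exps = [math.exp(x) for x in logits]
total = sum(exps)
probs = [e / total for e in exps]

# Argmax picks the most likely category
class_id = max(range(len(probs)), key=probs.__getitem__)
score = probs[class_id]
print(f"{categories[class_id]}: {100 * score:.1f}%")
```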
Table of all available classification weights | ||
--------------------------------------------- | ||
@@ -78,6 +149,35 @@ pre-trained weights:

    models/googlenet_quant
    models/mobilenetv2_quant

Here is an example of how to use the pre-trained quantized image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_QuantizedWeights.DEFAULT
    model = resnet50(weights=weights, quantize=True)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

Review comment: I wonder if we should just refer to the snippet above instead, and mention to pass …

Reply: I'm OK giving a second example, given that we have a different namespace. Hopefully in the future this will be deprecated, moving all model quantization into the main model builders (with FX quant).
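For background, quantized models store parameters as int8 plus a scale and zero-point. The sketch below shows the standard affine round trip with made-up numbers; it illustrates the general idea, not torchvision's internal quantization code:

```python
def quantize(x, scale, zero_point):
    """Map a float to an int8 value using affine quantization."""
    q = round(x / scale) + zero_point
    return max(-128, min(127, q))  # clamp to the int8 range

def dequantize(q, scale, zero_point):
    """Recover an approximate float from the int8 value."""
    return (q - zero_point) * scale

scale, zero_point = 0.05, 0  # made-up parameters
x = 1.234
q = quantize(x, scale, zero_point)
x_hat = dequantize(q, scale, zero_point)
print(q, x_hat)  # the round trip loses at most scale/2 of precision
```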
Table of all available quantized classification weights | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
@@ -101,6 +201,37 @@ pre-trained weights:

    models/fcn
    models/lraspp

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()
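The mask lookup in Step 4 is plain indexing: build a name-to-index map from the category list, then select that class's probability map. A toy version with a 2x2 "image" and invented categories:

```python
# Invented category list standing in for weights.meta["categories"]
categories = ["__background__", "cat", "dog"]
class_to_idx = {cls: idx for idx, cls in enumerate(categories)}

# Toy per-class probability maps for a 2x2 image: scores[class][row][col]
scores = [
    [[0.7, 0.1], [0.2, 0.1]],  # background
    [[0.2, 0.1], [0.1, 0.1]],  # cat
    [[0.1, 0.8], [0.7, 0.8]],  # dog
]

# Selecting the "dog" map mirrors normalized_masks[0, class_to_idx["dog"]]
mask = scores[class_to_idx["dog"]]
print(mask)
```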
Table of all available semantic segmentation weights | ||
---------------------------------------------------- | ||
@@ -130,6 +261,41 @@ weights:

    models/ssd
    models/ssdlite

Here is an example of how to use the pre-trained object detection models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()

Review comment: I'm getting … But I haven't compiled or updated the nightlies recently, maybe that's why.

Reply: That's strange, it works for me. Can you post the recent … in case this is a bug?
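The ``box_score_thresh=0.9`` argument above keeps only confident detections. The thresholding itself is simple filtering, sketched here with made-up boxes, labels, and scores:

```python
# Made-up raw detections: (box in xyxy format, label index, confidence score)
detections = [
    ((10, 20, 110, 220), 1, 0.97),
    ((40, 50, 90, 150), 2, 0.42),
    ((200, 30, 260, 120), 1, 0.91),
]
categories = ["__background__", "person", "tie"]  # invented label list

box_score_thresh = 0.9
keep = [(box, categories[label], score)
        for box, label, score in detections
        if score >= box_score_thresh]

for box, name, score in keep:
    print(name, box, f"{score:.2f}")
```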
Table of all available Object detection weights | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
@@ -191,6 +357,38 @@ pre-trained weights:

    models/video_resnet

Here is an example of how to use the pre-trained video classification models:

.. code:: python

    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
    vid = vid[:32]  # optionally shorten duration

    # Step 1: Initialize model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    category_name = weights.meta["categories"][label]
    print(f"{category_name}: {100 * score:.1f}%")

Review comment: @bjuncek Apparently this emits a vague deprecation warning (without a fixed removal date). Does it make sense to remove the warning, given that we are still uncertain about what we will do with this specific API? Or should I modify my example to avoid the warnings?

Reply: Yeah, feel free to remove it.

Reply: See #6056
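The ``vid[:32]`` slice keeps a single 32-frame clip. A common extension, not shown in the docs above, is to split a longer video into overlapping clips and average the scores; the clip-index arithmetic alone, with made-up numbers:

```python
def clip_starts(num_frames, clip_len, stride):
    """Start indices of consecutive clips that fit entirely in the video."""
    return list(range(0, num_frames - clip_len + 1, stride))

# A hypothetical 100-frame video, 32-frame clips, hop of 16 frames
starts = clip_starts(num_frames=100, clip_len=32, stride=16)
clips = [(s, s + 32) for s in starts]
print(clips)  # each pair is a (start, end) frame range for one clip
```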
Table of all available video classification weights | ||
--------------------------------------------------- | ||