Datapoint -> VisionTensor; datapoint[s] -> vision_tensor[s]
NicolasHug committed Aug 29, 2023
1 parent 655ebdb commit e27ac58
Showing 88 changed files with 1,163 additions and 1,135 deletions.
4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -88,8 +88,8 @@ def __init__(self, src_dir):
"plot_transforms_e2e.py",
"plot_cutmix_mixup.py",
"plot_custom_transforms.py",
"plot_datapoints.py",
"plot_custom_datapoints.py",
"plot_vision_tensors.py",
"plot_custom_vision_tensors.py",
]

def __call__(self, filename):
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -32,7 +32,7 @@ architectures, and common image transformations for computer vision.
:caption: Package Reference

transforms
- datapoints
+ vision_tensors
models
datasets
utils
8 changes: 4 additions & 4 deletions docs/source/transforms.rst
@@ -30,12 +30,12 @@ tasks (image classification, detection, segmentation, video classification).
.. code:: python
# Detection (re-using imports and transforms from above)
- from torchvision import datapoints
+ from torchvision import vision_tensors
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
bboxes = torch.randint(0, H // 2, size=(3, 4))
bboxes[:, 2:] += bboxes[:, :2]
- bboxes = datapoints.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
+ bboxes = vision_tensors.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
# The same transforms can be used!
img, bboxes = transforms(img, bboxes)
@@ -183,8 +183,8 @@ Transforms are available as classes like
This is very much like the :mod:`torch.nn` package which defines both classes
and functional equivalents in :mod:`torch.nn.functional`.

- The functionals support PIL images, pure tensors, or :ref:`datapoints
- <datapoints>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
+ The functionals support PIL images, pure tensors, or :ref:`vision_tensors
+ <vision_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
valid.

.. note::
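For context beyond the diff, here is a minimal sketch of the point above, assuming the renamed ``torchvision.vision_tensors`` module introduced by this commit (not text from the committed files). The same functional accepts both a plain image tensor and a BoundingBoxes vision_tensor:

    # Hedged sketch, assuming the post-rename API from this commit.
    import torch
    from torchvision import vision_tensors
    from torchvision.transforms.v2 import functional as F

    H, W = 256, 256
    img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
    bboxes = vision_tensors.BoundingBoxes(
        torch.tensor([[10, 10, 50, 50]]), format="XYXY", canvas_size=(H, W)
    )

    # Both calls are valid: the kernel is chosen from the input's class,
    # so the image is resized and the box coordinates are rescaled.
    resized_img = F.resize(img, size=(128, 128), antialias=True)
    resized_bboxes = F.resize(bboxes, size=(128, 128))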
14 changes: 7 additions & 7 deletions docs/source/datapoints.rst → docs/source/vision_tensors.rst
@@ -1,13 +1,13 @@
- .. _datapoints:
+ .. _vision_tensors:

- Datapoints
- ==========
+ VisionTensors
+ =============

- .. currentmodule:: torchvision.datapoints
+ .. currentmodule:: torchvision.vision_tensors

- Datapoints are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
+ VisionTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
dispatch their inputs to the appropriate lower-level kernels. Most users do not
- need to manipulate datapoints directly and can simply rely on dataset wrapping -
+ need to manipulate vision_tensors directly and can simply rely on dataset wrapping -
see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.

.. autosummary::
@@ -19,6 +19,6 @@ see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
BoundingBoxFormat
BoundingBoxes
Mask
- Datapoint
+ VisionTensor
set_return_type
wrap
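As a brief illustration of the dispatch behaviour described above (a sketch assuming this commit's ``torchvision.vision_tensors`` module, not part of the diff): vision_tensors behave like ordinary tensors, and the v2 transforms use their class to route each input to the right low-level kernel:

    # Hedged sketch; assumes the renamed module from this commit.
    import torch
    from torchvision import vision_tensors
    from torchvision.transforms import v2

    img = vision_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
    mask = vision_tensors.Mask(torch.zeros(32, 32, dtype=torch.uint8))

    print(isinstance(img, torch.Tensor))  # True: it is just a tensor subclass

    # One transform call; each input is dispatched to its own kernel
    # and comes back wrapped in the same class.
    out_img, out_mask = v2.RandomHorizontalFlip(p=1)(img, mask)
    print(type(out_img).__name__, type(out_mask).__name__)  # Image Mask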
4 changes: 2 additions & 2 deletions gallery/others/plot_video_api.py
@@ -86,7 +86,7 @@
print("PTS for first five frames ", ptss[:5])
print("Total number of frames: ", len(frames))
approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]
print("Approx total number of datapoints we can expect: ", approx_nf)
print("Approx total number of vision_tensors we can expect: ", approx_nf)
print("Read data size: ", frames[0].size(0) * len(frames))

# %%
@@ -170,7 +170,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au
return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()


- # Total number of frames should be 327 for video and 523264 datapoints for audio
+ # Total number of frames should be 327 for video and 523264 vision_tensors for audio
vf, af, info, meta = example_read_video(video)
print(vf.size(), af.size())

4 changes: 2 additions & 2 deletions gallery/transforms/helpers.py
@@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
- from torchvision import datapoints
+ from torchvision import vision_tensors
from torchvision.transforms.v2 import functional as F


@@ -22,7 +22,7 @@ def plot(imgs, row_title=None, **imshow_kwargs):
if isinstance(target, dict):
boxes = target.get("boxes")
masks = target.get("masks")
- elif isinstance(target, datapoints.BoundingBoxes):
+ elif isinstance(target, vision_tensors.BoundingBoxes):
boxes = target
else:
raise ValueError(f"Unexpected target type: {type(target)}")
10 changes: 5 additions & 5 deletions gallery/transforms/plot_custom_transforms.py
@@ -13,7 +13,7 @@

# %%
import torch
- from torchvision import datapoints
+ from torchvision import vision_tensors
from torchvision.transforms import v2


@@ -62,7 +62,7 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured

H, W = 256, 256
img = torch.rand(3, H, W)
- bboxes = datapoints.BoundingBoxes(
+ bboxes = vision_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
@@ -74,9 +74,9 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
- # While working with datapoint classes in your code, make sure to
+ # While working with vision_tensor classes in your code, make sure to
# familiarize yourself with this section:
- # :ref:`datapoint_unwrapping_behaviour`
+ # :ref:`vision_tensor_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
@@ -111,7 +111,7 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
- # based on the **class** of the entries, as all datapoints are
+ # based on the **class** of the entries, as all vision_tensors are
# tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
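To make the pytree idea above concrete, here is a minimal sketch (an editor's illustration assuming this commit's ``vision_tensors`` module; it is not torchvision's actual implementation):

    # Sketch of the dispatch idea: flatten any input structure, transform only
    # the entries whose class we recognize, and repack the rest untouched.
    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten
    from torchvision import vision_tensors


    def hflip_structure(inputs):
        flat, spec = tree_flatten(inputs)  # arbitrary nesting -> flat list + spec
        out = [
            vision_tensors.wrap(entry.flip(-1), like=entry)
            if isinstance(entry, (vision_tensors.Image, vision_tensors.Mask))
            else entry  # labels, strings, plain ints, ... pass through unchanged
            for entry in flat
        ]
        return tree_unflatten(out, spec)  # repack into the original structure


    sample = {"img": vision_tensors.Image(torch.rand(3, 4, 4)), "label": 3}
    flipped = hflip_structure(sample)
    print(type(flipped["img"]).__name__, flipped["label"])  # Image 3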
gallery/transforms/plot_custom_datapoints.py → gallery/transforms/plot_custom_vision_tensors.py
@@ -1,62 +1,62 @@
"""
- =====================================
- How to write your own Datapoint class
- =====================================
+ ========================================
+ How to write your own VisionTensor class
+ ========================================
.. note::
- Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_datapoints.ipynb>`_
- or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_datapoints.py>` to download the full example code.
+ Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_vision_tensors.ipynb>`_
+ or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_vision_tensors.py>` to download the full example code.
This guide is intended for advanced users and downstream library maintainers. We explain how to
- write your own datapoint class, and how to make it compatible with the built-in
+ write your own vision_tensor class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
- :ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`.
+ :ref:`sphx_glr_auto_examples_transforms_plot_vision_tensors.py`.
"""

# %%
import torch
- from torchvision import datapoints
+ from torchvision import vision_tensors
from torchvision.transforms import v2

# %%
# We will create a very simple class that just inherits from the base
- # :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
+ # :class:`~torchvision.vision_tensors.VisionTensor` class. It will be enough to cover
# what you need to know to implement your more elaborate use cases. If you need
# to create a class that carries meta-data, take a look at how the
- # :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
- # <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
+ # :class:`~torchvision.vision_tensors.BoundingBoxes` class is `implemented
+ # <https://github.com/pytorch/vision/blob/main/torchvision/vision_tensors/_bounding_box.py>`_.


- class MyDatapoint(datapoints.Datapoint):
+ class MyVisionTensor(vision_tensors.VisionTensor):
pass


- my_dp = MyDatapoint([1, 2, 3])
+ my_dp = MyVisionTensor([1, 2, 3])
my_dp

# %%
- # Now that we have defined our custom Datapoint class, we want it to be
+ # Now that we have defined our custom VisionTensor class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
- # operation of our MyDatapoint class, and register it to the functional API.
+ # operation of our MyVisionTensor class, and register it to the functional API.

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
@F.register_kernel(functional="hflip", vision_tensor_cls=MyVisionTensor)
def hflip_my_vision_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
- return datapoints.wrap(out, like=my_dp)
+ return vision_tensors.wrap(out, like=my_dp)


# %%
- # To understand why :func:`~torchvision.datapoints.wrap` is used, see
- # :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
+ # To understand why :func:`~torchvision.vision_tensors.wrap` is used, see
+ # :ref:`vision_tensor_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
# .. note::
@@ -67,9 +67,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# ``@register_kernel(functional=F.hflip, ...)``.
#
# Now that we have registered our kernel, we can call the functional API on a
- # ``MyDatapoint`` instance:
+ # ``MyVisionTensor`` instance:

- my_dp = MyDatapoint(torch.rand(3, 256, 256))
+ my_dp = MyVisionTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)

# %%
@@ -102,10 +102,10 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as

- def hflip_my_datapoint(my_dp): # noqa
+ def hflip_my_vision_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
- return datapoints.wrap(out, like=my_dp)
+ return vision_tensors.wrap(out, like=my_dp)


# %%
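A short usage sketch (an illustration under the assumptions of the guide above, i.e. the ``MyVisionTensor`` class and the registered ``hflip`` kernel; not part of the diff): once the kernel is registered, the class-based transforms pick it up too, since the guide notes they dispatch to the same ``hflip`` functional under the hood:

    # Hedged sketch; assumes MyVisionTensor and the registered kernel from the
    # guide above are defined in the current session.
    from torchvision.transforms import v2

    t = v2.RandomHorizontalFlip(p=1)
    out = t(my_dp)              # prints "Flipping!" via the registered kernel
    print(type(out).__name__)   # MyVisionTensor, thanks to vision_tensors.wrap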
6 changes: 3 additions & 3 deletions gallery/transforms/plot_transforms_e2e.py
@@ -23,7 +23,7 @@
import torch
import torch.utils.data

- from torchvision import models, datasets, datapoints
+ from torchvision import models, datasets, vision_tensors
from torchvision.transforms import v2

torch.manual_seed(0)
@@ -72,7 +72,7 @@
# %%
# We used the ``target_keys`` parameter to specify the kind of output we're
# interested in. Our dataset now returns a target which is dict where the values
- # are :ref:`Datapoints <what_are_datapoints>` (all are :class:`torch.Tensor`
+ # are :ref:`VisionTensors <what_are_vision_tensors>` (all are :class:`torch.Tensor`
# subclasses). We've dropped all unnecessary keys from the previous output, but
# if you need any of the original keys e.g. "image_id", you can still ask for
# it.
@@ -103,7 +103,7 @@
[
v2.ToImage(),
v2.RandomPhotometricDistort(p=1),
- v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}),
+ v2.RandomZoomOut(fill={vision_tensors.Image: (123, 117, 104), "others": 0}),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(),
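For the dataset-wrapping step mentioned above, a minimal sketch (editor's illustration, not from the diff; the dataset paths are placeholders and it assumes the v2 pipeline shown above is assigned to ``transforms``):

    # Hedged sketch: wrap a COCO-style detection dataset so its targets come
    # back as vision_tensors.
    from torchvision import datasets

    dataset = datasets.CocoDetection(
        "path/to/images", "path/to/annotations.json", transforms=transforms
    )
    dataset = datasets.wrap_dataset_for_transforms_v2(
        dataset, target_keys=("boxes", "labels", "masks")
    )

    img, target = dataset[0]
    print(type(target["boxes"]))  # vision_tensors.BoundingBoxes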