Skip to content

Commit aeff837

Browse files
author
Vincent Moens
committed
[Feature] broadcast pointwise ops for tensor/tensordict mixed inputs
ghstack-source-id: bbefbb1 Pull Request resolved: #1166
1 parent 646683c commit aeff837

File tree

4 files changed

+404
-6
lines changed

4 files changed

+404
-6
lines changed

docs/source/reference/tensordict.rst

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,114 @@ However, physical storage of PyTorch tensors should not be any different:
109109

110110
MemoryMappedTensor
111111

112+
Pointwise Operations
113+
--------------------
114+
115+
Tensordict supports various pointwise operations, allowing you to perform element-wise computations on the tensors
116+
stored within it. These operations are similar to those performed on regular PyTorch tensors.
117+
118+
Supported Operations
119+
~~~~~~~~~~~~~~~~~~~~
120+
121+
The following pointwise operations are currently supported:
122+
123+
- Left and right addition (`+`)
124+
- Left and right subtraction (`-`)
125+
- Left and right multiplication (`*`)
126+
- Left and right division (`/`)
127+
- Left power (`**`)
128+
129+
Many other ops, like :meth:`~tensordict.TensorDict.clamp`, :meth:`~tensordict.TensorDict.sqrt` etc. are supported.
130+
131+
Performing Pointwise Operations
132+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
133+
134+
You can perform pointwise operations between two Tensordicts or between a Tensordict and a tensor/scalar value.
135+
136+
Example 1: Tensordict-Tensordict Operation
137+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
138+
139+
>>> import torch
140+
>>> from tensordict import TensorDict
141+
>>> td1 = TensorDict(
142+
... a=torch.randn(3, 4),
143+
... b=torch.zeros(3, 4, 5),
144+
... c=torch.ones(3, 4, 5, 6),
145+
... batch_size=(3, 4),
146+
... )
147+
>>> td2 = TensorDict(
148+
... a=torch.randn(3, 4),
149+
... b=torch.zeros(3, 4, 5),
150+
... c=torch.ones(3, 4, 5, 6),
151+
... batch_size=(3, 4),
152+
... )
153+
>>> result = td1 * td2
154+
155+
In this example, the * operator is applied element-wise to the corresponding tensors in td1 and td2.
156+
157+
Example 2: Tensordict-Tensor Operation
158+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
159+
160+
>>> import torch
161+
>>> from tensordict import TensorDict
162+
>>> td = TensorDict(
163+
... a=torch.randn(3, 4),
164+
... b=torch.zeros(3, 4, 5),
165+
... c=torch.ones(3, 4, 5, 6),
166+
... batch_size=(3, 4),
167+
... )
168+
>>> tensor = torch.randn(4)
169+
>>> result = td * tensor
170+
171+
ere, the * operator is applied element-wise to each tensor in td and the provided tensor. The tensor is broadcasted to match the shape of each tensor in the Tensordict.
172+
173+
Example 3: Tensordict-Scalar Operation
174+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
175+
176+
>>> import torch
177+
>>> from tensordict import TensorDict
178+
>>> td = TensorDict(
179+
... a=torch.randn(3, 4),
180+
... b=torch.zeros(3, 4, 5),
181+
... c=torch.ones(3, 4, 5, 6),
182+
... batch_size=(3, 4),
183+
... )
184+
>>> scalar = 2.0
185+
>>> result = td * scalar
186+
187+
In this case, the * operator is applied element-wise to each tensor in td and the provided scalar.
188+
189+
Broadcasting Rules
190+
~~~~~~~~~~~~~~~~~~
191+
192+
When performing pointwise operations between a Tensordict and a tensor/scalar, the tensor/scalar is broadcasted to match
193+
the shape of each tensor in the Tensordict: the tensor is broadcast on the left to match the tensordict shape, then
194+
individually broadcast on the right to match the tensors shapes. This follows the standard broadcasting rules used in
195+
PyTorch if one thinks of the ``TensorDict`` as a single tensor instance.
196+
197+
For example, if you have a Tensordict with tensors of shape ``(3, 4)`` and you multiply it by a tensor of shape ``(4,)``,
198+
the tensor will be broadcasted to shape (3, 4) before the operation is applied. If the tensordict contains a tensor of
199+
shape ``(3, 4, 5)``, the tensor used for the multiplication will be broadcast to ``(3, 4, 5)`` on the right for that
200+
multiplication.
201+
202+
If the pointwise operation is executed across multiple tensordicts and their batch-size differ, they will be
203+
broadcasted to a common shape.
204+
205+
Efficiency of pointwise operations
206+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
207+
208+
When possible, ``torch._foreach_<op>`` fused kernels will be used to speed up the computation of the pointwise
209+
operation.
210+
211+
Handling Missing Entries
212+
~~~~~~~~~~~~~~~~~~~~~~~~
213+
214+
When performing pointwise operations between two Tensordicts, they must have the same keys.
215+
Some operations, like :meth:`~tensordict.TensorDict.add`, have a ``default`` keyword argument that can be used
216+
to operate with tensordict with exclusive entries.
217+
If ``default=None`` (the default), the two Tensordicts must have exactly matching key sets.
218+
If ``default="intersection"``, only the intersecting key sets will be considered, and other keys will be ignored.
219+
In all other cases, ``default`` will be used for all missing entries on both sides of the operation.
112220

113221
Utils
114222
-----

tensordict/_td.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_is_leaf_nontensor,
3131
_is_tensor_collection,
3232
_load_metadata,
33+
_maybe_broadcast_other,
3334
_NESTED_TENSORS_AS_LISTS,
3435
_register_tensor_class,
3536
BEST_ATTEMPT_INPLACE,
@@ -611,6 +612,7 @@ def _quick_set(swap_dict, swap_td):
611612
else:
612613
return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))
613614

615+
@_maybe_broadcast_other("__ne__")
614616
def __ne__(self, other: Any) -> T | bool:
615617
if is_tensorclass(other):
616618
return other != self
@@ -635,6 +637,7 @@ def __ne__(self, other: Any) -> T | bool:
635637
)
636638
return True
637639

640+
@_maybe_broadcast_other("__xor__")
638641
def __xor__(self, other: Any) -> T | bool:
639642
if is_tensorclass(other):
640643
return other ^ self
@@ -659,6 +662,7 @@ def __xor__(self, other: Any) -> T | bool:
659662
)
660663
return True
661664

665+
@_maybe_broadcast_other("__or__")
662666
def __or__(self, other: Any) -> T | bool:
663667
if is_tensorclass(other):
664668
return other | self
@@ -683,6 +687,7 @@ def __or__(self, other: Any) -> T | bool:
683687
)
684688
return False
685689

690+
@_maybe_broadcast_other("__eq__")
686691
def __eq__(self, other: Any) -> T | bool:
687692
if is_tensorclass(other):
688693
return other == self
@@ -705,6 +710,7 @@ def __eq__(self, other: Any) -> T | bool:
705710
)
706711
return False
707712

713+
@_maybe_broadcast_other("__ge__")
708714
def __ge__(self, other: Any) -> T | bool:
709715
if is_tensorclass(other):
710716
return other <= self
@@ -727,6 +733,7 @@ def __ge__(self, other: Any) -> T | bool:
727733
)
728734
return False
729735

736+
@_maybe_broadcast_other("__gt__")
730737
def __gt__(self, other: Any) -> T | bool:
731738
if is_tensorclass(other):
732739
return other < self
@@ -749,6 +756,7 @@ def __gt__(self, other: Any) -> T | bool:
749756
)
750757
return False
751758

759+
@_maybe_broadcast_other("__le__")
752760
def __le__(self, other: Any) -> T | bool:
753761
if is_tensorclass(other):
754762
return other >= self
@@ -771,6 +779,7 @@ def __le__(self, other: Any) -> T | bool:
771779
)
772780
return False
773781

782+
@_maybe_broadcast_other("__lt__")
774783
def __lt__(self, other: Any) -> T | bool:
775784
if is_tensorclass(other):
776785
return other > self

0 commit comments

Comments
 (0)