docs/source/reference/tensordict.rst (108 additions, 0 deletions)
@@ -109,6 +109,114 @@ However, physical storage of PyTorch tensors should not be any different:
MemoryMappedTensor

Pointwise Operations
--------------------

Tensordict supports various pointwise operations, allowing you to perform element-wise computations on the tensors
stored within it. These operations are similar to those performed on regular PyTorch tensors.

Supported Operations
~~~~~~~~~~~~~~~~~~~~

The following pointwise operations are currently supported:

- Left and right addition (`+`)
- Left and right subtraction (`-`)
- Left and right multiplication (`*`)
- Left and right division (`/`)
- Left power (`**`)

Many other ops, such as :meth:`~tensordict.TensorDict.clamp` and :meth:`~tensordict.TensorDict.sqrt`, are also supported.
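
For instance (a minimal sketch, assuming these methods mirror their ``torch.Tensor`` counterparts and return a new
``TensorDict`` with the same keys and batch size):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5), batch_size=(3, 4))
>>> td_clamped = td.clamp(-1.0, 1.0)  # clamp every entry to [-1, 1]
>>> td_sqrt = (td * td).sqrt()        # square each entry, then take the square root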

Performing Pointwise Operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can perform pointwise operations between two Tensordicts or between a Tensordict and a tensor/scalar value.

Example 1: Tensordict-Tensordict Operation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

>>> import torch
>>> from tensordict import TensorDict
>>> td1 = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> td2 = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> result = td1 * td2

In this example, the * operator is applied element-wise to the corresponding tensors in td1 and td2.

Example 2: Tensordict-Tensor Operation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> tensor = torch.randn(4)
>>> result = td * tensor

Here, the * operator is applied element-wise to each tensor in td and the provided tensor. The tensor is broadcast
to match the shape of each tensor in the Tensordict.

Example 3: Tensordict-Scalar Operation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> scalar = 2.0
>>> result = td * scalar

In this case, the * operator is applied element-wise to each tensor in td and the provided scalar.

Broadcasting Rules
~~~~~~~~~~~~~~~~~~

When performing pointwise operations between a Tensordict and a tensor/scalar, the tensor/scalar is broadcast to match
the shape of each tensor in the Tensordict: the tensor is broadcast on the left to match the tensordict shape, then
individually broadcast on the right to match each tensor's shape. This follows the standard broadcasting rules used in
PyTorch if one thinks of the ``TensorDict`` as a single tensor instance.

For example, if you have a Tensordict with tensors of shape ``(3, 4)`` and you multiply it by a tensor of shape ``(4,)``,
the tensor will be broadcast to shape ``(3, 4)`` before the operation is applied. If the tensordict contains a tensor of
shape ``(3, 4, 5)``, the tensor used for the multiplication will be broadcast to ``(3, 4, 5)`` on the right for that
multiplication.

If the pointwise operation is executed across multiple tensordicts and their batch sizes differ, they will be
broadcast to a common shape.
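
A minimal sketch of these rules, reusing the shapes from Example 2 (only the shapes matter here; the values are
random):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5), batch_size=(3, 4))
>>> result = td * torch.randn(4)  # (4,) -> (3, 4) on the left, then expanded on the right per entry
>>> result["a"].shape
torch.Size([3, 4])
>>> result["b"].shape
torch.Size([3, 4, 5])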

Efficiency of pointwise operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When possible, ``torch._foreach_<op>`` fused kernels will be used to speed up the computation of the pointwise
operation.
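
Conceptually, the fused path amounts to something like the following sketch (an illustration only, not tensordict's
actual implementation; ``torch._foreach_mul`` is a private PyTorch API):

>>> import torch
>>> leaves = [torch.randn(3, 4), torch.randn(3, 4, 5)]  # stand-ins for the tensors stored in a TensorDict
>>> scaled = torch._foreach_mul(leaves, 2.0)  # one grouped call instead of a separate kernel launch per tensor
>>> [t.shape for t in scaled]
[torch.Size([3, 4]), torch.Size([3, 4, 5])]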

Handling Missing Entries
~~~~~~~~~~~~~~~~~~~~~~~~

When performing pointwise operations between two Tensordicts, they must have the same keys.
Some operations, like :meth:`~tensordict.TensorDict.add`, have a ``default`` keyword argument that can be used
to operate on tensordicts with exclusive (non-matching) entries.
If ``default=None`` (the default), the two Tensordicts must have exactly matching key sets.
If ``default="intersection"``, only the intersecting key sets will be considered, and other keys will be ignored.
In all other cases, ``default`` will be used for all missing entries on both sides of the operation.
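
For example (a sketch based on the description above; it assumes the result contains the union of keys when a fill
value is passed as ``default``):

>>> import torch
>>> from tensordict import TensorDict
>>> td1 = TensorDict(a=torch.ones(3), b=torch.ones(3), batch_size=(3,))
>>> td2 = TensorDict(a=torch.ones(3), c=torch.ones(3), batch_size=(3,))
>>> both = td1.add(td2, default=0.0)  # "b" and "c" are added to the fill value 0.0
>>> sorted(both.keys())
['a', 'b', 'c']
>>> common = td1.add(td2, default="intersection")  # only the shared key "a" is considered
>>> sorted(common.keys())
['a']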