|
2 | 2 |
|
3 | 3 | """ |
4 | 4 | (Prototype) MaskedTensor Overview |
5 | | -================================= |
| 5 | +********************************* |
6 | 6 | """ |
7 | 7 |
|
8 | 8 | ###################################################################### |
|
14 | 14 | # * use any masked semantics (for example, variable-length tensors, nan* operators, etc.)
15 | 15 | # * differentiation between 0 and NaN gradients |
16 | 16 | # * various sparse applications (see tutorial below) |
17 | | -# |
| 17 | +# |
18 | 18 | # For a more detailed introduction on what MaskedTensors are, please find the |
19 | 19 | # `torch.masked documentation <https://pytorch.org/docs/master/masked.html>`__. |
20 | | -# |
| 20 | +# |
21 | 21 | # Using MaskedTensor |
22 | | -# ++++++++++++++++++ |
| 22 | +# ================== |
| 23 | +# |
| 24 | +# In this section we discuss how to use MaskedTensor, including how to construct it,
| 25 | +# access the data and mask, and index and slice it.
| 26 | +# |
| 27 | +# Preparation |
| 28 | +# ----------- |
23 | 29 | # |
| 30 | +# We'll begin by doing the necessary setup for the tutorial: |
| 31 | +# |
| 32 | + |
| 33 | +import torch |
| 34 | +from torch.masked import masked_tensor, as_masked_tensor |
| 35 | +import warnings |
| 36 | + |
| 37 | +# Disable prototype warnings and such |
| 38 | +warnings.filterwarnings(action='ignore', category=UserWarning) |
| 39 | + |
| 40 | +###################################################################### |
24 | 41 | # Construction |
25 | 42 | # ------------ |
26 | | -# |
| 43 | +# |
27 | 44 | # There are a few different ways to construct a MaskedTensor: |
28 | 45 | # |
29 | 46 | # * The first way is to directly invoke the MaskedTensor class |
|
52 | 69 | # as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns: |
53 | 70 | # |
54 | 71 |
|
55 | | -import torch |
56 | | -from torch.masked import masked_tensor, as_masked_tensor |
57 | | - |
58 | 72 | data = torch.arange(24).reshape(2, 3, 4) |
59 | 73 | mask = data % 2 == 0 |
60 | 74 |
|
61 | | -print("data\n", data) |
62 | | -print("mask\n", mask) |
| 75 | +print("data:\n", data) |
| 76 | +print("mask:\n", mask) |
| 77 | + |
| 78 | +###################################################################### |
| 79 | +# |
63 | 80 |
|
64 | 81 | # float is used for cleaner visualization when printed
65 | 82 | mt = masked_tensor(data.float(), mask) |
66 | 83 |
|
67 | | -print ("mt[0]:\n", mt[0]) |
68 | | -print ("mt[:, :, 2:4]", mt[:, :, 2:4]) |
| 84 | +print("mt[0]:\n", mt[0]) |
| 85 | +print("mt[:, :, 2:4]:\n", mt[:, :, 2:4]) |
69 | 86 |
|
70 | 87 | ###################################################################### |
71 | 88 | # Why is MaskedTensor useful? |
72 | | -# +++++++++++++++++++++++++++ |
| 89 | +# =========================== |
73 | 90 | # |
74 | 91 | # Because of :class:`MaskedTensor`'s treatment of specified and unspecified values as first-class citizens
75 | 92 | # instead of an afterthought (with filled values, nans, etc.), it is able to solve for several of the shortcomings |
|
90 | 107 | # |
91 | 108 | # :class:`MaskedTensor` is the perfect solution for this! |
92 | 109 | # |
93 | | -# :func:`torch.where` |
94 | | -# ^^^^^^^^^^^^^^^^^^^ |
| 110 | +# torch.where |
| 111 | +# ^^^^^^^^^^^ |
95 | 112 | # |
96 | 113 | # In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations |
97 | 114 | # can matter when using :func:`torch.where` because we have trouble distinguishing whether the 0 is a real 0
|
121 | 138 | # The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where` |
122 | 139 | # to mask out elements instead of setting them to zero. |
123 | 140 | # |
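| | +# A minimal sketch of that behavior (values chosen only for illustration):
| | +# because the unselected branch is masked out rather than multiplied by 0,
| | +# the overflowing ``exp`` values can no longer produce ``inf * 0 = nan`` gradients.
| | +
| | +x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
| | +mask = x < 0
| | +mx = masked_tensor(x, mask, requires_grad=True)
| | +my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
| | +y = torch.where(mask, torch.exp(mx), my)
| | +y.sum().backward()
| | +print("mx.grad:\n", mx.grad)
| | +
| | +######################################################################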
124 | | -# Another :func:`torch.where` |
125 | | -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 141 | +# Another torch.where |
| 142 | +# ^^^^^^^^^^^^^^^^^^^ |
126 | 143 | # |
127 | 144 | # `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example. |
128 | 145 | # |
|
174 | 191 | x = torch.tensor([1., 1.], requires_grad=True) |
175 | 192 | div = torch.tensor([0., 1.]) |
176 | 193 | y = x/div # => y is [inf, 1] |
177 | | - >>> |
178 | 194 | mask = (div != 0) # => mask is [0, 1] |
179 | 195 | loss = as_masked_tensor(y, mask) |
180 | 196 | loss.sum().backward() |
181 | 197 | x.grad |
182 | 198 |
|
183 | 199 | ###################################################################### |
184 | 200 | # :func:`torch.nansum` and :func:`torch.nanmean` |
185 | | -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 201 | +# ---------------------------------------------- |
186 | 202 | # |
187 | 203 | # In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__, |
188 | 204 | # the gradient isn't calculated properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.
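| | +#
| | +# A minimal sketch of the difference, using a scalar parameter scaled into a
| | +# tensor that contains a NaN (values chosen only for illustration):
| | +
| | +a = torch.randn((), requires_grad=True)
| | +b = torch.tensor([1., float('nan')])
| | +(a * b).nansum().backward()
| | +print("a.grad:", a.grad)  # nan: the "ignored" NaN still poisons the gradient
| | +
| | +a2 = torch.randn((), requires_grad=True)
| | +mt = masked_tensor(b, ~torch.isnan(b))
| | +(a2 * mt).sum().backward()
| | +print("a2.grad:", a2.grad)  # the masked-out NaN contributes no gradient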
|
213 | 229 | # Safe Softmax |
214 | 230 | # ------------ |
215 | 231 | # |
216 | | -# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`_ |
| 232 | +# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`__ |
217 | 233 | # that arises frequently. In a nutshell, if there is an entire batch that is "masked out" |
218 | 234 | # or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
219 | 235 | # then this will result in NaNs, which can lead to training divergence. |
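| | +#
| | +# A minimal sketch of the failure mode, with one column fully masked out
| | +# (shapes and values chosen only for illustration):
| | +
| | +data = torch.randn(3, 3)
| | +mask = torch.tensor([[True, False, False],
| | +                     [True, False, True],
| | +                     [True, False, False]])
| | +x = data.masked_fill(~mask, float('-inf'))
| | +mt = masked_tensor(data, mask)
| | +print("x.softmax(0):\n", x.softmax(0))    # the all -inf column turns into nan
| | +print("mt.softmax(0):\n", mt.softmax(0))  # the masked-out column simply stays masked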
|
247 | 263 |
|
248 | 264 | ###################################################################### |
249 | 265 | # Implementing missing torch.nan* operators |
250 | | -# -------------------------------------------------------------------------------------------------------------- |
| 266 | +# ----------------------------------------- |
251 | 267 | # |
252 | | -# In `Issue 61474 <<https://github.com/pytorch/pytorch/issues/61474>`__, |
| 268 | +# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__, |
253 | 269 | # there is a request to add additional operators to cover the various `torch.nan*` applications, |
254 | 270 | # such as ``torch.nanmax``, ``torch.nanmin``, etc. |
255 | 271 | # |
256 | 272 | # In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional |
257 | | -# operators, we propose using :class:`MaskedTensor`s instead. Since |
258 | | -# `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`_, we can use it as a comparison point: |
| 273 | +# operators, we propose using :class:`MaskedTensor` instead. |
| 274 | +# Since `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__, |
| 275 | +# we can use it as a comparison point: |
259 | 276 | # |
260 | 277 |
|
261 | 278 | x = torch.arange(16).float() |
262 | 279 | y = x * x.fmod(4) |
263 | 280 | z = y.masked_fill(y == 0, float('nan')) # we want to get the mean of y when ignoring the zeros |
264 | 281 |
|
265 | | -print("y:\n, y") |
| 282 | +###################################################################### |
| 283 | +# |
| 284 | +print("y:\n", y) |
266 | 285 | # z is just y with the zeros replaced with nan's |
267 | 286 | print("z:\n", z) |
| 287 | + |
| 288 | +###################################################################### |
| 289 | +# |
| 290 | + |
268 | 291 | print("y.mean():\n", y.mean()) |
269 | 292 | print("z.nanmean():\n", z.nanmean()) |
270 | 293 | # MaskedTensor successfully ignores the 0's |
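| | +# a sketch: mask out the zeros directly instead of round-tripping through nan
| | +print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))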
|
296 | 319 | # This is a problem similar to safe softmax, where `0/0 = nan` when what we really want is an undefined value.
297 | 320 | # |
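| | +# A minimal sketch of that fully masked reduction (values chosen for illustration):
| | +
| | +x = torch.full((4,), float('nan'))
| | +print("x.nanmean():\n", x.nanmean())  # nan, since the reduction divides 0 by 0
| | +print("torch.mean(masked_tensor(x, ~torch.isnan(x))):\n",
| | +      torch.mean(masked_tensor(x, ~torch.isnan(x))))  # remains an undefined (masked) value
| | +
| | +######################################################################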
298 | 321 | # Conclusion |
299 | | -# ++++++++++ |
| 322 | +# ========== |
300 | 323 | # |
301 | 324 | # In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their |
302 | 325 | # value through a series of examples and issues that they've helped resolve. |
303 | 326 | # |
304 | 327 | # Further Reading |
305 | | -# +++++++++++++++ |
| 328 | +# =============== |
306 | 329 | # |
307 | 330 | # To continue learning more, you can find our |
308 | 331 | # `MaskedTensor Sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__ |