as_strided batching rule #47224
Conversation
This PR adds a batching rule for as_strided. `as_strided` is a really weird operation and I hope that users don't use it very much.

Motivation
----------

The motivation for adding a batching rule for as_strided is batched gradient computation. AsStridedBackward appears in PyTorch when handling view+in-place operations, and it calls `as_strided` on a fresh tensor with storage_offset equal to 0. We would like to be able to vmap through the backward graph of view+in-place operations for batched gradient computation, especially because internally we have a number of functions that are implemented as a view+in-place.

Alternatives
------------

If we think that as_strided is too crazy to have a batching rule, we could either:
- have a flag that controls the autograd view+in-place behavior, or
- require that the input tensor's storage offset be equal to 0, to make it easier to reason about.

I think the batching rule makes sense, so I didn't pursue the alternatives.

The batching rule
-----------------

```
y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
```

The result of the above should be "equivalent" to:
- assuming that each x has storage offset equal to xs.storage_offset() (call that S), and
- calling as_strided with (sizes, strides, offset + x[i].storage_offset() - S) on each x.

More concretely, this returns a view on `xs` such that each y[i] has:
- sizes: `sizes`
- strides: `strides`
- storage_offset: offset + i * xs.stride(batch_dim)

Why the behavior can be weird
-----------------------------

The behavior of the batching rule may differ from actually running as_strided in a for-loop, because `as_strided` takes `offset` as an "absolute offset" into the underlying storage. As an example, consider

```
>>> x = torch.tensor([0., 1., 2., 3., 4.])
>>> z = [x[i].as_strided([1], [1], 1) for i in range(5)]
```

Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))! However, we consider the above list comprehension to be a user error: a user should have written the following if they wanted to use as_strided in a per-sample way:

```
>>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset()) for i in range(4)]
```

Test Plan
---------

- Added some tests that compare vmap+as_strided to vmap+(the equivalent operator)
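As a hedged, concrete check of the description above (example tensors made up here, not part of the PR's tests), the batched result can be built by hand by prepending the batch dim, i.e. xs.as_strided([B] + sizes, [S] + strides, offset), and compared against the per-sample offsets:

```python
import torch

# Batch of B = 3 samples; xs.stride(0) == 4 plays the role of the batch stride S.
xs = torch.arange(12.).reshape(3, 4)
sizes, strides, offset = [2], [1], 1

# Equivalent of the batching rule's output: prepend the batch dim.
batched = xs.as_strided([3] + sizes, [xs.stride(0)] + strides, offset)

for i in range(xs.size(0)):
    # Per the description, y[i] has storage_offset == offset + i * xs.stride(0).
    per_sample = xs.as_strided(sizes, strides, offset + i * xs.stride(0))
    assert torch.equal(batched[i], per_sample)
```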
In your example above, you should have
I am not sure what the general structure of the vmap doc will be. Should we plan to add a note there about as_strided?
```
physical_strides.insert(
    physical_strides.end(),
    physical_view.tensor().strides().begin(),
    physical_view.tensor().strides().begin() + num_batch_dims);
```
There is no way the two successive calls to physical_view.tensor().strides() would return different objects, right?
There is no way the two successive calls to physical_view.tensor().strides() would return different objects, right?

They do return different IntArrayRef objects, which raises the question of how the iterators interact. I will change this to use a single strides() call to avoid potential bugs.
```
// These memory locations are exactly the same as what we got for [[A]],
// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [[B]] Hand-wavy proof of Claim 1:
```
nit: `[[B]]` here is a bit confusing with the `[B]` size used for the sizes. Maybe numbers or `*` would be better?
```
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
//   offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//   <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
```
nit: You can mention that, under the assumption we're making, the lower bound (the lowest-indexed memory location) is trivially within the storage.
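For reference, one way to spell out that lower bound (my reading of the nit, not text from the PR): since PyTorch strides and storage offsets are non-negative, the smallest memory location touched by xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) is the storage offset itself, and with xs[i].offset() = xs.offset() + i * S,

```
offset + xs[i].offset() - xs.offset() = offset + i * S >= offset >= 0
```

so the lowest-indexed location trivially stays inside the storage.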
```
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// Assume that there's only one batch dim and it is at the front of the
```
Is that assumption "without loss of generality"? Or does it change when the batch dimension is not at the front, or when there is more than one?
This is without loss of generality.

- If the batch dim is not at the front, we can move it to the front and then go through the contents here (we do not assume a particular stride for xs or the batch dim here).
- For multiple batch dims, this works out similarly. It's just a bit annoying to write: if there are K batch dimensions then we need K indices I_1, ..., I_K, and `offset + S * i` becomes `offset + \sum_{j=1}^{K} S_j * I_j` (a concrete two-batch-dim check is sketched below).
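To make that concrete, here is a hedged two-batch-dim sketch (example shapes made up, not code from the PR) comparing the batched view against per-sample as_strided calls:

```python
import torch

# Two batch dims of sizes 2 and 3; S1, S2 play the roles of the batch strides.
xs = torch.arange(24.).reshape(2, 3, 4)
sizes, strides, offset = [2], [1], 1
S1, S2 = xs.stride(0), xs.stride(1)

batched = xs.as_strided([2, 3] + sizes, [S1, S2] + strides, offset)
for i in range(2):
    for j in range(3):
        # Per the comment above: storage offset is offset + S1*i + S2*j per (i, j) sample.
        per_sample = xs.as_strided(sizes, strides, offset + S1 * i + S2 * j)
        assert torch.equal(batched[i, j], per_sample)
```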
```
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
```
Should we add some sanity checks in the batching rule to reduce this risk? Like making sure that the original as_strided wasn't indexing outside of the input?
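A hedged sketch of what such a sanity check could look like (hypothetical helper, not code from this PR), mirroring the upper-bound inequality from the comment above:

```python
import torch

def assert_within_bounds(x, sizes, strides, storage_offset):
    # Largest memory location the proposed view would touch; PyTorch strides
    # are non-negative, so the maximum is at the last element of each dim.
    view_max = storage_offset + sum((sz - 1) * st for sz, st in zip(sizes, strides))
    # Largest memory location indexable by x itself.
    x_max = x.storage_offset() + sum((sz - 1) * st
                                     for sz, st in zip(x.size(), x.stride()))
    assert view_max <= x_max, "as_strided would index memory outside the input"

# Example: a view of the first row of a 3x4 tensor stays in bounds.
x = torch.arange(12.).reshape(3, 4)
assert_within_bounds(x[0], [2], [1], 1)
```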
I needed to move this PR in the stack. ghstack doesn't like it when PRs are moved, so it turned into a brand-new PR: #47364