
as_strided batching rule #47224

Closed
wants to merge 4 commits

Conversation

zou3519
Contributor

@zou3519 zou3519 commented Nov 2, 2020

Stack from ghstack:

This PR adds a batching rule for as_strided. as_strided is a really weird
operation and I hope that users don't use it very much.

Motivation

The motivation for adding a batching rule for as_strided is batched
gradient computation.

AsStridedBackward appears in PyTorch when handling view+in-place
operations; it calls as_strided on a fresh tensor with storage_offset
equal to 0. We would like to be able to vmap through the backward graph
of view+in-place operations for batched gradient computation, especially
because internally we have a number of functions that are implemented as
a view+in-place.

Alternatives

If we think that as_strided is too crazy to have a batching rule, we
could either:

  • have a flag that controls the autograd view+in-place behavior, or
  • require that the input tensor's storage offset be equal to 0, to make
    it easier to reason about.

I think the batching rule makes sense, so I didn't pursue the
alternatives.

The batching rule

y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)

The result of the above should be "equivalent" to:

  • assuming that each xs[i] has storage offset equal to xs.storage_offset()
    (call that S), and
  • calling as_strided with (sizes, strides, offset + xs[i].storage_offset() - S) on each xs[i].

More concretely, this returns a view on xs, such that each y[i] has (see the sketch below):

  • sizes: sizes
  • strides: strides
  • storage_offset: offset + i * xs.stride(batch_dim)
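
For illustration, here is a minimal sketch of that equivalence under simplifying assumptions (a contiguous xs with one batch dim at the front, and made-up values for sizes, strides, and offset); it mirrors the semantics described above rather than the PR's actual implementation:

```
import torch

# Hypothetical example: xs is the batched tensor, with batch dim 0.
xs = torch.arange(10.).reshape(2, 5)       # xs.storage_offset() == 0
sizes, strides, offset = [2], [2], 1       # made-up as_strided arguments

# What the batching rule produces: a single as_strided call on xs with the
# batch dim's size and stride prepended.
y = xs.as_strided([xs.size(0)] + sizes, [xs.stride(0)] + strides, offset)

# The per-sample "equivalent": shift the offset by each slice's storage offset
# relative to S = xs.storage_offset().
S = xs.storage_offset()
expected = torch.stack([
    xs[i].as_strided(sizes, strides, offset + xs[i].storage_offset() - S)
    for i in range(xs.size(0))
])
assert torch.equal(y, expected)
```

In this sketch each y[i] has storage_offset 1 + i * 5, which matches the offset + i * xs.stride(batch_dim) formula above.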

Why the behavior can be weird

The behavior of the batching rule may be different from actually running
as_strided in a for-loop because as_strided takes offset as an
"absolute offset". As an example, consider

>>> x = torch.tensor([0., 1., 2., 3., 4.])
>>> z = [x[i].as_strided([1], [1], 1) for i in range(5)]

Each z[i] is actually the same view on x (z[i] == torch.tensor([0.]))!
However, we consider the above list comprehension to be a user error:
a user should have written the following if they wanted to use as_strided
in a per-sample way:

>>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset()) for i in range(4)]

Test Plan

  • Added some tests that compare vmap+as_strided to vmap+(the equivalent operator); a sketch of the flavor of such a comparison is below.
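
To give the flavor of such a comparison (a hypothetical sketch, not the PR's actual test code; an explicit per-row loop stands in for vmap): as_strided([3], [1], x.storage_offset()) on each contiguous 1-D row x is the same view as x[:3], so the two batched results should match.

```
import torch

# Hypothetical comparison: per-sample as_strided vs. an equivalent slicing op.
xs = torch.randn(4, 5)

# Per-sample as_strided, written with each row's own storage offset.
via_as_strided = torch.stack([
    x.as_strided([3], [1], x.storage_offset()) for x in xs
])

# The "equivalent operator": plain slicing.
via_slice = xs[:, :3]

assert torch.equal(via_as_strided, via_slice)
```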

@dr-ci

dr-ci bot commented Nov 2, 2020

💊 CI failures summary and remediations

As of commit dd4e5ea (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


@zou3519 zou3519 requested review from albanD and ezyang November 2, 2020 22:48
@albanD
Collaborator

albanD commented Nov 3, 2020

In your example above, you should have z[i] == torch.tensor([1.]) (and not 0.), right? Since the offset you give is 1?

Collaborator

@albanD albanD left a comment

I am not sure what the general structure of the vmap doc will be. Should we plan to add a note there about as_strided?

physical_strides.insert(
    physical_strides.end(),
    physical_view.tensor().strides().begin(),
    physical_view.tensor().strides().begin() + num_batch_dims);
Collaborator

There is no way the two successive calls to physical_view.tensor().strides() would return different objects, right?

Contributor Author

> There is no way the two successive calls to physical_view.tensor().strides() would return different objects, right?

They do return different IntArrayRef objects. Now there is a question of how the iterators work, so I will change this to use the same strides() to avoid potential bugs.

// These memory locations are exactly the same as what we got for [[A]],
// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [[B]] Hand-wavy proof of Claim 1:
Collaborator

nit: [[B]] here is a bit confusing with the [B] size used for the sizes. Maybe numbers or * would be better?

// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
// offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
// <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
Collaborator

nit: You can mention that under the assumption we're making, the lower bound (lowest indexed memory) is trivially within the storage.
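
For reference, after canceling the common xs[i].offset() + 1 term from both sides, the upper-bound condition in the quoted comment reduces to

$$\text{offset} - \text{xs.offset()} + \sum_j (\text{sizes}[j] - 1)\cdot\text{strides}[j] \;\le\; \sum_j (\text{xs}[i]\text{.size}(j) - 1)\cdot\text{xs}[i]\text{.stride}(j),$$

and neither side depends on i (every slice xs[i] shares the same sizes and strides), so checking it once covers all slices.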

// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// Assume that there's only one batch dim and it is at the front of the
Collaborator

Is that assumption "without loss of generality"? Or does it change when the batch dimension is not at the front or when there is more than one?

Contributor Author

This is without loss of generality.

  • If the batch dim is not at the front, we can move it to the front and then go through the contents here (we do not assume a particular stride for xs or the batch dim here); see the sketch below.
  • For multiple batch dims, this works out similarly. It's just a bit annoying to write because if there are K batch dimensions then we need K indices I_1, ..., I_K, and offset + S * i becomes offset + \sum_{j=1}^{K} S_j * I_j.
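
As a small illustration of the first point (a hypothetical example, not code from the PR): moving the batch dim to the front only permutes the size/stride metadata, while the storage, the storage offset, and the batch dim's stride are unchanged, so the argument above goes through after the move.

```
import torch

# Hypothetical tensor whose batch dim (size 3) sits at position 1.
xs = torch.arange(24.).reshape(2, 3, 4)

# Moving the batch dim to the front is just a view over the same storage.
front = xs.movedim(1, 0)   # shape (3, 2, 4)

assert front.data_ptr() == xs.data_ptr()              # same storage
assert front.storage_offset() == xs.storage_offset()  # same storage offset
assert front.stride(0) == xs.stride(1)                # batch stride preserved
```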

//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
Collaborator

Should we add some sanity checks in the batching rule to reduce this risk? Like making sure that the original as_strided wasn't indexing outside of the input?

@zou3519
Contributor Author

zou3519 commented Nov 4, 2020

I needed to move this PR in the stack. ghstack doesn't like it when PRs are moved, so it turned into a brand-new PR: #47364

@zou3519 zou3519 closed this Nov 4, 2020
@facebook-github-bot facebook-github-bot deleted the gh/zou3519/323/head branch December 5, 2020 15:30