Add batching rule for torch.clone(tensor, torch.contiguous_format) #47365

Closed
wants to merge 5 commits into from

Conversation

zou3519 (Contributor) commented Nov 4, 2020

Stack from ghstack:

I wanted to avoid defining vmap behavior over contiguous_format for as
long as possible. This is potentially ambiguous; consider the following:

```
>>> x = torch.randn(3, B0, 5)
>>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=1)(x)
>>> y[:,0].is_contiguous()  # ??
```

There are two possible ways to interpret this operation (if we choose to
allow it to succeed):

  1. Each per-sample becomes contiguous, so y[:,0] is contiguous.
  2. The output of vmap is contiguous (so y is contiguous, but y[:,0] is not).

(1) makes more sense because vmap operates on a per-sample level.
It also interacts well with the vmap fallback:

  • There are places in the codebase where we perform .contiguous() and
    then pass the result to an operator op that only accepts contiguous
    inputs.
  • If we vmap over such code and don't have a batching rule implemented for
    op, then we want the per-samples to be contiguous so that
    when op goes through the vmap fallback, it receives contiguous
    per-samples.

(1) is the approach we've selected for this PR.
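
A minimal sketch of what interpretation (1) means, emulated without vmap itself (B0, the shapes, in_dims=1, and out_dims=1 follow the example above; movedim is just one way to mimic the batch-dim bookkeeping):

```python
import torch

B0 = 7
x = torch.randn(3, B0, 5)

# Emulate vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=1)(x)
# under interpretation (1): move the batch dim to the front, make every
# per-sample contiguous, then move the batch dim back to out_dims=1.
y = x.movedim(1, 0).clone(memory_format=torch.contiguous_format).movedim(0, 1)

print(y[:, 0].is_contiguous())  # True: each per-sample is contiguous
print(y.is_contiguous())        # False: the batched output as a whole is not
```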

Motivation

To vmap over CopySlices, we have to vmap over a clone(contiguous_format)
call:

auto res = (*fn)({ grad_slice.clone(at::MemoryFormat::Contiguous) });
(from torch/csrc/autograd/functions/tensor.cpp: https://github.com/pytorch/pytorch/blob/e4bc785dd57b15ae091eb8e8ca71a604da9b3fb2/torch/csrc/autograd/functions/tensor.cpp#L93)
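
For context, here is a minimal sketch (not code from this PR) of the kind of program whose backward graph contains a CopySlices node; that node's backward is where the clone above runs:

```python
import torch

x = torch.randn(3, 5, requires_grad=True)
y = x.clone()
y[0] = y[0] * 2     # in-place write into a slice -> CopySlices in y's backward
y.sum().backward()  # backward through CopySlices clones the incoming grad slice
print(x.grad[0])    # the overwritten row picks up the extra factor of 2
```

Vmapping over such a backward (e.g., for batched gradients) is what hits CopySlices and, with it, the clone(contiguous_format) call.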

Alternatives

  • Implementing (2) is difficult in the current design because vmap is
    allowed to move batch dimensions to the front of the tensor. We would
    need some global information about the in_dims and out_dims passed to
    vmap.
  • We could also error out if someone calls clone(contiguous_format) and
    the batch dims are not at the front. This would resolve the ambiguity at
    the cost of limiting what vmap can do.

Future Work

  • Add the behavior of contiguous_format to a "vmap gotchas" page.
  • Implement is_contiguous and Tensor.contiguous() with the same semantics.
    They currently error out.

Test Plan

  • new tests

Differential Revision: D24741683

dr-ci bot commented Nov 4, 2020

💊 CI failures summary and remediations

As of commit 84951db (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build binary_linux_libtorch_3_7m_cpu_devtoolset7_shared-with-deps_build (1/1)

Step: "Checkout pytorch/builder repo" (full log | diagnosis details | 🔁 rerun)

```
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
+ sleep 2
+ git submodule update --init --recursive
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl'
+ sleep 4
+ git submodule update --init --recursive
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl'
+ sleep 8
+ git submodule update --init --recursive
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl'
```


@zou3519 zou3519 requested review from albanD and ezyang November 4, 2020 18:43
albanD (Collaborator) left a comment

SGTM

ezyang (Contributor) commented Nov 5, 2020

Thank you for the PR descriptions, they are super clear.

ezyang (Contributor) commented Nov 5, 2020

I'm trying to think if there is any justification for why (2) is right and pulling up a blank. vmap is all about making local per-sample decisions which then can be lifted into global batched variant. So I feel like you are obligated to preserve the local invariant that y[:,0] is contiguous. If you want global contiguity, then as long as you use the option to preserve memory format, you ought to get it.

zou3519 (Contributor, Author) commented Nov 5, 2020

I'm trying to think if there is any justification for why (2) is right and pulling up a blank. vmap is all about making local per-sample decisions which then can be lifted into global batched variant. So I feel like you are obligated to preserve the local invariant that y[:,0] is contiguous. If you want global contiguity, then as long as you use the option to preserve memory format, you ought to get it.

I agree, I think I am overthinking the problem

@zou3519 merged this pull request in ead86b2.

facebook-github-bot pushed a commit that referenced this pull request Nov 11, 2020
Summary:
Pull Request resolved: #47621

Followup to #47365.

is_contiguous on BatchedTensorImpl is implemented as:
- Whenever one creates a BatchedTensorImpl, we cache the strides of the
per-examples, just like how we cache the sizes of the per-examples.
- With the cached strides, we use TensorImpl::refresh_contiguous() to
compute if the tensor is contiguous or not.
- is_contiguous checks the `is_contiguous_` flag that
refresh_contiguous() populates.

Both contiguous and is_contiguous only support torch.contiguous_format.
I'm not sure what the semantics should be for other memory formats; they
are also rank dependent (e.g., channels_last tensor must have 4
dimensions) which makes this a bit tricky.

Test Plan: - new tests

Reviewed By: Chillee, anjali411

Differential Revision: D24840975

Pulled By: zou3519

fbshipit-source-id: 4d86dbf11e2eec45f3f08300ae3f2d79615bb99d
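
For illustration only, a hedged Python sketch of the contiguity check described above. The real implementation caches per-example strides on BatchedTensorImpl and reuses TensorImpl::refresh_contiguous(); the computation is essentially the standard row-major check below (the function name and calling convention here are made up for the example):

```python
def per_example_is_contiguous(sizes, strides):
    """Row-major (torch.contiguous_format) check over per-example sizes and
    strides, analogous to what TensorImpl::refresh_contiguous() computes."""
    expected_stride = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if size == 1:
            continue  # size-1 dims place no constraint on the layout
        if stride != expected_stride:
            return False
        expected_stride *= size
    return True

print(per_example_is_contiguous((3, 5), (5, 1)))   # True: dense row-major layout
print(per_example_is_contiguous((3, 5), (15, 1)))  # False: rows are strided apart
```
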
facebook-github-bot deleted the gh/zou3519/328/head branch November 13, 2020 15:22