Batched gradient support for view+inplace operations #47227

Closed
zou3519 wants to merge 10 commits

Conversation

zou3519 (Contributor) commented Nov 2, 2020

Stack from ghstack:

Motivation

We would like to compute batched gradients for view+inplace operations.
This most notably shows up in the internal implementation of operations.
For example, many view backward functions (SelectBackward, DiagonalBackward)
are implemented with view+inplace, so to support vectorized Hessian
computation for e.g. torch.select and torch.diagonal we need a
way to handle or work around view+inplace.

Approach

view+inplace creates a CopySlices node and transmutes the view backward nodes
into AsStridedBackward nodes. For example:

```
import torch

leaf = torch.randn(4, 5, requires_grad=True)
base = leaf * leaf   # non-leaf base, so autograd allows in-place ops on its views
view = base[0]       # `view` shares storage with `base`
view.cos_()          # in-place op on the view rewrites base's and view's grad_fn
```

`base.grad_fn` is CopySlices and `view.grad_fn` is AsStridedBackward.
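For reference, the node types can be inspected directly; the class names shown are what PyTorch reported around the time of this PR (newer releases append a numeric suffix, e.g. AsStridedBackward0):

```
print(type(base.grad_fn).__name__)  # CopySlices
print(type(view.grad_fn).__name__)  # AsStridedBackward
```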

To support vmap over CopySlices and AsStridedBackward:

  • We use `new_empty_strided` instead of `empty_strided` in CopySlices
    so that the batch dims get propagated (see the sketch after this list).
  • We use `new_zeros` inside AsStridedBackward so that the batch dims get
    propagated.
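A minimal Python sketch of the distinction these two bullets rely on (illustrative user-level code, not the autograd C++ implementation): a free-function factory such as `torch.empty_strided` has no handle on the gradient tensor flowing through the node, while Tensor-method factories such as `new_empty_strided` and `new_zeros` are called on that tensor, so their vmap batching rules can propagate its batch dims along with dtype and device.

```
import torch

grad = torch.randn(3, 4, 5)  # pretend dim 0 is a vmap batch dim

# Free-function factory: knows nothing about `grad`, so any batch dims are lost.
a = torch.empty_strided((4, 5), (5, 1))

# Tensor-method factories: created "from" `grad`, so a batching rule can carry
# its batch dims (and dtype/device) into the result.
b = grad.new_empty_strided((4, 5), (5, 1))
c = grad.new_zeros((4, 5))
```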

Test Plan

  • New tests. When we get closer to having most operations support batched
    gradient computation via vmap, I'd like to add it as an option to gradcheck
    and turn it on for our tests (a sketch of what that computation looks like
    follows this list).
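For context, a hedged sketch of what "batched grad computation via vmap" means from the user side, using the torch.vmap prototype available around this PR (exposed today as torch.vmap / torch.func.vmap); the example is illustrative and is not taken from this PR's tests:

```
import torch

x = torch.randn(4, 5, requires_grad=True)
y = torch.diagonal(x * x)  # DiagonalBackward is one of the view+inplace-backed backwards

def vjp(v):
    # A single vector-Jacobian product through y's backward graph.
    return torch.autograd.grad(y, x, v, retain_graph=True)[0]

# Batched gradients: one vmap call pushes all basis cotangents through the
# backward graph (exercising CopySlices / AsStridedBackward under vmap)
# instead of looping in Python.
jacobian = torch.vmap(vjp)(torch.eye(y.numel()))  # shape (4, 4, 5)
```

For what it's worth, torch.autograd.gradcheck later gained a check_batched_grad flag that performs essentially this check.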

Differential Revision: [D24741687](https://our.internmc.facebook.com/intern/diff/D24741687)

zou3519 added a commit that referenced this pull request Nov 2, 2020

albanD (Collaborator) left a comment:

lgtm!

Review thread on test/test_vmap.py (outdated, resolved).
zou3519 (Contributor, Author) commented Nov 2, 2020

> lgtm!

The interesting bits are in the stack below this PR, but I'll tag you as a reviewer when they are ready (I am just waiting for some tests to pass).

albanD (Collaborator) commented Nov 2, 2020

Not sure what the black magic in the rest of the stack is. But this one looks surprisingly clean for sure :D

zou3519 added a commit that referenced this pull request Nov 2, 2020

dr-ci bot commented Nov 2, 2020

💊 CI failures summary and remediations

As of commit c5c558f (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build binary_linux_libtorch_3_7m_cpu_devtoolset7_shared-with-deps_build (1/1)

Step: "Checkout pytorch/builder repo" (full log | diagnosis details | 🔁 rerun)

```
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
+ sleep 2
+ git submodule update --init --recursive
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl'
+ sleep 4
+ git submodule update --init --recursive
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl'
+ sleep 8
+ git submodule update --init --recursive
fatal: reference is not a tree: cd5a9b73c3028d2496666201588111a8c8d84878
Unable to checkout 'cd5a9b73c3028d2496666201588111a8c8d84878' in submodule path 'third_party/nccl/nccl'
```

zou3519 added a commit that referenced this pull request Nov 3, 2020
zou3519 added a commit that referenced this pull request Nov 4, 2020
zou3519 added a commit that referenced this pull request Nov 5, 2020
zou3519 added a commit that referenced this pull request Nov 9, 2020

facebook-github-bot (Contributor) commented:

@zou3519 merged this pull request in 57dcb04.

facebook-github-bot deleted the gh/zou3519/326/head branch on November 14, 2020 15:17.