Batched gradient support for view+inplace operations #47227
Conversation
Motivation
----------
We would like to compute batched gradients for view+inplace operations. This most notably shows up in the internal implementation of operations. For example, many view backward functions (SelectBackward, DiagonalBackward) are implemented with view+inplace, so to support vectorized hessian computation for e.g. torch.select and torch.diagonal we need a way to handle or work around view+inplace.

Approach
--------
view+inplace creates a CopySlices node and transmutes view backward nodes into an AsStrided node. For example,
```
leaf = torch.randn(4, 5, requires_grad=True)
base = leaf * leaf
view = base[0]
view.cos_()
```
base.grad_fn is CopySlices and view.grad_fn is AsStridedBackward.

To support vmap over CopySlices and AsStridedBackward:
- We use `new_empty_strided` instead of `empty_strided` in CopySlices so that the batch dims get propagated.
- We use `new_zeros` inside AsStridedBackward so that the batch dims get propagated.

Test Plan
---------
- New tests. When we get closer to having most operations support batched grad computation via vmap, I'd like to add it as an option to gradcheck and turn it on for our tests.
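The example above can be run as-is to inspect the autograd graph nodes the PR targets. The exact backward-node class names vary slightly across PyTorch versions (newer releases append a `0` suffix, e.g. `AsStridedBackward0`), so this sketch only checks substrings:

```python
import torch

leaf = torch.randn(4, 5, requires_grad=True)
base = leaf * leaf   # base is produced by a differentiable op
view = base[0]       # a view into base
view.cos_()          # in-place op on the view rewrites the graph

# The in-place write on a view replaces base's backward node with
# CopySlices and turns the view's backward node into AsStridedBackward.
print(type(base.grad_fn).__name__)   # contains "CopySlices"
print(type(view.grad_fn).__name__)   # contains "AsStrided"
```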
lgtm!
The interesting bits are in the stack below this PR, but I'll tag you as reviewer when they are ready (I am just waiting for some tests to pass).
Not sure what the black magic in the rest of the stack is. But this one looks surprisingly clean for sure :D
Stack from ghstack:
Motivation
We would like to compute batched gradients for view+inplace operations.
This most notably shows up in the internal implementation of operations.
For example, many view backward functions (SelectBackward, DiagonalBackward)
are implemented with view+inplace, so to support vectorized hessian
computation for e.g. torch.select and torch.diagonal we would need a
way to handle or work around view+inplace.
Approach
view+inplace creates a CopySlices node and transmutes view backward nodes
into an AsStrided node. For example,
```
leaf = torch.randn(4, 5, requires_grad=True)
base = leaf * leaf
view = base[0]
view.cos_()
```
base.grad_fn is CopySlices and view.grad_fn is AsStridedBackward.
To support vmap over CopySlices and AsStridedBackward:
- We use `new_empty_strided` instead of `empty_strided` in CopySlices so that the batch dims get propagated.
- We use `new_zeros` inside AsStridedBackward so that the batch dims get propagated.
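Why the `new_*` methods help: they are methods on an existing tensor rather than free factory functions, so they inherit (and can dispatch on) that tensor's properties; under vmap, the batched-tensor wrapper gets to intercept them and carry the batch dims along. A minimal illustration of the inheritance difference, using dtype:

```python
import torch

x = torch.randn(2, 3, dtype=torch.float64)

# Free factory function: knows nothing about x, uses the default dtype.
a = torch.empty_strided((2, 3), (3, 1))

# Methods on x: inherit x's dtype/device, and under vmap they dispatch
# through x's wrapper so batch dims can propagate.
b = x.new_empty_strided((2, 3), (3, 1))
z = x.new_zeros(4)

print(a.dtype)  # torch.float32 (default)
print(b.dtype)  # torch.float64
print(z.dtype)  # torch.float64
```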
Test Plan
- New tests. When we get closer to having most operations support batched grad computation via vmap, I'd like to add it as an option to gradcheck and turn it on for our tests.
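For context on what "batched grad computation" replaces, here is a hypothetical baseline (not code from this PR): computing a Jacobian with one backward call per output element. This is exactly the loop that a vmapped backward pass would collapse into a single call:

```python
import torch

def jacobian_loop(f, x):
    # Baseline: one torch.autograd.grad call per output element,
    # each with a one-hot grad_outputs vector.
    y = f(x)
    rows = []
    for i in range(y.numel()):
        g = torch.zeros_like(y).reshape(-1)
        g[i] = 1.0
        (row,) = torch.autograd.grad(
            y, x, grad_outputs=g.reshape(y.shape), retain_graph=True
        )
        rows.append(row.reshape(-1))
    return torch.stack(rows)

x = torch.randn(3, requires_grad=True)
J = jacobian_loop(lambda t: t * t, x)
# For f(t) = t*t elementwise, the Jacobian is diag(2*x).
assert torch.allclose(J, torch.diag(2 * x.detach()))
```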
Differential Revision: D24741687