-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Use new_zeros in evenly_distribute_backward #46674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary ------- This adds batched gradient support (i.e., vmap through the gradient formulas) for Tensor.max(), Tensor.min(), Tensor.median() that have evenly_distribute_backward as their backward formula. Previously, the plan was to register incompatible gradient formulas as backward operators (see #44052). However, it turns out that we can just use `new_zeros` to get around some incompatible gradient formulas (see next section for discussion). Context: the vmap+inplace problem --------------------------------- A lot of backwards functions are incompatible with BatchedTensor due to using in-place operations. Sometimes we can allow the in-place operations, but other times we can't. For example, consider select_backward: ``` Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { auto grad_input = at::zeros(input_sizes, grad.options()); grad_input.select(dim, index).copy_(grad); return grad_input; } ``` and consider the following code: ``` x = torch.randn(5, requires_grad=True) def select_grad(v): torch.autograd.grad(x[0], x, v) vs = torch.randn(B0) batched_grads = vmap(select_grad)(vs) ``` For the batched gradient use case, grad is a BatchedTensor. The physical version of grad has size (B0,). However, select_backward creates a grad_input of shape (5), and tries to copy grad to a slice of it. Up until now, the proposal to handle this has been to register these backward formulas as operators so that vmap doesn’t actually see the `copy_` calls (see #44052). However, it turns out we can actually just use `new_zeros` to construct a new Tensor that has the same "batched-ness" as grad: ``` auto grad_input = grad.new_zeros(input_sizes); grad_input.select(dim, index).copy_(grad); ``` We should use this for simple backward functions. For more complicated backward functions where this solution doesn't work, we should register those as operators. Alternatives ------------ Option 2: Register `evenly_distribute_backward` as an operator and have the vmap fallback run it in a loop. - This requires more LOC changes. - Furthermore, we'd have to write an efficient batching rule for `evenly_distribute_backward` in the future. - If we use `new_zeros` instead, we don't need to write an efficient batching rule for `evenly_distribute_backward` as long as the constituents of `evenly_distributed_backward` have efficient batching rules. Option 3: Have factory functions perform differently if they are called inside vmap. - For example, `at::zeros(3, 5)` could return a Tensor of shape `(B0, B1, 3, 5)` if we are vmapping over two dimensions with size B0 and B1. This requires maintaining some global and/or thread-local state about the size of the dims being vmapped over which can be tricky. And more... Future ------ - I will undo some of the work I’ve done in the past to move backward functions to being operators (#44052, #44408). The simpler backward functions (like select backward) can just use Tensor.new_zeros. I apologize for the thrashing. - Include a NOTE about the vmap+inplace problem somewhere in the codebase. I don't have a good idea of where to put it at the moment. Test Plan --------- - New tests [ghstack-poisoned]
Summary ------- This adds batched gradient support (i.e., vmap through the gradient formulas) for Tensor.max(), Tensor.min(), Tensor.median() that have evenly_distribute_backward as their backward formula. Previously, the plan was to register incompatible gradient formulas as backward operators (see #44052). However, it turns out that we can just use `new_zeros` to get around some incompatible gradient formulas (see next section for discussion). Context: the vmap+inplace problem --------------------------------- A lot of backwards functions are incompatible with BatchedTensor due to using in-place operations. Sometimes we can allow the in-place operations, but other times we can't. For example, consider select_backward: ``` Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { auto grad_input = at::zeros(input_sizes, grad.options()); grad_input.select(dim, index).copy_(grad); return grad_input; } ``` and consider the following code: ``` x = torch.randn(5, requires_grad=True) def select_grad(v): torch.autograd.grad(x[0], x, v) vs = torch.randn(B0) batched_grads = vmap(select_grad)(vs) ``` For the batched gradient use case, grad is a BatchedTensor. The physical version of grad has size (B0,). However, select_backward creates a grad_input of shape (5), and tries to copy grad to a slice of it. Up until now, the proposal to handle this has been to register these backward formulas as operators so that vmap doesn’t actually see the `copy_` calls (see #44052). However, it turns out we can actually just use `new_zeros` to construct a new Tensor that has the same "batched-ness" as grad: ``` auto grad_input = grad.new_zeros(input_sizes); grad_input.select(dim, index).copy_(grad); ``` We should use this for simple backward functions. For more complicated backward functions where this solution doesn't work, we should register those as operators. Alternatives ------------ Option 2: Register `evenly_distribute_backward` as an operator and have the vmap fallback run it in a loop. - This requires more LOC changes. - Furthermore, we'd have to write an efficient batching rule for `evenly_distribute_backward` in the future. - If we use `new_zeros` instead, we don't need to write an efficient batching rule for `evenly_distribute_backward` as long as the constituents of `evenly_distributed_backward` have efficient batching rules. Option 3: Have factory functions perform differently if they are called inside vmap. - For example, `at::zeros(3, 5)` could return a Tensor of shape `(B0, B1, 3, 5)` if we are vmapping over two dimensions with size B0 and B1. This requires maintaining some global and/or thread-local state about the size of the dims being vmapped over which can be tricky. And more... Future ------ - I will undo some of the work I’ve done in the past to move backward functions to being operators (#44052, #44408). The simpler backward functions (like select backward) can just use Tensor.new_zeros. I apologize for the thrashing. - Include a NOTE about the vmap+inplace problem somewhere in the codebase. I don't have a good idea of where to put it at the moment. Test Plan --------- - New tests ghstack-source-id: caeed93 Pull Request resolved: #46674
💊 CI failures summary and remediationsAs of commit cbf8760 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 3 times. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
} else { | ||
auto mask = value.isnan().item<bool>() ? input.isnan() : input == value; | ||
return at::zeros_like(input).masked_fill_(mask, grad / mask.sum()); | ||
return grad.new_zeros(input.sizes(), input.options()).masked_fill_(mask, grad / mask.sum()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It makes sense that the batch info are taken from grad and the other sizes from input.
I think it is worth mentioning in the "vmap gotcha" (if you have that) that the new_*
functions behave this way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There isn't a vmap gotcha section anywhere, but I'll make a note on that in the future
Summary ------- This adds batched gradient support (i.e., vmap through the gradient formulas) for Tensor.max(), Tensor.min(), Tensor.median() that have evenly_distribute_backward as their backward formula. Previously, the plan was to register incompatible gradient formulas as backward operators (see #44052). However, it turns out that we can just use `new_zeros` to get around some incompatible gradient formulas (see next section for discussion). Context: the vmap+inplace problem --------------------------------- A lot of backwards functions are incompatible with BatchedTensor due to using in-place operations. Sometimes we can allow the in-place operations, but other times we can't. For example, consider select_backward: ``` Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { auto grad_input = at::zeros(input_sizes, grad.options()); grad_input.select(dim, index).copy_(grad); return grad_input; } ``` and consider the following code: ``` x = torch.randn(5, requires_grad=True) def select_grad(v): torch.autograd.grad(x[0], x, v) vs = torch.randn(B0) batched_grads = vmap(select_grad)(vs) ``` For the batched gradient use case, grad is a BatchedTensor. The physical version of grad has size (B0,). However, select_backward creates a grad_input of shape (5), and tries to copy grad to a slice of it. Up until now, the proposal to handle this has been to register these backward formulas as operators so that vmap doesn’t actually see the `copy_` calls (see #44052). However, it turns out we can actually just use `new_zeros` to construct a new Tensor that has the same "batched-ness" as grad: ``` auto grad_input = grad.new_zeros(input_sizes); grad_input.select(dim, index).copy_(grad); ``` We should use this for simple backward functions. For more complicated backward functions where this solution doesn't work, we should register those as operators. Alternatives ------------ Option 2: Register `evenly_distribute_backward` as an operator and have the vmap fallback run it in a loop. - This requires more LOC changes. - Furthermore, we'd have to write an efficient batching rule for `evenly_distribute_backward` in the future. - If we use `new_zeros` instead, we don't need to write an efficient batching rule for `evenly_distribute_backward` as long as the constituents of `evenly_distributed_backward` have efficient batching rules. Option 3: Have factory functions perform differently if they are called inside vmap. - For example, `at::zeros(3, 5)` could return a Tensor of shape `(B0, B1, 3, 5)` if we are vmapping over two dimensions with size B0 and B1. This requires maintaining some global and/or thread-local state about the size of the dims being vmapped over which can be tricky. And more... Future ------ - I will undo some of the work I’ve done in the past to move backward functions to being operators (#44052, #44408). The simpler backward functions (like select backward) can just use Tensor.new_zeros. I apologize for the thrashing. - Include a NOTE about the vmap+inplace problem somewhere in the codebase. I don't have a good idea of where to put it at the moment. Test Plan --------- - New tests Differential Revision: [D24456781](https://our.internmc.facebook.com/intern/diff/D24456781) [ghstack-poisoned]
Summary ------- This adds batched gradient support (i.e., vmap through the gradient formulas) for Tensor.max(), Tensor.min(), Tensor.median() that have evenly_distribute_backward as their backward formula. Previously, the plan was to register incompatible gradient formulas as backward operators (see #44052). However, it turns out that we can just use `new_zeros` to get around some incompatible gradient formulas (see next section for discussion). Context: the vmap+inplace problem --------------------------------- A lot of backwards functions are incompatible with BatchedTensor due to using in-place operations. Sometimes we can allow the in-place operations, but other times we can't. For example, consider select_backward: ``` Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { auto grad_input = at::zeros(input_sizes, grad.options()); grad_input.select(dim, index).copy_(grad); return grad_input; } ``` and consider the following code: ``` x = torch.randn(5, requires_grad=True) def select_grad(v): torch.autograd.grad(x[0], x, v) vs = torch.randn(B0) batched_grads = vmap(select_grad)(vs) ``` For the batched gradient use case, grad is a BatchedTensor. The physical version of grad has size (B0,). However, select_backward creates a grad_input of shape (5), and tries to copy grad to a slice of it. Up until now, the proposal to handle this has been to register these backward formulas as operators so that vmap doesn’t actually see the `copy_` calls (see #44052). However, it turns out we can actually just use `new_zeros` to construct a new Tensor that has the same "batched-ness" as grad: ``` auto grad_input = grad.new_zeros(input_sizes); grad_input.select(dim, index).copy_(grad); ``` We should use this for simple backward functions. For more complicated backward functions where this solution doesn't work, we should register those as operators. Alternatives ------------ Option 2: Register `evenly_distribute_backward` as an operator and have the vmap fallback run it in a loop. - This requires more LOC changes. - Furthermore, we'd have to write an efficient batching rule for `evenly_distribute_backward` in the future. - If we use `new_zeros` instead, we don't need to write an efficient batching rule for `evenly_distribute_backward` as long as the constituents of `evenly_distributed_backward` have efficient batching rules. Option 3: Have factory functions perform differently if they are called inside vmap. - For example, `at::zeros(3, 5)` could return a Tensor of shape `(B0, B1, 3, 5)` if we are vmapping over two dimensions with size B0 and B1. This requires maintaining some global and/or thread-local state about the size of the dims being vmapped over which can be tricky. And more... Future ------ - I will undo some of the work I’ve done in the past to move backward functions to being operators (#44052, #44408). The simpler backward functions (like select backward) can just use Tensor.new_zeros. I apologize for the thrashing. - Include a NOTE about the vmap+inplace problem somewhere in the codebase. I don't have a good idea of where to put it at the moment. Test Plan --------- - New tests ghstack-source-id: 5cf5f04 Pull Request resolved: #46674
Codecov Report
@@ Coverage Diff @@
## gh/zou3519/316/base #46674 +/- ##
====================================================
Coverage 68.98% 68.98%
====================================================
Files 433 433
Lines 55921 55921
====================================================
+ Hits 38578 38579 +1
+ Misses 17343 17342 -1 |
Stack from ghstack:
Summary
This adds batched gradient support (i.e., vmap through the gradient
formulas) for Tensor.max(), Tensor.min(), Tensor.median()
that have evenly_distribute_backward as their backward formula.
Previously, the plan was to register incompatible gradient formulas as
backward operators (see #44052). However, it turns out that we can just use
new_zeros
to get around some incompatible gradient formulas (see nextsection for discussion).
Context: the vmap+inplace problem
A lot of backwards functions are incompatible with BatchedTensor due to
using in-place operations. Sometimes we can allow the in-place
operations, but other times we can't. For example, consider select_backward:
and consider the following code:
For the batched gradient use case, grad is a BatchedTensor.
The physical version of grad has size (B0,).
However, select_backward creates a grad_input of shape (5), and
tries to copy grad to a slice of it.
Up until now, the proposal to handle this has been to register these
backward formulas as operators so that vmap doesn’t actually see the
copy_
calls (see #44052). However, it turns out we can actually justuse
new_zeros
to construct a new Tensor that has the same"batched-ness" as grad:
We should use this for simple backward functions. For more complicated
backward functions where this solution doesn't work, we should register
those as operators.
Alternatives
Option 2: Register
evenly_distribute_backward
as an operator and have thevmap fallback run it in a loop.
evenly_distribute_backward
in the future.new_zeros
instead, we don't need to write an efficientbatching rule for
evenly_distribute_backward
as long as theconstituents of
evenly_distributed_backward
have efficient batching rules.Option 3: Have factory functions perform differently if they are called
inside vmap.
at::zeros(3, 5)
could return a Tensor of shape(B0, B1, 3, 5)
if we are vmapping over two dimensions with size B0 and B1.This requires maintaining some global and/or thread-local state about
the size of the dims being vmapped over which can be tricky.
And more...
Future
functions to being operators (Register some backwards functions as operators #44052, Add trace_backward, masked_select_backward, and take_backward as ops #44408). The simpler backward
functions (like select backward) can just use Tensor.new_zeros.
I apologize for the thrashing.
codebase. I don't have a good idea of where to put it at the moment.
Test Plan
Differential Revision: D24456781