autograd.Function supports vmap staticmethod
This PR adds a `vmap` staticmethod to autograd.Function and a corresponding vmap kernel for custom_function_call. Together, these allow an autograd.Function with a vmap staticmethod to be used with vmap.

```py
class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```

API Spec
- The staticmethod takes two arguments (info, in_dims) as well as the unexpanded inputs (x, y).
- If we think of the signature as `vmap(info, in_dims, *args)`, then `in_dims` is a pytree with the same tree structure as `args`. Each entry is None if the corresponding arg is not being vmapped over, or an integer batch-dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one field, `info.batch_size`. In the future we can extend this by adding things like randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors; they have already been unpacked.
- We expect the user to return an `(outputs, out_dims)` tuple. `out_dims` must "broadcast" to the same pytree structure as `outputs`.

Semantics
- `vmap(NumpyMul.apply)(x)` will apply the vmap staticmethod if there is one and will never actually run NumpyMul.forward.
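The in_dims/out_dims contract above can be sketched without this branch of PyTorch by using NumPy ops as stand-ins for the tensor ops (all names here are illustrative, not from the PR):

```python
import numpy as np

class Info:
    """Stand-in for the vmap info object; mirrors info.batch_size."""
    def __init__(self, batch_size):
        self.batch_size = batch_size

def numpy_mul_vmap(info, in_dims, x, y):
    # Mirrors NumpyMul.vmap: in_dims has one entry per argument, either
    # an integer batch-dim index or None if that argument is unbatched.
    x_bdim, y_bdim = in_dims
    x = np.moveaxis(x, x_bdim, -1) if x_bdim is not None else x[..., None]
    y = np.moveaxis(y, y_bdim, -1) if y_bdim is not None else y[..., None]
    result = x * y                      # stands in for NumpyMul.apply(x, y)
    result = np.moveaxis(result, -1, 0)
    # Return (outputs, out_dims): the batch dim of the output is 0.
    return result, 0

info = Info(batch_size=3)
x = np.arange(6.0).reshape(3, 2)   # batched along dim 0
y = np.array([10.0, 100.0])        # not batched
out, out_dim = numpy_mul_vmap(info, (0, None), x, y)
assert out.shape == (3, 2) and out_dim == 0
```

Each row of `out` is the product of the corresponding row of `x` with `y`, which is exactly the per-sample semantics vmap promises.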
- In order for the autograd.Function to support nested vmap (e.g., `vmap(vmap(NumpyMul.apply))(x)`), the vmap staticmethod must call into operations that vmap understands (i.e., PyTorch operators or other autograd.Functions).

At a high level, this PR adds a vmap rule for custom_function_call.

Testing
- Added some tests for in_dims and info.
- Added a vmap staticmethod to most of the autograd.Functions in autograd_function_db and sent them through functorch's vmap-related OpInfo tests.

Future
- Better error messages if the user gets the return contract wrong. I didn't include them in this PR because it might involve a refactor of some of the existing code in functorch/_src/vmap.py that would add ~200 LOC to the PR, but LMK if you'd prefer it here.

ghstack-source-id: b17a03e7563a663418f30f99f6e21366c9d62015
Pull Request resolved: #90037
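The nested-vmap requirement holds because a rule written in terms of vmappable operations composes: each vmap level just adds one more batch dimension. A small NumPy sketch of this (illustrative only; `f` stands in for NumpyMul.apply, and broadcasting stands in for vmap's batching):

```python
import numpy as np

def f(x, y):
    # Stand-in for NumpyMul.apply: an elementwise multiply, built from
    # operations that broadcasting (and, analogously, vmap) understand.
    return x * y

x = np.arange(24.0).reshape(2, 3, 4)
y = np.arange(24.0, 48.0).reshape(2, 3, 4)

# Semantics of vmap(vmap(f)) with both batch dims at position 0:
# loop over the outer dim, then the inner dim, and stack the results.
looped = np.stack([
    np.stack([f(x[i, j], y[i, j]) for j in range(x.shape[1])])
    for i in range(x.shape[0])
])

# Because f is written with batchable ops, the same answer falls out
# without any loops -- this is what lets vmap rules nest.
batched = f(x, y)
assert np.allclose(looped, batched)
```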
Showing 12 changed files with 297 additions and 19 deletions.