autograd.Function supports vmap staticmethod #90037
Conversation
This PR adds a `vmap` staticmethod to `autograd.Function` and a corresponding vmap kernel for `custom_function_call`. Together, these mean that an `autograd.Function` with a `vmap` staticmethod can be used with `vmap`.

```py
class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        # Note: compare against None, not truthiness -- a batch dim of 0 is falsy.
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```

API Spec
- The staticmethod takes two arguments (`info`, `in_dims`) as well as the unexpanded inputs (`x`, `y`).
- Thinking of it as `vmap(info, in_dims, *args)`, `in_dims` is a pytree with the same tree structure as `args`. It holds `None` if the arg is not being vmapped over and an integer batch-dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one field, `info.batch_size`. In the future we can extend this by adding things like randomness information.
- If there is a single vmap going on, `(x, y)` are NOT BatchedTensors; they've already been unpacked.
- We expect the user to return an `(outputs, out_dims)` tuple. `out_dims` must "broadcast" to the same pytree structure as `outputs`.

Semantics
- `vmap(NumpyMul.apply)(x)` will apply the vmap staticmethod if there is one and will never actually run `NumpyMul.forward`.
- In order for the autograd.Function to support nested vmap (e.g., `vmap(vmap(NumpyMul.apply))(x)`), the vmap staticmethod must call into operations that vmap understands (i.e., PyTorch operators or other autograd.Functions).

At a high level, this PR:
- adds a vmap rule for `custom_function_call`

Testing
- Added some tests for `in_dims` and `info`
- Added a vmap staticmethod to most of the autograd.Functions in `autograd_function_db` and sent them through functorch's vmap-related OpInfo tests

Future
- Better error messages if the user gets the return contract wrong. I didn't include them in this PR because it might involve a refactor of some of the existing code in `functorch/_src/vmap.py` that would add ~200 LOC to the PR, but LMK if you'd prefer it here.
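The calling convention above can be sketched without PyTorch. The following is a minimal, hypothetical dispatcher (`toy_vmap`, `VmapInfo`, and `mul_rule` are illustrative stand-ins, not real functorch APIs) that follows the contract: it passes `(info, in_dims, *args)` to the rule and expects `(outputs, out_dims)` back.

```python
from collections import namedtuple

# Hypothetical stand-in for the real `info` object; it carries batch_size.
VmapInfo = namedtuple("VmapInfo", ["batch_size"])

def toy_vmap(rule, in_dims, *args):
    """Call a vmap-staticmethod-style `rule` per the contract described above."""
    # Derive batch_size from the first argument that is being vmapped over.
    batch_size = next(len(arg) for arg, d in zip(args, in_dims) if d is not None)
    info = VmapInfo(batch_size=batch_size)
    outputs, out_dims = rule(info, in_dims, *args)
    assert out_dims == 0  # this toy only handles out_dims == 0
    return outputs

# A rule for elementwise multiply over plain Python lists (dim 0 is the only dim).
def mul_rule(info, in_dims, x, y):
    x_bdim, y_bdim = in_dims
    # An unbatched arg (in_dim None) is broadcast across info.batch_size.
    xs = x if x_bdim is not None else [x] * info.batch_size
    ys = y if y_bdim is not None else [y] * info.batch_size
    return [a * b for a, b in zip(xs, ys)], 0

print(toy_vmap(mul_rule, (0, None), [1, 2, 3], 10))  # [10, 20, 30]
print(toy_vmap(mul_rule, (0, 0), [1, 2], [3, 4]))    # [3, 8]
```

Note how `None` in `in_dims` marks an unbatched argument, mirroring the pytree description in the API spec.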
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90037
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 Failures: as of commit 0826dca, the following jobs have failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
   return materializeGradWrappers(tensor, level());
 }

 Tensor JvpInterpreterPtr::lift(const Tensor& tensor) const {
-  return base_lift(tensor, level());
+  return materializeGradWrappers(tensor, level());
```
Changes undone because it was better to unwrap dead wrappers before any interpreters run (see `dispatch_functorch`).
```diff
 def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     # Broadcasting
-    yield SampleInput(make_arg(2, low=0.9, high=2), args=(make_arg(3, 2, low=0.9, high=2),))
+    yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))
```
Sizes adjusted because we add extra dimensions of size 2 in vmap testing, so it is easier to debug when the sizes (before being passed through the vmap tests) don't have dimensions of size 2.
As a vmap noob who hasn't debugged many batching rules, I don't understand this lol, could you elaborate a little?
Yup!
Let's say I incorrectly wrote the mul batching rule as:

```py
def vmap(info, in_dims, x, y):
    return x * y, 0
```

and then we tried:

```py
x = torch.randn(2, 2)
y = torch.randn(2, 2, 2)
z = vmap(NumpyMul.apply, in_dims=(0, 0))(x, y)
```

This would succeed and also be wrong: we would end up with `z = x * y`, but the result should be `z = x.unsqueeze(1) * y` (we need to line up dimension 0 of `x` with dimension 0 of `y`). The test runner would fail with something like "output values are incorrect".
If the dimensions being vmapped over for `x` and `y` have a size distinct from the rest of the dimensions, then we get a loud error:

```py
x = torch.randn(2, 4)
y = torch.randn(2, 3, 4)
z = vmap(NumpyMul.apply, in_dims=(0, 0))(x, y)
# Error: can't broadcast Tensor of shape [2, 4] with Tensor of shape [2, 3, 4]
```

and with this error we can tell that we're doing something wrong (we forgot to change `x` into having shape `[2, 1, 4]` before multiplying it with `y`).
Now, how does this apply to the samples in the OpInfo?
- vmap testing takes the samples in the OpInfo and expands them to include an additional dimension of size 2.
- Then it just does `vmap(op)(*expanded_inputs)`.
- Ideally the original samples in the OpInfo do not have dimensions of size 2, so that we get a loud shape error instead of a "values are incorrect" error -- the shape error tells us we misaligned dimensions.
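The size-2 masking effect can be demonstrated with a small pure-Python sketch of (approximate) NumPy/PyTorch broadcasting rules; `broadcast_shape` is a hypothetical helper written for illustration, not part of either library:

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Right-aligned broadcasting, approximating NumPy/PyTorch rules:
    paired sizes must match or one of them must be 1."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"can't broadcast {a} with {b}")
        out.append(max(x, y))
    return tuple(reversed(out))

# Distinct sizes: the misalignment is a loud shape error.
try:
    broadcast_shape((2, 4), (2, 3, 4))
except ValueError as e:
    print(e)  # can't broadcast (2, 4) with (2, 3, 4)

# All sizes equal to 2: the same mistake broadcasts "successfully",
# so the bug only shows up as incorrect output values.
print(broadcast_shape((2, 2), (2, 2, 2)))  # (2, 2, 2)

# The correct rule inserts a size-1 dim first, which broadcasts either way:
print(broadcast_shape((2, 1, 4), (2, 3, 4)))  # (2, 3, 4)
```

This is why the OpInfo samples avoid size-2 dimensions: the vmap test harness's added size-2 batch dimension then can't silently line up with a data dimension.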
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, thanks!
This PR adds a `vmap` staticmethod to autograd.Function and a corresponding vmap kernel for custom_function_call. These two items mean that autograd.Function with a vmap staticmethod can be used with vmap. ```py class NumpyMul(torch.autograd.Function) staticmethod def forward(x, y): return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) staticmethod def setup_context(ctx, outputs, x, y): ctx.save_for_backward(x, y) staticmethod def backward(ctx, grad_output): x, y = ctx.saved_tensors gx = None if isinstance(x, torch.Tensor) and x.requires_grad: gx = NumpyMul.apply(grad_output, y) gy = None if isinstance(y, torch.Tensor) and y.requires_grad: gy = NumpyMul.apply(grad_output, x) return gx, gy staticmethod def vmap(info, in_dims, x, y): x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1) result = NumpyMul.apply(x, y) result = result.movedim(-1, 0) return result, 0 ``` API Spec - the staticmethod takes two arguments (info, in_dims) as well as the unexpanded inputs (x, y). - If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a pytree with the same tree structure as args. It has None if the arg is not being vmapped over and an integer vmapped dimension index if it is. - `info` is an object with metadata about the vmap. It currently has one field, `info.batch_size`. In the future we can extend this by adding things like the randomness information. - If there is a single vmap going on, (x, y) are NOT BatchedTensors, they've already been unpacked. - We expect the user to return a `(outputs, out_dims)` tuple. `out_dims` must "broadcast" to the same pytree structure as `outputs`. Semantics - vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is one and will never actually run NumpyMul.forward. 
- In order for the autograd.Function to support nested vmap (e.g., `vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call into operations that vmap understands (i.e. PyTorch operators or more autograd.Function). At a high level, this PR: - adds a vmap rule for custom_function_call Testing - Added some tests for in_dims and info - Added vmap staticmethod to most of the autograd.Function in autograd_function_db and sent them through functorch's vmap-related OpInfo tests Future - Better error messages if the user gets the return contract wrong. I didn't include them in this PR because it might involve a refactor of some of the existing code in functorch/_src/vmap.py that will add ~200LOC to the PR, but LMK if you'd prefer it here. [ghstack-poisoned]
This PR adds a `vmap` staticmethod to autograd.Function and a corresponding vmap kernel for custom_function_call. These two items mean that autograd.Function with a vmap staticmethod can be used with vmap. ```py class NumpyMul(torch.autograd.Function) staticmethod def forward(x, y): return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) staticmethod def setup_context(ctx, outputs, x, y): ctx.save_for_backward(x, y) staticmethod def backward(ctx, grad_output): x, y = ctx.saved_tensors gx = None if isinstance(x, torch.Tensor) and x.requires_grad: gx = NumpyMul.apply(grad_output, y) gy = None if isinstance(y, torch.Tensor) and y.requires_grad: gy = NumpyMul.apply(grad_output, x) return gx, gy staticmethod def vmap(info, in_dims, x, y): x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1) result = NumpyMul.apply(x, y) result = result.movedim(-1, 0) return result, 0 ``` API Spec - the staticmethod takes two arguments (info, in_dims) as well as the unexpanded inputs (x, y). - If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a pytree with the same tree structure as args. It has None if the arg is not being vmapped over and an integer vmapped dimension index if it is. - `info` is an object with metadata about the vmap. It currently has one field, `info.batch_size`. In the future we can extend this by adding things like the randomness information. - If there is a single vmap going on, (x, y) are NOT BatchedTensors, they've already been unpacked. - We expect the user to return a `(outputs, out_dims)` tuple. `out_dims` must "broadcast" to the same pytree structure as `outputs`. Semantics - vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is one and will never actually run NumpyMul.forward. 
- In order for the autograd.Function to support nested vmap (e.g., `vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call into operations that vmap understands (i.e. PyTorch operators or more autograd.Function). At a high level, this PR: - adds a vmap rule for custom_function_call Testing - Added some tests for in_dims and info - Added vmap staticmethod to most of the autograd.Function in autograd_function_db and sent them through functorch's vmap-related OpInfo tests Future - Better error messages if the user gets the return contract wrong. I didn't include them in this PR because it might involve a refactor of some of the existing code in functorch/_src/vmap.py that will add ~200LOC to the PR, but LMK if you'd prefer it here. ghstack-source-id: b17a03e7563a663418f30f99f6e21366c9d62015 Pull Request resolved: #90037
This PR adds a `vmap` staticmethod to autograd.Function and a corresponding vmap kernel for custom_function_call. These two items mean that autograd.Function with a vmap staticmethod can be used with vmap. ```py class NumpyMul(torch.autograd.Function) staticmethod def forward(x, y): return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) staticmethod def setup_context(ctx, outputs, x, y): ctx.save_for_backward(x, y) staticmethod def backward(ctx, grad_output): x, y = ctx.saved_tensors gx = None if isinstance(x, torch.Tensor) and x.requires_grad: gx = NumpyMul.apply(grad_output, y) gy = None if isinstance(y, torch.Tensor) and y.requires_grad: gy = NumpyMul.apply(grad_output, x) return gx, gy staticmethod def vmap(info, in_dims, x, y): x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1) result = NumpyMul.apply(x, y) result = result.movedim(-1, 0) return result, 0 ``` API Spec - the staticmethod takes two arguments (info, in_dims) as well as the unexpanded inputs (x, y). - If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a pytree with the same tree structure as args. It has None if the arg is not being vmapped over and an integer vmapped dimension index if it is. - `info` is an object with metadata about the vmap. It currently has one field, `info.batch_size`. In the future we can extend this by adding things like the randomness information. - If there is a single vmap going on, (x, y) are NOT BatchedTensors, they've already been unpacked. - We expect the user to return a `(outputs, out_dims)` tuple. `out_dims` must "broadcast" to the same pytree structure as `outputs`. Semantics - vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is one and will never actually run NumpyMul.forward. 
- In order for the autograd.Function to support nested vmap (e.g., `vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call into operations that vmap understands (i.e. PyTorch operators or more autograd.Function). At a high level, this PR: - adds a vmap rule for custom_function_call Testing - Added some tests for in_dims and info - Added vmap staticmethod to most of the autograd.Function in autograd_function_db and sent them through functorch's vmap-related OpInfo tests Future - Better error messages if the user gets the return contract wrong. I didn't include them in this PR because it might involve a refactor of some of the existing code in functorch/_src/vmap.py that will add ~200LOC to the PR, but LMK if you'd prefer it here. [ghstack-poisoned]
This PR adds a `vmap` staticmethod to autograd.Function and a corresponding vmap kernel for custom_function_call. These two items mean that autograd.Function with a vmap staticmethod can be used with vmap. ```py class NumpyMul(torch.autograd.Function) staticmethod def forward(x, y): return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) staticmethod def setup_context(ctx, outputs, x, y): ctx.save_for_backward(x, y) staticmethod def backward(ctx, grad_output): x, y = ctx.saved_tensors gx = None if isinstance(x, torch.Tensor) and x.requires_grad: gx = NumpyMul.apply(grad_output, y) gy = None if isinstance(y, torch.Tensor) and y.requires_grad: gy = NumpyMul.apply(grad_output, x) return gx, gy staticmethod def vmap(info, in_dims, x, y): x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1) result = NumpyMul.apply(x, y) result = result.movedim(-1, 0) return result, 0 ``` API Spec - the staticmethod takes two arguments (info, in_dims) as well as the unexpanded inputs (x, y). - If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a pytree with the same tree structure as args. It has None if the arg is not being vmapped over and an integer vmapped dimension index if it is. - `info` is an object with metadata about the vmap. It currently has one field, `info.batch_size`. In the future we can extend this by adding things like the randomness information. - If there is a single vmap going on, (x, y) are NOT BatchedTensors, they've already been unpacked. - We expect the user to return a `(outputs, out_dims)` tuple. `out_dims` must "broadcast" to the same pytree structure as `outputs`. Semantics - vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is one and will never actually run NumpyMul.forward. 
- In order for the autograd.Function to support nested vmap (e.g., `vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call into operations that vmap understands (i.e. PyTorch operators or more autograd.Function). At a high level, this PR: - adds a vmap rule for custom_function_call Testing - Added some tests for in_dims and info - Added vmap staticmethod to most of the autograd.Function in autograd_function_db and sent them through functorch's vmap-related OpInfo tests Future - Better error messages if the user gets the return contract wrong. I didn't include them in this PR because it might involve a refactor of some of the existing code in functorch/_src/vmap.py that will add ~200LOC to the PR, but LMK if you'd prefer it here. [ghstack-poisoned]
This PR adds a `vmap` staticmethod to autograd.Function and a corresponding vmap kernel for custom_function_call. These two items mean that autograd.Function with a vmap staticmethod can be used with vmap. ```py class NumpyMul(torch.autograd.Function) staticmethod def forward(x, y): return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) staticmethod def setup_context(ctx, outputs, x, y): ctx.save_for_backward(x, y) staticmethod def backward(ctx, grad_output): x, y = ctx.saved_tensors gx = None if isinstance(x, torch.Tensor) and x.requires_grad: gx = NumpyMul.apply(grad_output, y) gy = None if isinstance(y, torch.Tensor) and y.requires_grad: gy = NumpyMul.apply(grad_output, x) return gx, gy staticmethod def vmap(info, in_dims, x, y): x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1) result = NumpyMul.apply(x, y) result = result.movedim(-1, 0) return result, 0 ``` API Spec - the staticmethod takes two arguments (info, in_dims) as well as the unexpanded inputs (x, y). - If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a pytree with the same tree structure as args. It has None if the arg is not being vmapped over and an integer vmapped dimension index if it is. - `info` is an object with metadata about the vmap. It currently has one field, `info.batch_size`. In the future we can extend this by adding things like the randomness information. - If there is a single vmap going on, (x, y) are NOT BatchedTensors, they've already been unpacked. - We expect the user to return a `(outputs, out_dims)` tuple. `out_dims` must "broadcast" to the same pytree structure as `outputs`. Semantics - vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is one and will never actually run NumpyMul.forward. 
- In order for the autograd.Function to support nested vmap (e.g., `vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call into operations that vmap understands (i.e. PyTorch operators or more autograd.Function). At a high level, this PR: - adds a vmap rule for custom_function_call Testing - Added some tests for in_dims and info - Added vmap staticmethod to most of the autograd.Function in autograd_function_db and sent them through functorch's vmap-related OpInfo tests Future - Better error messages if the user gets the return contract wrong. I didn't include them in this PR because it might involve a refactor of some of the existing code in functorch/_src/vmap.py that will add ~200LOC to the PR, but LMK if you'd prefer it here. [ghstack-poisoned]
This PR adds a `vmap` staticmethod to autograd.Function and a corresponding vmap kernel for custom_function_call. These two items mean that autograd.Function with a vmap staticmethod can be used with vmap. ```py class NumpyMul(torch.autograd.Function) staticmethod def forward(x, y): return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) staticmethod def setup_context(ctx, outputs, x, y): ctx.save_for_backward(x, y) staticmethod def backward(ctx, grad_output): x, y = ctx.saved_tensors gx = None if isinstance(x, torch.Tensor) and x.requires_grad: gx = NumpyMul.apply(grad_output, y) gy = None if isinstance(y, torch.Tensor) and y.requires_grad: gy = NumpyMul.apply(grad_output, x) return gx, gy staticmethod def vmap(info, in_dims, x, y): x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1) result = NumpyMul.apply(x, y) result = result.movedim(-1, 0) return result, 0 ``` API Spec - the staticmethod takes two arguments (info, in_dims) as well as the unexpanded inputs (x, y). - If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a pytree with the same tree structure as args. It has None if the arg is not being vmapped over and an integer vmapped dimension index if it is. - `info` is an object with metadata about the vmap. It currently has one field, `info.batch_size`. In the future we can extend this by adding things like the randomness information. - If there is a single vmap going on, (x, y) are NOT BatchedTensors, they've already been unpacked. - We expect the user to return a `(outputs, out_dims)` tuple. `out_dims` must "broadcast" to the same pytree structure as `outputs`. Semantics - vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is one and will never actually run NumpyMul.forward. 
- In order for the autograd.Function to support nested vmap (e.g., `vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call into operations that vmap understands (i.e. PyTorch operators or more autograd.Function). At a high level, this PR: - adds a vmap rule for custom_function_call Testing - Added some tests for in_dims and info - Added vmap staticmethod to most of the autograd.Function in autograd_function_db and sent them through functorch's vmap-related OpInfo tests Future - Better error messages if the user gets the return contract wrong. I didn't include them in this PR because it might involve a refactor of some of the existing code in functorch/_src/vmap.py that will add ~200LOC to the PR, but LMK if you'd prefer it here. [ghstack-poisoned]
```py
# https://github.com/pytorch/pytorch/issues/90224
flat_args, spec = pytree.tree_flatten(args)
flat_bdims = _broadcast_to_and_flatten(bdims, spec)
assert flat_bdims is not None
```
When is this None?
`_broadcast_to_and_flatten` returns None if there is an error (if it's impossible to broadcast `bdims` to `spec`). I'm planning on writing better error messages in a followup.
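For illustration, here is a toy version of that contract (not the real pytree implementation; this sketch treats tuples as the only containers):

```py
def broadcast_to_and_flatten(bdims, tree):
    """Broadcast `bdims` to the structure of `tree` and return a flat list,
    or None if the structures are incompatible (mirroring the real helper's
    error-signaling convention)."""
    if not isinstance(bdims, tuple):
        # A single leaf (e.g. 0 or None) broadcasts to every leaf of `tree`.
        if isinstance(tree, tuple):
            flat = []
            for sub in tree:
                res = broadcast_to_and_flatten(bdims, sub)
                if res is None:
                    return None
                flat.extend(res)
            return flat
        return [bdims]
    if not isinstance(tree, tuple) or len(bdims) != len(tree):
        return None  # impossible to broadcast
    flat = []
    for b, sub in zip(bdims, tree):
        res = broadcast_to_and_flatten(b, sub)
        if res is None:
            return None
        flat.extend(res)
    return flat

print(broadcast_to_and_flatten(0, ("a", ("b", "c"))))   # [0, 0, 0]
print(broadcast_to_and_flatten((0, None), ("a", "b")))  # [0, None]
print(broadcast_to_and_flatten((0, 1, 2), ("a", "b")))  # None
```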
```py
# but that one is hyperspecialized on error messages.
# https://github.com/pytorch/pytorch/issues/90224
flat_args, spec = pytree.tree_flatten(args)
flat_bdims = _broadcast_to_and_flatten(bdims, spec)
```
To check, this means that if a user's return looks like `return (a, b), 0`, this will imply that a and b are both batched at index 0?
Yes. It's the same rules as vmap's in_dims arg: `vmap(f, in_dims=0)(a, b)` implies both a and b are batched at index 0.
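A quick runnable illustration of that in_dims shorthand, using the public `torch.func.vmap` rather than this PR's internals:

```py
import torch
from torch.func import vmap

def f(a, b):
    return a * b

a = torch.randn(3, 4)
b = torch.randn(3, 4)
# in_dims=0 broadcasts to both args: a and b are each treated as batched
# at dim 0 -- the same shorthand as returning `(outputs, 0)` for out_dims.
out = vmap(f, in_dims=0)(a, b)
assert out.shape == (3, 4)
```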
Looks good! Just a couple questions on the UX for returning batch dims
```py
unwrapped_output, out_dims = autograd_function.vmap(info, in_dims, *unwrapped_operands)
output = wrap_batched(unwrapped_output, out_dims, current_level)
return output
```
This is more of a pyfunctorch question, but is there a reason this plumbing cannot be done at the pyfunctorch level? If that were in place, all we need to do here is call `autograd_function.vmap`.

Analogously for autograd, interpreter.lower() would handle the boilerplate of unwrapping and wrapping.
This is more of a pyfunctorch question, but is there a reason this plumbing cannot be done at the pyfunctorch level? If that is in place, all we need to do here is call autograd_function.vmap.
That's a good question. I can see it going either way.
In C++, there are two types of batching rules that people write:
(1) unwrap(), lower(), do_something(), wrap()
(2) call a sequence of PyTorch operations
Most of our C++ batching rules are written as the former. CompositeImplicitAutograd and some other things are written as the latter (the batching rule for cross_entropy_loss calls log_softmax and nll_loss without unwrapping and lowering). There are a non-zero number of rules that are a hybrid of the two (yikes).
pyfunctorch is supposed to be consistent with our C++ dispatcher, so it should in theory let people (developers) do either. If we did this plumbing at the pydispatcher level, this would force (1) onto every PyOperator, right? For autograd.Function with vmap staticmethod, we've made the decision that we are forcing (1) onto people for simplicity (so they don't need to manually call interpreter.lower() and understand what an interpreter is).
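Style (1) can be sketched abstractly; `BatchedTensor`, `unwrap`, and `wrap` below are illustrative stand-ins for functorch's internals, not its real API:

```py
from dataclasses import dataclass
import numpy as np

@dataclass
class BatchedTensor:
    value: np.ndarray  # underlying data
    bdim: int          # which dim is the batch dim
    level: int         # which vmap layer owns it

def unwrap(t, level):
    """Peel off one layer of wrapping if it belongs to `level`."""
    if isinstance(t, BatchedTensor) and t.level == level:
        return t.value, t.bdim
    return t, None

def wrap(value, bdim, level):
    return BatchedTensor(value, bdim, level) if bdim is not None else value

# Style (1): unwrap, do the computation on plain values, wrap the result.
def sin_batching_rule(t, level):
    value, bdim = unwrap(t, level)
    return wrap(np.sin(value), bdim, level)

x = BatchedTensor(np.zeros((5, 3)), bdim=0, level=1)
y = sin_batching_rule(x, level=1)
assert isinstance(y, BatchedTensor) and y.bdim == 0
```

Style (2) would instead just call other batching-rule-aware operations on the still-wrapped inputs, with no unwrap/wrap step.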
Interesting, I didn't think about the case where we want to register a composite. But one workaround could be: if we force (1) on people, maybe we could replace that functionality with an API specifically for registering decompositions for layers.
Sounds like this might prevent people from doing the hybrid approach, but maybe we want to discourage those anyway, and this would actually be the way of enforcing that?
Hybrid approach isn't bad if there's a good reason for it (which there is in the C++ code -- it's difficult to write a batching rule when your operation does broadcasting).
For this PR I'll land it as-is, but we can revisit the situation in the future.
Analogously for autograd, interpreter.lower() would handle the boilerplate of unwrapping and wrapping.
To check my understanding of this comment -- would one no longer have an interpreter.lower() inside the custom_function_call_grad function if we gave the same treatment to the grad transform?
To check my understanding of this comment -- would one no longer have an interpreter.lower() inside the custom_function_call_grad function if we gave the same treatment to the grad transform?
Oh what I meant was: In C++ functorch, the "send to next" of the autograd interpreter handles unwrapping (which is pop layer off stack + unwrap), but in pyfunctorch the "send to next" for the autograd interpreter is just lower (just pop layer off stack), so that is why the user has to manually unwrap. I don't know if we want to do anything about that though.
```diff
@@ -675,6 +676,7 @@ def fn(inp, *args, **kwargs):
     xfail("double"),  # rank 4 tensor for channels_last
     xfail("float"),  # rank 4 tensor for channels_last
     xfail("half"),  # rank 4 tensor for channels_last
+    xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable autograd.Function
```
I feel like if we implemented vmap for this one, it would be composable in at least one way: grad(vmap). vmap(grad) probably still won't work.
(just a comment, not really actionable unless you also wanted to test this case)
I don't actually know how to implement vmap for this one (and this is intentional -- I don't think users should be trying to implement vmap support for it). NumpyCubeNotComposableAutogradFunction returns a (tensor, ndarray) pair. One is only allowed to specify bdims (the dimension being vmapped over) for Tensors (and the only returns from vmap can be Tensors), but we have an ndarray here.
```cpp
  }
  return std::make_tuple(tensor, nullopt);
}
```
How is this different from the get_unwrapped function that we used previously? Do we want to update those uses to use unwrapBatched?
Are you talking about `functorch._C.get_unwrapped`? Given a {Tensor, BatchedTensor, TensorWrapper, FunctionalTensorWrapper}, `get_unwrapped` returns the underlying Tensor (if applicable).

Given a {Tensor, BatchedTensor} and an integer level, `unwrapBatched` returns the underlying Tensor (if applicable) and the dimension it is being vmapped over.
```py
def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # Broadcasting
    yield SampleInput(make_arg(2, low=0.9, high=2), args=(make_arg(3, 2, low=0.9, high=2),))
    yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))
```
As a vmap noob who hasn't debugged many batching rules, I don't understand this lol, could you elaborate a little?
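For background, shapes like `(2,)` against `(3, 2)` exercise broadcasting — the case the C++ discussion above calls difficult for batching rules. A NumPy sketch (illustrative shapes, not the OpInfo machinery itself):

```py
import numpy as np

a = np.random.rand(2)      # like make_arg(2)
b = np.random.rand(3, 2)   # like make_arg(3, 2); trailing dims align
assert (a * b).shape == (3, 2)

# With a batch dim of size 7 on `a` only, the shapes (7, 2) and (3, 2) no
# longer line up, so a batching rule must insert an axis to keep broadcasting:
batched_a = np.random.rand(7, 2)
out = batched_a[:, None, :] * b   # (7, 1, 2) * (3, 2) -> (7, 3, 2)
assert out.shape == (7, 3, 2)
```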
LGTM, just had a couple small questions
This pull request has been merged in 3049d99.