-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[autograd.Function] add nice error message for incorrect usage of vmap #92023
Conversation
This PR: - adds a nice error message if the user doesn't follow the API of the vmap staticmethod correctly. That is, the user must return two arguments from the vmap staticmethod API: (outputs, out_dims), and out_dims must be a PyTree with either the same structure as `outputs` our be broadcastable to the same structure as `outputs`. - Fixes an edge case for out_dims=None. out_dims is allowed to be None, but wrap_outputs_maintaining_identity was treating "None" as "This is not the vmap case" Test Plan: - new tests [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92023
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 6b7f0fa: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…sage of vmap" This PR: - adds a nice error message if the user doesn't follow the API of the vmap staticmethod correctly. That is, the user must return two arguments from the vmap staticmethod API: (outputs, out_dims), and out_dims must be a PyTree with either the same structure as `outputs` our be broadcastable to the same structure as `outputs`. - Fixes an edge case for out_dims=None. out_dims is allowed to be None, but wrap_outputs_maintaining_identity was treating "None" as "This is not the vmap case" Test Plan: - new tests [ghstack-poisoned]
…sage of vmap" This PR: - adds a nice error message if the user doesn't follow the API of the vmap staticmethod correctly. That is, the user must return two arguments from the vmap staticmethod API: (outputs, out_dims), and out_dims must be a PyTree with either the same structure as `outputs` our be broadcastable to the same structure as `outputs`. - Fixes an edge case for out_dims=None. out_dims is allowed to be None, but wrap_outputs_maintaining_identity was treating "None" as "This is not the vmap case" Test Plan: - new tests [ghstack-poisoned]
@@ -267,16 +291,14 @@ def custom_function_call_vmap(interpreter, autograd_function, *operands): | |||
return custom_function_call(autograd_function, *operands) | |||
|
|||
with interpreter.lower(): | |||
unwrapped_output, out_dims = autograd_function.vmap(info, in_dims, *unwrapped_operands) | |||
result = autograd_function.vmap(info, in_dims, *unwrapped_operands) | |||
validate_vmap_returns_two(result) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think we might as well fully describe it here, then a reader won't even have to read the function to figure out what it does - "validate_vmap_returns_tuple_of_two_elements".
@staticmethod | ||
def vmap(info, in_dims, input): | ||
assert in_dims == (0,) | ||
return torch.zeros(input.shape[1:], device=input.device), None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this work? Does _broadcast_to_and_flatten do something special if out_dims is None? Otherwise it seems that the wrap_fn we define should just not wrap the output in a batched tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your reading is correct -- _broadcast_to_and_flatten
doesn't have any special logic for None. It is intended that vmap does not construct a BatchedTensor in this case.
The expected behavior of this is that the output (call it zeros
) is not wrapped in a BatchedTensor. That's OK. Inside of vmap
, a Tensor is either a BatchedTensor or a regular Tensor. This is the latter case.
There are two cases:
- Eventually the regular Tensor interacts with a BatchedTensor. E.g.
regular_tensor * batched_tensor
. This produces a BatchedTensor. - The regular Tensor gets returned from the function being vmapped over. vmap has special handling to take this Tensor and expand out the batch dimension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks ok, just have a question about what happens in None case.
@@ -145,11 +145,14 @@ def jvp(ctx, *tangents): | |||
) | |||
return Generated | |||
|
|||
NO_OUT_DIMS = "not specified" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to avoid confusing None with "not specified" right? Maybe we could have a comment about what happens in the None case?
…sage of vmap" This PR: - adds a nice error message if the user doesn't follow the API of the vmap staticmethod correctly. That is, the user must return two arguments from the vmap staticmethod API: (outputs, out_dims), and out_dims must be a PyTree with either the same structure as `outputs` our be broadcastable to the same structure as `outputs`. - Fixes an edge case for out_dims=None. out_dims is allowed to be None, but wrap_outputs_maintaining_identity was treating "None" as "This is not the vmap case" Test Plan: - new tests [ghstack-poisoned]
Hi @zou3519 , do you need to make some changes in |
…sage of vmap" This PR: - adds a nice error message if the user doesn't follow the API of the vmap staticmethod correctly. That is, the user must return two arguments from the vmap staticmethod API: (outputs, out_dims), and out_dims must be a PyTree with either the same structure as `outputs` our be broadcastable to the same structure as `outputs`. - Fixes an edge case for out_dims=None. out_dims is allowed to be None, but wrap_outputs_maintaining_identity was treating "None" as "This is not the vmap case" Test Plan: - new tests cc gujinghui @PenghuiCheng @XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j @Guobing-Chen @Xia-Weiwen [ghstack-poisoned]
…sage of vmap" This PR: - adds a nice error message if the user doesn't follow the API of the vmap staticmethod correctly. That is, the user must return two arguments from the vmap staticmethod API: (outputs, out_dims), and out_dims must be a PyTree with either the same structure as `outputs` our be broadcastable to the same structure as `outputs`. - Fixes an edge case for out_dims=None. out_dims is allowed to be None, but wrap_outputs_maintaining_identity was treating "None" as "This is not the vmap case" Test Plan: - new tests cc gujinghui @PenghuiCheng @XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j @Guobing-Chen @Xia-Weiwen [ghstack-poisoned]
Stack from ghstack:
This PR:
vmap staticmethod correctly. That is, the user must return two
arguments from the vmap staticmethod API: (outputs, out_dims), and
out_dims must be a PyTree with either the same structure as
outputs
our be broadcastable to the same structure as
outputs
.but wrap_outputs_maintaining_identity was treating "None" as "This is
not the vmap case"
Test Plan:
cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen