Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autograd.Function] add nice error message for incorrect usage of vmap #92023

Closed
wants to merge 6 commits into from

Conversation

zou3519
Copy link
Contributor

@zou3519 zou3519 commented Jan 11, 2023

Stack from ghstack:

This PR:

  • adds a nice error message if the user doesn't follow the API of the
    vmap staticmethod correctly. That is, the user must return two
    arguments from the vmap staticmethod API: (outputs, out_dims), and
    out_dims must be a PyTree with either the same structure as outputs
    our be broadcastable to the same structure as outputs.
  • Fixes an edge case for out_dims=None. out_dims is allowed to be None,
    but wrap_outputs_maintaining_identity was treating "None" as "This is
    not the vmap case"

Test Plan:

  • new tests

cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen

This PR:
- adds a nice error message if the user doesn't follow the API of the
  vmap staticmethod correctly. That is, the user must return two
  arguments from the vmap staticmethod API: (outputs, out_dims), and
  out_dims must be a PyTree with either the same structure as `outputs`
  our be broadcastable to the same structure as `outputs`.
- Fixes an edge case for out_dims=None. out_dims is allowed to be None,
  but wrap_outputs_maintaining_identity was treating "None" as "This is
  not the vmap case"

Test Plan:
- new tests

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Jan 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92023

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6b7f0fa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…sage of vmap"

This PR:
- adds a nice error message if the user doesn't follow the API of the
  vmap staticmethod correctly. That is, the user must return two
  arguments from the vmap staticmethod API: (outputs, out_dims), and
  out_dims must be a PyTree with either the same structure as `outputs`
  our be broadcastable to the same structure as `outputs`.
- Fixes an edge case for out_dims=None. out_dims is allowed to be None,
  but wrap_outputs_maintaining_identity was treating "None" as "This is
  not the vmap case"

Test Plan:
- new tests

[ghstack-poisoned]
…sage of vmap"

This PR:
- adds a nice error message if the user doesn't follow the API of the
  vmap staticmethod correctly. That is, the user must return two
  arguments from the vmap staticmethod API: (outputs, out_dims), and
  out_dims must be a PyTree with either the same structure as `outputs`
  our be broadcastable to the same structure as `outputs`.
- Fixes an edge case for out_dims=None. out_dims is allowed to be None,
  but wrap_outputs_maintaining_identity was treating "None" as "This is
  not the vmap case"

Test Plan:
- new tests

[ghstack-poisoned]
@@ -267,16 +291,14 @@ def custom_function_call_vmap(interpreter, autograd_function, *operands):
return custom_function_call(autograd_function, *operands)

with interpreter.lower():
unwrapped_output, out_dims = autograd_function.vmap(info, in_dims, *unwrapped_operands)
result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
validate_vmap_returns_two(result)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we might as well fully describe it here, then a reader won't even have to read the function to figure out what it does - "validate_vmap_returns_tuple_of_two_elements".

@staticmethod
def vmap(info, in_dims, input):
assert in_dims == (0,)
return torch.zeros(input.shape[1:], device=input.device), None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this work? Does _broadcast_to_and_flatten do something special if out_dims is None? Otherwise it seems that the wrap_fn we define should just not wrap the output in a batched tensor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your reading is correct -- _broadcast_to_and_flatten doesn't have any special logic for None. It is intended that vmap does not construct a BatchedTensor in this case.

The expected behavior of this is that the output (call it zeros) is not wrapped in a BatchedTensor. That's OK. Inside of vmap, a Tensor is either a BatchedTensor or a regular Tensor. This is the latter case.

There are two cases:

  • Eventually the regular Tensor interacts with a BatchedTensor. E.g. regular_tensor * batched_tensor. This produces a BatchedTensor.
  • The regular Tensor gets returned from the function being vmapped over. vmap has special handling to take this Tensor and expand out the batch dimension.

Copy link
Contributor

@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks ok, just have a question about what happens in None case.

@@ -145,11 +145,14 @@ def jvp(ctx, *tangents):
)
return Generated

NO_OUT_DIMS = "not specified"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to avoid confusing None with "not specified" right? Maybe we could have a comment about what happens in the None case?

…sage of vmap"

This PR:
- adds a nice error message if the user doesn't follow the API of the
  vmap staticmethod correctly. That is, the user must return two
  arguments from the vmap staticmethod API: (outputs, out_dims), and
  out_dims must be a PyTree with either the same structure as `outputs`
  our be broadcastable to the same structure as `outputs`.
- Fixes an edge case for out_dims=None. out_dims is allowed to be None,
  but wrap_outputs_maintaining_identity was treating "None" as "This is
  not the vmap case"

Test Plan:
- new tests

[ghstack-poisoned]
@github-actions github-actions bot added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Jan 17, 2023
@yanbing-j
Copy link
Collaborator

Hi @zou3519 , do you need to make some changes in third_party/ideep? If not, perhaps you can sync the submodule to get the latest third_party.

…sage of vmap"

This PR:
- adds a nice error message if the user doesn't follow the API of the
  vmap staticmethod correctly. That is, the user must return two
  arguments from the vmap staticmethod API: (outputs, out_dims), and
  out_dims must be a PyTree with either the same structure as `outputs`
  our be broadcastable to the same structure as `outputs`.
- Fixes an edge case for out_dims=None. out_dims is allowed to be None,
  but wrap_outputs_maintaining_identity was treating "None" as "This is
  not the vmap case"

Test Plan:
- new tests

cc gujinghui @PenghuiCheng @XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j @Guobing-Chen @Xia-Weiwen

[ghstack-poisoned]
@yanbing-j yanbing-j removed the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Jan 17, 2023
…sage of vmap"

This PR:
- adds a nice error message if the user doesn't follow the API of the
  vmap staticmethod correctly. That is, the user must return two
  arguments from the vmap staticmethod API: (outputs, out_dims), and
  out_dims must be a PyTree with either the same structure as `outputs`
  our be broadcastable to the same structure as `outputs`.
- Fixes an edge case for out_dims=None. out_dims is allowed to be None,
  but wrap_outputs_maintaining_identity was treating "None" as "This is
  not the vmap case"

Test Plan:
- new tests

cc gujinghui @PenghuiCheng @XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j @Guobing-Chen @Xia-Weiwen

[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Merged release notes: functorch release notes category; Pertaining to torch.func or pytorch/functorch triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants