Closed
Labels
high priority
module: functorch (pertaining to torch.func or pytorch/functorch)
module: vmap
triage review
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🚀 The feature, motivation and pitch
Hi, I am trying to take batched gradients of a vector output produced by _scaled_dot_product_efficient_attention, but I get the following warning when running my code:
/site-packages/optimum/bettertransformer/models/attention.py:56: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_scaled_dot_product_efficient_attention. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at ../aten/src/ATen/functorch/BatchedFallback.cpp:82.)
  sdpa_result = torch.nn.functional.scaled_dot_product_attention
Implementing this batching rule would substantially increase throughput in our application. I would also be happy to take a stab at implementing it myself if there is a document describing what I need to do at a high level.
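For readers unfamiliar with the pattern, here is a minimal sketch of what "batched gradients of a vector output" can look like with torch.func: one vector-Jacobian product per output element, batched with vmap instead of a Python loop. This is not the reporter's code. The names reference_attention and batched_query_grads are hypothetical, and reference_attention is a plain-PyTorch stand-in for attention, used so that the sketch does not depend on which fused kernel gets selected.

```python
import math

import torch
from torch.func import vjp, vmap


def reference_attention(q, k, v):
    # Hypothetical stand-in: softmax(Q K^T / sqrt(d)) V built from basic ops,
    # so no fused attention kernel is involved.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v


def batched_query_grads(attn_fn, q, k, v):
    # Gradient of every output element with respect to q.
    out, vjp_fn = vjp(lambda q_: attn_fn(q_, k, v), q)
    n = out.numel()
    # One one-hot cotangent per output element, stacked along dim 0.
    basis = torch.eye(n, dtype=out.dtype).reshape(n, *out.shape)
    # vmap runs the backward pass once per cotangent in a single batched call.
    (grads,) = vmap(vjp_fn)(basis)
    return grads  # shape: (out.numel(), *q.shape)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
grads = batched_query_grads(reference_attention, q, k, v)
```

Passing torch.nn.functional.scaled_dot_product_attention as attn_fn instead of the stand-in is the kind of call this issue is about: when an operator in the backward pass has no batching rule, vmap falls back to a slower per-sample path and emits the warning quoted above.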
Alternatives
No response
Additional context
No response
cc @ezyang @gchanan @zou3519 @kadeng @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki @Chillee @samdow @kshitij12345 @janeyx99