Open
Labels
actionable; module: nestedtensor (NestedTensor tag, see issue #25032); triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
When I create a nested tensor with the jagged layout and pass it to torch.nn.MultiheadAttention (MHA), it throws:
AssertionError: MultiheadAttention does not support NestedTensor outside of its fast path. The fast path was not hit because some Tensor argument has_torch_function
Reproduce with:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
N = 512  # embed_dim
nheads = 8
query = torch.nested.nested_tensor([
    torch.randn(100, N, device=device)
    for _ in range(2)
], layout=torch.jagged)
mha = torch.nn.MultiheadAttention(N, nheads, batch_first=True, device=device).eval()
mha(query, query, query)
It works with strided nested tensors. However, the documentation says that jagged nested tensors should be supported, and the underlying scaled_dot_product_attention does support them.
Versions
I am on Python 3.12 and torch 2.8.0a0 (+fb).
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ