
Multihead Attention does not work with jagged tensors due to __torch_function__ #153472

@SamuelGabriel

Description


🐛 Describe the bug

When I create a nested tensor in the jagged layout and pass it to torch.nn.MultiheadAttention, it throws:

AssertionError: MultiheadAttention does not support NestedTensor outside of its fast path. The fast path was not hit because some Tensor argument has_torch_function

Reproduce with:

import torch

device = "cpu"  # or "cuda"; the original snippet left `device` undefined
N = 512
nheads = 8

query = torch.nested.nested_tensor(
    [torch.randn(100, N, device=device) for _ in range(2)],
    layout=torch.jagged,
)

mha = torch.nn.MultiheadAttention(N, nheads, batch_first=True).eval()
mha(query, query, query)

It works with strided nested tensors, but the documentation says the jagged layout should be supported, and the underlying scaled_dot_product_attention does support it.

Versions

I am on Python 3.12 and torch 2.8.0a0 (+fb).

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ

Metadata


Labels: actionable; module: nestedtensor (NestedTensor tag, see issue #25032); triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
