🐛 [Bug] add_attention fails with dynamic head dimension when exporting GQA models via torch.export #4145

@chohk88

Description

Bug Description

When exporting HuggingFace GQA models (e.g., Qwen2.5-0.5B with Q=14 heads, KV=2 heads) using torch.export with dynamic seq_len, the query tensor's head dimension becomes dynamic (-1) after .view().transpose(). Key/value tensors preserve their static head dimension because they go through repeat_kv (.expand().reshape()).

TRT's add_attention requires static head dimensions and fails at network construction:

Error Code 3: API Usage Error ((Unnamed Layer*) [AttentionInput]:
query number of heads must be divisible by key/value number of heads)

Observed shapes at the scaled_dot_product_efficient_attention converter:

scaled_query: (1, -1, -1, 64)   # number of heads (dim 1) is dynamic
key:          (1, 14, -1, 64)   # number of heads is static

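To make the failure mode concrete, here is a hypothetical, shape-level sketch (not TensorRT's actual source) of the static-shape validation that add_attention effectively performs at network construction time. The function name `check_attention_heads` and the `DYNAMIC` marker are illustrative assumptions:

```python
# Hypothetical sketch of the head-count check that add_attention performs
# when the network is built. TensorRT marks dynamic dimensions as -1, and
# a dynamic head count cannot be validated for GQA divisibility.

DYNAMIC = -1  # TensorRT's marker for a dynamic dimension

def check_attention_heads(query_shape, kv_shape):
    """Return True if head counts (dim 1) are static and GQA-compatible."""
    q_heads, kv_heads = query_shape[1], kv_shape[1]
    if q_heads == DYNAMIC or kv_heads == DYNAMIC:
        # Divisibility cannot be proven for a dynamic head count,
        # so network construction fails with the API Usage Error above.
        return False
    return q_heads % kv_heads == 0

# Shapes observed at the SDPA converter: query heads are dynamic -> fails.
print(check_attention_heads((1, -1, -1, 64), (1, 14, -1, 64)))  # False
# With both head counts static, 14 % 2 == 0 would pass.
print(check_attention_heads((1, 14, -1, 64), (1, 2, -1, 64)))   # True
```

This illustrates why the error is reported even though the model's actual head counts (14 and 2, or 14 and 14 after repeat_kv) are perfectly divisible: the check sees -1, not 14.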
To Reproduce

import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", use_cache=False, attn_implementation="sdpa"
).eval().cuda().to(torch.float16)

input_ids = torch.randint(1, 10000, (1, 64), dtype=torch.int64, device="cuda")
position_ids = torch.arange(64).unsqueeze(0).cuda()

seq_len = torch.export.Dim("seq_len", min=1, max=128)
ep = torch.export._trace._export(
    model,
    args=(input_ids,),
    kwargs={"position_ids": position_ids},
    dynamic_shapes=({1: seq_len}, {1: seq_len}),
    strict=False,
    prefer_deferred_runtime_asserts_over_guards=True,
)

# This fails with "query number of heads must be divisible by key/value number of heads"
trt_model = torch_tensorrt.dynamo.compile(
    ep,
    inputs=[input_ids, position_ids],
    enabled_precisions={torch.float32},
    use_explicit_typing=True,
    use_fp32_acc=True,
    device=torch.device("cuda:0"),
    min_block_size=1,
)

Or, equivalently, run this standalone reproduction script:

"""
Minimal reproduction for dynamic head dimension issue with IAttention.

When torch.export traces HuggingFace models with dynamic seq_len,
the query tensor's head dimension becomes dynamic (-1) after
view().transpose(), while key/value preserve static head dim through
repeat_kv's expand().reshape().

TRT's add_attention requires static head dimensions and fails with:
  Error Code 3: API Usage Error ((Unnamed Layer*) [AttentionInput]:
  query number of heads must be divisible by key/value number of heads)

This test uses the actual HuggingFace Qwen2.5-0.5B model to reproduce
the exact tracing behavior.
"""

import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM


def test_dynamic_head_dim_with_hf_model():
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"

    model = (
        AutoModelForCausalLM.from_pretrained(
            model_name, use_cache=False, attn_implementation="sdpa"
        )
        .eval()
        .cuda()
        .to(torch.float16)
    )

    input_ids = torch.randint(1, 10000, (1, 64), dtype=torch.int64, device="cuda")
    position_ids = torch.arange(64).unsqueeze(0).cuda()

    seq_len = torch.export.Dim("seq_len", min=1, max=128)
    try:
        ep = torch.export.export(
            model,
            args=(input_ids,),
            kwargs={"position_ids": position_ids},
            dynamic_shapes=({1: seq_len}, {1: seq_len}),
            strict=False,
        )
    except Exception:
        ep = torch.export._trace._export(
            model,
            args=(input_ids,),
            kwargs={"position_ids": position_ids},
            dynamic_shapes=({1: seq_len}, {1: seq_len}),
            strict=False,
            prefer_deferred_runtime_asserts_over_guards=True,
        )

    trt_model = torch_tensorrt.dynamo.compile(
        ep,
        inputs=[input_ids, position_ids],
        enabled_precisions={torch.float32},
        use_explicit_typing=True,
        use_fp32_acc=True,
        device=torch.device("cuda:0"),
        min_block_size=1,
    )

    with torch.no_grad():
        ref = model(input_ids, position_ids=position_ids).logits
        out = trt_model(input_ids, position_ids)
        # TRT model may return CausalLMOutputWithPast or tuple
        if hasattr(out, "logits"):
            out = out.logits
        elif isinstance(out, (tuple, list)):
            out = out[0]

    torch.testing.assert_close(ref, out, rtol=1e-1, atol=2e-1)
    print("PASSED: dynamic head dim test with HF model")


if __name__ == "__main__":
    test_dynamic_head_dim_with_hf_model()
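To show why key/value keep a static head count while the query does not, here is a shape-only sketch of the expand().reshape() path used by HuggingFace's repeat_kv. This is an illustrative simplification (the helper `repeat_kv_shape` and the `"s0"` symbol are assumptions, mimicking torch.export's symbolic seq_len), not the transformers implementation itself:

```python
# Shape-only sketch of the repeat_kv path: the KV head count is a Python
# int baked in at trace time, so expanding and reshaping multiplies two
# concrete integers, and the resulting head count stays static even when
# seq_len is symbolic. "s0" stands in for torch.export's symbolic dim.

SEQ = "s0"  # symbolic (dynamic) sequence length

def repeat_kv_shape(kv_shape, n_rep):
    """(batch, kv_heads, seq, head_dim) -> (batch, kv_heads * n_rep, seq, head_dim).

    kv_heads and n_rep are concrete ints from the model config, so the
    product 2 * 7 = 14 is a static dimension in the exported graph.
    """
    batch, kv_heads, seq, head_dim = kv_shape
    return (batch, kv_heads * n_rep, seq, head_dim)

# Qwen2.5-0.5B: 2 KV heads repeated 7x to match 14 query heads.
print(repeat_kv_shape((1, 2, SEQ, 64), 7))  # (1, 14, 's0', 64)
```

By contrast, the query goes through q.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2), and under symbolic tracing the view's inferred dimensions can end up symbolic, which is how the converter observes (1, -1, -1, 64) for scaled_query.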

Expected behavior

Compilation with torch_tensorrt.dynamo.compile should succeed even when the query's head dimension is traced as dynamic, and the TRT outputs should match eager PyTorch within reasonable FP16 tolerances (as checked in the script above).

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
