Labels: bug
Description
Bug Description
When exporting HuggingFace GQA models (e.g., Qwen2.5-0.5B with Q=14 heads, KV=2 heads) using `torch.export` with a dynamic `seq_len`, the query tensor's head dimension becomes dynamic (-1) after `.view().transpose()`. Key/value tensors preserve their static head dimension because they go through `repeat_kv` (`.expand().reshape()`).
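For reference, a minimal sketch of the `repeat_kv` path the key/value tensors take (paraphrased from the `transformers` modeling code; the tensor sizes below are the Qwen2.5-0.5B values mentioned above):

```python
import torch

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim)
    #   -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    b, kv_heads, s, d = hidden.shape
    if n_rep == 1:
        return hidden
    # expand() then reshape(): the head dimension stays a static product
    hidden = hidden[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d)
    return hidden.reshape(b, kv_heads * n_rep, s, d)

kv = torch.randn(1, 2, 64, 64)   # 2 KV heads, head_dim 64
out = repeat_kv(kv, 14 // 2)     # repeat up to the 14 query heads
print(out.shape)                 # torch.Size([1, 14, 64, 64])
```

Because the head count here is computed as `kv_heads * n_rep` from static ints, the exported graph keeps it static, while the query path's `.view()` on a symbolic `seq_len` does not.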
TRT's `add_attention` requires static head dimensions and fails at network construction:

```
Error Code 3: API Usage Error ((Unnamed Layer*) [AttentionInput]:
query number of heads must be divisible by key/value number of heads)
```
Observed shapes at the `scaled_dot_product_efficient_attention` converter:

```
scaled_query: (1, -1, -1, 64)  # head dim is dynamic
key:          (1, 14, -1, 64)  # head dim is static
```
To Reproduce
```python
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM

model = (
    AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-0.5B-Instruct", use_cache=False, attn_implementation="sdpa"
    )
    .eval()
    .cuda()
    .to(torch.float16)
)

input_ids = torch.randint(1, 10000, (1, 64), dtype=torch.int64, device="cuda")
position_ids = torch.arange(64).unsqueeze(0).cuda()
seq_len = torch.export.Dim("seq_len", min=1, max=128)

ep = torch.export._trace._export(
    model,
    args=(input_ids,),
    kwargs={"position_ids": position_ids},
    dynamic_shapes=({1: seq_len}, {1: seq_len}),
    strict=False,
    prefer_deferred_runtime_asserts_over_guards=True,
)

# This fails with "query number of heads must be divisible by key/value number of heads"
trt_model = torch_tensorrt.dynamo.compile(
    ep,
    inputs=[input_ids, position_ids],
    enabled_precisions={torch.float32},
    use_explicit_typing=True,
    use_fp32_acc=True,
    device=torch.device("cuda:0"),
    min_block_size=1,
)
```

or
"""
Minimal reproduction for dynamic head dimension issue with IAttention.
When torch.export traces HuggingFace models with dynamic seq_len,
the query tensor's head dimension becomes dynamic (-1) after
view().transpose(), while key/value preserve static head dim through
repeat_kv's expand().reshape().
TRT's add_attention requires static head dimensions and fails with:
Error Code 3: API Usage Error ((Unnamed Layer*) [AttentionInput]:
query number of heads must be divisible by key/value number of heads)
This test uses the actual HuggingFace Qwen2.5-0.5B model to reproduce
the exact tracing behavior.
"""
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM
def test_dynamic_head_dim_with_hf_model():
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = (
AutoModelForCausalLM.from_pretrained(
model_name, use_cache=False, attn_implementation="sdpa"
)
.eval()
.cuda()
.to(torch.float16)
)
input_ids = torch.randint(1, 10000, (1, 64), dtype=torch.int64, device="cuda")
position_ids = torch.arange(64).unsqueeze(0).cuda()
seq_len = torch.export.Dim("seq_len", min=1, max=128)
try:
ep = torch.export.export(
model,
args=(input_ids,),
kwargs={"position_ids": position_ids},
dynamic_shapes=({1: seq_len}, {1: seq_len}),
strict=False,
)
except Exception:
ep = torch.export._trace._export(
model,
args=(input_ids,),
kwargs={"position_ids": position_ids},
dynamic_shapes=({1: seq_len}, {1: seq_len}),
strict=False,
prefer_deferred_runtime_asserts_over_guards=True,
)
trt_model = torch_tensorrt.dynamo.compile(
ep,
inputs=[input_ids, position_ids],
enabled_precisions={torch.float32},
use_explicit_typing=True,
use_fp32_acc=True,
device=torch.device("cuda:0"),
min_block_size=1,
)
with torch.no_grad():
ref = model(input_ids, position_ids=position_ids).logits
out = trt_model(input_ids, position_ids)
# TRT model may return CausalLMOutputWithPast or tuple
if hasattr(out, "logits"):
out = out.logits
elif isinstance(out, (tuple, list)):
out = out[0]
torch.testing.assert_close(ref, out, rtol=1e-1, atol=2e-1)
print("PASSED: dynamic head dim test with HF model")
if __name__ == "__main__":
test_dynamic_head_dim_with_hf_model()Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context