Bug Description
```python
from typing import Optional, Tuple
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn as nn
import torch_tensorrt


class SimpleNetwork(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, padding_mask, hidden_states):
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        # target_shape = (batch_size, 1, num_frames, height, width)
        # hidden_states = torch.cat(
        #     [hidden_states, padding_mask.unsqueeze(2).expand(target_shape)], dim=1
        # )
        hidden_states = torch.cat(
            [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)],
            dim=1,
        )
        return hidden_states


def export_attention(model, padding_mask, hidden_states):
    with torch.no_grad():
        # Only mark the frame dimension as dynamic, like run_llm.py does.
        # Don't mark the batch dimension as dynamic to avoid constraint violations.
        seq_len = torch.export.Dim("seq_len", min=1, max=16)
        print("Trying to export the model using torch.export.export()..")
        # strict=False only enables autograd tracing and excludes dynamo.
        # Use the tuple format like export_llm - only mark dim 2 of hidden_states as dynamic.
        ep = torch.export.export(
            model,
            args=(padding_mask, hidden_states),
            kwargs={},
            dynamic_shapes=({}, {2: seq_len}),
            strict=False,
        )
    return ep


def compile_torchtrt(model, padding_mask, hidden_states, min_block_size, debug):
    ep = export_attention(model, padding_mask, hidden_states)
    # Set precision-specific flags
    use_fp32_acc = False
    use_explicit_typing = False
    enabled_precisions = {torch.bfloat16}
    with torch_tensorrt.logging.debug() if debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[padding_mask, hidden_states],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )
    return trt_model


if __name__ == "__main__":
    min_block_size = 1
    enable_pytorch_run = True
    debug = False
    device = "cuda"

    with torch.inference_mode():
        model = SimpleNetwork().to(device)
        # Convert the model to the appropriate precision
        model = model.to(torch.bfloat16)
        input_dtype = torch.bfloat16

        # Prepare inputs for benchmarking or evaluation
        padding_mask = torch.randn(1, 1, 88, 160, dtype=input_dtype).to(device)
        hidden_states = torch.randn(1, 17, 8, 88, 160, dtype=input_dtype).to(device)

        # PyTorch reference run
        pyt_output = model(padding_mask, hidden_states)
        print("PyTorch output shape:", pyt_output.shape)
        print("PyTorch output:", pyt_output.flatten())

        # Compile the model with Torch-TensorRT
        trt_model = compile_torchtrt(model, padding_mask, hidden_states, min_block_size, debug)
        # trt_model = torch.compile(
        #     model,
        #     backend="torch_tensorrt",
        #     options={
        #         "enabled_precisions": {input_dtype},
        #         "use_python_runtime": True,
        #         "min_block_size": min_block_size,
        #     },
        #     dynamic=None,
        # )
        trt_model = trt_model.to(device)
        trt_output = trt_model(padding_mask, hidden_states)
        print("TensorRT output shape:", trt_output.shape)
        print("TensorRT output:", trt_output.flatten())

        # Verify the results match
        diff = (pyt_output - trt_output).abs().max().item()
        print(f"Max difference between PyTorch and TRT: {diff}")

        # Check if the results are close enough
        tolerance = 0.01
        if diff < tolerance:
            print(f"✅ Results match! (difference: {diff} < {tolerance})")
        else:
            print(f"⚠️ Results differ! (difference: {diff} >= {tolerance})")
```
Here is the error message:
```
Traceback (most recent call last):
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 235, in <module>
    trt_model = compile_torchtrt(model, padding_mask, hidden_states, min_block_size, debug)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 170, in compile_torchtrt
    trt_model = torch_tensorrt.dynamo.compile(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/_compiler.py", line 782, in compile
    trt_gm = compile_module(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/_compiler.py", line 1028, in compile_module
    trt_module = convert_module(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 145, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 78, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 721, in run
    self._construct_trt_network_def()
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 433, in _construct_trt_network_def
    super().run()
  File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 790, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 256, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 897, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 677, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1236, in aten_ops_expand
    return impl.slice.expand(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 231, in expand
    input_t = prepend_ones(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 1082, in prepend_ones
    layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape)
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (arg0: tensorrt.tensorrt.IShuffleLayer, arg1: tensorrt.tensorrt.Dims) -> None

Invoked with: <tensorrt.tensorrt.IShuffleLayer object at 0x7a36c40701b0>, (1, 1, 1, 1, 1, 1, 1, 1, 88, 160)

While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze, [1, 1, %sym_size_int_2, 1, 1, 1, 1, 1, 88, 160]), kwargs = {})

Original traceback:
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 137, in forward
    [hidden_states, padding_mask.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)], dim=1

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
```
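The failure appears to be a rank-limit issue: when the exported program is lowered, `aten.repeat` on the 5-D tensor decomposes into an `aten.expand` over an intermediate of twice the rank, and the resulting 10-element shape cannot be assigned to `tensorrt.Dims`, which supports at most 8 dimensions. A minimal sketch (plain PyTorch, no TensorRT needed; `RepeatNet` and `ExpandNet` are hypothetical names for illustration) that should make the decomposition visible:

```python
import torch
import torch.nn as nn

class RepeatNet(nn.Module):
    def forward(self, x):
        # 5-D tensor after unsqueeze; repeat tiles the frame dimension
        return x.unsqueeze(2).repeat(1, 1, 8, 1, 1)

class ExpandNet(nn.Module):
    def forward(self, x):
        # same result via broadcasting; the rank stays at 5
        return x.unsqueeze(2).expand(1, 1, 8, 88, 160)

x = torch.randn(1, 1, 88, 160)
# run_decompositions() lowers repeat to core ATen ops; the RepeatNet graph is
# expected to contain an aten.expand with a 10-element shape (as in the
# traceback above), while the ExpandNet graph keeps a 5-element shape,
# well within TensorRT's 8-dimension limit.
print(torch.export.export(RepeatNet(), (x,)).run_decompositions().graph)
print(torch.export.export(ExpandNet(), (x,)).run_decompositions().graph)
```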
To Reproduce
Steps to reproduce the behavior:
- Run the Python script above
Expected behavior
Compilation succeeds and the Torch-TensorRT output matches the PyTorch output.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
- PyTorch Version (e.g. 1.0): 2.9.0
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): PYTHON_ONLY=1 pip install -e .
- Are you using local sources or building from archives: No
- Python version: 3.10
- CUDA version: 12.9
- GPU models and configuration: Nvidia B200
- Any other relevant information: None
Additional context
The workaround (WAR) is to use `expand` instead of `repeat`:
```python
target_shape = (batch_size, 1, num_frames, height, width)
hidden_states = torch.cat(
    [hidden_states, padding_mask.unsqueeze(2).expand(target_shape)], dim=1
)
```
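For the shapes in this reproducer the two formulations are numerically identical, since the mask's singleton frame dimension is either broadcast (`expand`) or materialized (`repeat`). A quick sanity check, using the mask shape from the reproducer:

```python
import torch

padding_mask = torch.randn(1, 1, 88, 160, dtype=torch.bfloat16)
repeated = padding_mask.unsqueeze(2).repeat(1, 1, 8, 1, 1)      # materializes a copy
expanded = padding_mask.unsqueeze(2).expand(1, 1, 8, 88, 160)   # zero-copy broadcast view
assert torch.equal(repeated, expanded)  # same values, so the WAR preserves behavior
```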