
🐛 [Bug] Shape error with repeat #3974


Bug Description

from contextlib import nullcontext

import torch
import torch.nn as nn
import torch_tensorrt


class SimpleNetwork(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, padding_mask, hidden_states):
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        # target_shape = (batch_size, 1, num_frames, height, width)
        # hidden_states = torch.cat(
        #     [hidden_states,  padding_mask.unsqueeze(2).expand(target_shape)], dim=1
        # )
        hidden_states = torch.cat(
            [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
        )
        return hidden_states


def export_attention(model, padding_mask, hidden_states):
    with torch.no_grad():
        # Only mark the frame dimension as dynamic, like run_llm.py does.
        # Don't mark the batch dimension as dynamic to avoid constraint violations.
        seq_len = torch.export.Dim("seq_len", min=1, max=16)
        print("Trying to export the model using torch.export.export()...")
        # strict=False traces with the eager autograd interpreter instead of dynamo.
        # Use the tuple format like export_llm - only dim 2 of hidden_states is dynamic.
        ep = torch.export.export(
            model,
            args=(padding_mask, hidden_states),
            kwargs={},
            dynamic_shapes=({}, {2: seq_len}),
            strict=False,
        )

    return ep


def compile_torchtrt(model, padding_mask, hidden_states, min_block_size, debug):
    ep = export_attention(model, padding_mask, hidden_states)
    # Set precision specific flags
    use_fp32_acc = False
    use_explicit_typing = False
    enabled_precisions = {torch.bfloat16}

    with torch_tensorrt.logging.debug() if debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[padding_mask, hidden_states],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )

    return trt_model


if __name__ == "__main__":
    min_block_size = 1
    enable_pytorch_run = True
    debug = False
    device = "cuda"

    with torch.inference_mode():
        model = SimpleNetwork().to(device)

        # Convert model to the appropriate precision
        model = model.to(torch.bfloat16)
        input_dtype = torch.bfloat16

        # Prepare input for benchmarking or evaluation
        padding_mask = torch.randn(
            1, 1, 88, 160, dtype=input_dtype
        ).to(device)
        hidden_states = torch.randn(
            1, 17, 8, 88, 160, dtype=input_dtype
        ).to(device)

        # PyTorch reference run
        pyt_output = model(padding_mask, hidden_states)
        print("PyTorch output shape:", pyt_output.shape)
        print("PyTorch output:", pyt_output.flatten())

        # Compile the model with Torch-TensorRT
        trt_model = compile_torchtrt(model, padding_mask, hidden_states, min_block_size, debug)
        # trt_model = torch.compile(
        #     model,
        #     backend="torch_tensorrt",
        #     options={
        #         "enabled_precisions": {input_dtype},
        #         "use_python_runtime": True,
        #         "min_block_size": min_block_size,
        #     },
        #     dynamic=None,
        # )
        trt_model = trt_model.to(device)

        trt_output = trt_model(padding_mask, hidden_states)
        print("TensorRT output shape:", trt_output.shape)
        print("TensorRT output:", trt_output.flatten())
    
    # Verify results match
    diff = (pyt_output - trt_output).abs().max().item()
    print(f"Max difference between PyTorch and TRT: {diff}")

    # Check if results are close enough
    tolerance = 0.01
    if diff < tolerance:
        print(f"✅ Results match! (difference: {diff} < {tolerance})")
    else:
        print(f"⚠️  Results differ! (difference: {diff} >= {tolerance})")

Here is the error message:

Traceback (most recent call last):
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 235, in <module>
    trt_model = compile_torchtrt(model, padding_mask, hidden_states, min_block_size, debug)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 170, in compile_torchtrt
    trt_model = torch_tensorrt.dynamo.compile(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/_compiler.py", line 782, in compile
    trt_gm = compile_module(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/_compiler.py", line 1028, in compile_module
    trt_module = convert_module(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 145, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 78, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 721, in run
    self._construct_trt_network_def()
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 433, in _construct_trt_network_def
    super().run()
  File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 790, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 256, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 897, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 677, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1236, in aten_ops_expand
    return impl.slice.expand(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 231, in expand
    input_t = prepend_ones(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 1082, in prepend_ones
    layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape)
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (arg0: tensorrt.tensorrt.IShuffleLayer, arg1: tensorrt.tensorrt.Dims) -> None

Invoked with: <tensorrt.tensorrt.IShuffleLayer object at 0x7a36c40701b0>, (1, 1, 1, 1, 1, 1, 1, 1, 88, 160)

While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze, [1, 1, %sym_size_int_2, 1, 1, 1, 1, 1, 88, 160]), kwargs = {})
Original traceback:
File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 137, in forward
    [hidden_states, padding_mask.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)], dim=1
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
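
A note on the root cause: the "While executing %expand" line shows that aten.repeat on the rank-5 unsqueezed mask is lowered to a rank-10 aten.expand. The converter's prepend_ones pads the rank-5 input with five leading singleton dims, producing the 10-element tuple in the error, and TensorRT's Dims supports at most 8 dimensions, so the reshape_dims assignment is rejected. Below is a minimal sketch of what the lowered graph is doing, reconstructed from the %expand node above (plain PyTorch; not necessarily the exact decomposition Torch-TensorRT emits):

import torch

padding_mask = torch.randn(1, 1, 88, 160)
num_frames = 8

# What the user code asks for: rank-5 in, rank-5 out.
reference = padding_mask.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)

# What the decomposed graph does per the traceback: pad to rank 10 with
# leading singleton dims, expand, then reshape back down to rank 5.
x = padding_mask.unsqueeze(2)                           # (1, 1, 1, 88, 160)
x = x.reshape(1, 1, 1, 1, 1, 1, 1, 1, 88, 160)          # rank 10
x = x.expand(1, 1, num_frames, 1, 1, 1, 1, 1, 88, 160)  # exceeds TRT's 8-dim cap
x = x.reshape(1, 1, num_frames, 88, 160)

assert torch.equal(x, reference)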

To Reproduce

Steps to reproduce the behavior:

  1. Run the Python script above.

Expected behavior

Compilation passes and the Torch-TRT output matches the PyTorch output.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
  • PyTorch Version (e.g. 1.0): 2.9.0
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): PYTHON_ONLY=1 pip install -e .
  • Are you using local sources or building from archives: No
  • Python version: 3.10
  • CUDA version: 12.9
  • GPU models and configuration: Nvidia B200
  • Any other relevant information: None

Additional context

The WAR is to use expand instead of repeat:

target_shape = (batch_size, 1, num_frames, height, width)
hidden_states = torch.cat(
    [hidden_states, padding_mask.unsqueeze(2).expand(target_shape)], dim=1
)
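
This sidesteps the failure because expand to target_shape stays at rank 5, so the converter never needs a reshape beyond TensorRT's 8-dimension limit (expand over singleton dims is also a zero-copy broadcast, while repeat materializes copies). A quick check that the two forms produce identical values for this pattern (plain PyTorch, batch size 1 as in the reproducer):

import torch

padding_mask = torch.randn(1, 1, 88, 160)
batch_size, num_frames, height, width = 1, 8, 88, 160

via_repeat = padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)
via_expand = padding_mask.unsqueeze(2).expand(batch_size, 1, num_frames, height, width)
assert torch.equal(via_repeat, via_expand)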
