Bug Description
```python
from typing import Optional, Tuple
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn as nn
import torch_tensorrt


class CosmosLearnablePositionalEmbed(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        max_size: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        eps: float = 1e-6,
    ) -> None:
        super().__init__()

        self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
        self.patch_size = patch_size
        self.eps = eps

        self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
        self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
        self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]

        # Use expand() instead of repeat() - torch_tensorrt compatible
        # expand() creates a view without copying data, better for dynamic shapes
        emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(batch_size, -1, pe_size[1], pe_size[2], -1)
        emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].expand(batch_size, pe_size[0], -1, pe_size[2], -1)
        emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].expand(batch_size, pe_size[0], pe_size[1], -1, -1)

        emb = emb_t + emb_h + emb_w
        emb = emb.flatten(1, 3)

        norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
        norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
        # # Use torch operations instead of np.sqrt to support dynamic shapes in torch.export
        # # Compute the scale factor: sqrt(norm.numel() / emb.numel())
        # alpha = (norm.numel() / emb.numel()) ** 0.5
        # norm = torch.add(self.eps, norm, alpha=alpha)
        out = (emb / norm).type_as(hidden_states)
        return out


def export_attention(model, hidden_states):
    with torch.no_grad():
        # Only mark the frame dimension as dynamic, like run_llm.py does.
        # Don't mark the batch dimension as dynamic to avoid constraint violations.
        seq_len = torch.export.Dim("seq_len", min=1, max=16)
        print("Trying to export the model using torch.export.export()..")
        # strict=False only enables autograd tracing and excludes dynamo.
        # Use tuple format like export_llm - only dim 2 (num_frames) is marked dynamic.
        ep = torch.export.export(
            model,
            args=(hidden_states,),
            kwargs={},
            dynamic_shapes=({2: seq_len},),
            strict=False,
        )
        return ep


def compile_torchtrt(model, hidden_states, min_block_size, debug):
    ep = export_attention(model, hidden_states)

    # Set precision-specific flags
    use_fp32_acc = False
    use_explicit_typing = False
    enabled_precisions = {torch.bfloat16}

    with torch_tensorrt.logging.debug() if debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[hidden_states],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )
    return trt_model


if __name__ == "__main__":
    precision = "BF16"
    min_block_size = 1
    batch_size = 1
    seq_len = 28160
    num_attention_heads = 32
    attention_head_dim = 128
    enable_pytorch_run = True
    debug = False
    device = "cuda"
    hidden_size = num_attention_heads * attention_head_dim

    with torch.inference_mode():
        model = CosmosLearnablePositionalEmbed(
            hidden_size=hidden_size,
            max_size=(128, 240, 240),
            patch_size=(1, 2, 2),
        ).to(device)

        # Convert model to the appropriate precision
        model = model.to(torch.bfloat16)
        input_dtype = torch.bfloat16

        # Prepare input for benchmarking or evaluation
        hidden_states = torch.randn(1, 17, 16, 88, 160, dtype=input_dtype).to(device)

        # PyTorch reference run
        pyt_output = model(hidden_states)
        print("PyTorch output shape:", pyt_output.shape)
        print("PyTorch output:", pyt_output.flatten())

        # Compile the model with Torch-TensorRT
        trt_model = compile_torchtrt(model, hidden_states, min_block_size, debug)
        # trt_model = torch.compile(
        #     model,
        #     backend="torch_tensorrt",
        #     options={
        #         "enabled_precisions": {input_dtype},
        #         "use_python_runtime": True,
        #         "min_block_size": min_block_size,
        #     },
        #     dynamic=None,
        # )
        trt_model = trt_model.to(device)
        trt_output = trt_model(hidden_states)
        print("TensorRT output shape:", trt_output.shape)
        print("TensorRT output:", trt_output.flatten())

        # Verify results match
        diff = (pyt_output - trt_output).abs().max().item()
        print(f"Max difference between PyTorch and TRT: {diff}")

        # Check if results are close enough
        tolerance = 0.01
        if diff < tolerance:
            print(f"✅ Results match! (difference: {diff} < {tolerance})")
        else:
            print(f"⚠️ Results differ! (difference: {diff} >= {tolerance})")
```
Here is the error message:

```
AttributeError: 'SymFloat' object has no attribute 'sqrt'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 215, in <module>
trt_model = compile_torchtrt(model, hidden_states, min_block_size, debug)
File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 146, in compile_torchtrt
ep = export_attention(model, hidden_states)
File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 134, in export_attention
ep = torch.export.export(
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 311, in export
raise e
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 277, in export
return _export(
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1229, in wrapper
raise e
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1195, in wrapper
ep = fn(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2336, in _export
ep = _export_for_training(
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1229, in wrapper
raise e
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1195, in wrapper
ep = fn(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2144, in _export_for_training
export_artifact = export_func(
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2075, in _non_strict_export
aten_export_artifact = _to_aten_func(
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1863, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1993, in _aot_export_non_strict
gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1775, in _make_fx_helper
gm = make_fx(
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2581, in wrapped
return make_fx_tracer.trace(f, *args)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2508, in trace
return self._trace_inner(f, *args)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2470, in _trace_inner
t = dispatch_trace(
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/_compile.py", line 54, in inner
return disable_fn(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1129, in _fn
return fn(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1443, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2062, in trace
res = super().trace(root, concrete_args)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1129, in _fn
return fn(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 869, in trace
(self.create_arg(fn(*args)),),
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1501, in wrapped
out = f(*tensors) # type:ignore[call-arg]
File "<string>", line 1, in <lambda>
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1662, in wrapped_fn
return tuple(flat_fn(*args))
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 192, in flat_fn
tree_out = fn(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1369, in functional_call
out = mod(*args[params_len:], **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 844, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2149, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
ret_val = forward(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 837, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1977, in forward
tree_out = mod(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 844, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2149, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
ret_val = forward(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 837, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
return forward_call(*args, **kwargs)
File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 117, in forward
norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
TypeError: loop of ufunc does not support argument 0 of type SymFloat which has no callable sqrt method
```
To Reproduce
Steps to reproduce the behavior:
- Run the Python script above.
Expected behavior
torch.export and Torch-TensorRT compilation succeed, and the Torch-TensorRT output matches the PyTorch output.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
- PyTorch Version (e.g. 1.0): 2.9.0
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): PYTHON_ONLY=1 pip install -e .
- Are you using local sources or building from archives: No
- Python version: 3.10
- CUDA version: 12.9
- GPU models and configuration: Nvidia B200
- Any other relevant information: None
Additional context
The workaround (WAR) is to compute the scale factor with a torch tensor / torch operations instead of NumPy, so that np.sqrt is never called on a SymFloat.
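A minimal sketch of such a fix for the failing line in `forward()` (an assumption about how the workaround could look, not necessarily the exact patch; this variant uses plain SymFloat arithmetic rather than wrapping the value in a tensor):

```python
# Hypothetical replacement for:
#   norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
#
# Under torch.export with dynamic shapes, norm.numel() and emb.numel() are SymInts,
# so their ratio is a SymFloat. np.sqrt() has no handler for SymFloat, but Python
# arithmetic on SymFloat (** 0.5) and tensor-scalar ops remain traceable.
scale = (norm.numel() / emb.numel()) ** 0.5
norm = self.eps + scale * norm  # same math as torch.add(self.eps, norm, alpha=scale)
```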