
Commit cfef1fe (1 parent: c9b8bf5)

feat: Add support for SymFloat input and truediv converter

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>

File tree: 4 files changed, +115 additions, -0 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Lines changed: 1 addition & 0 deletions

@@ -2026,6 +2026,7 @@ def aten_ops_sub(
     )
 
 
+@dynamo_tensorrt_converter(operator.truediv, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar, supports_dynamic_shapes=True)
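
Note: operator.truediv appears as a graph-level call_function node (rather than an aten.div op) when two symbolic scalars are divided during tracing. The following is a minimal sketch of how such a node can arise; it is illustrative and not code from this commit, and depending on the PyTorch version the symbolic division may or may not survive simplification as an operator.truediv node in the exported graph:

import operator

import torch


class ScaleByNumel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With dim 0 dynamic, x.shape[0] and x.numel() are SymInts; their
        # quotient is recorded as operator.truediv and yields a SymFloat,
        # which previously had no registered converter.
        ratio = x.shape[0] / x.numel()
        return x * ratio


n = torch.export.Dim("n", min=2, max=32)
ep = torch.export.export(
    ScaleByNumel(), (torch.randn(4, 8),), dynamic_shapes=({0: n},)
)
print([node for node in ep.graph.nodes if node.target is operator.truediv])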

py/torch_tensorrt/dynamo/partitioning/common.py
Lines changed: 9 additions & 0 deletions

@@ -102,6 +102,15 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
                     is_shape_tensor=True,
                 )
             )
+        elif isinstance(input_meta, torch.SymFloat):
+            torchtrt_inputs.append(
+                get_input(
+                    [1],
+                    torch.float32,
+                    name=input.name,
+                    is_shape_tensor=False,  # Only SymInt inputs are treated as shape tensors
+                )
+            )
         else:
             raise ValueError(
                 f"The meta val for input node {input.target} is of type : {type(input_meta)}. Supported types: torch.Tensor|FakeTensor|torch.SymInt"

py/torch_tensorrt/dynamo/utils.py
Lines changed: 5 additions & 0 deletions

@@ -455,6 +455,9 @@ def unwrap_tensor_shape(
             tensor_shape.append(min_max_opt[mode])
         else:
             tensor_shape.append((min_max_opt["min"], min_max_opt["max"]))
+    elif isinstance(tensor, torch.SymFloat):
+        # SymFloats can sometimes be graph inputs. We register their shape as [1] to avoid errors.
+        tensor_shape.append(1)
     elif isinstance(tensor, (torch.Tensor, FakeTensor)):
         for dimension in tensor.shape:
             tensor_shape.extend(unwrap_tensor_shape(dimension, mode=mode))
@@ -472,6 +475,8 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
         return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
+    elif isinstance(tensor, torch.SymFloat):
+        return torch.float32
     elif tensor is None:
         # Case where we explicitly pass one of the inputs to be None (eg: FLUX.1-dev)
         return None
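
Note: restated as a self-contained sketch, condensed from the real utility, which also handles plain Python scalars and the min/max/opt shape modes:

from typing import Optional

import torch


def unwrap_dtype_sketch(value: object) -> Optional[torch.dtype]:
    # Condensed mirror of unwrap_tensor_dtype's dispatch after this commit.
    if isinstance(value, torch.Tensor):
        return value.dtype
    elif isinstance(value, torch.SymInt):
        return torch.int64
    elif isinstance(value, torch.SymFloat):
        return torch.float32  # new in this commit
    elif value is None:
        return None  # inputs explicitly passed as None (eg: FLUX.1-dev)
    raise ValueError(f"Unsupported type: {type(value)}")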

tests/py/dynamo/models/test_models.py
Lines changed: 100 additions & 0 deletions

@@ -2,9 +2,11 @@
 import importlib
 import platform
 import unittest
+from typing import Optional
 
 import pytest
 import torch
+import torch.nn as nn
 import torch_tensorrt as torchtrt
 from torch_tensorrt.dynamo.utils import (
     COSINE_THRESHOLD,
@@ -420,6 +422,104 @@ def test_resnet18_half(ir):
     torch._dynamo.reset()
 
 
+@pytest.mark.unit
+def test_cosmos_true_div(ir):
+    class CosmosLearnablePositionalEmbed(torch.nn.Module):
+        def __init__(
+            self,
+            hidden_size: int,
+            max_size: tuple[int, int, int],
+            patch_size: tuple[int, int, int],
+            eps: float = 1e-6,
+        ) -> None:
+            super().__init__()
+
+            self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
+            self.patch_size = patch_size
+            self.eps = eps
+
+            self.pos_emb_t = nn.Parameter(torch.randn(self.max_size[0], hidden_size))
+            self.pos_emb_h = nn.Parameter(torch.randn(self.max_size[1], hidden_size))
+            self.pos_emb_w = nn.Parameter(torch.randn(self.max_size[2], hidden_size))
+
+        def forward(
+            self,
+            hidden_states: torch.Tensor,
+            num_ranks: Optional[int] = None,
+            rank_id: Optional[torch.Tensor] = None,
+        ) -> torch.Tensor:
+            batch_size, num_channels, num_frames, height, width = hidden_states.shape
+            pe_size = [
+                num_frames // self.patch_size[0],
+                height // self.patch_size[1],
+                width // self.patch_size[2],
+            ]
+            if num_ranks is not None and rank_id is not None:
+                pe_size[0] = pe_size[0] * num_ranks
+
+            # Use expand() instead of repeat() - torch_tensorrt compatible.
+            # expand() creates a view without copying data, better for dynamic shapes.
+            emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(
+                batch_size, -1, pe_size[1], pe_size[2], -1
+            )
+            emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].expand(
+                batch_size, pe_size[0], -1, pe_size[2], -1
+            )
+            emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].expand(
+                batch_size, pe_size[0], pe_size[1], -1, -1
+            )
+            emb = emb_t + emb_h + emb_w
+            emb = emb.flatten(1, 3)
+
+            norm = torch.linalg.vector_norm(
+                emb, dim=-1, keepdim=True, dtype=torch.float32
+            )
+            alpha = (norm.numel() / emb.numel()) ** 0.5
+            # hidden_size = emb.shape[-1]
+            # alpha = (1.0 / hidden_size) ** 0.5
+            norm = torch.add(self.eps, norm, alpha=alpha)
+            return (emb / norm).type_as(hidden_states)
+
+    with torch.no_grad():
+        hidden_states = torch.randn(1, 16, 16, 88, 160).cuda()
+        model = CosmosLearnablePositionalEmbed(
+            hidden_size=4096,
+            max_size=(128, 240, 240),
+            patch_size=(1, 2, 2),
+        )
+        model.eval().cuda()
+        pyt_output = model(hidden_states)
+        num_latent_frames = torch.export.Dim("num_latent_frames", min=1, max=16)
+
+        ep = torch.export.export(
+            model,
+            args=(hidden_states,),
+            dynamic_shapes=({2: num_latent_frames},),  # Make dimension 2 dynamic
+            strict=False,
+        )
+        trt_model = torchtrt.dynamo.compile(
+            ep,
+            inputs=(hidden_states,),
+            enabled_precisions={torch.bfloat16},
+            use_explicit_typing=False,
+            use_fp32_acc=False,
+            device="cuda:0",
+            disable_tf32=True,
+            use_python_runtime=True,
+            min_block_size=1,
+        )
+        trt_output = trt_model(hidden_states)
+
+        cos_sim = cosine_similarity(pyt_output, trt_output)
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"Cosmos Learnable Positional Embed TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+
 @pytest.mark.unit
 @unittest.skipIf(
     torchtrt.ENABLED_FEATURES.tensorrt_rtx,
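
Note on why this model exercises the new paths: with dimension 2 (num_frames) dynamic, norm.numel() and emb.numel() are SymInts, so alpha = (norm.numel() / emb.numel()) ** 0.5 records an operator.truediv that produces a SymFloat. During partitioning that SymFloat can then surface as a subgraph input, which is what the converter, partitioning, and utils changes above handle together. The test can be run on a CUDA machine with TensorRT installed via the standard pytest selector: pytest tests/py/dynamo/models/test_models.py -k test_cosmos_true_div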
