From 068d0692147b3e5d91cb2c0a5b98e6e78a3d3b17 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 12 Nov 2025 14:47:37 -0800 Subject: [PATCH 1/2] fix L0 RTX test issues --- .../dynamo/conversion/aten_ops_converters.py | 8 ++++++-- tests/py/dynamo/conversion/test_index_aten.py | 14 +++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 23c5287cc2..147813d8e0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -8,6 +8,7 @@ import torch from tensorrt import ITensor as TRTTensor from torch.fx.node import Argument, Node, Target +from torch_tensorrt import ENABLED_FEATURES from torch_tensorrt._features import needs_not_tensorrt_rtx from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor from torch_tensorrt.dynamo._settings import CompilationSettings @@ -427,8 +428,8 @@ def index_dtype_validator( def index_nonbool_validator( node: Node, settings: Optional[CompilationSettings] = None ) -> bool: - # for thor, we don't support boolean indices - if is_thor(): + # for thor and tensorrt_rtx, we don't support boolean indices, due to nonzero op not supported + if is_thor() or ENABLED_FEATURES.tensorrt_rtx: index = node.args[1] for ind in index: if ind is not None: @@ -903,6 +904,8 @@ def aten_ops_select( @dynamo_tensorrt_converter( torch.ops.aten.index_put.default, + capability_validator=lambda node, settings: index_dtype_validator(node, settings) + and index_nonbool_validator(node, settings), supports_dynamic_shapes=True, ) @enforce_tensor_types( @@ -2786,6 +2789,7 @@ def aten_ops_max_pool( @dynamo_tensorrt_converter( torch.ops.aten._reshape_copy.default, supports_dynamic_shapes=True ) +@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 5a5c971367..e34dc48dd5 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -5,7 +5,7 @@ import torch_tensorrt from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input +from torch_tensorrt import ENABLED_FEATURES, Input from torch_tensorrt._utils import is_tegra_platform, is_thor from .harness import DispatchTestCase @@ -114,8 +114,8 @@ def forward(self, input): ] ) @unittest.skipIf( - is_thor(), - "Skipped on Thor due to nonzero not supported", + is_thor() or ENABLED_FEATURES.tensorrt_rtx, + "Skipped on Thor or tensorrt_rtx due to nonzero not supported", ) def test_index_constant_bool_mask(self, _, index, input): class TestModule(torch.nn.Module): @@ -148,6 +148,10 @@ def forward(self, x, index0): [input, index0], ) + @unittest.skipIf( + is_thor() or ENABLED_FEATURES.tensorrt_rtx, + "Skipped on Thor or tensorrt_rtx due to nonzero not supported", + ) def test_index_zero_two_dim_ITensor_mask(self): class TestModule(nn.Module): def forward(self, x, index0): @@ -176,8 +180,8 @@ def forward(self, x, index0): self.run_test(TestModule(), [input, index0]) @unittest.skipIf( - is_thor(), - "Skipped on Thor due to nonzero not supported", + is_thor() or ENABLED_FEATURES.tensorrt_rtx, + "Skipped on Thor or tensorrt_rtx due to nonzero not supported", ) def test_index_zero_index_three_dim_mask_ITensor(self): class TestModule(nn.Module): From a8d02f0850045e8c0b5e8713371265d2584e6079 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 12 Nov 2025 15:06:49 -0800 Subject: [PATCH 2/2] fix for rtx l0 test --- tests/py/dynamo/lowering/test_decompositions.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 32bf7f8b98..19169f68e0 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -1679,15 +1679,8 @@ def forward(self, x): ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - f"Log_softmax TRT outputs don't match with the original model.", + assert torch.allclose( + optimized_model_results, torch_model_results, atol=1e-3, rtol=1e-3 ) @parameterized.expand(