8 changes: 6 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -8,6 +8,7 @@
 import torch
 from tensorrt import ITensor as TRTTensor
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt import ENABLED_FEATURES
 from torch_tensorrt._features import needs_not_tensorrt_rtx
 from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -427,8 +428,8 @@ def index_dtype_validator(
 def index_nonbool_validator(
     node: Node, settings: Optional[CompilationSettings] = None
 ) -> bool:
-    # for thor, we don't support boolean indices
-    if is_thor():
+    # for thor and tensorrt_rtx, we don't support boolean indices, due to nonzero op not supported
+    if is_thor() or ENABLED_FEATURES.tensorrt_rtx:
         index = node.args[1]
         for ind in index:
             if ind is not None:
@@ -903,6 +904,8 @@ def aten_ops_select(
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.index_put.default,
+    capability_validator=lambda node, settings: index_dtype_validator(node, settings)
+    and index_nonbool_validator(node, settings),
     supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
@@ -2786,6 +2789,7 @@ def aten_ops_max_pool(
 @dynamo_tensorrt_converter(
     torch.ops.aten._reshape_copy.default, supports_dynamic_shapes=True
 )
+@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
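The index_put converter above is now only claimed for TensorRT conversion when both index_dtype_validator and index_nonbool_validator accept the node; if either check fails, the op is left to fall back outside the converter. A minimal sketch of the same composition pattern follows; `compose_validators` is a hypothetical helper and not torch_tensorrt API, and `node`/`settings` stand in for torch.fx.Node and CompilationSettings.

```python
# Illustrative sketch of composing capability validators the way the
# decorator in this diff does with an inline lambda.
from typing import Any, Callable, Optional

Validator = Callable[[Any, Optional[Any]], bool]

def compose_validators(*validators: Validator) -> Validator:
    """Accept a node only if every individual validator accepts it."""
    def combined(node: Any, settings: Optional[Any] = None) -> bool:
        return all(v(node, settings) for v in validators)
    return combined

# Equivalent in spirit to the decorator argument:
#   capability_validator=lambda node, settings: index_dtype_validator(node, settings)
#   and index_nonbool_validator(node, settings)
```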
14 changes: 9 additions & 5 deletions tests/py/dynamo/conversion/test_index_aten.py
@@ -5,7 +5,7 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt import Input
+from torch_tensorrt import ENABLED_FEATURES, Input
 from torch_tensorrt._utils import is_tegra_platform, is_thor
 
 from .harness import DispatchTestCase
@@ -114,8 +114,8 @@ def forward(self, input):
         ]
     )
     @unittest.skipIf(
-        is_thor(),
-        "Skipped on Thor due to nonzero not supported",
+        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
     )
     def test_index_constant_bool_mask(self, _, index, input):
         class TestModule(torch.nn.Module):
@@ -148,6 +148,10 @@ def forward(self, x, index0):
             [input, index0],
         )
 
+    @unittest.skipIf(
+        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+    )
     def test_index_zero_two_dim_ITensor_mask(self):
         class TestModule(nn.Module):
             def forward(self, x, index0):
@@ -176,8 +180,8 @@ def forward(self, x, index0):
         self.run_test(TestModule(), [input, index0])
 
     @unittest.skipIf(
-        is_thor(),
-        "Skipped on Thor due to nonzero not supported",
+        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
     )
     def test_index_zero_index_three_dim_mask_ITensor(self):
         class TestModule(nn.Module):
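The skip conditions above all point at the same root cause: advanced indexing with a boolean mask is lowered through the nonzero op, which is unavailable on Thor and under tensorrt_rtx. A small plain-PyTorch illustration of that equivalence (not taken from the test file):

```python
import torch

x = torch.arange(12.0).reshape(3, 4)
mask = x > 5  # boolean index of the same shape as x

# Indexing with a bool tensor is equivalent to converting the mask to integer
# coordinates via nonzero() and gathering those positions, which is why these
# tests are skipped wherever nonzero cannot be converted.
assert torch.equal(x[mask], x[mask.nonzero(as_tuple=True)])
```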
11 changes: 2 additions & 9 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -1679,15 +1679,8 @@ def forward(self, x):
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
         torch_model_results = fx_graph(*inputs).detach().cpu()
-
-        max_diff = float(
-            torch.max(torch.abs(optimized_model_results - torch_model_results))
-        )
-        self.assertAlmostEqual(
-            max_diff,
-            0,
-            DECIMALS_OF_AGREEMENT,
-            f"Log_softmax TRT outputs don't match with the original model.",
+        assert torch.allclose(
+            optimized_model_results, torch_model_results, atol=1e-3, rtol=1e-3
         )
 
     @parameterized.expand(
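In the test_decompositions.py change, torch.allclose replaces the manual max-absolute-difference assertion; the added relative term lets the tolerance scale with magnitude. A self-contained comparison of the two styles (the tensor values are made up, not from the test):

```python
import torch

a = torch.tensor([1.0000, 100.0005])
b = torch.tensor([1.0004, 100.0000])

# Old style: bound the maximum absolute difference.
max_diff = float(torch.max(torch.abs(a - b)))
assert max_diff < 1e-3

# New style: allclose checks |a - b| <= atol + rtol * |b| elementwise,
# so elements with larger magnitude get a proportionally looser bound.
assert torch.allclose(a, b, atol=1e-3, rtol=1e-3)
```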