Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tensorrt import ITensor as TRTTensor
from torch.fx.node import Argument, Node, Target
from torch_tensorrt._features import needs_not_tensorrt_rtx
from torch_tensorrt._utils import is_tensorrt_version_supported
from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
Expand All @@ -24,6 +24,7 @@
get_positive_dim,
is_only_operator_on_placeholder,
)
from torch_tensorrt._utils import is_thor

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -424,9 +425,24 @@ def index_dtype_validator(
return True


def index_nonbool_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
# for thor, we don't support boolean indices
if is_thor():
index = node.args[1]
for ind in index:
if ind is not None:
val = ind.meta.get("val")
if val is not None and val.dtype == torch.bool:
return False
return True


@dynamo_tensorrt_converter(
torch.ops.aten.index.Tensor,
capability_validator=index_dtype_validator,
capability_validator=lambda node, settings: index_dtype_validator(node, settings)
and index_nonbool_validator(node, settings),
supports_dynamic_shapes=True,
requires_output_allocator=True,
)
Expand Down Expand Up @@ -3601,10 +3617,18 @@ def aten_ops_full(
)


def nonzero_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
return not is_thor()


# currently nonzero is not supported for tensorrt_rtx
# TODO: lan to add the nonzero support once tensorrt_rtx team has added the support
# TODO: apbose to remove the capability validator once thor bug resolve in NGC
@dynamo_tensorrt_converter(
torch.ops.aten.nonzero.default,
capability_validator=nonzero_validator,
supports_dynamic_shapes=True,
requires_output_allocator=True,
)
Expand Down
9 changes: 8 additions & 1 deletion tests/core/conversion/converters/test_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,14 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
%3 : bool = prim::Constant[value=1]()
%5 : Tensor = aten::any(%0, %1, %3)
return (%5))IR";
auto in = at::randint(-2, 2, {2, 32}, at::kCUDA);
std::vector<int> data(64, 0);
for (int i = 0; i < 64; ++i) {
if (i % 7 == 0)
data[i] = 1; // some positives
if (i % 13 == 0)
data[i] = -1; // some negatives
}
auto in = at::tensor(data, at::TensorOptions().dtype(at::kInt).device(at::kCUDA)).reshape({2, 32}); // shape [2, 32]
test_body(graph, in);
}

Expand Down
Loading