Skip to content

Commit

Permalink
Addressing review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed May 16, 2024
1 parent 0318b33 commit 2d2d214
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 19 deletions.
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,13 +691,14 @@ def aten_ops_clamp(
)


@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
@enforce_tensor_types(
{
0: (TRTTensor,),
2: (TRTTensor,),
}
)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
def aten_ops_scatter(
ctx: ConversionContext,
target: Target,
Expand Down
20 changes: 4 additions & 16 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
set_layer_name,
)
from torch_tensorrt.fx.types import Shape, TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -405,26 +404,15 @@ def scatter(
) -> TRTTensor:
input_shape = input.shape
index_shape = index.shape
index_shape_list = list(index.shape)
if not (isinstance(index, TRTTensor)):
if isinstance(index, torch.Tensor):
if index.dtype == torch.int64:
index = index.to(torch.int32)
elif isinstance(index, np.ndarray):
if index.dtype == np.int64:
index = index.astype(np.int32)
index = get_trt_tensor(ctx, index, f"_index_tensor")
index_shape_list = list(index_shape)
if index.dtype == trt.int64:
index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
dim = get_positive_dim(dim, len(input_shape))
dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

src_tensor = src
# scatter.value
if isinstance(src, int) or isinstance(src, float):
src_tensor = get_trt_tensor(
ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
ctx, src * np.ones(index_shape_list), name + "_value_tensor"
)
src_tensor = cast_trt_tensor(
ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
Expand Down
1 change: 0 additions & 1 deletion tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import logging
import time
import unittest
Expand Down

0 comments on commit 2d2d214

Please sign in to comment.