From c4de771222029497a6fad6b42258b640b82aa87b Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 18 Apr 2024 17:39:03 -0700 Subject: [PATCH] addressing review comments and changing test names --- .../dynamo/conversion/aten_ops_converters.py | 23 ++--- .../dynamo/conversion/impl/select.py | 86 +++---------------- .../py/dynamo/conversion/test_scatter_aten.py | 11 +-- 3 files changed, 26 insertions(+), 94 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 5d20d12cc8..60e0ab4e9c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -691,28 +691,21 @@ def aten_ops_clamp( ) -@dynamo_tensorrt_converter(torch.ops.aten.scatter.value) -def aten_ops_scatter_value( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.select.scatter_value( - ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3] - ) - - +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) @dynamo_tensorrt_converter(torch.ops.aten.scatter.src) -def aten_ops_scatter_src( +@dynamo_tensorrt_converter(torch.ops.aten.scatter.value) +def aten_ops_scatter( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.select.scatter_src( + return impl.select.scatter( ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3] ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6603a5d280..232bada26a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -393,7 +393,7 @@ def index_select( return gather_layer.get_output(0) -def scatter_value( +def scatter( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], @@ -401,92 +401,30 @@ def scatter_value( input: TRTTensor, dim: int, index: Union[TRTTensor, np.ndarray, torch.Tensor], - value: float, + src: Union[TRTTensor, int, float], ) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"scatter_tensor received input {input} that is not part " - "of the TensorRT region!" - ) input_shape = input.shape index_shape = index.shape index_shape_list = list(index.shape) if not (isinstance(index, TRTTensor)): index = get_trt_tensor(ctx, index, f"_index_tensor") - if len(input_shape) != len(index_shape): - raise RuntimeError(f"The no of dimensions of input and index should be equal") dim = get_positive_dim(dim, len(input_shape)) dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape: # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" - input_dims = len(input_shape) - for i in range(0, input_dims): - if i != dim and (index_shape[i] >= input.shape[i]): - raise RuntimeError( - f"cannot have index size greater than the input size along dimension {dim}" - ) - - value_tensor = get_trt_tensor( - ctx, value * torch.ones(index_shape_list), name + "_value_tensor" - ) - value_tensor = cast_trt_tensor( - ctx, value_tensor, input.dtype, name + "_cast_value_tensor" - ) - scatter_layer = ctx.net.add_scatter( - input, index, value_tensor, trt.ScatterMode.ELEMENT - ) - scatter_layer.axis = dim - set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir) - out = scatter_layer.get_output(0) - return out - - -def scatter_src( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - dim: Shape, - index: Shape, - src: TRTTensor, -) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"scatter_tensor received input {input} that is not part " - "of the TensorRT region!" - ) - input_shape = input.shape - index_shape = index.shape - src_shape = src.shape - if not (isinstance(index, TRTTensor)): - index = get_trt_tensor(ctx, index, f"_index_tensor") - if len(input_shape) != len(index_shape): - raise RuntimeError(f"The no of dimensions of input and index should be equal") - if len(index_shape) != len(src_shape): - raise RuntimeError(f"The no of dimensions of src and index should be equal") - - input_dims = len(input_shape) - dim = get_positive_dim(cast(int, dim), input_dims) - dynamic_shape = has_dynamic_shape(input.shape) - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" - - for i in range(0, input_dims): - if i != dim and (index_shape[i] >= input.shape[i]): - raise RuntimeError( - f"cannot have index size greater than the input size along dimension {dim}" - ) - input_dtype = input.dtype - # required for cases where src is a constant - src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT) - if input_dtype != src_dtype: - raise RuntimeError(f"The type of input and src should be made") src_tensor = src - if not (isinstance(src, TRTTensor)): + # scatter.value + if isinstance(src, int) or isinstance(src, float): + src_tensor = get_trt_tensor( + ctx, src * torch.ones(index_shape_list), name + "_value_tensor" + ) + src_tensor = cast_trt_tensor( + ctx, src_tensor, input.dtype, name + "_cast_value_tensor" + ) + # scatter.src + elif not (isinstance(src, TRTTensor)): src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor") scatter_layer = ctx.net.add_scatter( diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py index 6a2a398907..aff4b9fe5e 100644 --- a/tests/py/dynamo/conversion/test_scatter_aten.py +++ b/tests/py/dynamo/conversion/test_scatter_aten.py @@ -1,9 +1,10 @@ import torch -from harness import DispatchTestCase from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from .harness import DispatchTestCase + class TestScatterValueConverter(DispatchTestCase): @parameterized.expand( @@ -87,25 +88,25 @@ class TestScatterSrcConverter(DispatchTestCase): @parameterized.expand( [ ( - "scatter_zero_dim_indexOne_constant_src", + "scatter_zero_dim_indexOne_src", 0, torch.tensor([[0, 1, 2, 0]]), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32), ), ( - "scatter_zero_dim_indexTwo_constant_src", + "scatter_zero_dim_indexTwo_src", 0, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32), ), ( - "scatter_one_dim_indexOne_constant_src", + "scatter_one_dim_indexOne_src", 1, torch.tensor([[0, 1, 2, 0]]), torch.tensor([[1, 2, 3, 1]], dtype=torch.int32), ), ( - "scatter_one_dim_indexTwo_constant_src", + "scatter_one_dim_indexTwo_src", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),