From 628fab7cbee2039029f364f74f9227035fd4d49b Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 19 Jan 2024 10:20:46 -0800
Subject: [PATCH 01/10] aten::scatter converter

---
 .../dynamo/conversion/impl/select.py | 73 +++++++++++++++++++
 1 file changed, 73 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 2ec6420e0b..44ca21a1f9 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -390,3 +390,76 @@ def index_select(
     set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
 
     return gather_layer.get_output(0)
+
+
+def scatter_value(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Shape,
+    index: Shape,
+    value: TRTTensor,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"scatter_tensor received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+
+    ranks = len(input.shape)
+    dim = get_positive_dim(cast(int, dim), ranks)
+    dynamic_shape = has_dynamic_shape(input.shape)
+    if dynamic_shape:
+        # Check whether slice target dim is dynamic shape dim
+        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
+
+    input_dims = len(input.shape)
+    for i in range(0, input_dims):
+        if index[i] >= input.shape[i]:
+            raise RuntimeError(
+                f"cannot have index greater than the dimension length! {input.shape[dim]}"
+            )
+    value_tensor = value * torch.ones(index.shape)
+    scatter_layer = ctx.net.add_scatter(input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT)
+    scatter_layer.set_axis(dim)
+    set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
+    out = scatter_layer.get_output(0)
+    return out
+
+
+def scatter_src(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Shape,
+    index: Shape,
+    src: float,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"scatter_tensor received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+
+    ranks = len(input.shape)
+    dim = get_positive_dim(cast(int, dim), ranks)
+    dynamic_shape = has_dynamic_shape(input.shape)
+    if dynamic_shape:
+        # Check whether slice target dim is dynamic shape dim
+        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
+
+    input_dims = len(input.shape)
+    for i in range(0, input_dims):
+        if index[i] >= input.shape[i]:
+            raise RuntimeError(
+                f"cannot have index greater than the dimension length! 
{input.shape[dim]}" + ) + scatter_layer = ctx.net.add_scatter(input, index, src, trt.tensorrt.ScatterModekELEMENT) + scatter_layer.set_axis(dim) + set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir) + out = scatter_layer.get_output(0) + return out From 6fbc0ec9a2cd38f4c566671c28eaa26241f99026 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 27 Feb 2024 12:26:30 -0800 Subject: [PATCH 02/10] scatter_value and scatter_src converter --- .../dynamo/conversion/aten_ops_converters.py | 38 ++++++++++++ .../dynamo/conversion/impl/select.py | 35 +++++++---- .../py/dynamo/conversion/test_scatter_aten.py | 59 +++++++++++++++++++ 3 files changed, 122 insertions(+), 10 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_scatter_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1705dd06db..d3bdba61d3 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -691,6 +691,44 @@ def aten_ops_clamp( ) +@dynamo_tensorrt_converter(torch.ops.aten.scatter.value) +def aten_ops_scatter_value( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.select.scatter_value( + ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.scatter.src) +def aten_ops_scatter_src( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.select.scatter_src( + ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ) + + +def aten_ops_select( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.select.select( + ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ) + + @dynamo_tensorrt_converter(torch.ops.aten.select.int) def aten_ops_select( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 44ca21a1f9..9ed15428e6 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -400,20 +400,25 @@ def scatter_value( input: TRTTensor, dim: Shape, index: Shape, - value: TRTTensor, + value: float, ) -> TRTTensor: if not isinstance(input, TRTTensor): raise RuntimeError( f"scatter_tensor received input {input} that is not part " "of the TensorRT region!" ) - - ranks = len(input.shape) + input_shape = input.shape + index_shape = index.shape + if (len(input_shape) != len(index_shape)): + raise RuntimeError( + f"The no of dimensions of input and index should be equal" + ) + ranks = len(input_shape) dim = get_positive_dim(cast(int, dim), ranks) dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape: # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't select on negative shape dimension!" + assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" 
input_dims = len(input.shape) for i in range(0, input_dims): @@ -437,22 +442,32 @@ def scatter_src( input: TRTTensor, dim: Shape, index: Shape, - src: float, + src: TRTTensor, ) -> TRTTensor: if not isinstance(input, TRTTensor): raise RuntimeError( f"scatter_tensor received input {input} that is not part " "of the TensorRT region!" ) - - ranks = len(input.shape) - dim = get_positive_dim(cast(int, dim), ranks) + input_shape = input.shape + index_shape = index.shape + src_shape = src.shape + if (len(input_shape) != len(index_shape)): + raise RuntimeError( + f"The no of dimensions of input and index should be equal" + ) + if (len(index_shape) != len(src_shape)): + raise RuntimeError( + f"The no of dimensions of src and index should be equal" + ) + + input_dims = len(input_shape) + dim = get_positive_dim(cast(int, dim), input_dims) dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape: # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't select on negative shape dimension!" + assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" - input_dims = len(input.shape) for i in range(0, input_dims): if index[i] >= input.shape[i]: raise RuntimeError( diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py new file mode 100644 index 0000000000..666a12dea9 --- /dev/null +++ b/tests/py/dynamo/conversion/test_scatter_aten.py @@ -0,0 +1,59 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestScatterValueConverter(DispatchTestCase): + @parameterized.expand( + [ + ("scatter_zero_dim_indexOne_value", 0, [[0, 1, 2, 0]], 1), + ("scatter_zero_dim_indexTwo_value", 0, [[0, 1, 2, 0], [1, 2, 1, 1]], 1), + ("scatter_one_dim_indexOne_value", 1, [[0, 1, 2, 0]], 1), + ("scatter_one_dim_indexTwo_value", 1, [[0, 1, 2, 0], [1, 2, 1, 1]], 1), + ] + ) + def test_scatter(self, _, dim, index, value): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, src): + return torch.ops.aten.scatter.value(input, dim, index, value) + + input = [torch.zeros(3, 5, dtype = torch.int32)] + self.run_test( + TestModule(), + input, + ) + + +class TestScatterSrcConverter(DispatchTestCase): + @parameterized.expand( + [ + ("scatter_zero_dim_indexOne", 0, [[0, 1, 2, 0]]), + ("scatter_zero_dim_indexTwo", 0, [[0, 1, 2, 0], [1, 2, 1, 1]]), + ("scatter_one_dim_indexOne", 1, [[0, 1, 2, 0]]), + ("scatter_one_dim_indexTwo", 1, [[0, 1, 2, 0], [1, 2, 1, 1]]), + ] + ) + def test_scatter(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, src): + return torch.ops.aten.scatter.src(input, dim, index, src) + + src = [torch.arange(1, 11).reshape((2,5))] + input = torch.zeros(3, 5, dtype = src.dtype) + inputs = [input, src] + self.run_test( + TestModule(), + inputs, + ) + + + \ No newline at end of file From bfd3498efc58f4c3bf79d3eaa0742090a1d72bbf Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 27 Feb 2024 13:54:06 -0800 Subject: [PATCH 03/10] Linting fix --- .../dynamo/conversion/impl/select.py | 32 +++++++++---------- .../py/dynamo/conversion/test_scatter_aten.py | 11 +++---- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 
9ed15428e6..c1384f7855 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -409,17 +409,15 @@ def scatter_value( ) input_shape = input.shape index_shape = index.shape - if (len(input_shape) != len(index_shape)): - raise RuntimeError( - f"The no of dimensions of input and index should be equal" - ) + if len(input_shape) != len(index_shape): + raise RuntimeError(f"The no of dimensions of input and index should be equal") ranks = len(input_shape) dim = get_positive_dim(cast(int, dim), ranks) dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape: # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" - + input_dims = len(input.shape) for i in range(0, input_dims): if index[i] >= input.shape[i]: @@ -427,7 +425,9 @@ def scatter_value( f"cannot have index greater than the dimension length! {input.shape[dim]}" ) value_tensor = value * torch.ones(index.shape) - scatter_layer = ctx.net.add_scatter(input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT) + scatter_layer = ctx.net.add_scatter( + input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT + ) scatter_layer.set_axis(dim) set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir) out = scatter_layer.get_output(0) @@ -452,28 +452,26 @@ def scatter_src( input_shape = input.shape index_shape = index.shape src_shape = src.shape - if (len(input_shape) != len(index_shape)): - raise RuntimeError( - f"The no of dimensions of input and index should be equal" - ) - if (len(index_shape) != len(src_shape)): - raise RuntimeError( - f"The no of dimensions of src and index should be equal" - ) - + if len(input_shape) != len(index_shape): + raise RuntimeError(f"The no of dimensions of input and index should be equal") + if len(index_shape) != len(src_shape): + raise RuntimeError(f"The no of dimensions of src and index should be equal") + input_dims = len(input_shape) dim = get_positive_dim(cast(int, dim), input_dims) dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape: # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" - + for i in range(0, input_dims): if index[i] >= input.shape[i]: raise RuntimeError( f"cannot have index greater than the dimension length! 
{input.shape[dim]}" ) - scatter_layer = ctx.net.add_scatter(input, index, src, trt.tensorrt.ScatterModekELEMENT) + scatter_layer = ctx.net.add_scatter( + input, index, src, trt.tensorrt.ScatterModekELEMENT + ) scatter_layer.set_axis(dim) set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir) out = scatter_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py index 666a12dea9..fc161bf4c2 100644 --- a/tests/py/dynamo/conversion/test_scatter_aten.py +++ b/tests/py/dynamo/conversion/test_scatter_aten.py @@ -23,7 +23,7 @@ def __init__(self): def forward(self, input, src): return torch.ops.aten.scatter.value(input, dim, index, value) - input = [torch.zeros(3, 5, dtype = torch.int32)] + input = [torch.zeros(3, 5, dtype=torch.int32)] self.run_test( TestModule(), input, @@ -46,14 +46,11 @@ def __init__(self): def forward(self, input, src): return torch.ops.aten.scatter.src(input, dim, index, src) - - src = [torch.arange(1, 11).reshape((2,5))] - input = torch.zeros(3, 5, dtype = src.dtype) + + src = [torch.arange(1, 11).reshape((2, 5))] + input = torch.zeros(3, 5, dtype=src.dtype) inputs = [input, src] self.run_test( TestModule(), inputs, ) - - - \ No newline at end of file From 870b79f5fc4391d19667725d4a0e15ffed404692 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 6 Mar 2024 01:55:59 -0800 Subject: [PATCH 04/10] scatter adding test cases for scatter.value and scatter.src --- .../dynamo/conversion/aten_ops_converters.py | 16 +- .../dynamo/conversion/impl/select.py | 48 ++++-- tests/py/dynamo/conversion/harness.py | 47 +++-- .../py/dynamo/conversion/test_scatter_aten.py | 162 ++++++++++++++++-- 4 files changed, 216 insertions(+), 57 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index d3bdba61d3..5d20d12cc8 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -700,7 +700,7 @@ def aten_ops_scatter_value( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.select.scatter_value( - ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3] ) @@ -713,19 +713,7 @@ def aten_ops_scatter_src( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.select.scatter_src( - ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2] - ) - - -def aten_ops_select( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.select.select( - ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3] ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index c1384f7855..6603a5d280 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -21,6 +21,7 @@ set_layer_name, ) from torch_tensorrt.fx.types import Shape, TRTTensor +from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -398,8 +399,8 @@ def scatter_value( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - dim: Shape, - index: Shape, + dim: int, + index: Union[TRTTensor, np.ndarray, 
torch.Tensor], value: float, ) -> TRTTensor: if not isinstance(input, TRTTensor): @@ -409,26 +410,34 @@ def scatter_value( ) input_shape = input.shape index_shape = index.shape + index_shape_list = list(index.shape) + if not (isinstance(index, TRTTensor)): + index = get_trt_tensor(ctx, index, f"_index_tensor") if len(input_shape) != len(index_shape): raise RuntimeError(f"The no of dimensions of input and index should be equal") - ranks = len(input_shape) - dim = get_positive_dim(cast(int, dim), ranks) + dim = get_positive_dim(dim, len(input_shape)) dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape: # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" - input_dims = len(input.shape) + input_dims = len(input_shape) for i in range(0, input_dims): - if index[i] >= input.shape[i]: + if i != dim and (index_shape[i] >= input.shape[i]): raise RuntimeError( - f"cannot have index greater than the dimension length! {input.shape[dim]}" + f"cannot have index size greater than the input size along dimension {dim}" ) - value_tensor = value * torch.ones(index.shape) + + value_tensor = get_trt_tensor( + ctx, value * torch.ones(index_shape_list), name + "_value_tensor" + ) + value_tensor = cast_trt_tensor( + ctx, value_tensor, input.dtype, name + "_cast_value_tensor" + ) scatter_layer = ctx.net.add_scatter( - input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT + input, index, value_tensor, trt.ScatterMode.ELEMENT ) - scatter_layer.set_axis(dim) + scatter_layer.axis = dim set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir) out = scatter_layer.get_output(0) return out @@ -452,6 +461,8 @@ def scatter_src( input_shape = input.shape index_shape = index.shape src_shape = src.shape + if not (isinstance(index, TRTTensor)): + index = get_trt_tensor(ctx, index, f"_index_tensor") if len(input_shape) != len(index_shape): raise RuntimeError(f"The no of dimensions of input and index should be equal") if len(index_shape) != len(src_shape): @@ -465,14 +476,23 @@ def scatter_src( assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" for i in range(0, input_dims): - if index[i] >= input.shape[i]: + if i != dim and (index_shape[i] >= input.shape[i]): raise RuntimeError( - f"cannot have index greater than the dimension length! 
{input.shape[dim]}" + f"cannot have index size greater than the input size along dimension {dim}" ) + input_dtype = input.dtype + # required for cases where src is a constant + src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT) + if input_dtype != src_dtype: + raise RuntimeError(f"The type of input and src should be made") + src_tensor = src + if not (isinstance(src, TRTTensor)): + src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor") + scatter_layer = ctx.net.add_scatter( - input, index, src, trt.tensorrt.ScatterModekELEMENT + input, index, src_tensor, trt.ScatterMode.ELEMENT ) - scatter_layer.set_axis(dim) + scatter_layer.axis = dim set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir) out = scatter_layer.get_output(0) return out diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 7bed68a11e..9407e3b694 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -1,5 +1,4 @@ -# type: ignore - +import copy import logging import time import unittest @@ -14,6 +13,9 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS as CONVERTERS, +) from torch_tensorrt.dynamo.lowering import apply_lowering_passes from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule @@ -50,16 +52,20 @@ def setUp(self): def run_test( self, mod, - inputs, + fx_inputs, + trt_interpreter_inputs, interpreter, rtol, atol, check_dtype=True, ): with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) + cuda_fx_inputs = [] + cuda_trt_inputs = [] + for i in trt_interpreter_inputs: + cuda_trt_inputs.append(i.cuda()) + for i in fx_inputs: + cuda_fx_inputs.append(i.cuda()) mod.eval() start = time.perf_counter() @@ -73,13 +79,13 @@ def run_test( ) mod = mod.cuda() - ref_outputs = mod(*cuda_inputs) + ref_outputs = mod(*cuda_fx_inputs) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() - outputs = trt_mod(*cuda_inputs) + outputs = trt_mod(*cuda_trt_inputs) end_event.record() torch.cuda.synchronize() _LOGGER.info( @@ -237,6 +243,25 @@ def run_test( debug=True, ) + num_inputs = len(inputs) + trt_inputs = inputs + for num_input in range(num_inputs): + input = inputs[num_input] + if input.dtype in (torch.int64, torch.float64): + dtype_32bit = ( + torch.int32 if (input.dtype == torch.int64) else torch.int64 + ) + # should we modify graph here to insert clone nodes? 
+ # ideally not required + trt_inputs = ( + list(trt_inputs[:num_input]) + + [ + input.to(dtype_32bit), + ] + + list(trt_inputs[num_input + 1 :]) + ) + + trt_input_specs = [Input.from_tensor(i) for i in trt_inputs] input_specs = [Input.from_tensor(i) for i in inputs] output_dtypes = None @@ -245,7 +270,7 @@ def run_test( mod, input_specs, compilation_settings.device, - truncate_double=compilation_settings.truncate_double, + truncate_long_and_double=compilation_settings.truncate_long_and_double, ) _LOGGER.debug(f"Compilation settings: {compilation_settings}") @@ -254,13 +279,15 @@ def run_test( interp = TRTInterpreter( mod, - input_specs, + trt_input_specs, output_dtypes=output_dtypes, compilation_settings=compilation_settings, ) + super().run_test( mod, inputs, + trt_inputs, interp, rtol, atol, diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py index fc161bf4c2..6a2a398907 100644 --- a/tests/py/dynamo/conversion/test_scatter_aten.py +++ b/tests/py/dynamo/conversion/test_scatter_aten.py @@ -1,56 +1,180 @@ import torch +from harness import DispatchTestCase from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from .harness import DispatchTestCase - class TestScatterValueConverter(DispatchTestCase): @parameterized.expand( [ - ("scatter_zero_dim_indexOne_value", 0, [[0, 1, 2, 0]], 1), - ("scatter_zero_dim_indexTwo_value", 0, [[0, 1, 2, 0], [1, 2, 1, 1]], 1), - ("scatter_one_dim_indexOne_value", 1, [[0, 1, 2, 0]], 1), - ("scatter_one_dim_indexTwo_value", 1, [[0, 1, 2, 0], [1, 2, 1, 1]], 1), + ( + "scatter_zero_dim_indexOne_constant_value", + 0, + torch.tensor([[0, 1, 2, 0]]), + 1, + ), + ( + "scatter_zero_dim_indexTwo_constant_value", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + 1, + ), + ( + "scatter_one_dim_indexOne_constant_value", + 1, + torch.tensor([[0, 1, 2, 0]]), + 1, + ), + ( + "scatter_one_dim_indexTwo_costant_value", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + 1, + ), ] ) - def test_scatter(self, _, dim, index, value): + def test_scatter_index_constant(self, _, dim, index, value): class TestModule(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, input, src): + def forward(self, input): return torch.ops.aten.scatter.value(input, dim, index, value) - input = [torch.zeros(3, 5, dtype=torch.int32)] + input = torch.zeros(3, 5, dtype=torch.int32) + inputs = [input] self.run_test( TestModule(), - input, + inputs, + ) + + @parameterized.expand( + [ + ("scatter_zero_dim_indexOne_value", 0, torch.tensor([[0, 1, 2, 0]]), 1), + ( + "scatter_zero_dim_indexTwo_value", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + 1, + ), + ("scatter_one_dim_indexOne_value", 1, torch.tensor([[0, 1, 2, 0]]), 1), + ( + "scatter_one_dim_indexTwo_value", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + 1, + ), + ] + ) + def test_scatter_index_input(self, _, dim, index, value): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, index): + return torch.ops.aten.scatter.value(input, dim, index, value) + + input = torch.zeros(3, 5, dtype=torch.int32) + inputs = [input, index] + self.run_test( + TestModule(), + inputs, ) class TestScatterSrcConverter(DispatchTestCase): @parameterized.expand( [ - ("scatter_zero_dim_indexOne", 0, [[0, 1, 2, 0]]), - ("scatter_zero_dim_indexTwo", 0, [[0, 1, 2, 0], [1, 2, 1, 1]]), - ("scatter_one_dim_indexOne", 1, [[0, 1, 2, 0]]), - 
("scatter_one_dim_indexTwo", 1, [[0, 1, 2, 0], [1, 2, 1, 1]]), + ( + "scatter_zero_dim_indexOne_constant_src", + 0, + torch.tensor([[0, 1, 2, 0]]), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32), + ), + ( + "scatter_zero_dim_indexTwo_constant_src", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32), + ), + ( + "scatter_one_dim_indexOne_constant_src", + 1, + torch.tensor([[0, 1, 2, 0]]), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32), + ), + ( + "scatter_one_dim_indexTwo_constant_src", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32), + ), + # These are special cases where in the harness.py code might need to be changed to input cuda_inputs + # In that case below two test cases would also require index and src to be on cuda + # ("scatter_one_dim_indexOne_constant_src", 1, torch.tensor([[0, 1, 2, 0]]), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)), + # ("scatter_one_dim_indexTwo_constant_src", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32)), ] ) - def test_scatter(self, _, dim, index): + def test_scatter_index_constant(self, _, dim, index, src): class TestModule(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, input, src): + def forward(self, input): return torch.ops.aten.scatter.src(input, dim, index, src) - src = [torch.arange(1, 11).reshape((2, 5))] - input = torch.zeros(3, 5, dtype=src.dtype) - inputs = [input, src] + input = torch.zeros(3, 5, dtype=torch.int32) + inputs = [input] + scatter = TestModule() self.run_test( TestModule(), inputs, ) + + @parameterized.expand( + [ + ( + "scatter_zero_dim_indexOne_constant_src", + 0, + torch.tensor([[0, 1, 2, 0]]), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32), + ), + ( + "scatter_zero_dim_indexTwo_constant_src", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32), + ), + ( + "scatter_one_dim_indexOne_constant_src", + 1, + torch.tensor([[0, 1, 2, 0]]), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32), + ), + ( + "scatter_one_dim_indexTwo_constant_src", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32), + ), + ] + ) + def test_scatter_index_input(self, _, dim, index, src): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, index): + return torch.ops.aten.scatter.src(input, dim, index, src) + + input = torch.zeros(3, 5, dtype=torch.int32) + inputs = [input, index] + self.run_test( + TestModule(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From 812114ee2c7cc3dba1b4b482b5e3feb88f4873e6 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 18 Apr 2024 17:39:03 -0700 Subject: [PATCH 05/10] addressing review comments and changing test names --- .../dynamo/conversion/aten_ops_converters.py | 23 ++--- .../dynamo/conversion/impl/select.py | 86 +++---------------- .../py/dynamo/conversion/test_scatter_aten.py | 11 +-- 3 files changed, 26 insertions(+), 94 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 5d20d12cc8..60e0ab4e9c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -691,28 +691,21 @@ def aten_ops_clamp( ) 
-@dynamo_tensorrt_converter(torch.ops.aten.scatter.value) -def aten_ops_scatter_value( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.select.scatter_value( - ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3] - ) - - +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) @dynamo_tensorrt_converter(torch.ops.aten.scatter.src) -def aten_ops_scatter_src( +@dynamo_tensorrt_converter(torch.ops.aten.scatter.value) +def aten_ops_scatter( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.select.scatter_src( + return impl.select.scatter( ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3] ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6603a5d280..232bada26a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -393,7 +393,7 @@ def index_select( return gather_layer.get_output(0) -def scatter_value( +def scatter( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], @@ -401,92 +401,30 @@ def scatter_value( input: TRTTensor, dim: int, index: Union[TRTTensor, np.ndarray, torch.Tensor], - value: float, + src: Union[TRTTensor, int, float], ) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"scatter_tensor received input {input} that is not part " - "of the TensorRT region!" - ) input_shape = input.shape index_shape = index.shape index_shape_list = list(index.shape) if not (isinstance(index, TRTTensor)): index = get_trt_tensor(ctx, index, f"_index_tensor") - if len(input_shape) != len(index_shape): - raise RuntimeError(f"The no of dimensions of input and index should be equal") dim = get_positive_dim(dim, len(input_shape)) dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape: # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" - input_dims = len(input_shape) - for i in range(0, input_dims): - if i != dim and (index_shape[i] >= input.shape[i]): - raise RuntimeError( - f"cannot have index size greater than the input size along dimension {dim}" - ) - - value_tensor = get_trt_tensor( - ctx, value * torch.ones(index_shape_list), name + "_value_tensor" - ) - value_tensor = cast_trt_tensor( - ctx, value_tensor, input.dtype, name + "_cast_value_tensor" - ) - scatter_layer = ctx.net.add_scatter( - input, index, value_tensor, trt.ScatterMode.ELEMENT - ) - scatter_layer.axis = dim - set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir) - out = scatter_layer.get_output(0) - return out - - -def scatter_src( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - dim: Shape, - index: Shape, - src: TRTTensor, -) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"scatter_tensor received input {input} that is not part " - "of the TensorRT region!" 
- ) - input_shape = input.shape - index_shape = index.shape - src_shape = src.shape - if not (isinstance(index, TRTTensor)): - index = get_trt_tensor(ctx, index, f"_index_tensor") - if len(input_shape) != len(index_shape): - raise RuntimeError(f"The no of dimensions of input and index should be equal") - if len(index_shape) != len(src_shape): - raise RuntimeError(f"The no of dimensions of src and index should be equal") - - input_dims = len(input_shape) - dim = get_positive_dim(cast(int, dim), input_dims) - dynamic_shape = has_dynamic_shape(input.shape) - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" - - for i in range(0, input_dims): - if i != dim and (index_shape[i] >= input.shape[i]): - raise RuntimeError( - f"cannot have index size greater than the input size along dimension {dim}" - ) - input_dtype = input.dtype - # required for cases where src is a constant - src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT) - if input_dtype != src_dtype: - raise RuntimeError(f"The type of input and src should be made") src_tensor = src - if not (isinstance(src, TRTTensor)): + # scatter.value + if isinstance(src, int) or isinstance(src, float): + src_tensor = get_trt_tensor( + ctx, src * torch.ones(index_shape_list), name + "_value_tensor" + ) + src_tensor = cast_trt_tensor( + ctx, src_tensor, input.dtype, name + "_cast_value_tensor" + ) + # scatter.src + elif not (isinstance(src, TRTTensor)): src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor") scatter_layer = ctx.net.add_scatter( diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py index 6a2a398907..aff4b9fe5e 100644 --- a/tests/py/dynamo/conversion/test_scatter_aten.py +++ b/tests/py/dynamo/conversion/test_scatter_aten.py @@ -1,9 +1,10 @@ import torch -from harness import DispatchTestCase from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from .harness import DispatchTestCase + class TestScatterValueConverter(DispatchTestCase): @parameterized.expand( @@ -87,25 +88,25 @@ class TestScatterSrcConverter(DispatchTestCase): @parameterized.expand( [ ( - "scatter_zero_dim_indexOne_constant_src", + "scatter_zero_dim_indexOne_src", 0, torch.tensor([[0, 1, 2, 0]]), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32), ), ( - "scatter_zero_dim_indexTwo_constant_src", + "scatter_zero_dim_indexTwo_src", 0, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32), ), ( - "scatter_one_dim_indexOne_constant_src", + "scatter_one_dim_indexOne_src", 1, torch.tensor([[0, 1, 2, 0]]), torch.tensor([[1, 2, 3, 1]], dtype=torch.int32), ), ( - "scatter_one_dim_indexTwo_constant_src", + "scatter_one_dim_indexTwo_src", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32), From 6afbb8d27dc5dec9399eed44dbd2dd378c287ecf Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 2 May 2024 10:41:23 -0700 Subject: [PATCH 06/10] uncommenting the tests --- tests/py/dynamo/conversion/harness.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 9407e3b694..c1b3708c92 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -249,7 +249,7 @@ def run_test( input = inputs[num_input] if 
input.dtype in (torch.int64, torch.float64):
                 dtype_32bit = (
-                    torch.int32 if (input.dtype == torch.int64) else torch.int64
+                    torch.int32 if (input.dtype == torch.int64) else torch.float32
                 )
                 # should we modify graph here to insert clone nodes?
                 # ideally not required
                 trt_inputs = (
                     list(trt_inputs[:num_input])
                     + [
                         input.to(dtype_32bit),
                     ]
                     + list(trt_inputs[num_input + 1 :])
-            )
+                )
 
         trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
         input_specs = [Input.from_tensor(i) for i in inputs]
@@ -270,7 +270,7 @@ def run_test(
             mod,
             input_specs,
             compilation_settings.device,
-            truncate_long_and_double=compilation_settings.truncate_long_and_double,
+            truncate_double=compilation_settings.truncate_double,
         )
 
         _LOGGER.debug(f"Compilation settings: {compilation_settings}")

From 19049bc56d8181466d1af21faf04cebd83b822f9 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 2 May 2024 14:51:45 -0700
Subject: [PATCH 07/10] Remove the int64 to int32 casting in harness.py since
 native int64 is supported in TRT10

---
 tests/py/dynamo/conversion/harness.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index c1b3708c92..f0bca786bc 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -247,10 +247,8 @@ def run_test(
         trt_inputs = inputs
         for num_input in range(num_inputs):
             input = inputs[num_input]
-            if input.dtype in (torch.int64, torch.float64):
-                dtype_32bit = (
-                    torch.int32 if (input.dtype == torch.int64) else torch.float32
-                )
+            if input.dtype is torch.float64:
+                dtype_32bit = torch.float32
                 # should we modify graph here to insert clone nodes?
                 # ideally not required
                 trt_inputs = (

From 0449068a8d2eda2b183b40ef10242a72aaa5f767 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 9 May 2024 14:56:33 -0700
Subject: [PATCH 08/10] Fix errors in the Dynamo test cases and add int64 to
 int32 index conversion support for tests in TRT10

---
 .../dynamo/conversion/impl/select.py          |  6 ++++
 tests/py/dynamo/conversion/harness.py         | 16 +++++++--
 .../py/dynamo/conversion/test_scatter_aten.py | 34 +++++++++----------
 3 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 232bada26a..bd1257a8d1 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -407,6 +407,12 @@ def scatter(
     index_shape = index.shape
     index_shape_list = list(index.shape)
     if not (isinstance(index, TRTTensor)):
+        if isinstance(index, torch.Tensor):
+            if index.dtype == torch.int64:
+                index = index.to(torch.int32)
+        elif isinstance(index, np.ndarray):
+            if index.dtype == np.int64:
+                index = index.astype(np.int32)
         index = get_trt_tensor(ctx, index, f"_index_tensor")
     dim = get_positive_dim(dim, len(input_shape))
     dynamic_shape = has_dynamic_shape(input.shape)
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index f0bca786bc..24737c0a6b 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -226,6 +226,7 @@ def run_test(
         check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
+        int32_reqd=False,
     ):
         mod.eval()
         mod = self.generate_graph(
@@ -245,10 +246,19 @@ def run_test(
 
         num_inputs = len(inputs)
         trt_inputs = inputs
+        dtype_to_change = []
+        if int32_reqd:
+            dtype_to_change = [torch.int64, torch.float64]
+        else:
+            dtype_to_change = [
+                torch.float64,
+            ]
         for
num_input in range(num_inputs): input = inputs[num_input] - if input.dtype is torch.float64: - dtype_32bit = torch.float32 + if input.dtype in dtype_to_change: + dtype_32bit = ( + torch.float32 if (input.dtype == torch.float64) else torch.int32 + ) # should we modify graph here to insert clone nodes? # ideally not required trt_inputs = ( @@ -360,4 +370,4 @@ def run_test_with_dynamic_shape( # Since the lowering is based on optimal shape. We need to test with # different shape(for ex. max shape) for testing dynamic shape inputs_max = [spec.example_tensor("max_shape") for spec in input_specs] - super().run_test(mod, inputs_max, interp, rtol, atol) + super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol) diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py index aff4b9fe5e..d4c4d02fb8 100644 --- a/tests/py/dynamo/conversion/test_scatter_aten.py +++ b/tests/py/dynamo/conversion/test_scatter_aten.py @@ -45,10 +45,7 @@ def forward(self, input): input = torch.zeros(3, 5, dtype=torch.int32) inputs = [input] - self.run_test( - TestModule(), - inputs, - ) + self.run_test(TestModule(), inputs, int32_reqd=True) @parameterized.expand( [ @@ -78,10 +75,7 @@ def forward(self, input, index): input = torch.zeros(3, 5, dtype=torch.int32) inputs = [input, index] - self.run_test( - TestModule(), - inputs, - ) + self.run_test(TestModule(), inputs, int32_reqd=True) class TestScatterSrcConverter(DispatchTestCase): @@ -113,8 +107,18 @@ class TestScatterSrcConverter(DispatchTestCase): ), # These are special cases where in the harness.py code might need to be changed to input cuda_inputs # In that case below two test cases would also require index and src to be on cuda - # ("scatter_one_dim_indexOne_constant_src", 1, torch.tensor([[0, 1, 2, 0]]), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)), - # ("scatter_one_dim_indexTwo_constant_src", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32)), + ( + "scatter_one_dim_indexOne_constant_src", + 1, + torch.tensor([[0, 1, 2, 0]]), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32), + ), + ( + "scatter_one_dim_indexTwo_constant_src", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32), + ), ] ) def test_scatter_index_constant(self, _, dim, index, src): @@ -128,10 +132,7 @@ def forward(self, input): input = torch.zeros(3, 5, dtype=torch.int32) inputs = [input] scatter = TestModule() - self.run_test( - TestModule(), - inputs, - ) + self.run_test(TestModule(), inputs, int32_reqd=True) @parameterized.expand( [ @@ -171,10 +172,7 @@ def forward(self, input, index): input = torch.zeros(3, 5, dtype=torch.int32) inputs = [input, index] - self.run_test( - TestModule(), - inputs, - ) + self.run_test(TestModule(), inputs, int32_reqd=True) if __name__ == "__main__": From 0318b33b12a7a3aedafe4adcb078c5fdc1e80db7 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 10 May 2024 12:44:20 -0700 Subject: [PATCH 09/10] code cleanup --- tests/py/dynamo/conversion/harness.py | 3 --- tests/py/dynamo/conversion/test_scatter_aten.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 24737c0a6b..ff90115550 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -13,9 +13,6 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from 
torch_tensorrt.dynamo.conversion import TRTInterpreter from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( - DYNAMO_CONVERTERS as CONVERTERS, -) from torch_tensorrt.dynamo.lowering import apply_lowering_passes from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py index d4c4d02fb8..85b7094a5c 100644 --- a/tests/py/dynamo/conversion/test_scatter_aten.py +++ b/tests/py/dynamo/conversion/test_scatter_aten.py @@ -105,8 +105,6 @@ class TestScatterSrcConverter(DispatchTestCase): torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32), ), - # These are special cases where in the harness.py code might need to be changed to input cuda_inputs - # In that case below two test cases would also require index and src to be on cuda ( "scatter_one_dim_indexOne_constant_src", 1, From 487906d72990f92d9cf2706d874415caf2663fab Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 14 May 2024 12:53:06 -0700 Subject: [PATCH 10/10] Addressing review comments --- .../dynamo/conversion/aten_ops_converters.py | 5 +++-- .../dynamo/conversion/impl/select.py | 20 ++++--------------- tests/py/dynamo/conversion/harness.py | 3 --- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 60e0ab4e9c..5493075dce 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -691,13 +691,14 @@ def aten_ops_clamp( ) +@dynamo_tensorrt_converter(torch.ops.aten.scatter.src) +@dynamo_tensorrt_converter(torch.ops.aten.scatter.value) @enforce_tensor_types( { 0: (TRTTensor,), + 2: (TRTTensor,), } ) -@dynamo_tensorrt_converter(torch.ops.aten.scatter.src) -@dynamo_tensorrt_converter(torch.ops.aten.scatter.value) def aten_ops_scatter( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index bd1257a8d1..f9351346a6 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -21,7 +21,6 @@ set_layer_name, ) from torch_tensorrt.fx.types import Shape, TRTTensor -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -405,26 +404,15 @@ def scatter( ) -> TRTTensor: input_shape = input.shape index_shape = index.shape - index_shape_list = list(index.shape) - if not (isinstance(index, TRTTensor)): - if isinstance(index, torch.Tensor): - if index.dtype == torch.int64: - index = index.to(torch.int32) - elif isinstance(index, np.ndarray): - if index.dtype == np.int64: - index = index.astype(np.int32) - index = get_trt_tensor(ctx, index, f"_index_tensor") + index_shape_list = list(index_shape) + if index.dtype == trt.int64: + index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor") dim = get_positive_dim(dim, len(input_shape)) - dynamic_shape = has_dynamic_shape(input.shape) - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!" 
- src_tensor = src # scatter.value if isinstance(src, int) or isinstance(src, float): src_tensor = get_trt_tensor( - ctx, src * torch.ones(index_shape_list), name + "_value_tensor" + ctx, src * np.ones(index_shape_list), name + "_value_tensor" ) src_tensor = cast_trt_tensor( ctx, src_tensor, input.dtype, name + "_cast_value_tensor" diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index ff90115550..8f55ce3fb6 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -1,4 +1,3 @@ -import copy import logging import time import unittest @@ -256,8 +255,6 @@ def run_test( dtype_32bit = ( torch.float32 if (input.dtype == torch.float64) else torch.int32 ) - # should we modify graph here to insert clone nodes? - # ideally not required trt_inputs = ( list(trt_inputs[:num_input]) + [
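
Note: the converters in this series must reproduce the eager-mode semantics
of torch.ops.aten.scatter. Below is a minimal reference sketch of those
semantics in plain PyTorch, mirroring the 3x5 test cases in
test_scatter_aten.py; it illustrates the expected behavior and is not part
of the patches themselves:

    import torch

    # scatter.value: write the scalar `value` at the positions named by
    # `index` along `dim`; all other elements of `input` stay untouched.
    input = torch.zeros(3, 5, dtype=torch.int32)
    index = torch.tensor([[0, 1, 2, 0]])
    out_value = torch.ops.aten.scatter.value(input, 0, index, 1)
    # For dim=0: out_value[index[0][j]][j] == 1 for each column j of index.

    # scatter.src: same indexing rule, but the written values are taken
    # element-wise from `src`, which must have the same dtype as `input`.
    src = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
    out_src = torch.ops.aten.scatter.src(input, 0, index, src)
    # For dim=0: out_src[index[0][j]][j] == src[0][j].

Both overloads map onto TensorRT's ScatterMode.ELEMENT with axis=dim, which
is why PATCH 05 collapses scatter_value and scatter_src into a single
scatter() converter: it materializes a value tensor for the scalar overload
and then issues one add_scatter call for both.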