diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 1705dd06db..5493075dce 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -691,6 +691,26 @@ def aten_ops_clamp(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
+@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        2: (TRTTensor,),
+    }
+)
+def aten_ops_scatter(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.scatter(
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.select.int)
 def aten_ops_select(
     ctx: ConversionContext,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 2ec6420e0b..f9351346a6 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -390,3 +390,41 @@ def index_select(
     set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
 
     return gather_layer.get_output(0)
+
+
+def scatter(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+    index: Union[TRTTensor, np.ndarray, torch.Tensor],
+    src: Union[TRTTensor, int, float],
+) -> TRTTensor:
+    input_shape = input.shape
+    index_shape = index.shape
+    index_shape_list = list(index_shape)
+    if index.dtype == trt.int64:
+        index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
+    dim = get_positive_dim(dim, len(input_shape))
+    src_tensor = src
+    # scatter.value overload: broadcast the scalar across the index shape
+    if isinstance(src, (int, float)):
+        src_tensor = get_trt_tensor(
+            ctx, src * np.ones(index_shape_list), name + "_value_tensor"
+        )
+        src_tensor = cast_trt_tensor(
+            ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
+        )
+    # scatter.src overload: materialize non-TRT sources as ITensors
+    elif not isinstance(src, TRTTensor):
+        src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")
+
+    scatter_layer = ctx.net.add_scatter(
+        input, index, src_tensor, trt.ScatterMode.ELEMENT
+    )
+    scatter_layer.axis = dim
+    set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
+    out = scatter_layer.get_output(0)
+    return out
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index 7bed68a11e..8f55ce3fb6 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -1,5 +1,3 @@
-# type: ignore
-
 import logging
 import time
 import unittest
@@ -50,16 +48,20 @@ def setUp(self):
     def run_test(
         self,
         mod,
-        inputs,
+        fx_inputs,
+        trt_interpreter_inputs,
         interpreter,
         rtol,
         atol,
         check_dtype=True,
     ):
         with torch.no_grad():
-            cuda_inputs = []
-            for i in inputs:
-                cuda_inputs.append(i.cuda())
+            cuda_fx_inputs = []
+            cuda_trt_inputs = []
+            for i in trt_interpreter_inputs:
+                cuda_trt_inputs.append(i.cuda())
+            for i in fx_inputs:
+                cuda_fx_inputs.append(i.cuda())
 
             mod.eval()
             start = time.perf_counter()
@@ -73,13 +75,13 @@ def run_test(
                 )
 
                 mod = mod.cuda()
-                ref_outputs = mod(*cuda_inputs)
+                ref_outputs = mod(*cuda_fx_inputs)
 
                 torch.cuda.synchronize()
                 start_event = torch.cuda.Event(enable_timing=True)
                 end_event = torch.cuda.Event(enable_timing=True)
                 start_event.record()
-                outputs = trt_mod(*cuda_inputs)
+                outputs = trt_mod(*cuda_trt_inputs)
                 end_event.record()
                 torch.cuda.synchronize()
                 _LOGGER.info(
@@ -220,6 +222,7 @@ def run_test(
         check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
+        int32_reqd=False,
     ):
         mod.eval()
         mod = self.generate_graph(
@@ -237,6 +240,30 @@ def run_test(
             debug=True,
         )
 
+        num_inputs = len(inputs)
+        trt_inputs = inputs
+        dtype_to_change = []
+        if int32_reqd:
+            dtype_to_change = [torch.int64, torch.float64]
+        else:
+            dtype_to_change = [
+                torch.float64,
+            ]
+        for num_input in range(num_inputs):
+            input = inputs[num_input]
+            if input.dtype in dtype_to_change:
+                dtype_32bit = (
+                    torch.float32 if (input.dtype == torch.float64) else torch.int32
+                )
+                trt_inputs = (
+                    list(trt_inputs[:num_input])
+                    + [
+                        input.to(dtype_32bit),
+                    ]
+                    + list(trt_inputs[num_input + 1 :])
+                )
+
+        trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
         input_specs = [Input.from_tensor(i) for i in inputs]
 
         output_dtypes = None
@@ -254,13 +281,15 @@ def run_test(
 
         interp = TRTInterpreter(
             mod,
-            input_specs,
+            trt_input_specs,
             output_dtypes=output_dtypes,
             compilation_settings=compilation_settings,
         )
+
         super().run_test(
             mod,
             inputs,
+            trt_inputs,
             interp,
             rtol,
             atol,
@@ -335,4 +364,4 @@ def run_test_with_dynamic_shape(
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. max shape) for testing dynamic shape
         inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
-        super().run_test(mod, inputs_max, interp, rtol, atol)
+        super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py
new file mode 100644
index 0000000000..85b7094a5c
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_scatter_aten.py
@@ -0,0 +1,177 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestScatterValueConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "scatter_zero_dim_indexOne_constant_value",
+                0,
+                torch.tensor([[0, 1, 2, 0]]),
+                1,
+            ),
+            (
+                "scatter_zero_dim_indexTwo_constant_value",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                1,
+            ),
+            (
+                "scatter_one_dim_indexOne_constant_value",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                1,
+            ),
+            (
+                "scatter_one_dim_indexTwo_constant_value",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                1,
+            ),
+        ]
+    )
+    def test_scatter_index_constant(self, _, dim, index, value):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.scatter.value(input, dim, index, value)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
+    @parameterized.expand(
+        [
+            ("scatter_zero_dim_indexOne_value", 0, torch.tensor([[0, 1, 2, 0]]), 1),
+            (
+                "scatter_zero_dim_indexTwo_value",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                1,
+            ),
+            ("scatter_one_dim_indexOne_value", 1, torch.tensor([[0, 1, 2, 0]]), 1),
+            (
+                "scatter_one_dim_indexTwo_value",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                1,
+            ),
+        ]
+    )
+    def test_scatter_index_input(self, _, dim, index, value):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, index):
+                return torch.ops.aten.scatter.value(input, dim, index, value)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input, index]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
+
+class TestScatterSrcConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "scatter_zero_dim_indexOne_src",
+                0,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_zero_dim_indexTwo_src",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexOne_src",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexTwo_src",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexOne_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexTwo_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_scatter_index_constant(self, _, dim, index, src):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.scatter.src(input, dim, index, src)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input]
+        scatter = TestModule()
+        self.run_test(scatter, inputs, int32_reqd=True)
+
+    @parameterized.expand(
+        [
+            (
+                "scatter_zero_dim_indexOne_constant_src",
+                0,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_zero_dim_indexTwo_constant_src",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexOne_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexTwo_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_scatter_index_input(self, _, dim, index, src):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, index):
+                return torch.ops.aten.scatter.src(input, dim, index, src)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input, index]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
+
+if __name__ == "__main__":
+    run_tests()
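
For reference, the converter and tests above target the eager semantics of `aten.scatter`: `index` selects write positions along `dim`, and the written elements come from either a scalar (`scatter.value` overload) or a tensor (`scatter.src` overload). A minimal eager-mode sketch of that behavior, illustrative only and not part of this diff (tensor values are borrowed from the test cases):

    import torch

    inp = torch.zeros(3, 5, dtype=torch.int32)
    index = torch.tensor([[0, 1, 2, 0]])

    # scatter.value: out[index[0][j]][j] = 1 for every column j (dim=0)
    out = torch.ops.aten.scatter.value(inp, 0, index, 1)
    # tensor([[1, 0, 0, 1, 0],
    #         [0, 1, 0, 0, 0],
    #         [0, 0, 1, 0, 0]], dtype=torch.int32)

    # scatter.src: out[index[0][j]][j] = src[0][j] for every column j (dim=0)
    src = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
    out = torch.ops.aten.scatter.src(inp, 0, index, src)
    # tensor([[1, 0, 0, 4, 0],
    #         [0, 2, 0, 0, 0],
    #         [0, 0, 3, 0, 0]], dtype=torch.int32)

This is also why the harness changes track separate `fx_inputs` and `trt_interpreter_inputs`: TensorRT lacks int64/float64 support in this path, so tests flagged `int32_reqd` feed downcast copies to the TRT interpreter while the reference module still runs on the original tensors.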