diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 721a0a546c..d55bb5fec2 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -801,6 +801,77 @@ def aten_ops_select(
     )
 
 
+def index_put_validator(node: Node) -> bool:
+    """Check whether an aten.index_put node can be converted to TensorRT.
+
+    Returns False (leaving the op to PyTorch) when accumulate=True is requested,
+    when any entry of the indices argument is None, when the input node has no
+    "tensor_meta", or when the number of index tensors differs from the number
+    of input dimensions.
+    """
+    # The converter does not use `accumulate`, so accumulate=True must never
+    # reach it.
+    if args_bounds_check(node.args, 3, False):
+        _LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
+        return False
+
+    # aten.index_put permits None entries in `indices`; the converter turns every
+    # entry into a TRT tensor and has no handling for None.
+    indices = node.args[1]
+    if any(index is None for index in indices):
+        _LOGGER.debug(
+            "We do not support None entries in the indices of aten.index_put operation"
+        )
+        return False
+
+    # Retrieve input tensor's meta information
+    input_meta = node.args[0].meta.get("tensor_meta")
+    if not input_meta:
+        _LOGGER.warning(
+            "Meta information of input is missing. Unable to validate if broadcasting is needed, falling back to PyTorch operation."
+        )
+        return False
+
+    # Exactly one index tensor per input dimension is required.
+    if len(indices) != len(input_meta.shape):
+        _LOGGER.debug(
+            "We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions."
+        )
+        return False
+
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.index_put.default,
+    capability_validator=index_put_validator,
+)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        2: (TRTTensor,),
+    }
+)
+def aten_ops_index_put(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    """Convert aten.index_put by delegating to impl.select.index_put_converter."""
+    return impl.select.index_put_converter(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args_bounds_check(args, 3, False),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 0654985bb5..6d9a86f89b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -6,6 +6,7 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     broadcastable,
@@ -410,7 +411,7 @@ def scatter(
     dim = get_positive_dim(dim, len(input_shape))
     src_tensor = src
     # scatter.value
-    if isinstance(src, int) or isinstance(src, float):
+    if isinstance(src, (int, float)):
         src_tensor = get_trt_tensor(
             ctx, src * np.ones(index_shape_list), name + "_value_tensor"
         )
@@ -446,3 +447,49 @@ def gather(
     set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
     out = gather_layer.get_output(0)
     return out
+
+
+def index_put_converter(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_tensor: TRTTensor,
+    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
+    values: TRTTensor,
+    accumulate: bool = False,
+) -> TRTTensor:
+    """Lower index_put to a TensorRT scatter layer in ND mode.
+
+    Each entry of ``indices`` is reshaped to a column of shape (N, 1); the
+    columns are concatenated along dim 1 and passed, together with ``values``,
+    to ``ctx.net.add_scatter`` with ``trt.ScatterMode.ND``.
+
+    ``accumulate`` is accepted but not used by this implementation.
+    """
+    # Reshape indices to add an extra dimension if necessary (indices is a Tuple of ITensors)
+    reshaped_indices = []
+    for i, each_input in enumerate(indices):
+        if not isinstance(each_input, TRTTensor):
+            each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
+        each_input = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_{i}",
+            each_input,
+            (-1, 1),  # Reshape to (N, 1)
+        )
+        reshaped_indices.append(each_input)
+
+    # Concatenate along the second dimension (columns)
+    indices_cat = impl.cat.cat(
+        ctx, target, source_ir, f"{name}_cat", reshaped_indices, dim=1
+    )
+
+    scatter_layer = ctx.net.add_scatter(
+        input_tensor, indices_cat, values, trt.ScatterMode.ND
+    )
+    scatter_layer.axis = 0
+    set_layer_name(scatter_layer, target, f"{name}_scatter_layer", source_ir)
+    return scatter_layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py b/tests/py/dynamo/conversion/test_index_put_aten.py
new file mode 100644
index 0000000000..13e5853e51
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_index_put_aten.py
@@ -0,0 +1,196 @@
+import torch
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestIndexPutConverter(DispatchTestCase):
+    """Conversion tests for aten.index_put with one index tensor per input dimension."""
+
+    @parameterized.expand(
+        [
+            param(
+                test_name="1d_indices_single",
+                source_tensor=torch.zeros([5], dtype=torch.int32),
+                indices_tensor=(torch.tensor([0], dtype=torch.int32),),
+                value_tensor=torch.tensor([1], dtype=torch.int32),
+            ),
+            param(
+                test_name="1d_indices_multiple",
+                source_tensor=torch.zeros([5], dtype=torch.int32),
+                indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),),
+                value_tensor=torch.tensor([1, 3], dtype=torch.int32),
+            ),
+            param(
+                test_name="2d_indices_single",
+                source_tensor=torch.zeros([5, 5], dtype=torch.int32),
+                indices_tensor=(
+                    torch.tensor([2], dtype=torch.int32),
+                    torch.tensor([0], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([3], dtype=torch.int32),
+            ),
+            param(
+                test_name="2d_indices_multiple",
+                source_tensor=torch.zeros([5, 5], dtype=torch.int32),
+                indices_tensor=(
+                    torch.tensor([0, 2, 2], dtype=torch.int32),
+                    torch.tensor([2, 0, 2], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([1, 3, 4], dtype=torch.int32),
+            ),
+            param(
+                test_name="3d_indices_single",
+                source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
+                indices_tensor=(
+                    torch.tensor([1], dtype=torch.int32),
+                    torch.tensor([2], dtype=torch.int32),
+                    torch.tensor([2], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([7], dtype=torch.int32),
+            ),
+            param(
+                test_name="3d_indices_multiple",
+                source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
+                indices_tensor=(
+                    torch.tensor([0, 1, 1], dtype=torch.int32),
+                    torch.tensor([1, 2, 1], dtype=torch.int32),
+                    torch.tensor([2, 0, 2], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([5, 7, 2], dtype=torch.int32),
+            ),
+            param(
+                test_name="4d_indices_single",
+                source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
+                indices_tensor=(
+                    torch.tensor([1], dtype=torch.int32),
+                    torch.tensor([1], dtype=torch.int32),
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.tensor([1], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([5], dtype=torch.int32),
+            ),
+            param(
+                test_name="4d_indices_multiple",
+                source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
+                indices_tensor=(
+                    torch.tensor([0, 1], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                    torch.tensor([1, 0], dtype=torch.int32),
+                    torch.tensor([1, 0], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([5, 7], dtype=torch.int32),
+            ),
+            param(
+                test_name="negative_indices",
+                source_tensor=torch.zeros([5, 5], dtype=torch.int32),
+                indices_tensor=(
+                    torch.tensor([-1, -2], dtype=torch.int32),
+                    torch.tensor([2, 0], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([1, 3], dtype=torch.int32),
+            ),
+            param(
+                test_name="mixed_indices",
+                source_tensor=torch.zeros([4, 4], dtype=torch.int32),
+                indices_tensor=(
+                    torch.tensor([0, 1, -1, -2], dtype=torch.int32),
+                    torch.tensor([0, -1, 2, 1], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([2, 4, 6, 8], dtype=torch.int32),
+            ),
+            param(
+                test_name="1d_indices_float",
+                source_tensor=torch.zeros([5], dtype=torch.float32),
+                indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),),
+                value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32),
+            ),
+            param(
+                test_name="2d_indices_float",
+                source_tensor=torch.zeros([5, 5], dtype=torch.float32),
+                indices_tensor=(
+                    torch.tensor([0, 2], dtype=torch.int32),
+                    torch.tensor([2, 0], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32),
+            ),
+            param(
+                test_name="3d_indices_float",
+                source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32),
+                indices_tensor=(
+                    torch.tensor([0, 1], dtype=torch.int32),
+                    torch.tensor([1, 2], dtype=torch.int32),
+                    torch.tensor([2, 0], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32),
+            ),
+            param(
+                test_name="4d_indices_float",
+                source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32),
+                indices_tensor=(
+                    torch.tensor([0, 1], dtype=torch.int32),
+                    torch.tensor([1, 0], dtype=torch.int32),
+                    torch.tensor([0, 1], dtype=torch.int32),
+                    torch.tensor([1, 0], dtype=torch.int32),
+                ),
+                value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32),
+            ),
+            # Disabled: index_put_validator rejects these cases (index count not
+            # equal to the input rank, and accumulate=True).
+            # param(
+            #     test_name="3d_indices_float_broadcast_index",
+            #     source_tensor=torch.zeros([3, 3, 3], dtype = torch.int32),
+            #     indices_tensor=(
+            #         torch.tensor([0,1], dtype=torch.int32),
+            #         torch.tensor([0,1], dtype=torch.int32),
+            #     ),
+            #     value_tensor=torch.tensor([10], dtype = torch.int32),
+            # ),
+            # param(
+            #     test_name="2d_indices_accumulate_True",
+            #     source_tensor=torch.zeros([5, 5], dtype=torch.int32),
+            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
+            #     accumulate=True,
+            # ),
+            # param(
+            #     test_name="3d_indices_accumulate_True",
+            #     source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
+            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
+            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
+            #     accumulate=True,
+            # ),
+            # param(
+            #     test_name="4d_indices_accumulate_True",
+            #     source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
+            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
+            #     accumulate=True,
+            # ),
+        ]
+    )
+    def test_index_put(
+        self, test_name, source_tensor, indices_tensor, value_tensor, accumulate=False
+    ):
+        @torch._dynamo.assume_constant_result
+        def get_indices_tensor():
+            return indices_tensor
+
+        class TestIndexPut(torch.nn.Module):
+            def forward(self, source_tensor, value_tensor):
+                indices_tensor_const = get_indices_tensor()
+                return torch.ops.aten.index_put.default(
+                    source_tensor, indices_tensor_const, value_tensor, accumulate
+                )
+
+        self.run_test(
+            TestIndexPut(),
+            inputs=[source_tensor, value_tensor],
+            enable_passes=True,
+            use_dynamo_tracer=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()