diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 15a993668b..11a213551d 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -800,6 +800,40 @@ def aten_ops_tile(
     )
 
 
+def zero_output_validator(node: Node) -> bool:
+    # Reject output sizes that contain a zero-sized dimension
+    if 0 in node.args[1]:
+        _LOGGER.debug(
+            f"We do not support output size {node.args[1]} with zero-sized dimensions for this operation."
+        )
+        return False
+    else:
+        return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.as_strided.default,
+    capability_validator=zero_output_validator,
+)
+def aten_ops_as_strided(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.as_strided(
+        ctx,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        size=args[1],
+        stride=args[2],
+        storage_offset=args_bounds_check(args, 3, None),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.permute.default)
 @enforce_tensor_types(
     {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 61d71fe9a0..139ecc1149 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -8,6 +8,7 @@ from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    flatten_dims,
     get_positive_dim,
     get_trt_tensor,
 )
@@ -259,3 +260,61 @@ def flip(
     )
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def as_strided(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    size: Sequence[int],
+    stride: Sequence[int],
+    storage_offset: Optional[int],
+) -> TRTTensor:
+    # Ensure storage_offset is an integer before passing to nested
+    if storage_offset is None:
+        storage_offset = 0
+
+    flatten_shape = flatten_dims(input, 0, -1)
+    flatten_output = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_reshape_flatten_output", input, flatten_shape
+    )
+
+    indices = []
+
+    # Recursive function to compute indices for as_strided operation
+    def nested(
+        rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
+    ) -> None:
+        if (
+            dim == rank
+        ):  # If the current dimension equals the rank, append the computed index
+            indices.append(current)
+            return
+        for i in range(size[dim]):  # Recursively compute indices across dimensions
+            nested(
+                rank, size, stride, current + stride[dim] * i, dim + 1
+            )  # Calculate the index for the current dimension and recursively explore further dimensions
+
+    nested(len(size), size, stride, storage_offset, 0)
+
+    indices = np.array(indices, dtype=np.int32)
+
+    indices_tensor = get_trt_tensor(ctx, indices, f"{name}_indices")
+
+    # Use gather to reorder elements based on computed indices
+    gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
+    gather_output = gather_layer.get_output(0)
+
+    # Reshape the gathered tensor to the desired size
+    reshape_output = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_gather_output",
+        gather_output,
+        tuple(size),
+    )
+
+    return reshape_output
diff --git a/tests/py/dynamo/conversion/test_as_strided_aten.py b/tests/py/dynamo/conversion/test_as_strided_aten.py
new file mode 100644
index 0000000000..ba723bf4f9
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_as_strided_aten.py
@@ -0,0 +1,89 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestAsStridedConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                (5, 5),
+                (2, 3),
+                (1, 2),
+                0,
+            ),
+            (
+                (5, 5),
+                (2, 3),
+                (2, 2),
+                1,
+            ),
+            (
+                (20, 20),
+                (2, 3, 2),
+                (2, 2, 2),
+                0,
+            ),
+            (
+                (8, 8, 8),
+                (2, 2, 3),
+                (1, 2, 2),
+                1,
+            ),
+            (
+                (200, 200, 200),
+                (9, 9, 3, 2),
+                (2, 2, 2, 3),
+                1,
+            ),
+            (
+                (10, 25, 12),
+                (3, 7, 3),
+                (2, 1, 3),
+                1,
+            ),
+            (
+                (10, 25, 12),
+                (3, 7, 3),
+                (2, 0, 3),
+                1,
+            ),
+            (
+                (10, 25, 12, 100),
+                (6, 5, 7, 10),
+                (0, 0, 0, 0),
+                0,
+            ),
+            (
+                (10, 25, 12, 100),
+                (6, 5, 7, 10),
+                (0, 0, 0, 0),
+                1,
+            ),
+        ]
+    )
+    def test_as_strided(
+        self,
+        input_shape,
+        output_size,
+        stride,
+        storage_offset=0,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.as_strided.default(
+                    x, output_size, stride, storage_offset
+                )
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            TestModule(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
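
For readers unfamiliar with how this converter lowers as_strided, the standalone sketch below (illustrative only, not part of the patch; the helper name as_strided_flat_indices is invented here) reproduces the flat-index enumeration performed by the nested helper in impl/slice/ops.py and checks the flatten -> gather -> reshape sequence against torch.as_strided on a contiguous input.

import torch


def as_strided_flat_indices(size, stride, storage_offset=0):
    # For each output coordinate (i_0, ..., i_{n-1}) the source element sits at
    # storage_offset + sum(i_k * stride[k]) in the flattened (contiguous) input;
    # the indices are enumerated in row-major output order, mirroring `nested`.
    indices = []

    def walk(dim, current):
        if dim == len(size):
            indices.append(current)
            return
        for i in range(size[dim]):
            walk(dim + 1, current + i * stride[dim])

    walk(0, storage_offset)
    return indices


if __name__ == "__main__":
    x = torch.arange(25.0).reshape(5, 5)
    size, stride, offset = (2, 3), (1, 2), 0

    # Gathering those indices from the flattened input and reshaping to `size`
    # reproduces torch.as_strided, which is what the TensorRT gather layer does.
    idx = torch.tensor(as_strided_flat_indices(size, stride, offset))
    rebuilt = x.reshape(-1)[idx].reshape(size)
    assert torch.equal(rebuilt, torch.as_strided(x, size, stride, offset))
    print(rebuilt)  # rows [0, 2, 4] and [1, 3, 5]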