Skip to content

Commit

Permalink
feat: support aten.as_strided converter (#2735)
Browse files Browse the repository at this point in the history
  • Loading branch information
chohk88 committed May 22, 2024
1 parent 763a4a1 commit 9341e9b
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
34 changes: 34 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,40 @@ def aten_ops_tile(
)


def zero_output_validator(node: Node) -> bool:
if 0 in node.args[1]:
_LOGGER.debug(
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
)
return False
else:
return True


@dynamo_tensorrt_converter(
torch.ops.aten.as_strided.default,
capability_validator=zero_output_validator,
)
@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
def aten_ops_as_strided(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.slice.as_strided(
ctx,
target,
source_ir=SourceIR.ATEN,
name=name,
input=args[0],
size=args[1],
stride=args[2],
storage_offset=args_bounds_check(args, 3, None),
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
@enforce_tensor_types(
{
Expand Down
59 changes: 59 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
flatten_dims,
get_positive_dim,
get_trt_tensor,
)
Expand Down Expand Up @@ -259,3 +260,61 @@ def flip(
)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def as_strided(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
size: Sequence[int],
stride: Sequence[int],
storage_offset: Optional[int],
) -> TRTTensor:
# Ensure storage_offset is an integer before passing to nested
if storage_offset is None:
storage_offset = 0

flatten_shape = flatten_dims(input, 0, -1)
flatten_output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_flatten_output", input, flatten_shape
)

indices = []

# Recursive function to compute indices for as_strided operation
def nested(
rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
) -> None:
if (
dim == rank
): # If the current dimension equals the rank, append the computed index
indices.append(current)
return
for i in range(size[dim]): # Recursively compute indices across dimensions
nested(
rank, size, stride, current + stride[dim] * i, dim + 1
) # Calculate the index for the current dimension and recursively explore further dimensions

nested(len(size), size, stride, storage_offset, 0)

indices = np.array(indices, dtype=np.int32)

indices_tensor = get_trt_tensor(ctx, indices, f"{name}_indices")

# Use gather to reorder elements based on computed indices
gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
gather_output = gather_layer.get_output(0)

# Reshape the gathered tensor to the desired size
reshape_output = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_gather_output",
gather_output,
tuple(size),
)

return reshape_output
89 changes: 89 additions & 0 deletions tests/py/dynamo/conversion/test_as_strided_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestAsStridedConverter(DispatchTestCase):
@parameterized.expand(
[
(
(5, 5),
(2, 3),
(1, 2),
0,
),
(
(5, 5),
(2, 3),
(2, 2),
1,
),
(
(20, 20),
(2, 3, 2),
(2, 2, 2),
0,
),
(
(8, 8, 8),
(2, 2, 3),
(1, 2, 2),
1,
),
(
(200, 200, 200),
(9, 9, 3, 2),
(2, 2, 2, 3),
1,
),
(
(10, 25, 12),
(3, 7, 3),
(2, 1, 3),
1,
),
(
(10, 25, 12),
(3, 7, 3),
(2, 0, 3),
1,
),
(
(10, 25, 12, 100),
(6, 5, 7, 10),
(0, 0, 0, 0),
0,
),
(
(10, 25, 12, 100),
(6, 5, 7, 10),
(0, 0, 0, 0),
1,
),
]
)
def test_as_strided(
self,
input_shape,
output_size,
stride,
storage_offset=0,
):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.as_strided.default(
x, output_size, stride, storage_offset
)

inputs = [torch.randn(input_shape)]
self.run_test(
TestModule(),
inputs,
)


if __name__ == "__main__":
run_tests()

0 comments on commit 9341e9b

Please sign in to comment.