feat: support aten.as_strided converter
chohk88 committed May 7, 2024
1 parent 4142d3f commit 9bc1e0b
Showing 3 changed files with 143 additions and 0 deletions.
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2186,6 +2186,26 @@ def aten_ops_linear(
    )


@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
def aten_ops_as_strided(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.slice.as_strided(
        ctx,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=args[0],
        size=args[1],
        stride=args[2],
        storage_offset=args_bounds_check(args, 3, 0),
    )


def avg_pool_param_validator(pool_node: Node) -> bool:
    ceil_mode = args_bounds_check(pool_node.args, 4, False)
    divisor_override = args_bounds_check(pool_node.args, 6)
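Note (illustrative, not part of the commit): a minimal eager-mode sketch of the semantics the new converter reproduces. aten.as_strided gathers from the flattened contiguous input at storage_offset + sum(idx[d] * stride[d]) for each output index idx, which is exactly the index set the implementation in the next file enumerates.

# Illustrative sketch only; assumes a contiguous input tensor.
import torch

x = torch.arange(25.0).reshape(5, 5)  # contiguous 5x5 input
y = torch.as_strided(x, size=(2, 3), stride=(1, 2), storage_offset=0)

flat = x.reshape(-1)
manual = torch.stack(
    [flat[0 + i * 1 + j * 2] for i in range(2) for j in range(3)]
).reshape(2, 3)
assert torch.equal(y, manual)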
58 changes: 58 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -3,11 +3,13 @@

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    flatten_dims,
    get_positive_dim,
    get_trt_tensor,
)
@@ -259,3 +261,59 @@ def flip(
    )
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def as_strided(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    size: Sequence[int],
    stride: Sequence[int],
    storage_offset: int,
) -> TRTTensor:
    assert len(size) == len(stride), "size and stride shapes must be the same"

    flatten_shape = flatten_dims(input, 0, -1)
    flatten_output = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
    )

    indices = []

    # Recursive function to compute indices for as_strided operation
    def nested(
        rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
    ) -> None:
        if (
            dim == rank
        ):  # If the current dimension equals the rank, append the computed index
            indices.append(current)
            return
        for i in range(size[dim]):  # Recursively compute indices across dimensions
            nested(
                rank, size, stride, current + stride[dim] * i, dim + 1
            )  # Calculate the index for the current dimension and recursively explore further dimensions

    nested(len(size), size, stride, storage_offset, 0)

    indices = torch.tensor(indices, dtype=torch.int)

    indices_tensor = get_trt_tensor(ctx, (indices), f"{name}_indices")

    # Use gather to reorder elements based on computed indices
    gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
    gather_output = gather_layer.get_output(0)

    # Reshape the gathered tensor to the desired size
    reshape_output = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        f"{name}_reshape",
        gather_output,
        tuple(size),
    )

    return reshape_output
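Note on the index computation (illustrative, not part of the commit): the recursive nested helper walks the output dimensions in row-major order, so the flat indices it collects match this itertools-based equivalent.

# Illustrative sketch only; as_strided_indices is a hypothetical helper, not part of the codebase.
from itertools import product


def as_strided_indices(size, stride, storage_offset=0):
    return [
        storage_offset + sum(i * s for i, s in zip(idx, stride))
        for idx in product(*(range(d) for d in size))
    ]


# size=(2, 3), stride=(1, 2), storage_offset=0 -> [0, 2, 4, 1, 3, 5]
assert as_strided_indices((2, 3), (1, 2)) == [0, 2, 4, 1, 3, 5]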
65 changes: 65 additions & 0 deletions tests/py/dynamo/conversion/test_as_strided_aten.py
@@ -0,0 +1,65 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestAsStridedConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                (5, 5),
                (2, 3),
                (1, 2),
                0,
            ),
            (
                (5, 5),
                (2, 3),
                (2, 2),
                1,
            ),
            (
                (20, 20),
                (2, 3, 2),
                (2, 2, 2),
                0,
            ),
            (
                (8, 8, 8),
                (2, 2, 3),
                (1, 2, 2),
                1,
            ),
            (
                (200, 200, 200),
                (9, 9, 3, 2),
                (2, 2, 2, 3),
                1,
            ),
        ]
    )
    def test_as_strided(
        self,
        input_shape,
        output_size,
        stride,
        storage_offset=0,
    ):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.as_strided.default(
                    x, output_size, stride, storage_offset
                )

        inputs = [torch.randn(input_shape)]
        self.run_test(
            TestModule(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
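Note (illustrative check, not part of the commit): the (20, 20) case above produces a (2, 3, 2) output of different rank than its input, which as_strided permits because indexing is done against the flat storage rather than the logical dimensions.

# Illustrative sketch only; mirrors one parameterized case in eager mode.
import torch

x = torch.randn(20, 20)
y = torch.ops.aten.as_strided.default(x, [2, 3, 2], [2, 2, 2], 0)
assert y.shape == (2, 3, 2)

flat = x.reshape(-1)
expected = torch.stack(
    [flat[2 * i + 2 * j + 2 * k] for i in range(2) for j in range(3) for k in range(2)]
).reshape(2, 3, 2)
assert torch.equal(y, expected)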
