Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aten.as_strided converter #2735

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,40 @@ def aten_ops_tile(
)


def zero_output_validator(node: Node) -> bool:
if 0 in node.args[1]:
_LOGGER.debug(
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
)
return False
else:
return True


@dynamo_tensorrt_converter(
torch.ops.aten.as_strided.default,
capability_validator=zero_output_validator,
)
@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
def aten_ops_as_strided(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.slice.as_strided(
ctx,
target,
source_ir=SourceIR.ATEN,
name=name,
input=args[0],
size=args[1],
stride=args[2],
storage_offset=args_bounds_check(args, 3, None),
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
@enforce_tensor_types(
{
Expand Down
59 changes: 59 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
flatten_dims,
get_positive_dim,
get_trt_tensor,
)
Expand Down Expand Up @@ -259,3 +260,61 @@ def flip(
)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def as_strided(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
size: Sequence[int],
stride: Sequence[int],
storage_offset: Optional[int],
) -> TRTTensor:
# Ensure storage_offset is an integer before passing to nested
if storage_offset is None:
storage_offset = 0

flatten_shape = flatten_dims(input, 0, -1)
flatten_output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_flatten_output", input, flatten_shape
)

indices = []

# Recursive function to compute indices for as_strided operation
def nested(
rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
) -> None:
if (
dim == rank
): # If the current dimension equals the rank, append the computed index
indices.append(current)
return
for i in range(size[dim]): # Recursively compute indices across dimensions
nested(
rank, size, stride, current + stride[dim] * i, dim + 1
) # Calculate the index for the current dimension and recursively explore further dimensions
Comment on lines +286 to +298
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function appears to only perform work when dim == rank.

Would it be possible to instead replace with an action that only occurs when dim == rank, for instance something like:

indices = [(storage_offset + stride[-1] * i) for i in range(size[-1])]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review and the suggestion to simplify the function.

The proposed list comprehension indeed simplifies the index computation for cases where we are only dealing with the last dimension. However, the recursive function was designed to handle multi-dimensional tensors where each dimension could potentially affect the computed index due to its stride and size.


nested(len(size), size, stride, storage_offset, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above comment


indices = np.array(indices, dtype=np.int32)

indices_tensor = get_trt_tensor(ctx, indices, f"{name}_indices")

# Use gather to reorder elements based on computed indices
gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
gather_output = gather_layer.get_output(0)

# Reshape the gathered tensor to the desired size
reshape_output = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_gather_output",
gather_output,
tuple(size),
)

return reshape_output
89 changes: 89 additions & 0 deletions tests/py/dynamo/conversion/test_as_strided_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestAsStridedConverter(DispatchTestCase):
@parameterized.expand(
[
(
(5, 5),
(2, 3),
(1, 2),
0,
),
(
(5, 5),
(2, 3),
(2, 2),
1,
),
(
(20, 20),
(2, 3, 2),
(2, 2, 2),
0,
),
(
(8, 8, 8),
(2, 2, 3),
(1, 2, 2),
1,
),
(
(200, 200, 200),
(9, 9, 3, 2),
(2, 2, 2, 3),
1,
),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more tests that size or stride contains 0, and/or the input is not a square/cube?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestions! I've added more test cases that include scenarios where the stride contains 0 and the input is not a square/cube. However, I encountered issues as we don't support output tensors with zero-sized dimensions for this operation. Therefore, I've added a validator to handle such cases.

I have a question regarding when the output shape contains zero. For instance, when using torch.as_strided(x, (0, 1), (2, 1), 0), it returns a tensor with a zero-sized dimension.
image

I attempted to implement functionality to return an empty tensor with the same shape, as shown below:

if 0 in size:
    empty_out = np.empty(size, dtype=np.float64)
    empty_out = get_trt_tensor(ctx, empty_out, f"{name}_empty")
    print(empty_out.dtype)
    return empty_out

# or

if 0 in size:
    empty_out = np.zeros(size, dtype=np.float64)
    empty_out = get_trt_tensor(ctx, empty_out, f"{name}_empty")
    print(empty_out.dtype)
    return empty_out

However, this approach leads to the following errors:

INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 134, GPU 1262 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1750, GPU +289, now: CPU 2019, GPU 1551 (MiB)
ERROR:torch_tensorrt [TensorRT Conversion Context]:3: [constantLayer.cpp::setWeights::40] Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/layers/constantLayer.cpp::setWeights::40, condition: !weights.values == !weights.count )
DataType.???
ERROR:torch_tensorrt [TensorRT Conversion Context]:2: [graphShapeAnalyzer.cpp::checkCalculationStatusSanity::1607] Error Code 2: Internal Error (Assertion !isPartialWork(p.second.outputExtents) failed. )
ERROR:torch_tensorrt [TensorRT Conversion Context]:2: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2270] Error Code 2: Internal Error (Assertion !isPartialWork(status.outputExtents) failed. )
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:01.714407
ERROR:torch_tensorrt [TensorRT Conversion Context]:2: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2270] Error Code 2: Internal Error (Assertion !isPartialWork(status.outputExtents) failed. )
F

Do you have any suggestions on how to handle returning tensors with an output size containing zero?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, it looks like the output doesn't support empty in torch-trt. I think it's ok to add a validator like you have done, but need to remove the duplicated decorator: https://github.com/pytorch/TensorRT/pull/2735/files#diff-9cbf43cfb62cdf682eef77f6fd5cabc488bb5a91ff09bfee5d593827542b2476R2203

(
(10, 25, 12),
(3, 7, 3),
(2, 1, 3),
1,
),
(
(10, 25, 12),
(3, 7, 3),
(2, 0, 3),
1,
),
(
(10, 25, 12, 100),
(6, 5, 7, 10),
(0, 0, 0, 0),
0,
),
(
(10, 25, 12, 100),
(6, 5, 7, 10),
(0, 0, 0, 0),
1,
),
]
)
def test_as_strided(
self,
input_shape,
output_size,
stride,
storage_offset=0,
):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.as_strided.default(
x, output_size, stride, storage_offset
)

inputs = [torch.randn(input_shape)]
self.run_test(
TestModule(),
inputs,
)


if __name__ == "__main__":
run_tests()
Loading