feat: support aten.as_strided converter #2735
@@ -8,6 +8,7 @@
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    flatten_dims,
    get_positive_dim,
    get_trt_tensor,
)
@@ -259,3 +260,61 @@ def flip(
    )
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def as_strided(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    size: Sequence[int],
    stride: Sequence[int],
    storage_offset: Optional[int],
) -> TRTTensor:
    # Ensure storage_offset is an integer before passing it to nested
    if storage_offset is None:
        storage_offset = 0

    flatten_shape = flatten_dims(input, 0, -1)
    flatten_output = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_reshape_flatten_output", input, flatten_shape
    )

    indices = []

    # Recursive helper that enumerates every multi-index of the requested
    # size and records the corresponding flat index into the storage
    def nested(
        rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
    ) -> None:
        if dim == rank:
            # All dimensions consumed: record the computed flat index
            indices.append(current)
            return
        for i in range(size[dim]):
            # Advance the flat offset by stride[dim] * i, then recurse
            # into the next dimension
            nested(rank, size, stride, current + stride[dim] * i, dim + 1)

    nested(len(size), size, stride, storage_offset, 0)
[Review comment attached to the nested(...) call above: See above comment.]

    indices = np.array(indices, dtype=np.int32)

    indices_tensor = get_trt_tensor(ctx, indices, f"{name}_indices")

    # Use gather to reorder elements based on the computed indices
    gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
    gather_output = gather_layer.get_output(0)

    # Reshape the gathered tensor to the desired size
    reshape_output = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        f"{name}_reshape_gather_output",
        gather_output,
        tuple(size),
    )

    return reshape_output
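To make the index computation concrete, here is a standalone sketch (not part of the PR) that reproduces the converter's index enumeration in plain PyTorch and checks it against torch.as_strided on a contiguous input. The size, stride, and storage_offset values are arbitrary illustration choices.

import torch

# Standalone check of the index enumeration used by the converter.
size, stride, storage_offset = (2, 3), (1, 2), 0
x = torch.arange(25, dtype=torch.float32).reshape(5, 5)

indices = []

def nested(rank, size, stride, current, dim):
    # Same recursion as in the converter: walk every multi-index and
    # record the corresponding flat index into the underlying storage.
    if dim == rank:
        indices.append(current)
        return
    for i in range(size[dim]):
        nested(rank, size, stride, current + stride[dim] * i, dim + 1)

nested(len(size), size, stride, storage_offset, 0)

# Gathering those flat indices from the flattened tensor and reshaping
# reproduces torch.as_strided for a contiguous input: here the indices
# are [0, 2, 4, 1, 3, 5].
gathered = x.flatten()[torch.tensor(indices)].reshape(size)
assert torch.equal(gathered, torch.as_strided(x, size, stride, storage_offset))

This is exactly the flatten → gather → reshape pipeline the converter builds with TensorRT layers.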
@@ -0,0 +1,89 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestAsStridedConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                (5, 5),
                (2, 3),
                (1, 2),
                0,
            ),
            (
                (5, 5),
                (2, 3),
                (2, 2),
                1,
            ),
            (
                (20, 20),
                (2, 3, 2),
                (2, 2, 2),
                0,
            ),
            (
                (8, 8, 8),
                (2, 2, 3),
                (1, 2, 2),
                1,
            ),
            (
                (200, 200, 200),
                (9, 9, 3, 2),
                (2, 2, 2, 3),
                1,
            ),
[Review thread attached here]
Reviewer: Can you add more tests that …
Author: Thanks for your suggestions! I've added more test cases that include scenarios where the … I have a question regarding when the output shape contains zero. For instance, when using … I attempted to implement functionality to return an empty tensor with the same shape, as shown below: … However, this approach leads to the following errors: … Do you have any suggestions on how to handle returning tensors with an output size containing zero?
Reviewer: Alright, it looks like the output doesn't support empty in torch-trt. I think it's ok to add a validator like you have done, but you need to remove the duplicated decorator: https://github.com/pytorch/TensorRT/pull/2735/files#diff-9cbf43cfb62cdf682eef77f6fd5cabc488bb5a91ff09bfee5d593827542b2476R2203
[A sketch of such a validator appears after the test file below.]
            (
                (10, 25, 12),
                (3, 7, 3),
                (2, 1, 3),
                1,
            ),
            (
                (10, 25, 12),
                (3, 7, 3),
                (2, 0, 3),
                1,
            ),
            (
                (10, 25, 12, 100),
                (6, 5, 7, 10),
                (0, 0, 0, 0),
                0,
            ),
            (
                (10, 25, 12, 100),
                (6, 5, 7, 10),
                (0, 0, 0, 0),
                1,
            ),
        ]
    )
    def test_as_strided(
        self,
        input_shape,
        output_size,
        stride,
        storage_offset=0,
    ):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.as_strided.default(
                    x, output_size, stride, storage_offset
                )

        inputs = [torch.randn(input_shape)]
        self.run_test(
            TestModule(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
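The validator the reviewer refers to is not shown in this excerpt. As a rough sketch of what it could look like — assuming the capability_validator hook of the dynamo_tensorrt_converter decorator used in torch_tensorrt/dynamo/conversion/aten_ops_converters.py, with aten_ops_as_strided as a hypothetical wrapper name and the impl.slice path inferred from the flip converter in the same file — it might simply reject any requested size containing a zero:

from torch.fx.node import Node


def as_strided_validator(node: Node) -> bool:
    # For aten.as_strided.default the args are (input, size, stride,
    # storage_offset), so args[1] is the requested output size. Reject
    # empty outputs, which the discussion above says are unsupported.
    size = node.args[1]
    return all(dim != 0 for dim in size)


# Hypothetical registration, mirroring the decorator style of
# aten_ops_converters.py (a single decorator, per the review comment
# about the duplicated one):
#
# @dynamo_tensorrt_converter(
#     torch.ops.aten.as_strided.default,
#     capability_validator=as_strided_validator,
# )
# def aten_ops_as_strided(ctx, target, args, kwargs, name):
#     return impl.slice.as_strided(
#         ctx, target, SourceIR.ATEN, name,
#         args[0], args[1], args[2],
#         args_bounds_check(args, 3, replacement=None),
#     )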
[Review thread on the nested() helper]
Reviewer: This function appears to only perform work when dim == rank. Would it be possible to instead replace the recursion with an action that occurs only when dim == rank, for instance something like: …
Author: Thank you for the review and the suggestion to simplify the function. The proposed list comprehension does simplify the index computation when only the last dimension is involved. However, the recursive function is designed to handle multi-dimensional tensors, where each dimension can affect the computed index through its stride and size.
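For readers following this exchange, here is one non-recursive formulation that still handles arbitrary rank — offered purely as an illustration, not as the reviewer's elided snippet — using itertools.product to enumerate every multi-index:

import itertools
from typing import List, Sequence


def strided_indices(
    size: Sequence[int], stride: Sequence[int], storage_offset: int = 0
) -> List[int]:
    # Enumerate all multi-indices in row-major order (last dimension
    # varying fastest) and map each to its flat storage index, exactly
    # as the recursive nested() helper does.
    return [
        storage_offset + sum(i * s for i, s in zip(idx, stride))
        for idx in itertools.product(*(range(d) for d in size))
    ]


# Example: size (2, 3), stride (1, 2), offset 0 -> [0, 2, 4, 1, 3, 5]
assert strided_indices((2, 3), (1, 2)) == [0, 2, 4, 1, 3, 5]

This produces the same enumeration order as the recursion, since itertools.product varies the last dimension fastest, just as the innermost recursive call does.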