Skip to content

Commit

Permalink
adding test cases and correcting empty__stride decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Jun 6, 2024
1 parent b4913b6 commit 1360961
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 1 deletion.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def select_scatter_decomposition(
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
empty_size = args[0]
empty_stride = args[1]
return torch.as_strided(torch.empty(empty_size), empty_stride)
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)


def get_decompositions(
Expand Down
95 changes: 95 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
import torch_tensorrt
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests
from parameterized import parameterized

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing

Expand Down Expand Up @@ -868,6 +870,99 @@ def forward(self, x, src, dim, index):
f"Select_scatter TRT outputs don't match with the original model.",
)

empty_ops = [
(
"empty_stride_one_dimension_firstcase",
[5, 5],
[1, 2],
None,
),
(
"empty_stride_two_dimension_secondcase",
[5, 5],
[2, 2],
None,
),
(
"empty_three_dimension",
[8, 8, 8],
[1, 2, 3],
torch.int32,
),
]

@parameterized.expand(
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
)
def test_empty_stride(self, _, shape_or_input, stride, data_type):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
# The add operation is added otherwise it returns an empty graph post lowering passes
add_tensor = torch.ops.aten.add(input[0], input[0])
shape_or_input[0] = input[0].shape[0]
empty_strided = torch.ops.aten.empty_strided.default(
shape_or_input, stride, dtype=data_type
)
add_tensor = empty_strided.cuda() + add_tensor
return add_tensor

# Operations expected to be included in the traced graph after decompositions
unexpected_ops = {
torch.ops.aten.empty_strided.default,
torch.ops.aten.empty_permuted.default,
}
expected_ops = {torch.ops.aten.add.Tensor}

input = [torch.randint(1, 3, shape_or_input, dtype=torch.int32).cuda()]
inputs = [input]

fx_graph = torch.fx.symbolic_trace(TestModule())

unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=2,
)

torch._dynamo.reset()

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

self.assertEqual(
optimized_model_results.shape,
torch_model_results.shape,
f"The optimized model results shape and torch model results shape should be equal in empty_stride",
)


if __name__ == "__main__":
run_tests()

0 comments on commit 1360961

Please sign in to comment.