From f0effc0bd8d04f8ff2a7e5d7a7092a116b54ffdd Mon Sep 17 00:00:00 2001 From: Ekaterina Ignasheva Date: Tue, 17 Jun 2025 16:53:14 -0700 Subject: [PATCH] Set pyre-strict for passes unit tests. (#11740) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/11740 This diff fixes this test failure https://www.internalfb.com/intern/test/844425134145806 ...and also ensures same problem won't pop up for other unit tests. Reviewed By: zonglinpeng Differential Revision: D76767018 --- backends/cadence/aot/TARGETS | 16 +- .../aot/tests/test_fusion_ops_passes.py | 216 ++++++++-------- .../cadence/aot/tests/test_memory_passes.py | 66 +++-- .../cadence/aot/tests/test_pass_filter.py | 23 +- .../aot/tests/test_remove_ops_passes.py | 65 ++--- .../aot/tests/test_reorder_ops_passes.py | 97 ++++---- .../aot/tests/test_replace_ops_passes.py | 231 ++++++++++-------- .../aot/tests/test_simplify_ops_passes.py | 12 +- backends/cadence/aot/typing_stubs.py | 22 ++ 9 files changed, 424 insertions(+), 324 deletions(-) create mode 100644 backends/cadence/aot/typing_stubs.py diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index a0de747cf3f..c3ca472147f 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -12,6 +12,7 @@ load( "CXX", ) load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib") +load("@fbcode_macros//build_defs:cpp_python_extension.bzl", "cpp_python_extension") oncall("odai_jarvis") @@ -275,7 +276,6 @@ python_library( "//executorch/exir/passes:spec_prop_pass", ], ) - python_library( name = "decompose_ops", srcs = [ @@ -293,6 +293,14 @@ python_library( ], ) +python_library( + name = "typing_stubs", + srcs = [ + "typing_stubs.py", + ], + typing = True, +) + python_unittest( name = "test_graph_builder", @@ -321,6 +329,7 @@ python_unittest( deps = [ "fbsource//third-party/pypi/parameterized:parameterized", ":compiler", + ":typing_stubs", ":replace_ops", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", @@ -344,6 +353,7 @@ python_unittest( ":compiler", ":decompose_ops", "//caffe2:torch", + ":typing_stubs", "//executorch/backends/cadence/aot:compiler", "//executorch/backends/cadence/aot:graph_builder", "//executorch/backends/cadence/aot:pass_utils", @@ -363,6 +373,7 @@ python_unittest( deps = [ "fbsource//third-party/pypi/parameterized:parameterized", ":compiler", + ":typing_stubs", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", "//executorch/backends/cadence/aot:fuse_ops", @@ -384,6 +395,7 @@ python_unittest( deps = [ "fbsource//third-party/pypi/parameterized:parameterized", "fbsource//third-party/pypi/pyre-extensions:pyre-extensions", + ":typing_stubs", ":compiler", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", @@ -404,6 +416,7 @@ python_unittest( supports_static_listing = False, typing = True, deps = [ + ":typing_stubs", "fbsource//third-party/pypi/parameterized:parameterized", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", @@ -477,6 +490,7 @@ python_unittest( deps = [ ":compiler", ":memory_planning", + ":typing_stubs", ":ops_registrations", ":pass_utils", "//caffe2:torch", diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index 30ea91bafb5..4d2e53ff264 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file 
in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import unittest -from typing import Final, List, Tuple +from typing import cast, Final, List, Tuple import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -26,10 +26,10 @@ ) from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ProxyValue -from parameterized import parameterized +from executorch.exir.pass_base import PassResult, ProxyValue from torch import nn @@ -43,7 +43,7 @@ def check_op_counts( class TestFusionPasses(TestFusionPassesBase): - def test_fuse_mm_with_add(self): + def test_fuse_mm_with_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) @@ -55,7 +55,9 @@ def test_fuse_mm_with_add(self): output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z)) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module + converted_graph.graph.eliminate_dead_code() self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 @@ -63,7 +65,7 @@ def test_fuse_mm_with_add(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) - def test_fuse_view_mm_view_add(self): + def test_fuse_view_mm_view_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32)) @@ -83,7 +85,9 @@ def test_fuse_view_mm_view_add(self): ) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 @@ -91,7 +95,7 @@ def test_fuse_view_mm_view_add(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) - def test_keep_view_mm_view_add(self): + def test_keep_view_mm_view_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32)) @@ -112,7 +116,8 @@ def test_keep_view_mm_view_add(self): ) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert that mm and add were not fused to addmm, since z cannot be # broadcasted to the out of mm. 
@@ -122,7 +127,7 @@ def test_keep_view_mm_view_add(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1) - def test_fuse_mm_add_with_bias(self): + def test_fuse_mm_add_with_bias(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) @@ -136,7 +141,8 @@ def test_fuse_mm_add_with_bias(self): ) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 @@ -144,7 +150,7 @@ def test_fuse_mm_add_with_bias(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) - def test_keep_mm_add_with_multiple_users(self): + def test_keep_mm_add_with_multiple_users(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) @@ -161,7 +167,8 @@ def test_keep_mm_add_with_multiple_users(self): ) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert that mm and add were not fused to addmm, since add has multiple # users. 
@@ -171,17 +178,19 @@ def test_keep_mm_add_with_multiple_users(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3) - # TODO(matthiascremon): enable that pass with new flow + # TODO(matthiascremon) -> None: enable that pass with new flow @torch.no_grad() @unittest.expectedFailure - def test_legacy_conv_bn_fusion(self): + def test_legacy_conv_bn_fusion(self) -> None: class ModelConvBN(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, kernel_size: int): + def __init__( + self, in_features: int, out_features: int, kernel_size: int + ) -> None: super().__init__() self.conv1d = nn.Conv1d(in_features, out_features, kernel_size) self.bn = nn.BatchNorm1d(out_features) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.conv1d(x) return self.bn(y) @@ -189,8 +198,7 @@ def forward(self, x): x = torch.randn(1, 64, 4) graph_module = ( - compiler.export_to_executorch(model.eval(), (x,)) - .exported_program() + compiler.export_to_executorch_gen_etrecord(model.eval(), (x,)) .exported_program() .graph_module ) @@ -207,7 +215,7 @@ def forward(self, x): 0, ) - def test_permute_transpose_fusion(self): + def test_permute_transpose_fusion(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32)) permute = builder.call_operator( @@ -217,11 +225,10 @@ def test_permute_transpose_fusion(self): op=exir_ops.edge.aten.transpose_copy.int, args=(permute, 1, 0), ) - builder.output(output) + builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseCascadedTransposeOrPermuteOps()( - original_graph - ).graph_module + p = FuseCascadedTransposeOrPermuteOps() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert that permute op was fused with transpose op self.assertEqual( @@ -231,7 +238,7 @@ def test_permute_transpose_fusion(self): count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0 ) - def test_view_fusion(self): + def test_view_fusion(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) view1 = builder.call_operator( @@ -243,16 +250,17 @@ def test_view_fusion(self): output = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10]) ) - builder.output(output) + builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseCascadedViewOps()(original_graph).graph_module + p = FuseCascadedViewOps() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert that only one view op remains self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1 ) - def test_view_fusion_branched(self): + def test_view_fusion_branched(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) y = builder.call_operator( @@ -266,14 +274,15 @@ def test_view_fusion_branched(self): ) builder.output([z, t]) original_graph = builder.get_graph_module() - converted_graph = FuseCascadedViewOps()(original_graph).graph_module + p = FuseCascadedViewOps() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # z and t should be fused and y should be eliminated. 
self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2 ) - def test_force_quant_dequant_fusion(self): + def test_force_quant_dequant_fusion(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) quant = builder.call_operator( @@ -287,11 +296,10 @@ def test_force_quant_dequant_fusion(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(permute, 4.5, 6, 0, 127, torch.int8), ) - builder.output(dequant) + builder.output([dequant]) original_graph = builder.get_graph_module() - converted_graph = FuseQuantDequantToRequantizePass( - force_quant_dequant_fusion=True - )(original_graph).graph_module + p = FuseQuantDequantToRequantizePass(force_quant_dequant_fusion=True) + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( converted_graph, expected_op_counts={ @@ -302,7 +310,7 @@ def test_force_quant_dequant_fusion(self): }, ) - def test_no_replace_quant_permute_dequant_with_requantize(self): + def test_no_replace_quant_permute_dequant_with_requantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) quant = builder.call_operator( @@ -316,11 +324,11 @@ def test_no_replace_quant_permute_dequant_with_requantize(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(permute, 4.5, 6, 0, 127, torch.int8), ) - builder.output(dequant) + builder.output([dequant]) original_graph = builder.get_graph_module() - converted_graph = FuseQuantDequantToRequantizePass( - force_quant_dequant_fusion=False - )(original_graph).graph_module + + p = FuseQuantDequantToRequantizePass(force_quant_dequant_fusion=False) + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( converted_graph, expected_op_counts={ @@ -332,7 +340,7 @@ def test_no_replace_quant_permute_dequant_with_requantize(self): }, ) - def test_replace_quant_view_dequant_with_requantize(self): + def test_replace_quant_view_dequant_with_requantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) quant = builder.call_operator( @@ -346,11 +354,10 @@ def test_replace_quant_view_dequant_with_requantize(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(view, 4.5, 6, 0, 127, torch.int8), ) - builder.output(dequant) + builder.output([dequant]) original_graph = builder.get_graph_module() - converted_graph = FuseQuantDequantToRequantizePass()( - original_graph - ).graph_module + p = FuseQuantDequantToRequantizePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( converted_graph, expected_op_counts={ @@ -361,7 +368,7 @@ def test_replace_quant_view_dequant_with_requantize(self): }, ) - def test_replace_dequant_quant_with_requantize(self): + def test_replace_dequant_quant_with_requantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) dequant = builder.call_operator( @@ -372,13 +379,13 @@ def test_replace_dequant_quant_with_requantize(self): op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(dequant, 4.5, 6, 0, 127, torch.int8), ) - builder.output(quant) - graph_module = FuseQuantDequantToRequantizePass()( - builder.get_graph_module() - ).graph_module + builder.output([quant]) + original_graph = builder.get_graph_module() + p = FuseQuantDequantToRequantizePass() + 
converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ # Verify that dequant -> quant was replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, @@ -387,7 +394,7 @@ def test_replace_dequant_quant_with_requantize(self): }, ) - def test_replace_dequant_permute_quant_with_requantize(self): + def test_replace_dequant_permute_quant_with_requantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) dequant = builder.call_operator( @@ -401,13 +408,13 @@ def test_replace_dequant_permute_quant_with_requantize(self): op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(permute, 4.5, 6, 0, 127, torch.int8), ) - builder.output(quant) - graph_module = FuseQuantDequantToRequantizePass()( - builder.get_graph_module() - ).graph_module + builder.output([quant]) + original_graph = builder.get_graph_module() + p = FuseQuantDequantToRequantizePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ # Verify that dequant -> permute -> quant was replaced with permute -> requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, @@ -417,33 +424,33 @@ def test_replace_dequant_permute_quant_with_requantize(self): }, ) - def test_remove_nop_dequant_quant(self): - LEADING_DIMS: Final[int] = 12 - IN_DIM: Final[int] = 6 - OUT_DIM: Final[int] = 12 + def test_remove_nop_dequant_quant(self) -> None: + leading_dims = 12 + in_dim = 6 + out_dim = 12 builder = GraphBuilder() x = builder.placeholder( - "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32) + "x", torch.randn(leading_dims, in_dim, dtype=torch.float32) ) quant1 = builder.call_operator( op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(x, 4.5, 6, 0, 127, torch.int8), ) weights = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1) + op=exir_ops.edge.aten.full.default, args=([out_dim, in_dim], 1) ) bias = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) + op=exir_ops.edge.aten.full.default, args=([out_dim], 1) ) weight_zero_point = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0) + op=exir_ops.edge.aten.full.default, args=([in_dim], 0) ) out_multiplier = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) + op=exir_ops.edge.aten.full.default, args=([out_dim], 1) ) out_shift = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0) + op=exir_ops.edge.aten.full.default, args=([out_dim], 0) ) linear1 = builder.call_operator( op=exir_ops.edge.cadence.quantized_linear.default, @@ -488,12 +495,12 @@ def test_remove_nop_dequant_quant(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(linear2, 1.2, 3, 0, 127, torch.int8), ) - builder.output(dequant2) - graph_module = FuseQuantDequantToRequantizePass()( - builder.get_graph_module() - ).graph_module + builder.output([dequant2]) + original_graph = builder.get_graph_module() + p = FuseQuantDequantToRequantizePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ # Verify that one dequant/quant pair was removed from chain: # quant->linear->dequant->permute->quant->linear->dequant @@ -504,7 +511,7 
@@ def test_remove_nop_dequant_quant(self): }, ) - def test_fuse_mul_into_dequant(self): + def test_fuse_mul_into_dequant(self) -> None: INPUT_SHAPE: Final[List[int]] = [4, 32] DEQUANT_SCALE: Final[float] = 1.5 FULL_VALUE: Final[float] = 3 @@ -523,14 +530,14 @@ def test_fuse_mul_into_dequant(self): op=exir_ops.edge.aten.mul.Tensor, args=(dequant, full), ) - builder.output(mul) - graph_module = FuseMulTensorIntoDequantPass()( - builder.get_graph_module() - ).graph_module + builder.output([mul]) + original_graph = builder.get_graph_module() + p = FuseMulTensorIntoDequantPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # verify that the mul and full ops were removed self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, exir_ops.edge.aten.full.default: 0, @@ -539,7 +546,8 @@ def test_fuse_mul_into_dequant(self): ) # verify that the dequant scale value was updated correctly - for node in graph_module.graph.nodes: + deq_scale = -1 + for node in converted_graph.graph.nodes: if ( node.target == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default @@ -547,7 +555,7 @@ def test_fuse_mul_into_dequant(self): deq_scale = node.args[1] self.assertEqual(deq_scale, DEQUANT_SCALE * FULL_VALUE) - def test_fuse_mul_scalar_into_dequant(self): + def test_fuse_mul_scalar_into_dequant(self) -> None: dequant_scale = 0.006 mul_value = 0.3 @@ -565,14 +573,14 @@ def test_fuse_mul_scalar_into_dequant(self): op=exir_ops.edge.aten.mul.Scalar, args=(dequant, mul_value), ) - builder.output(mul_scalar) - graph_module = builder.get_graph_module() - - graph_module = FuseMulScalarIntoDequantPass()(graph_module).graph_module + builder.output([mul_scalar]) + original_graph = builder.get_graph_module() + p = FuseMulScalarIntoDequantPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # verify that the mul and full ops were removed self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, exir_ops.edge.aten.mul.Scalar: 0, @@ -580,7 +588,8 @@ def test_fuse_mul_scalar_into_dequant(self): ) # verify that the dequant scale value was updated correctly - for node in graph_module.graph.nodes: + deq_scale = -1 + for node in converted_graph.graph.nodes: if ( node.target == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default @@ -588,7 +597,7 @@ def test_fuse_mul_scalar_into_dequant(self): deq_scale = node.args[1] self.assertEqual(deq_scale, dequant_scale * mul_value) - def test_fuse_mul_into_quant(self): + def test_fuse_mul_into_quant(self) -> None: quant_scale = 1.5 mul_value = 10 @@ -606,14 +615,14 @@ def test_fuse_mul_into_quant(self): op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(mul, quant_scale, 0, 0, 255, torch.uint8), ) - builder.output(quant) - graph_module = FuseMulTensorIntoQuantPass()( - builder.get_graph_module() - ).graph_module + builder.output([quant]) + original_graph = builder.get_graph_module() + p = FuseMulTensorIntoQuantPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # verify that the mul and full ops were removed self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, exir_ops.edge.aten.full.default: 0, @@ -622,7 +631,8 @@ def test_fuse_mul_into_quant(self): ) # verify that the quant scale value was updated correctly - for 
node in graph_module.graph.nodes: + deq_scale = -1 + for node in converted_graph.graph.nodes: if ( node.target == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default @@ -630,7 +640,7 @@ def test_fuse_mul_into_quant(self): deq_scale = node.args[1] self.assertEqual(deq_scale, quant_scale * mul_value) - def test_fuse_then_transpose_pass(self): + def test_fuse_then_transpose_pass(self) -> None: # Create a graph with full -> transpose. builder = GraphBuilder() full_node = builder.call_operator( @@ -648,10 +658,10 @@ def test_fuse_then_transpose_pass(self): op=exir_ops.edge.aten.view_copy.default, args=(permute_node, (1, 6, 1)), ) - builder.output(view_node) - gm = builder.get_graph_module() + builder.output([view_node]) + original_graph = builder.get_graph_module() self.check_op_counts( - gm, + original_graph, expected_op_counts={ exir_ops.edge.aten.full.default: 1, exir_ops.edge.aten.transpose_copy.int: 1, @@ -661,7 +671,8 @@ def test_fuse_then_transpose_pass(self): ) # Check that the pass fuses the full with all other ops (transpose, permute, view). - gm_after_pass = FuseFullThenReshapePass()(gm).graph_module + p = FuseFullThenReshapePass() + gm_after_pass = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( gm_after_pass, expected_op_counts={ @@ -708,7 +719,7 @@ def _create_operator( else: raise ValueError(f"Unsupported op: {op}") - @parameterized.expand( + @expand( [ # transpose -> quant -> same transpose => fuse ( @@ -858,7 +869,7 @@ def test_fuse_transpose_permute_pairs( quant_op: torch._ops.OpOverload, expected_is_fused: bool, dims: Tuple[int, int, int] = (2, 3, 4), - ): + ) -> None: # Create a graph with transpose/permute -> quant -> transpose/permute. builder = GraphBuilder() x = builder.placeholder("x", torch.randn(dims)) @@ -911,7 +922,7 @@ def test_fuse_transpose_permute_pairs( expected_op_counts=expected_op_counts, ) - def test_fusion_for_forked_transposes(self): + def test_fusion_for_forked_transposes(self) -> None: # Create a graph with # transpose -> quant -> transpose. # -> quant -> transpose. @@ -946,7 +957,8 @@ def test_fusion_for_forked_transposes(self): ) # Fuse all the transpose ops. - gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module + p = FuseTransposeOrPermuteOpPairsPass() + gm_after_pass = cast(PassResult, p(gm)).graph_module self.check_op_counts( gm_after_pass, expected_op_counts={ diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py index b7616b047d3..73b0cba65ce 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict import math import unittest -from typing import cast, Optional +from typing import cast, List, Optional import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -19,6 +19,7 @@ find_peak_memory_usage, ) from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.backends.cadence.aot.utils import ( get_default_memory_config, MemoryConfig, @@ -27,7 +28,6 @@ from executorch.exir.memory_planning import collect_specs_from_nodes from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.tests.models import MultiLayerPerceptron -from parameterized.parameterized import parameterized from torch.fx import GraphModule @@ -224,11 +224,11 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None: # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes. def run_memory_planning( self, - original, - opt_level=2, - mem_algo=1, # greedy_by_size_for_offset_calculation_with_hierarchy - alloc_graph_input=True, - alloc_graph_output=True, + original: GraphModule, + opt_level: int = 2, + mem_algo: int = 1, # greedy_by_size_for_offset_calculation_with_hierarchy + alloc_graph_input: bool = True, + alloc_graph_output: bool = True, memory_config: Optional[MemoryConfig] = None, ) -> GraphModule: if memory_config is None: @@ -242,7 +242,7 @@ def run_memory_planning( alloc_graph_output=alloc_graph_output, )(graph_module).graph_module - @parameterized.expand( + @expand( [ [ [3, 6], # x_shape @@ -259,7 +259,11 @@ def run_memory_planning( ] ) def test_optimize_cat_on_placeholders( - self, x_shape, y_shape, concat_dim, alloc_graph_input + self, + x_shape: List[int], + y_shape: List[int], + concat_dim: int, + alloc_graph_input: bool, ) -> None: concat_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]] builder = GraphBuilder() @@ -294,7 +298,12 @@ def test_optimize_cat_on_placeholders( # "add_add_cat_model" : cat(x + 123, y + 456) # "add_add_cat_add_model": cat(x + 123, y + 456) + 789 def get_graph_module( - self, model_name, x_shape, y_shape, concated_shape, concat_dim + self, + model_name: str, + x_shape: List[int], + y_shape: List[int], + concated_shape: List[int], + concat_dim: int, ) -> GraphModule: builder = GraphBuilder() x = builder.placeholder("x", torch.ones(*x_shape, dtype=torch.float32)) @@ -346,7 +355,7 @@ def get_graph_module( raise ValueError(f"Unknown model name {model_name}") - @parameterized.expand( + @expand( [ ( "outermost", @@ -363,10 +372,14 @@ def get_graph_module( 1, # concat dim ), ], - name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", ) def test_cat_optimized( - self, _, x_shape, y_shape, concated_shape, concat_dim + self, + _, + x_shape: List[int], + y_shape: List[int], + concated_shape: List[int], + concat_dim: int, ) -> None: original = self.get_graph_module( "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim @@ -379,7 +392,7 @@ def test_cat_optimized( self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - @parameterized.expand( + @expand( [ ( "non_outermost", @@ -389,10 +402,14 @@ def test_cat_optimized( 1, # concat dim ), ], - name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", ) def test_cat_not_optimized( - self, _, x_shape, y_shape, concated_shape, concat_dim + self, + _, + x_shape: List[int], + y_shape: List[int], + concated_shape: List[int], + concat_dim: int, ) -> None: 
original = self.get_graph_module( "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim @@ -404,7 +421,7 @@ def test_cat_not_optimized( self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module) - @parameterized.expand( + @expand( [ ( "aligned", @@ -423,10 +440,15 @@ def test_cat_not_optimized( 1, # expected cat nodes ), ], - name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", ) def test_cat_not_graph_output( - self, _, x_shape, y_shape, concated_shape, concat_dim, expected_cat_nodes + self, + _, + x_shape: List[int], + y_shape: List[int], + concated_shape: List[int], + concat_dim: int, + expected_cat_nodes: int, ) -> None: original = self.get_graph_module( "add_add_cat_add_model", x_shape, y_shape, concated_shape, concat_dim @@ -493,13 +515,13 @@ def test_optimize_cat_with_slice(self) -> None: self.assertEqual(count_node(graph_module, exir_ops.edge.aten.slice.Tensor), 1) self.verify_nop_memory_alloc(graph_module) - @parameterized.expand( + @expand( [ (True,), # alloc_graph_input (False,), # alloc_graph_input ], ) - def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input) -> None: + def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input: bool) -> None: x_shape = [5, 6] y_shape = [3, 6] concated_shape = [8, 6] diff --git a/backends/cadence/aot/tests/test_pass_filter.py b/backends/cadence/aot/tests/test_pass_filter.py index 21b004d4942..9bfd71556bd 100644 --- a/backends/cadence/aot/tests/test_pass_filter.py +++ b/backends/cadence/aot/tests/test_pass_filter.py @@ -4,13 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import unittest - from copy import deepcopy +from typing import Callable, Dict + from executorch.backends.cadence.aot import pass_utils from executorch.backends.cadence.aot.pass_utils import ( ALL_CADENCE_PASSES, @@ -23,24 +24,26 @@ class TestBase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # Before running each test, create a copy of _all_passes to later restore it after test. # This avoids messing up the original _all_passes when running tests. self._all_passes_original = deepcopy(ALL_CADENCE_PASSES) # Clear _all_passes to do a clean test. It'll be restored after each test in tearDown(). pass_utils.ALL_CADENCE_PASSES.clear() - def tearDown(self): + def tearDown(self) -> None: # Restore _all_passes to original state before test. 
pass_utils.ALL_CADENCE_PASSES = self._all_passes_original - def get_filtered_passes(self, filter_): + def get_filtered_passes( + self, filter_: Callable[[ExportPass], bool] + ) -> Dict[ExportPass, CadencePassAttribute]: return {cls: attr for cls, attr in ALL_CADENCE_PASSES.items() if filter_(cls)} # Test pass registration class TestPassRegistration(TestBase): - def test_register_cadence_pass(self): + def test_register_cadence_pass(self) -> None: pass_attr_O0 = CadencePassAttribute(opt_level=0) pass_attr_debug = CadencePassAttribute(opt_level=None, debug_pass=True) pass_attr_O1_all_backends = CadencePassAttribute( @@ -73,7 +76,7 @@ class DummyPass_Debug(ExportPass): # Test pass filtering class TestPassFiltering(TestBase): - def test_filter_none(self): + def test_filter_none(self) -> None: pass_attr_O0 = CadencePassAttribute(opt_level=0) pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True) pass_attr_O1_all_backends = CadencePassAttribute( @@ -103,7 +106,7 @@ class DummyPass_O1_All_Backends(ExportPass): } self.assertEqual(O1_filter_passes, expected_passes) - def test_filter_debug(self): + def test_filter_debug(self) -> None: pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True) pass_attr_O2 = CadencePassAttribute(opt_level=2) @@ -122,7 +125,7 @@ class DummyPass_O2(ExportPass): # chooses debug=False. self.assertEqual(debug_filter_passes, {DummyPass_O2: pass_attr_O2}) - def test_filter_all(self): + def test_filter_all(self) -> None: @register_cadence_pass(CadencePassAttribute(opt_level=1)) class DummyPass_O1(ExportPass): pass @@ -138,7 +141,7 @@ class DummyPass_O2(ExportPass): # passes with opt_level <= 0 self.assertEqual(debug_filter_passes, {}) - def test_filter_opt_level_None(self): + def test_filter_opt_level_None(self) -> None: pass_attr_O1 = CadencePassAttribute(opt_level=1) pass_attr_O2_debug = CadencePassAttribute(opt_level=2, debug_pass=True) diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 012f109f313..5fe2848be94 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict import unittest -from typing import cast, Tuple +from typing import cast, List, Tuple import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -34,21 +34,22 @@ RemoveZeroSizedCatArgsPass, RemoveZeroSizedConstantPadNd, ) +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops -from parameterized.parameterized import parameterized from pyre_extensions import none_throws from torch.fx.passes.infra.pass_base import PassResult class TestRemoveOpsPasses(unittest.TestCase): - @parameterized.expand( + + @expand( [ [(1, 2, 3)], ] ) @torch.no_grad() - def test_remove_to_ops(self, shape: Tuple[int]): + def test_remove_to_ops(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) x = builder.call_operator( @@ -69,7 +70,7 @@ def test_remove_to_ops(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(7, 6, 5)], [(7, 6)], @@ -77,7 +78,7 @@ def test_remove_to_ops(self, shape: Tuple[int]): ] ) @torch.no_grad() - def test_remove_nop_add_op_pass(self, shape: Tuple[int]): + def test_remove_nop_add_op_pass(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) zeros = builder.call_operator( @@ -101,7 +102,7 @@ def test_remove_nop_add_op_pass(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(7, 6, 5)], [(7, 6)], @@ -109,7 +110,7 @@ def test_remove_nop_add_op_pass(self, shape: Tuple[int]): ] ) @torch.no_grad() - def test_remove_nop_mul_op_pass(self, shape: Tuple[int]): + def test_remove_nop_mul_op_pass(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) zeros = builder.call_operator( @@ -133,13 +134,13 @@ def test_remove_nop_mul_op_pass(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(1, 2, 3)], ] ) @torch.no_grad() - def test_remove_alias_copy(self, shape: Tuple[int]): + def test_remove_alias_copy(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) alias = builder.call_operator( @@ -155,13 +156,13 @@ def test_remove_alias_copy(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(1, 2, 3)], ] ) @torch.no_grad() - def test_remove_detach_copy(self, shape: Tuple[int]): + def test_remove_detach_copy(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) detach = builder.call_operator( @@ -177,7 +178,7 @@ def test_remove_detach_copy(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(1, 2, 3), (0, 0)], ] @@ -185,7 +186,7 @@ def test_remove_detach_copy(self, shape: Tuple[int]): @torch.no_grad() def test_remove_zero_sized_constant_pad_nd( self, shape: Tuple[int], padding: Tuple[int] - ): + ) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) pad = builder.call_operator( @@ -201,7 +202,7 @@ def test_remove_zero_sized_constant_pad_nd( 0, ) - def test_remove_expand(self): + def test_remove_expand(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn([2, 3, 5], dtype=torch.float32)) expand = builder.call_operator( @@ -216,7 +217,7 @@ def test_remove_expand(self): count_node(graph_after_passes, exir_ops.edge.aten.expand_copy.default), 0 ) - def 
test_remove_zero_arg_cat(self): + def test_remove_zero_arg_cat(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn([1, 0, 3, 5], dtype=torch.float32)) y = builder.placeholder("y", torch.randn([2, 0, 3, 5], dtype=torch.float32)) @@ -232,18 +233,19 @@ def test_remove_zero_arg_cat(self): count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 ) - def test_remove_clone(self): + def test_remove_clone(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32)) clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,)) builder.output([clone]) original = builder.get_graph_module() - graph_after_passes = RemoveCloneOpPass()(original).graph_module + p = RemoveCloneOpPass() + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, torch.ops.aten.clone.default), 0 ) - def test_remove_contiguous(self): + def test_remove_contiguous(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32)) contiguous = builder.call_operator( @@ -251,19 +253,20 @@ def test_remove_contiguous(self): ) builder.output([contiguous]) original = builder.get_graph_module() - graph_after_passes = RemoveContiguousOpPass()(original).graph_module + p = RemoveContiguousOpPass() + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, torch.ops.aten.contiguous.default), 0 ) - @parameterized.expand( + @expand( [ [(3, 5), [3, 5]], [(1,), [-1]], ] ) @torch.no_grad() - def test_remove_nop_view(self, shape, new_shape): + def test_remove_nop_view(self, shape: Tuple[int], new_shape: List[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) view = builder.call_operator( @@ -278,7 +281,7 @@ def test_remove_nop_view(self, shape, new_shape): count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0 ) - def test_remove_nop_slice(self): + def test_remove_nop_slice(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) slice_ = builder.call_operator( @@ -299,7 +302,7 @@ def test_remove_nop_slice(self): count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0 ) - def test_remove_nop_select_before_view(self): + def test_remove_nop_select_before_view(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) select = builder.call_operator( @@ -323,7 +326,7 @@ def test_remove_nop_select_before_view(self): count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - def test_remove_nop_select_before_add(self): + def test_remove_nop_select_before_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -345,7 +348,7 @@ def test_remove_nop_select_before_add(self): count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - def test_remove_nop_select_before_mul(self): + def test_remove_nop_select_before_mul(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -367,7 +370,7 @@ def test_remove_nop_select_before_mul(self): count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - def 
test_remove_nop_select_before_div(self): + def test_remove_nop_select_before_div(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -389,7 +392,7 @@ def test_remove_nop_select_before_div(self): count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - def test_remove_nop_quant_dequant(self): + def test_remove_nop_quant_dequant(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(8, 8)) q0 = builder.call_operator( @@ -441,7 +444,7 @@ def test_remove_nop_quant_dequant(self): 1, ) - def test_remove_nop_aten_linalg_vector_norm(self): + def test_remove_nop_aten_linalg_vector_norm(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 1, 128, dtype=torch.float32)) linalg_vector_norm = builder.call_operator( @@ -736,7 +739,7 @@ def test_remove_permutes_around_elemwise_ops_complicated_case(self) -> None: count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 4 ) - def test_remove_dequant_on_branch(self): + def test_remove_dequant_on_branch(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 8, 4, 6)) x = builder.call_operator(op=exir_ops.edge.aten.abs.default, args=(x,)) diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index 3e64a0ecd7c..2d7c2ea9edd 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import unittest +from typing import cast import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -29,10 +30,11 @@ SinkOpsCloserToUsePass, ) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult class TestReorderPasses(unittest.TestCase): - def test_sink_dequantize(self): + def test_sink_dequantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(32, 6, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(32, 6, dtype=torch.float32)) @@ -103,9 +105,10 @@ def test_sink_dequantize(self): op=exir_ops.edge.aten.cat.default, args=([abs_1, dequantize_per_tensor_1],), ) - builder.output(cat) + builder.output([cat]) original_graph = builder.get_graph_module() - converted_graph = SinkOpsCloserToUsePass()(original_graph).graph_module + p = SinkOpsCloserToUsePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # Expect the SinkDequant pass to move dequant(y) from above the relu to just below it self.assertTrue( @@ -123,7 +126,7 @@ def test_sink_dequantize(self): ), ) - def test_advance_branched_quantize(self): + def test_advance_branched_quantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(64, 3, dtype=torch.float32)) view = builder.call_operator( @@ -174,9 +177,8 @@ def test_advance_branched_quantize(self): ] ) original_graph = builder.get_graph_module() - graph_module = AdvanceQuantizeOpAboveDefInBranchPass()( - original_graph - ).graph_module + p = AdvanceQuantizeOpAboveDefInBranchPass() + graph_module = cast(PassResult, p(original_graph)).graph_module graph_module.graph.eliminate_dead_code() nodes = get_compute_nodes_in_gm(graph_module) # The quantize op should be 
hoisted to dominate the branch @@ -208,7 +210,8 @@ def test_advance_branched_quantize(self): ), 4, ) - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module + p = FuseQuantDequantToRequantizePass() + graph_module = cast(PassResult, p(graph_module)).graph_module # We expect 3 dequant/quant pairs to be removed because they have matching params, # leaving a single dequant/quant pair that is then merged into a requantize op self.assertEqual( @@ -220,7 +223,7 @@ def test_advance_branched_quantize(self): ) @torch.no_grad() - def test_advance_quantize(self): + def test_advance_quantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32)) weights = builder.placeholder( @@ -268,14 +271,13 @@ def test_advance_quantize(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8), ) - builder.output(dequantize_per_tensor) + builder.output([dequantize_per_tensor]) original_graph = builder.get_graph_module() - converted_graph = AdvanceQuantizeOpAboveDefInBranchPass()( - original_graph - ).graph_module - converted_graph = AdvanceQuantizeOpAboveDefChainPass()( - original_graph - ).graph_module + + p1 = AdvanceQuantizeOpAboveDefInBranchPass() + tmp_graph = cast(PassResult, p1(original_graph)).graph_module + p2 = AdvanceQuantizeOpAboveDefChainPass() + converted_graph = cast(PassResult, p2(tmp_graph)).graph_module # Assert that permute node is now the successor of the quant node. self.assertTrue( get_node_pos( @@ -284,7 +286,7 @@ def test_advance_quantize(self): < get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default) ) - def test_postpone_dequantize1(self): + def test_postpone_dequantize1(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32)) weights = builder.placeholder( @@ -332,11 +334,10 @@ def test_postpone_dequantize1(self): op=exir_ops.edge.aten.permute_copy.default, args=(dequantize_per_tensor, [1, 0, 3, 2]), ) - builder.output(permute) + builder.output([permute]) original_graph = builder.get_graph_module() - converted_graph = PostponeDequantizeOpBelowUseChainPass()( - original_graph - ).graph_module + p = PostponeDequantizeOpBelowUseChainPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # Assert that dequant node is now the successor of the permute node. 
self.assertTrue( get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default) @@ -346,7 +347,7 @@ def test_postpone_dequantize1(self): ) ) - def test_postpone_dequantize_branched(self): + def test_postpone_dequantize_branched(self) -> None: builder = GraphBuilder() x = builder.placeholder( "x", torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8) @@ -403,14 +404,13 @@ def test_postpone_dequantize_branched(self): ) builder.output([aten_mm_default, aten_mm_default_1, aten_mm_default_2]) original_graph = builder.get_graph_module() - graph_module = PostponeDequantizeOpBelowUseChainPass()( - original_graph - ).graph_module - graph_module.graph.eliminate_dead_code() + p = PostponeDequantizeOpBelowUseChainPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module + converted_graph.graph.eliminate_dead_code() # Asset that the dequant node was split into 4, one per branch self.assertEqual( count_node( - graph_module, + converted_graph, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, ), 3, @@ -419,7 +419,7 @@ def test_postpone_dequantize_branched(self): # Assert that the dequant node is no longer the predecessor of the squeeze node self.assertTrue( nodes_not_connected_in_gm( - graph_module, + converted_graph, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.aten.squeeze_copy.dims, ), @@ -427,14 +427,14 @@ def test_postpone_dequantize_branched(self): # Assert that dequant node is not predecessor of slice (it should've been moved below slice) self.assertTrue( nodes_not_connected_in_gm( - graph_module, + converted_graph, exir_ops.edge.cadence.dequantize_per_tensor.default, exir_ops.edge.aten.slice_copy.Tensor, ), ) # 4d -> permute -> 4d -> view -> 3d - def test_permute3_view4_chains(self): + def test_permute3_view4_chains(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 1, 768)) aten_view_copy_default = builder.call_operator( @@ -453,14 +453,10 @@ def test_permute3_view4_chains(self): op=exir_ops.edge.aten.permute_copy.default, args=(aten_view_copy_default_1, [0, 1, 3, 2]), ) - builder.output( - aten_permute_copy_default_1, - ) + builder.output([aten_permute_copy_default_1]) original_graph = builder.get_graph_module() - # Performing transform - converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( - original_graph - ).graph_module + p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert the order becomes view, view, permute, permute nodes = get_compute_nodes_in_gm(converted_graph) @@ -471,7 +467,7 @@ def test_permute3_view4_chains(self): self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy) # 3d -> permute -> 3d -> view -> 4d - def test_permute4_view3_chains(self): + def test_permute4_view3_chains(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 1, 768)) aten_view_copy_default = builder.call_operator( @@ -490,14 +486,11 @@ def test_permute4_view3_chains(self): op=exir_ops.edge.aten.permute_copy.default, args=(aten_view_copy_default_1, [2, 1, 0]), ) - builder.output( - aten_permute_copy_default_1, - ) + builder.output([aten_permute_copy_default_1]) original_graph = builder.get_graph_module() - # Performing transform - converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( - original_graph - ).graph_module + + p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + converted_graph = cast(PassResult, 
p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert the order becomes view, view, permute, permute @@ -511,7 +504,7 @@ def test_permute4_view3_chains(self): # Negative test case where the transform should not happen. # permute->4d->view->3d where the view not only removes the dimension whose # size is 1 (this is ok), but also changes the size of the dimensions (not ok). - def test_permute_view_chains_neg(self): + def test_permute_view_chains_neg(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 1, 768)) aten_view_copy_default = builder.call_operator( @@ -530,14 +523,12 @@ def test_permute_view_chains_neg(self): op=exir_ops.edge.aten.permute_copy.default, args=(aten_view_copy_default_1, [2, 1, 0]), ) - builder.output( - aten_permute_copy_default_1, - ) + builder.output([aten_permute_copy_default_1]) original_graph = builder.get_graph_module() + # Performing transform (nothing should happen) - converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( - original_graph - ).graph_module + p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert the order is still view, permute, view, permute diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 41002cda009..6d12c991d6d 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -8,7 +8,7 @@ import operator import unittest -from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union +from typing import cast, List, Optional, Sequence, Tuple, Union import torch from executorch.backends.cadence.aot.graph_builder import ( @@ -47,11 +47,11 @@ ReplaceTrivialConvWithLinear, ReplaceWhereWithFullArgsWithWhereScalar, ) + +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass from executorch.exir.passes import dead_code_elimination_pass - -from parameterized.parameterized import parameterized from torch.fx.passes.infra.pass_base import PassResult @@ -59,9 +59,9 @@ class TestReplaceOpsPasses(unittest.TestCase): def assertTargetCountEqual( self, graph_module: torch.fx.GraphModule, - target: Union[Callable[..., Any], str], + target: torch.fx.node.Target, expected_count: int, - ): + ) -> None: """Helper function to check the number of nodes with a given target.""" actual_count = count_node(graph_module, target) self.assertEqual( @@ -73,13 +73,13 @@ def assertTargetCountEqual( def assertTargetCountsEqual( self, graph_module: torch.fx.GraphModule, - targets_and_counts: List[Tuple[Union[Callable[..., Any], str], int]], - ): + targets_and_counts: List[Tuple[torch.fx.node.Target, int]], + ) -> None: """Helper function to check the number of nodes of all types for a given target.""" for target, expected_count in targets_and_counts: self.assertTargetCountEqual(graph_module, target, expected_count) - @parameterized.expand( + @expand( [ ( "regular", @@ -96,7 +96,7 @@ def assertTargetCountsEqual( @torch.no_grad() def test_replace_matmul_with_transposed_matmul( self, - _, + _: str, x_shape: Tuple[int], y_shape: Tuple[int], ) -> None: @@ -132,7 +132,7 @@ def test_replace_matmul_with_transposed_matmul( 1, ) - @parameterized.expand( + @expand( [ ("2d", (3, 5), [0, 0]), # shape # padding ("3d", (20, 1, 
 80), [0, 0, 0]), # shape # padding
@@ -141,7 +141,7 @@ def test_replace_matmul_with_transposed_matmul(
     @torch.no_grad()
     def test_replace_constant_pad_nd_with_slice(
         self, _, shape: Tuple[int], padding: Tuple[int]
-    ):
+    ) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
         matmul = builder.call_operator(
@@ -162,7 +162,7 @@ def test_replace_constant_pad_nd_with_slice(
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             ["3d", (7, 5, 6), 1.23],
             ["2d", (7, 5), 2],
@@ -172,7 +172,7 @@ def test_replace_constant_pad_nd_with_slice(
     @torch.no_grad()
     def test_add_replace_scalar_with_tensor_arg(
         self, _, shape: Tuple[int], other: float
-    ):
+    ) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -190,7 +190,7 @@ def test_add_replace_scalar_with_tensor_arg(
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             ["3d", (7, 5, 6), 1.23],
             ["2d", (7, 5), 2],
@@ -200,7 +200,7 @@ def test_add_replace_scalar_with_tensor_arg(
     @torch.no_grad()
     def test_sub_replace_scalar_with_tensor_arg(
         self, _, shape: Tuple[int], other: float
-    ):
+    ) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -218,7 +218,7 @@ def test_sub_replace_scalar_with_tensor_arg(
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             ["3d", (7, 5, 6), 1.23],
             ["2d", (7, 5), 2],
@@ -228,7 +228,7 @@ def test_sub_replace_scalar_with_tensor_arg(
     @torch.no_grad()
     def test_mul_replace_scalar_with_tensor_arg(
         self, _, shape: Tuple[int], other: float
-    ):
+    ) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -246,7 +246,7 @@ def test_mul_replace_scalar_with_tensor_arg(
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             ["3d", (7, 5, 6), 1.23],
             ["2d", (7, 5), 2],
@@ -259,7 +259,7 @@ def test_div_replace_scalar_with_tensor_arg(
         _,
         shape: Tuple[int],
         other: float,
-    ):
+    ) -> None:
         x = torch.randn(*shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -277,7 +277,7 @@ def test_div_replace_scalar_with_tensor_arg(
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             ["4d", (2, 3, 5, 6)],
             ["3d", (7, 6, 5)],
@@ -288,7 +288,7 @@ def test_div_replace_scalar_with_tensor_arg(
     @torch.no_grad()
     def test_replace_functionally_equivalent_op_targets_relu(
         self, _, shape: Tuple[int]
-    ):
+    ) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -307,16 +307,26 @@ def test_replace_functionally_equivalent_op_targets_relu(
             0,
         )

-    @parameterized.expand(
-        [["split_linear_tensor", (50,), i, 0] for i in range(2, 7)]
-        + [["split_leading_dim", (10, 2, 3), i, 0] for i in range(2, 7)]
-        + [["split_trailing_dim", (3, 3, 6), i, 2] for i in range(2, 6)]
-        + [["split_middle_dim", (3, 5, 14, 2, 3), i, 2] for i in range(2, 7)]
+    @expand(
+        [
+            ("split_linear_tensor_split_size_2", (50,), 2, 0),
+            ("split_linear_tensor_split_size_5", (50,), 5, 0),
+            ("split_linear_tensor_split_size_7", (50,), 7, 0),
+            ("split_leading_dim_split_size_2", (10, 2, 3), 2, 0),
+            ("split_leading_dim_split_size_5", (10, 2, 3), 5, 0),
+            ("split_leading_dim_split_size_7", (10, 2, 3), 7, 0),
+            ("split_trailing_dim_split_size_2", (3, 3, 6), 2, 2),
+            ("split_trailing_dim_split_size_4", (3, 3, 6), 4, 2),
+            ("split_trailing_dim_split_size_6", (3, 3, 6), 6, 2),
+            ("split_middle_dim_split_size_2", (3, 5, 14, 2, 3), 2, 2),
+            ("split_middle_dim_split_size_5", (3, 5, 14, 2, 3), 5, 2),
+            ("split_middle_dim_split_size_7", (3, 5, 14, 2, 3), 7, 2),
+        ]
     )
     @torch.no_grad()
     def test_replace_functionally_equivalent_op_targets_unsafe_split(
         self, _, shape: Tuple[int], split_size: int, dim: int
-    ):
+    ) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -333,7 +343,7 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split(
             count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x
         )

-    @parameterized.expand(
+    @expand(
         [
             [(1, 8, 33), 8, 16, 3],
             [(1, 8, 33), 8, 16, 5, 2],
@@ -356,7 +366,7 @@ def test_replace_transposed_conv_with_linear(
         depthwise: bool = False,
         bias_enabled: bool = True,
         channel_last: bool = False,
-    ):
+    ) -> None:
         transposed = True
         output_padding = [0]
         groups = in_channels if depthwise else 1
@@ -418,7 +428,7 @@ def test_replace_transposed_conv_with_linear(
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False],
             # # depthwise
@@ -442,7 +452,7 @@ def test_replace_convolution_optional_args_with_concrete_args(
         depthwise: bool = False,
         bias_enabled: bool = True,
         channel_last: bool = False,
-    ):
+    ) -> None:
         transposed = True
         output_padding = [0]
         groups = in_channels if depthwise else 1
@@ -496,7 +506,7 @@ def test_replace_convolution_optional_args_with_concrete_args(
             1,
         )

-    @parameterized.expand(
+    @expand(
         [
             [(1, 2, 3), [1, 1]],
             [
@@ -506,7 +516,7 @@ def test_replace_convolution_optional_args_with_concrete_args(
         ]
     )
     @torch.no_grad()
-    def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]):
+    def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -525,7 +535,7 @@ def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]):
         )

     @torch.no_grad()
-    def test_replace_repeat_with_cat(self):
+    def test_replace_repeat_with_cat(self) -> None:
         x = torch.randn([3, 5])
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -543,7 +553,7 @@ def test_replace_repeat_with_cat(self):
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             # x, mask
             [(1,)],
@@ -562,7 +572,7 @@ def test_replace_masked_scalar_tensor_with_full(
         self,
         shape: Tuple[int],
         mask_shape: Union[Tuple[int, ...], None] = None,
-    ):
+    ) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
         mask = builder.placeholder(
@@ -602,7 +612,7 @@ def test_replace_masked_scalar_tensor_with_full(
     @torch.no_grad()
     def test_replace_scalar_tensor_with_full(
         self,
-    ):
+    ) -> None:
         original_gm = single_op_builder(
             placeholders=(),
             op=exir_ops.edge.aten.scalar_tensor.default,
@@ -620,7 +630,7 @@ def test_replace_scalar_tensor_with_full(
         )

     @torch.no_grad()
-    def test_replace_linear_with_fully_connected(self):
+    def test_replace_linear_with_fully_connected(self) -> None:
         shape, in_channels, out_channels = (1, 14), 14, 128
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
@@ -661,7 +671,7 @@ def test_replace_linear_with_fully_connected(self):
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             [(4, 16, 256), 256, 512, True],
             [(7, 17, 12), 12, 34, False],
@@ -670,7 +680,7 @@ def test_replace_linear_with_fully_connected(self):
     @torch.no_grad()
     def test_replace_addmm_with_linear(
         self, shape: Tuple[int], in_features: int, out_features: int, bias: bool
-    ):
+    ) -> None:
         M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(N, dtype=torch.float32))
@@ -704,7 +714,7 @@ def test_replace_addmm_with_linear(
         )

     @torch.no_grad()
-    def test_replace_mm_with_addmm(self):
+    def test_replace_mm_with_addmm(self) -> None:
         M, K, N = 14, 48, 24
         x = torch.randn([M, K])
         y = torch.randn([K, N])
@@ -725,7 +735,7 @@ def test_replace_mm_with_addmm(self):
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             # shape
             [(5, 1, 6, 7)],
@@ -738,7 +748,9 @@ def test_replace_mm_with_addmm(self):
         ]
     )
     @torch.no_grad()
-    def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None):
+    def test_replace_squeeze_with_view(
+        self, shape: Tuple[int], dim: Optional[int] = None
+    ) -> None:
         x = torch.randn(shape)
         if dim:
             original_gm = single_op_builder(
@@ -770,7 +782,7 @@ def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None):
             0,
         )

-    @parameterized.expand(
+    @expand(
         [
             # shape, dim to unsqueeze
             [(5, 6, 7), 0],
@@ -780,7 +792,7 @@ def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None):
         ]
     )
     @torch.no_grad()
-    def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int):
+    def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -804,7 +816,7 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar(
         self,
         in_features: int = 16,
         out_features: int = 16,
-    ):
+    ) -> None:
         src_zero_point = 0
         out_zero_point = 0
         builder = GraphBuilder()
@@ -873,7 +885,7 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_
         self,
         in_features: int = 16,
         out_features: int = 16,
-    ):
+    ) -> None:
         src_zero_point = 0
         out_zero_point = 0
         builder = GraphBuilder()
@@ -946,7 +958,7 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_
         )

     @torch.no_grad()
-    def test_replace_conv1d_with_linear(self):
+    def test_replace_conv1d_with_linear(self) -> None:
         x = torch.randn(1, 96, 7)
         weights = torch.randn(192, 96, 7)
         bias = torch.randn(192)
@@ -957,11 +969,12 @@ def test_replace_conv1d_with_linear(self):
         )
         # First, replace the aten convolution with a cadence.convolution op
         p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
-        temp_graph = p1(original_gm).graph_module
+        temp_graph = cast(PassResult, p1(original_gm)).graph_module
+        # temp_graph = p1(original_gm).graph_module
         self.assertIsNotNone(temp_graph)

         p2 = ReplaceTrivialConvWithLinear()
-        graph_after_passes = p2(temp_graph).graph_module
+        graph_after_passes = cast(PassResult, p2(temp_graph)).graph_module

         # Assert that conv1d is trivially converted to linear
         self.assertEqual(
@@ -979,7 +992,7 @@ def test_replace_conv1d_with_linear(self):
         )

     @torch.no_grad()
-    def test_replace_conv2d_with_linear(self):
+    def test_replace_conv2d_with_linear(self) -> None:
         x = torch.randn(1, 96, 7, 7)
         weights = torch.randn(192, 96, 7, 7)
         bias = torch.randn(192)
@@ -990,11 +1003,11 @@ def test_replace_conv2d_with_linear(self):
         )
         # First, replace the aten convolution with a cadence.convolution op
         p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
-        temp_graph = p1(original_gm).graph_module
+        temp_graph = cast(PassResult, p1(original_gm)).graph_module
         self.assertIsNotNone(temp_graph)

         p2 = ReplaceTrivialConvWithLinear()
-        graph_after_passes = p2(temp_graph).graph_module
+        graph_after_passes = cast(PassResult, p2(temp_graph)).graph_module

         # Assert that conv2d is trivially converted to linear
         self.assertEqual(
@@ -1012,7 +1025,7 @@ def test_replace_conv2d_with_linear(self):
         )

     @torch.no_grad()
-    def test_replace_conv2d_with_im2row_and_linear(self):
+    def test_replace_conv2d_with_im2row_and_linear(self) -> None:
         x = torch.randn(1, 96, 47, 37)
         weights = torch.randn(192, 96, 7, 7)
         bias = torch.randn(192)
@@ -1035,14 +1048,16 @@ def test_replace_conv2d_with_im2row_and_linear(self):
             count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1
         )

-    @parameterized.expand(
+    @expand(
         [
             [(3, 1, 5), 1, 0],
             [(3, 4, 1), 2, -1],
         ]
     )
     @torch.no_grad()
-    def test_replace_select_with_view(self, shape: Tuple[int], dim: int, index: int):
+    def test_replace_select_with_view(
+        self, shape: Tuple[int], dim: int, index: int
+    ) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -1059,7 +1074,7 @@ def test_replace_select_with_view(self, shape: Tuple[int], dim: int, index: int)
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
         )

-    @parameterized.expand(
+    @expand(
         [
             [(2, 1, 3, 1), 1, 3, torch.float32],
             [(2, 1, 5), 1, 0, torch.int64],
@@ -1073,7 +1088,7 @@ def test_replace_nop_transpose_with_view(
         dim0: int,
         dim1: int,
         dtype: torch.dtype = torch.float32,
-    ):
+    ) -> None:
         if dtype == torch.float32:
             x = torch.randn(shape)
         else:
@@ -1094,7 +1109,7 @@ def test_replace_nop_transpose_with_view(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
         )

-    @parameterized.expand(
+    @expand(
         [
             # permutations that can be replaced by view
             [(3, 1, 3, 1, 4), (0, 2, 4, 1, 3)],
@@ -1102,7 +1117,9 @@ def test_replace_nop_transpose_with_view(
         ]
     )
     @torch.no_grad()
-    def test_replace_nop_permute_with_view(self, shape, dims):
+    def test_replace_nop_permute_with_view(
+        self, shape: Tuple[int], dims: Tuple[int]
+    ) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -1120,15 +1137,17 @@ def test_replace_nop_permute_with_view(self, shape, dims):
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
         )

-    @parameterized.expand(
+    @expand(
         [
             # permutations replaced by transpose
-            [(3, 4), [1, 0]],
+            [(3, 4), (1, 0)],
             [(3, 4, 6), (0, 2, 1)],
         ]
     )
     @torch.no_grad()
-    def test_replace_permute_with_transpose(self, shape, dims):
+    def test_replace_permute_with_transpose(
+        self, shape: Tuple[int], dims: Tuple[int]
+    ) -> None:
         x = torch.randn(shape)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -1149,7 +1168,7 @@ def test_replace_permute_with_transpose(self, shape, dims):
     @torch.no_grad()
     def test_replace_permute_with_transpose_nop(
         self,
-    ):
+    ) -> None:
         x = torch.randn(3, 4)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -1167,7 +1186,7 @@ def test_replace_permute_with_transpose_nop(
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
         )

-    def test_replace_aten_where_with_cadence(self):
+    def test_replace_aten_where_with_cadence(self) -> None:
         builder = GraphBuilder()
         cond = builder.placeholder("cond", torch.randn(4, 8))
         aten_gt_scalar = builder.call_operator(
@@ -1202,7 +1221,7 @@ def test_replace_aten_where_with_cadence(self):
             1,
         )

-    @parameterized.expand(
+    @expand(
         [
             [(4, 8), (4, 8), (4, 8), 0.0, 1.0],
             [(8,), (4, 8), (8,), 0.0, 1.0],
@@ -1210,8 +1229,13 @@ def test_replace_aten_where_with_cadence(self):
         ]
     )
     def test_replace_aten_where_with_cadence_broadcast(
-        self, cond_shape, a_shape, b_shape, val1, val2
-    ):
+        self,
+        cond_shape: Tuple[int],
+        a_shape: Tuple[int],
+        b_shape: Tuple[int],
+        val1: float,
+        val2: float,
+    ) -> None:
         # cond_shape, a_shape, b_shape, val1, val2 =
         builder = GraphBuilder()
         cond = builder.placeholder("cond", torch.randn(cond_shape))
@@ -1243,7 +1267,7 @@ def test_replace_aten_where_with_cadence_broadcast(
             1,
         )

-    def test_no_replace_aten_gelu_with_approximate_gelu(self):
+    def test_no_replace_aten_gelu_with_approximate_gelu(self) -> None:
         inputs = torch.randn(2, 1, 64)
         gm = single_op_builder(
@@ -1265,7 +1289,7 @@ def test_no_replace_aten_gelu_with_approximate_gelu(self):
             1,
         )

-    def test_replace_split_with_sizes_with_slice(self):
+    def test_replace_split_with_sizes_with_slice(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(1, 16, 8, 4))
         split = builder.call_operator(
@@ -1291,8 +1315,8 @@ def test_replace_split_with_sizes_with_slice(self):
             2,
         )

-    @parameterized.expand([[2], [3], [4]])
-    def test_replace_pow_with_mul(self, exponent: int):
+    @expand([[2], [3], [4]])
+    def test_replace_pow_with_mul(self, exponent: int) -> None:
         x = torch.randn(2, 1, 64)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -1316,13 +1340,13 @@ def test_replace_pow_with_mul(self, exponent: int):
             exponent - 1,
         )

-    @parameterized.expand(
+    @expand(
         [
             [1],
             [1.5],
         ]
     )
-    def test_replace_pow_with_mul_not_applied(self, exponent):
+    def test_replace_pow_with_mul_not_applied(self, exponent: float) -> None:
         x = torch.randn(2, 1, 64)
         original_gm = single_op_builder(
             placeholders=(x,),
@@ -1350,7 +1374,7 @@ def test_replace_pow_with_mul_not_applied(self, exponent):


 class TestReplaceIm2rowWithViewPass(unittest.TestCase):
-    def test_no_replacement_for_conv(self):
+    def test_no_replacement_for_conv(self) -> None:
         # Create a graph with a single im2row node.
         x = torch.randn(1, 3, 224, 224)
         pad_value = torch.randn(1)
@@ -1376,7 +1400,7 @@ def test_no_replacement_for_conv(self):
             count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0
         )

-    def test_no_replace_for_dilation(self):
+    def test_no_replace_for_dilation(self) -> None:
         # Create a graph with a single im2row node.
         x = torch.randn(1, 3, 5, 7)
         pad_value = torch.randn(1)
@@ -1401,7 +1425,7 @@ def test_no_replace_for_dilation(self):
             count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0
         )

-    def test_replace_linear_like_conv(self):
+    def test_replace_linear_like_conv(self) -> None:
         # Create a graph with a single im2row node.
         in_h, in_w = 13, 15
         x = torch.randn(1, 3, in_h, in_w)
@@ -1455,7 +1479,7 @@ def create_conv1d_graphmodule(
             args=args,
         )

-    def test_conv1d_default_channel_last(self):
+    def test_conv1d_default_channel_last(self) -> None:
         # Create a graph with a single convolution node.
         # Check if graph module is valid by running exportpass on it.
         gm = self.create_conv1d_graphmodule()
@@ -1483,7 +1507,7 @@ def test_conv1d_default_channel_last(self):
             self.assertEqual(len(node.args), 8, f"{node=}")
             self.assertTrue(node.args[7])

-    def test_conv1d_no_transpose_if_already_channel_last(self):
+    def test_conv1d_no_transpose_if_already_channel_last(self) -> None:
         gm = self.create_conv1d_graphmodule(channels_last=True)
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
@@ -1532,7 +1556,7 @@ def create_convolution_graph_module(
             args=args,
         )

-    def test_convolution_default_channel_last(self):
+    def test_convolution_default_channel_last(self) -> None:
         # Create a graph with a single convolution node.
         # Check if graph module is valid by running exportpass on it.
         gm = self.create_convolution_graph_module()
@@ -1560,7 +1584,7 @@ def test_convolution_default_channel_last(self):
             self.assertEqual(len(node.args), 8, f"{node=}")
             self.assertTrue(node.args[7])

-    def test_no_transpose_if_already_channel_last(self):
+    def test_no_transpose_if_already_channel_last(self) -> None:
         gm = self.create_convolution_graph_module(channels_last=True)
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
@@ -1637,7 +1661,7 @@ def create_quantized_convolution_graph_module(
             args=args,
         )

-    def test_quantized_convolution_default_channel_last(self):
+    def test_quantized_convolution_default_channel_last(self) -> None:
         # Create a graph with a single convolution node.
         gm = self.create_quantized_convolution_graph_module()
         self.assertEqual(
@@ -1667,7 +1691,7 @@ def test_quantized_convolution_default_channel_last(self):
             self.assertEqual(len(node.args), 15, f"{node=}")
             self.assertTrue(node.args[14])

-    def test_no_transpose_if_already_quantized_conv_channel_last(self):
+    def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
         # Create a graph with a single im2row node.
         gm = self.create_quantized_convolution_graph_module(channels_last=True)
         # Check if graph module is valid by running exportpass on it.
@@ -1710,7 +1734,7 @@ def create_slice_graph(
             args=(x, slice_dim, slice_begin, slice_end),
         )

-    def test_slice_no_transpose_if_already_outermost(self):
+    def test_slice_no_transpose_if_already_outermost(self) -> None:
         # Create a graph with a single slice node.
         gm = self.create_slice_graph((3, 224, 224), 0, 1, 2)
         # Check if graph module is valid by running exportpass on it.
@@ -1718,7 +1742,8 @@ def test_slice_no_transpose_if_already_outermost(self):
         self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)

         # Apply replacement pass.
-        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+        p = MakeSliceAndCatDimOutermostPass()
+        gm_after_pass = cast(PassResult, p(gm)).graph_module

         # Assert that no transpose ops were added.
         self.assertEqual(
@@ -1726,7 +1751,7 @@ def test_slice_no_transpose_if_already_outermost(self):
             0,
         )

-    def test_slice_no_transpose_if_outermost_dimensions_are_one(self):
+    def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None:
         # Create a graph with a single slice node on second outermost dimension.
         gm = self.create_slice_graph((1, 3, 4, 6), 1, 1, 2)
         # Check if graph module is valid by running exportpass on it.
@@ -1734,7 +1759,8 @@ def test_slice_no_transpose_if_outermost_dimensions_are_one(self):
         self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)

         # Apply replacement pass.
-        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+        p = MakeSliceAndCatDimOutermostPass()
+        gm_after_pass = cast(PassResult, p(gm)).graph_module

         # Assert that no transpose ops were added. The slice is on the second
         # outermost dimension, but the outermost dimension is already 1.
@@ -1743,7 +1769,7 @@ def test_slice_no_transpose_if_outermost_dimensions_are_one(self):
             0,
         )

-    def test_slice_insert_transpose(self):
+    def test_slice_insert_transpose(self) -> None:
         # Create a graph with a single slice node.
         gm = self.create_slice_graph((1, 3, 4, 6), 2, 1, 2)
         # Check if graph module is valid by running exportpass on it.
@@ -1751,7 +1777,8 @@ def test_slice_insert_transpose(self):
         self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)

         # Apply replacement pass.
-        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+        p = MakeSliceAndCatDimOutermostPass()
+        gm_after_pass = cast(PassResult, p(gm)).graph_module

         # Assert that there are two transpose ops added.
         self.assertEqual(
@@ -1771,7 +1798,7 @@ def create_cat_graph(
             args=(input_tensors, cat_dim),
         )

-    def test_cat_no_transpose_if_already_outermost(self):
+    def test_cat_no_transpose_if_already_outermost(self) -> None:
         # Create a graph with a single slice node on second outermost dimension.
         gm = self.create_cat_graph(input_shapes=((1, 3, 5), (2, 3, 5)), cat_dim=0)
         # Check if graph module is valid by running exportpass on it.
@@ -1779,7 +1806,8 @@ def test_cat_no_transpose_if_already_outermost(self):
         self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)

         # Apply replacement pass.
-        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+        p = MakeSliceAndCatDimOutermostPass()
+        gm_after_pass = cast(PassResult, p(gm)).graph_module

         # Assert that no transpose ops were added. The slice is on the second
         # outermost dimension, but the outermost dimension is already 1.
@@ -1788,7 +1816,7 @@ def test_cat_no_transpose_if_already_outermost(self):
             0,
         )

-    def test_cat_no_transpose_if_outermost_dimensions_are_one(self):
+    def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None:
         # Create a graph with a single slice node on second outermost dimension.
         gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1)
         # Check if graph module is valid by running exportpass on it.
@@ -1796,7 +1824,8 @@ def test_cat_no_transpose_if_outermost_dimensions_are_one(self):
         self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)

         # Apply replacement pass.
-        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+        p = MakeSliceAndCatDimOutermostPass()
+        gm_after_pass = cast(PassResult, p(gm)).graph_module

         # Assert that no transpose ops were added. The slice is on the second
         # outermost dimension, but the outermost dimension is already 1.
@@ -1805,7 +1834,7 @@ def test_cat_no_transpose_if_outermost_dimensions_are_one(self):
             0,
         )

-    def test_cat_insert_transpose(self):
+    def test_cat_insert_transpose(self) -> None:
         # Create a graph with a single slice node on second outermost dimension.
         gm = self.create_cat_graph(
             input_shapes=((1, 1, 3, 5), (1, 1, 3, 3)), cat_dim=-1
@@ -1815,7 +1844,8 @@ def test_cat_insert_transpose(self):
         self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)

         # Apply replacement pass.
-        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+        p = MakeSliceAndCatDimOutermostPass()
+        gm_after_pass = cast(PassResult, p(gm)).graph_module

         # Assert that transpose ops were added to make cat on outermost dimension.
         self.assertEqual(
@@ -1841,7 +1871,7 @@ def _get_slice_empty_gm(self) -> torch.fx.GraphModule:
         builder.output([cat])
         return builder.get_graph_module()

-    def test_empty_slice(self):
+    def test_empty_slice(self) -> None:
         gm = self._get_slice_empty_gm()
         self.assertEqual(
             len(
@@ -1859,7 +1889,8 @@ def test_empty_slice(self):
             ),
             0,
         )
-        updated_gm = ReplaceEmptyTensorsWithFullPass()(gm).graph_module
+        p = ReplaceEmptyTensorsWithFullPass()
+        updated_gm = cast(PassResult, p(gm)).graph_module
         self.assertEqual(
             len(
                 updated_gm.graph.find_nodes(
@@ -1877,14 +1908,16 @@ def test_empty_slice(self):
             1,
         )

-    @parameterized.expand(
+    @expand(
         [
             ("int", int(123)),
             ("float", float(456.0)),
         ],
     )
     @torch.no_grad()
-    def test_extract_mul_argument_to_full(self, _, value) -> None:
+    def test_extract_mul_argument_to_full(
+        self, _: str, value: Union[int, float]
+    ) -> None:
         x = torch.randn(2, 1, 64)
         gm = single_op_builder(
             placeholders=(x,),
diff --git a/backends/cadence/aot/tests/test_simplify_ops_passes.py b/backends/cadence/aot/tests/test_simplify_ops_passes.py
index 195c0ff00ab..f26fe897e1e 100644
--- a/backends/cadence/aot/tests/test_simplify_ops_passes.py
+++ b/backends/cadence/aot/tests/test_simplify_ops_passes.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe
+# pyre-strict

 import unittest

@@ -18,13 +18,13 @@
     BindOptionalArgsPass,
     SimplifySliceOpPass,
 )
+from executorch.backends.cadence.aot.typing_stubs import expand
 from executorch.exir.dialects._ops import ops as exir_ops
-from parameterized.parameterized import parameterized
 from torch.fx.passes.infra.pass_base import PassResult


 class TestSimplifyOpsPasses(unittest.TestCase):
-    @parameterized.expand(
+    @expand(
         [
             [(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
         ]
@@ -38,7 +38,7 @@ def test_simplify_slice_scatter_op(
         start: Optional[int] = None,
         end: Optional[int] = None,
         step: int = 1,
-    ):
+    ) -> None:
         x = torch.randn(*in_shape)
         y = torch.randn(*src_shape)
         gm = single_op_builder(
@@ -50,7 +50,7 @@ def test_simplify_slice_scatter_op(
         gm = cast(PassResult, p(gm)).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_scatter.default), 0)

-    @parameterized.expand(
+    @expand(
         [
             [(3, 16, 5), 1, 15, 3, 3],
         ]
@@ -63,7 +63,7 @@ def test_simplify_slice_op(
         start: Optional[int] = None,
         end: Optional[int] = None,
         step: int = 1,
-    ):
+    ) -> None:
         x = torch.randn(*in_shape)
         gm = single_op_builder(
             placeholders=(x,),
diff --git a/backends/cadence/aot/typing_stubs.py b/backends/cadence/aot/typing_stubs.py
new file mode 100644
index 00000000000..f15628f7948
--- /dev/null
+++ b/backends/cadence/aot/typing_stubs.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Callable
+
+    # This only runs during static type checking (not at runtime)
+    def expand(arg: object) -> Callable[..., None]: ...
+
+else:
+    # Real import used at runtime
+    # from parameterized.parameterized import parameterized.expand as expand  # noqa
+    from parameterized.parameterized import parameterized
+
+    expand = parameterized.expand
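Usage note (illustrative only, not part of the applied diff): the converted tests consume the shim exactly like parameterized.expand, while Pyre only ever sees the typed stub signature. A minimal sketch, assuming the executorch.backends.cadence.aot.typing_stubs module path introduced above; the test class and method names here are hypothetical.

    import unittest
    from typing import Tuple

    from executorch.backends.cadence.aot.typing_stubs import expand


    class ExampleShapeTest(unittest.TestCase):
        # At runtime expand is parameterized.expand, so each case is unpacked
        # into positional arguments and a separate test method is generated.
        @expand([["2d", (7, 5)], ["3d", (7, 5, 6)]])
        def test_rank(self, _: str, shape: Tuple[int, ...]) -> None:
            self.assertGreaterEqual(len(shape), 2)

Because the real decorator is only imported in the else branch, test discovery and naming are unchanged; only the static view of expand differs under pyre-strict.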