From 42d706110320eabdcc1c3014d9effc68ff39ab27 Mon Sep 17 00:00:00 2001
From: Songhao Jia
Date: Thu, 8 Feb 2024 22:21:20 -0800
Subject: [PATCH 1/2] move memory format pass into to_edge (#1891)

Summary:
This diff enables the memory format pass by moving it into `to_edge`. It also
introduces `_skip_dim_order` in the edge compile config so that the pass can
be enabled gradually. For now its default is True, and the pass is only
enabled in `test_memory_format_ops_pass` for evaluation. We will gradually
enable it across our systems, and finally remove the flag from
`EdgeCompileConfig`.

Differential Revision: D53567636
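To opt in, callers pass `_skip_dim_order=False` when building the edge compile
config. A minimal sketch of the intended usage, mirroring the updated test
below (the example module and input shape are illustrative only, not part of
this diff):

    import torch
    from executorch.exir import EdgeCompileConfig, to_edge
    from torch.export import export

    # Hypothetical example module; any module carrying a memory-format-
    # sensitive aten._to_copy works the same way.
    class ToContiguous(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.to(dtype=torch.double, memory_format=torch.contiguous_format)

    ep = export(ToContiguous().eval(), (torch.randn([2, 2, 2, 2]),))

    # _skip_dim_order=False turns on MemoryFormatOpsPass inside to_edge, which
    # replaces aten._to_copy with the edge dialect _to_dim_order_copy op.
    epm = to_edge(
        ep,
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False,  # dim-order verifier not wired up yet (see TODO)
            _skip_dim_order=False,
        ),
    )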
---
 exir/capture/_config.py                   |  2 ++
 exir/program/_program.py                  | 10 +++++++++-
 exir/tests/test_memory_format_ops_pass.py | 20 ++++++++++----------
 3 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/exir/capture/_config.py b/exir/capture/_config.py
index b9bcf8884ff..c0a33c2424d 100644
--- a/exir/capture/_config.py
+++ b/exir/capture/_config.py
@@ -35,6 +35,8 @@ class EdgeCompileConfig:
     # TODO(larryliu): remove this
     _use_edge_ops: bool = True
     _skip_type_promotion: bool = False
+    # TODO(gasoonjia): set it to False by default, and remove it in the long term
+    _skip_dim_order: bool = True
 
 
 @compatibility(is_backward_compatible=False)

diff --git a/exir/program/_program.py b/exir/program/_program.py
index b030a79214d..b4f66eea164 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -21,6 +21,7 @@
 from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import (
     EdgeToBackendOpsPass,
+    MemoryFormatOpsPass,
     OpReplacePass,
     post_op_replace_passes,
     pre_op_replace_passes,
@@ -494,6 +495,7 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
     passes = pre_op_replace_passes + (
         [] if config._skip_type_promotion else [RemoveMixedTypeOperators()]
     )
+
     new_ep = copy.deepcopy(ep).transform(*passes)
     if dialect == "ATEN":
         new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)
@@ -504,7 +506,11 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
         assert new_gm_res is not None
         new_gm = new_gm_res.graph_module
 
-    for p in post_op_replace_passes:
+    passes = post_op_replace_passes + (
+        [] if config._skip_dim_order else [MemoryFormatOpsPass]
+    )
+
+    for p in passes:
         new_gm_res = p(new_gm)
         assert new_gm_res is not None
         new_gm = new_gm_res.graph_module
@@ -837,6 +843,8 @@ def to_edge(
             passes.append(RemoveMixedTypeOperators())
         if config._use_edge_ops:
             passes.append(OpReplacePass())
+        if not config._skip_dim_order:
+            passes.append(MemoryFormatOpsPass())
 
         gm = program.graph_module
         for p in passes:

diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py
index 1f1b24a5ef9..6f56c5c8dd1 100644
--- a/exir/tests/test_memory_format_ops_pass.py
+++ b/exir/tests/test_memory_format_ops_pass.py
@@ -75,23 +75,23 @@ class TestSet:
                 edge_op_str
             ).run(before.graph_module.code)
 
-            ep = to_edge(
-                before, compile_config=EdgeCompileConfig(_check_ir_validity=False)
-            )  # Only replacing edge_ops
-
-            # Run the pass
-            # TODO move this in to_edge passes, make to_dim_copy pass verifier
-            after = ep.transform([MemoryFormatOpsPass()], check_ir_validity=False)
+            # TODO(gasoonjia): make to_dim_copy pass verifier
+            epm = to_edge(
+                before,
+                compile_config=EdgeCompileConfig(
+                    _check_ir_validity=False, _skip_dim_order=False
+                ),
+            )
 
             # check op strings
             FileCheck().check_not(aten_op_str).check_count(
                 edge_op_str, 1, exactly=True
-            ).run(after.exported_program().graph_module.code)
+            ).run(epm.exported_program().graph_module.code)
 
             # check EdgeOp and the new BackendOp should behave the same
             expected = before(*test_set.sample_input)
-            actual = after.exported_program()(*test_set.sample_input)
+            actual = epm.exported_program()(*test_set.sample_input)
             self.assertTrue(torch.allclose(actual, expected))
 
             # TODO - more
-            after.to_executorch()
+            epm.to_executorch()

From 66395ea1e493e52581e7605c75413cd69292339e Mon Sep 17 00:00:00 2001
From: Songhao Jia
Date: Thu, 8 Feb 2024 22:21:20 -0800
Subject: [PATCH 2/2] make memory format pass test support dim order
 propagation (#1900)

Summary:
Add more tests to the dim order pass, covering dim order update and
propagation through the graph.

Differential Revision: D53593453
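For reference, the channels-last check that the new test runner relies on is a
stride heuristic rather than an exact memory-format query. A small standalone
illustration in plain PyTorch (independent of the test harness; the tensor
shapes are arbitrary):

    import torch

    def is_channel_last(x: torch.Tensor) -> bool:
        # After converting a contiguous NCHW tensor to channels_last, the
        # channel dimension (dim 1) becomes the fastest-varying one, so its
        # stride is 1.
        return x.stride(1) == 1

    nchw = torch.randn(2, 3, 4, 5)                     # contiguous: stride(1) == 20
    nhwc = nchw.to(memory_format=torch.channels_last)  # stride(1) == 1
    assert not is_channel_last(nchw)
    assert is_channel_last(nhwc)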
---
 exir/tests/test_memory_format_ops_pass.py | 167 ++++++++++++++--------
 1 file changed, 111 insertions(+), 56 deletions(-)

diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py
index 6f56c5c8dd1..0afc5efba87 100644
--- a/exir/tests/test_memory_format_ops_pass.py
+++ b/exir/tests/test_memory_format_ops_pass.py
@@ -6,92 +6,147 @@
 import unittest
 from dataclasses import dataclass
-from typing import Any, List, Tuple
+from typing import Any, Tuple
 
 import torch
 from executorch.exir import EdgeCompileConfig, to_edge
-from executorch.exir.passes import MemoryFormatOpsPass
 from torch.export import export
 from torch.testing import FileCheck
 
 
+@dataclass
+class MemoryFormatTestSet:
+    module: torch.nn.Module
+    sample_input: Tuple[Any, ...]
+    target_memory_format: torch.memory_format
+
+
 class TestMemoryFormatOpsPass(unittest.TestCase):
-    def test_op_to_copy_replacement(self) -> None:
-        class ContiguousModule(torch.nn.Module):
+    def is_channel_last(self, x: torch.Tensor):
+        # This is a heuristic to determine whether the input tensor is in NHWC
+        # (channels last) format, since we do not have a good way to infer its
+        # dimension order or memory format. Note that this function is specific
+        # to contiguous tensors whose dim(1) is the channel dimension; other
+        # kinds of tensors may not work well due to different memory arrangements.
+
+        return x.stride(1) == 1
+
+    def memory_format_test_runner(self, test_set: MemoryFormatTestSet):
+        aten_op_str = "torch.ops.aten._to_copy.default"
+        edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
+
+        before = export(test_set.module, test_set.sample_input)
+
+        # check op strings before
+        FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
+            edge_op_str
+        ).run(before.graph_module.code)
+
+        # TODO(gasoonjia): make to_dim_copy pass verifier
+        epm = to_edge(
+            before,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False, _skip_dim_order=False
+            ),
+        )
+
+        # check op strings
+        FileCheck().check_not(aten_op_str).check_count(
+            edge_op_str, 1, exactly=True
+        ).run(epm.exported_program().graph_module.code)
+
+        # check EdgeOp and the new BackendOp should behave the same
+        expected = before(*test_set.sample_input)
+        actual = epm.exported_program()(*test_set.sample_input)
+        self.assertTrue(torch.allclose(actual, expected))
+        self.assertEqual(
+            self.is_channel_last(actual),
+            self.is_channel_last(expected),
+        )
+        if test_set.target_memory_format == torch.channels_last:
+            self.assertTrue(self.is_channel_last(actual))
+        elif test_set.target_memory_format == torch.contiguous_format:
+            self.assertFalse(self.is_channel_last(actual))
+        else:
+            raise RuntimeError("Unknown memory format")
+
+        # TODO - more
+        epm.to_executorch()
+
+    def test_op_to_copy_replacement_2d(self) -> None:
+        class Module(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return x.to(dtype=torch.double, memory_format=torch.contiguous_format)
 
-        class ChannelsLastModule(torch.nn.Module):
+        self.memory_format_test_runner(
+            MemoryFormatTestSet(
+                module=Module().eval(),
+                sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
+                target_memory_format=torch.contiguous_format,
+            )
+        )
+
+    def test_op_to_copy_replacement_4d(self) -> None:
+        class Module(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return x.to(dtype=torch.double, memory_format=torch.channels_last)
+                return x.to(dtype=torch.double, memory_format=torch.contiguous_format)
 
-        @dataclass
-        class TestSet:
-            module: torch.nn.Module
-            sample_input: Tuple[Any, ...]
+        self.memory_format_test_runner(
+            MemoryFormatTestSet(
+                module=Module().eval(),
+                sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
+                target_memory_format=torch.contiguous_format,
+            )
+        )
 
-        contiguous_module = ContiguousModule().eval()
-        channels_last_module = ChannelsLastModule().eval()
+    def test_op_dim_order_update(self) -> None:
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
 
-        all_test_sets: List[TestSet] = [
-            TestSet(
-                module=contiguous_module,
-                sample_input=(torch.randn([2, 2], dtype=torch.float32),),
-            ),
-            TestSet(
-                module=contiguous_module,
-                sample_input=(torch.randn([2, 2, 2], dtype=torch.float32),),
-            ),
-            TestSet(
-                module=channels_last_module,
-                sample_input=(torch.randn([2, 2, 2, 2], dtype=torch.float32),),
-            ),
-            TestSet(
-                module=channels_last_module,
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x.to(dtype=torch.double, memory_format=torch.channels_last)
+
+        self.memory_format_test_runner(
+            MemoryFormatTestSet(
+                module=Module().eval(),
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
                         dtype=torch.float32,
-                        memory_format=torch.channels_last,
+                        memory_format=torch.contiguous_format,
                     ),
                 ),
+                target_memory_format=torch.channels_last,
             ),
-        ]
+        )
 
-        aten_op_str = "torch.ops.aten._to_copy.default"
-        edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
-
-        for test_set in all_test_sets:
-            before = export(test_set.module, test_set.sample_input)
+    def test_op_dim_order_propagation(self) -> None:
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
 
-            # check op strings before
-            FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
-                edge_op_str
-            ).run(before.graph_module.code)
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                t1 = x.to(dtype=torch.double, memory_format=torch.channels_last)
+                t2 = t1 + t1
+                return t1 * t2
 
-            # TODO(gasoonjia): make to_dim_copy pass verifier
-            epm = to_edge(
-                before,
-                compile_config=EdgeCompileConfig(
-                    _check_ir_validity=False, _skip_dim_order=False
+        self.memory_format_test_runner(
+            MemoryFormatTestSet(
+                module=Module().eval(),
+                sample_input=(
+                    torch.rand_like(
+                        torch.zeros([2, 2, 2, 2]),
+                        dtype=torch.float32,
+                        memory_format=torch.contiguous_format,
+                    ),
                 ),
+                target_memory_format=torch.channels_last,
             )
-
-            # check op strings
-            FileCheck().check_not(aten_op_str).check_count(
-                edge_op_str, 1, exactly=True
-            ).run(epm.exported_program().graph_module.code)
-
-            # check EdgeOp and the new BackendOp should behave the same
-            expected = before(*test_set.sample_input)
-            actual = epm.exported_program()(*test_set.sample_input)
-            self.assertTrue(torch.allclose(actual, expected))
-
-            # TODO - more
-            epm.to_executorch()
+        )