2 changes: 2 additions & 0 deletions exir/capture/_config.py
@@ -35,6 +35,8 @@ class EdgeCompileConfig:
     # TODO(larryliu): remove this
     _use_edge_ops: bool = True
     _skip_type_promotion: bool = False
+    # TODO(gasoonjia): set it as False by default, and remove it in the long term
+    _skip_dim_order: bool = True
 
 
 @compatibility(is_backward_compatible=False)
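The new `_skip_dim_order` flag defaults to `True`, so the dim-order rewrite stays off unless a caller opts in. Going by the tests later in this diff, opting in would look roughly like the sketch below (the example module and input shape are made up for illustration):

```python
import torch
from torch.export import export

from executorch.exir import EdgeCompileConfig, to_edge


# Hypothetical example module; any graph containing aten._to_copy would do.
class ToChannelsLast(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(dtype=torch.double, memory_format=torch.channels_last)


ep = export(ToChannelsLast().eval(), (torch.randn(2, 3, 4, 5),))

# _skip_dim_order=False turns on the MemoryFormatOpsPass wired up in
# exir/program/_program.py below; _check_ir_validity=False mirrors the tests,
# since the dim-order op does not pass the verifier yet (see the TODO there).
edge = to_edge(
    ep,
    compile_config=EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=False),
)
```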
10 changes: 9 additions & 1 deletion exir/program/_program.py
@@ -21,6 +21,7 @@
 from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import (
     EdgeToBackendOpsPass,
+    MemoryFormatOpsPass,
     OpReplacePass,
     post_op_replace_passes,
     pre_op_replace_passes,
@@ -494,6 +495,7 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
     passes = pre_op_replace_passes + (
         [] if config._skip_type_promotion else [RemoveMixedTypeOperators()]
     )
+
     new_ep = copy.deepcopy(ep).transform(*passes)
     if dialect == "ATEN":
         new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)
@@ -504,7 +506,11 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
         assert new_gm_res is not None
         new_gm = new_gm_res.graph_module
 
-    for p in post_op_replace_passes:
+    passes = post_op_replace_passes + (
+        [] if config._skip_dim_order else [MemoryFormatOpsPass]
+    )
+
+    for p in passes:
         new_gm_res = p(new_gm)
         assert new_gm_res is not None
         new_gm = new_gm_res.graph_module
@@ -837,6 +843,8 @@ def to_edge(
         passes.append(RemoveMixedTypeOperators())
         if config._use_edge_ops:
             passes.append(OpReplacePass())
+        if not config._skip_dim_order:
+            passes.append(MemoryFormatOpsPass())
 
         gm = program.graph_module
         for p in passes:
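For orientation, the pass wired in above rewrites `aten._to_copy` call sites into the edge-dialect dim-order variant (the op names can be read off the FileCheck strings in the test file below). A rough, hypothetical torch.fx-style sketch of that kind of rewrite, not the actual MemoryFormatOpsPass implementation, might look like:

```python
import torch
from torch.fx import GraphModule


def replace_to_copy_with_dim_order_copy(gm: GraphModule, dim_order_copy_op) -> GraphModule:
    # Hypothetical sketch only: the real pass lives in executorch.exir.passes and
    # presumably also rewrites the memory_format argument into an explicit dim order;
    # here we only retarget the matching call sites and leave args/kwargs untouched.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.aten._to_copy.default:
            node.target = dim_order_copy_op
    gm.recompile()
    return gm
```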
171 changes: 113 additions & 58 deletions exir/tests/test_memory_format_ops_pass.py
@@ -6,92 +6,147 @@
 
 import unittest
 from dataclasses import dataclass
-from typing import Any, List, Tuple
+from typing import Any, Tuple
 
 import torch
 from executorch.exir import EdgeCompileConfig, to_edge
-from executorch.exir.passes import MemoryFormatOpsPass
 from torch.export import export
 from torch.testing import FileCheck
 
 
+@dataclass
+class MemoryFormatTestSet:
+    module: torch.nn.Module
+    sample_input: Tuple[Any, ...]
+    target_memory_format: torch.memory_format
+
+
 class TestMemoryFormatOpsPass(unittest.TestCase):
-    def test_op_to_copy_replacement(self) -> None:
-        class ContiguousModule(torch.nn.Module):
+    def is_channel_last(self, x: torch.Tensor):
+        # This is a heuristic to determine if the input tensor is in NHWC (channels-last)
+        # format, since we do not have a good way to infer the dimension order or the
+        # memory format of the input tensor. Please note this function is specific to
+        # contiguous tensors whose dim(1) is the channel dimension; other tensors may
+        # not work well due to different channel configurations and memory arrangements.
+
+        return x.stride(1) == 1
+
+    def memory_format_test_runner(self, test_set: MemoryFormatTestSet):
+        aten_op_str = "torch.ops.aten._to_copy.default"
+        edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
+
+        before = export(test_set.module, test_set.sample_input)
+
+        # check op strings before
+        FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
+            edge_op_str
+        ).run(before.graph_module.code)
+
+        # TODO(gasoonjia): make to_dim_copy pass verifier
+        epm = to_edge(
+            before,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False, _skip_dim_order=False
+            ),
+        )
+
+        # check op strings
+        FileCheck().check_not(aten_op_str).check_count(
+            edge_op_str, 1, exactly=True
+        ).run(epm.exported_program().graph_module.code)
+
+        # check EdgeOp and the new BackendOp should behave the same
+        expected = before(*test_set.sample_input)
+        actual = epm.exported_program()(*test_set.sample_input)
+        self.assertTrue(torch.allclose(actual, expected))
+        self.assertEqual(
+            self.is_channel_last(actual),
+            self.is_channel_last(expected),
+        )
+        if test_set.target_memory_format == torch.channels_last:
+            self.assertTrue(self.is_channel_last(actual))
+        elif test_set.target_memory_format == torch.contiguous_format:
+            self.assertFalse(self.is_channel_last(actual))
+        else:
+            raise RuntimeError("Unknown memory format")
+
+        # TODO - more
+        epm.to_executorch()
+
+    def test_op_to_copy_replacement_2d(self) -> None:
+        class Module(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return x.to(dtype=torch.double, memory_format=torch.contiguous_format)
 
-        class ChannelsLastModule(torch.nn.Module):
+        self.memory_format_test_runner(
+            MemoryFormatTestSet(
+                module=Module().eval(),
+                sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
+                target_memory_format=torch.contiguous_format,
+            )
+        )
+
+    def test_op_to_copy_replacement_4d(self) -> None:
+        class Module(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return x.to(dtype=torch.double, memory_format=torch.channels_last)
+                return x.to(dtype=torch.double, memory_format=torch.contiguous_format)
 
-        @dataclass
-        class TestSet:
-            module: torch.nn.Module
-            sample_input: Tuple[Any, ...]
+        self.memory_format_test_runner(
+            MemoryFormatTestSet(
+                module=Module().eval(),
+                sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
+                target_memory_format=torch.contiguous_format,
+            )
+        )
 
-        contiguous_module = ContiguousModule().eval()
-        channels_last_module = ChannelsLastModule().eval()
+    def test_op_dim_order_update(self) -> None:
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
 
-        all_test_sets: List[TestSet] = [
-            TestSet(
-                module=contiguous_module,
-                sample_input=(torch.randn([2, 2], dtype=torch.float32),),
-            ),
-            TestSet(
-                module=contiguous_module,
-                sample_input=(torch.randn([2, 2, 2], dtype=torch.float32),),
-            ),
-            TestSet(
-                module=channels_last_module,
-                sample_input=(torch.randn([2, 2, 2, 2], dtype=torch.float32),),
-            ),
-            TestSet(
-                module=channels_last_module,
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x.to(dtype=torch.double, memory_format=torch.channels_last)
+
+        self.memory_format_test_runner(
+            MemoryFormatTestSet(
+                module=Module().eval(),
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
                         dtype=torch.float32,
-                        memory_format=torch.channels_last,
+                        memory_format=torch.contiguous_format,
                     ),
                 ),
+                target_memory_format=torch.channels_last,
             ),
-        ]
-
-        aten_op_str = "torch.ops.aten._to_copy.default"
-        edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
-
-        for test_set in all_test_sets:
-            before = export(test_set.module, test_set.sample_input)
-
-            # check op strings before
-            FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
-                edge_op_str
-            ).run(before.graph_module.code)
-
-            ep = to_edge(
-                before, compile_config=EdgeCompileConfig(_check_ir_validity=False)
-            )  # Only replacing edge_ops
-
-            # Run the pass
-            # TODO move this in to_edge passes, make to_dim_copy pass verifier
-            after = ep.transform([MemoryFormatOpsPass()], check_ir_validity=False)
+        )
 
-            # check op strings
-            FileCheck().check_not(aten_op_str).check_count(
-                edge_op_str, 1, exactly=True
-            ).run(after.exported_program().graph_module.code)
+    def test_op_dim_order_propagation(self) -> None:
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
 
-            # check EdgeOp and the new BackendOp should behave the same
-            expected = before(*test_set.sample_input)
-            actual = after.exported_program()(*test_set.sample_input)
-            self.assertTrue(torch.allclose(actual, expected))
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                t1 = x.to(dtype=torch.double, memory_format=torch.channels_last)
+                t2 = t1 + t1
+                return t1 * t2
 
-            # TODO - more
-            after.to_executorch()
+        self.memory_format_test_runner(
+            MemoryFormatTestSet(
+                module=Module().eval(),
+                sample_input=(
+                    torch.rand_like(
+                        torch.zeros([2, 2, 2, 2]),
+                        dtype=torch.float32,
+                        memory_format=torch.contiguous_format,
+                    ),
+                ),
+                target_memory_format=torch.channels_last,
+            )
+        )
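As a quick sanity check of the `stride(1) == 1` heuristic used by `is_channel_last` above, plain PyTorch shows the channel stride collapsing to 1 after a channels-last conversion:

```python
import torch

x = torch.randn(2, 3, 4, 5)                  # NCHW, default contiguous layout
assert x.stride(1) == 4 * 5                  # channel stride equals H * W
y = x.to(memory_format=torch.channels_last)
assert y.stride(1) == 1                      # channels_last: channel stride is 1
assert torch.allclose(x, y)                  # same values, different physical layout
```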