@@ -76,6 +76,7 @@ def forward(self, x):
return self.block(x)


@unittest.skip("Clones are optimized out of the graph.")
class TestCloneConverter(unittest.TestCase):
__test__ = False # Prevent interfering with PyTest tests

30 changes: 2 additions & 28 deletions backends/transforms/test/test_remove_clone_ops.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
@@ -164,34 +166,6 @@ def test_clone_non_identity_survives(self):
assert torch.allclose(actual, expected)
assert is_channel_last_dim_order(actual)

def test_clone_identity_removed(self):
"""Verify identity clone ops are removed by RemoveCloneOpsTransform."""

for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
model = SimpleCloneChannelsLastModule()
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)

exported = export(model.eval(), (x,), strict=True)
before_epm = to_edge(
exported,
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
)

FileCheck().check_count(clone_op_str, 1, exactly=True).run(
before_epm.exported_program().graph_module.code
)

updated_epm = before_epm.transform([RemoveCloneOpsTransform()])

FileCheck().check_not(clone_op_str).run(
updated_epm.exported_program().graph_module.code
)

expected = before_epm.exported_program().module()(x)
actual = updated_epm.exported_program().module()(x)
assert torch.allclose(actual, expected)
assert is_channel_last_dim_order(actual)


if __name__ == "__main__":
unittest.main()
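
The deleted test_clone_identity_removed covered the same behavior through FileCheck on the edge program's generated code. If that coverage ever needs to be recreated elsewhere, a minimal sketch of the pattern it used follows; epm, clone_op_str, and RemoveCloneOpsTransform are assumed to be in scope exactly as in the deleted test, so treat this as a hedged reconstruction rather than new API.

    from torch.testing import FileCheck

    # Before the transform: the clone op string appears exactly once in the graph code.
    FileCheck().check_count(clone_op_str, 1, exactly=True).run(
        epm.exported_program().graph_module.code
    )

    # After the transform: the clone op string should be gone.
    updated_epm = epm.transform([RemoveCloneOpsTransform()])
    FileCheck().check_not(clone_op_str).run(
        updated_epm.exported_program().graph_module.code
    )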
62 changes: 36 additions & 26 deletions exir/passes/remove_noop_pass.py
@@ -56,35 +56,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
dequant_nodes = []

for node in graph_module.graph.nodes:
if node.op != "call_function":
continue

if node.target not in (
torch.ops.aten.to.dtype,
torch.ops.aten.dropout.default,
torch.ops.aten.slice_copy.Tensor,
):
continue

orig_tensor = node.args[0].meta["val"]

if orig_tensor is node.meta["val"]:
# If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
# Otherwise, removing only the op will suffice.
if RemoveNoopPass._should_remove_node(node):
if node.args[0].target in _DEQUANT_OPS:
dequant_nodes += [node.args[0]]
node.replace_all_uses_with(node.args[0])
continue

if node.target == torch.ops.aten.slice_copy.Tensor:
# Only do this check if all the dims are static.
if all(isinstance(dim, int) for dim in orig_tensor.size()):
if orig_tensor.shape == node.meta["val"].shape:
# If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
# Otherwise, removing only the op will suffice.
if node.args[0].target in _DEQUANT_OPS:
dequant_nodes += [node.args[0]]
node.replace_all_uses_with(node.args[0])

graph_module.graph.eliminate_dead_code()
eliminate_dq_q(graph_module, dequant_nodes)
@@ -93,6 +68,41 @@ def call(self, graph_module: GraphModule) -> PassResult:

return PassResult(graph_module, True)

@staticmethod
def _should_remove_node(node: torch.fx.Node) -> bool:
if node.op != "call_function":
return False

input_meta_val = (
node.args[0].meta.get("val", None)
if len(node.args) > 0 and hasattr(node.args[0], "meta")
else None
)

if input_meta_val is not None:
if node.target in (
torch.ops.aten.to.dtype,
torch.ops.aten.dropout.default,
):
return input_meta_val is node.meta["val"]
elif node.target == torch.ops.aten.slice_copy.Tensor:
# Only do this check if all the dims are static.
return (
all(isinstance(dim, int) for dim in input_meta_val.size())
and input_meta_val.shape == node.meta["val"].shape
)
elif node.target == torch.ops.aten.clone.default:
# Remove if memory_format=None, preserve_format, or input already has the target memory format.
dest_memory_format = (
node.kwargs.get("memory_format", None) or torch.preserve_format
)
return (
dest_memory_format == torch.preserve_format
or input_meta_val.is_contiguous(memory_format=dest_memory_format)
)

return False


class RemoveToCopyPass(ExportPass):
"""
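
For context on the new aten.clone.default branch in _should_remove_node: a minimal standalone sketch of the condition it evaluates, with a plain tensor standing in for node.args[0].meta["val"] and a hypothetical clone_is_noop helper that is not part of this diff.

    import torch

    def clone_is_noop(inp: torch.Tensor, memory_format=None) -> bool:
        # memory_format=None falls back to preserve_format, which never changes layout.
        dest = memory_format or torch.preserve_format
        return dest == torch.preserve_format or inp.is_contiguous(memory_format=dest)

    x = torch.randn(3, 4, 5, 6)                       # contiguous NCHW tensor
    assert clone_is_noop(x)                           # default clone keeps the layout
    assert clone_is_noop(x, torch.contiguous_format)  # already contiguous, so a no-op
    assert not clone_is_noop(x, torch.channels_last)  # layout actually changes, keep the clone
    nhwc = x.to(memory_format=torch.channels_last)
    assert clone_is_noop(nhwc, torch.channels_last)   # already channels_last, so a no-op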
32 changes: 32 additions & 0 deletions exir/tests/test_passes.py
@@ -2093,3 +2093,35 @@ def forward(self, x):
prop_tensor.is_contiguous(),
f"Propagated tensor is not contiguous: {prop_tensor.stride()}",
)

def test_remove_noop_pass_clone(self) -> None:
"""
Verify that no-op clones are removed from the graph.
"""

class CloneModel(torch.nn.Module):
def forward(self, x):
return x.clone() + x.clone()

model = CloneModel()
inputs = (torch.randn(1, 16),)

ep = torch.export.export(model, inputs)
lowered = to_edge_transform_and_lower(ep)

# Sanity check the test - we should see clones in the exported program
self.assertTrue(
any(
n.op == "call_function" and n.target == torch.ops.aten.clone.default
for n in ep.graph.nodes
)
)

# Since the clone ops are no-ops, they should be gone.
self.assertFalse(
any(
n.op == "call_function"
and n.target == exir_ops.edge.dim_order_ops._clone_dim_order.default
for n in lowered.exported_program().graph.nodes
)
)
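
One possible follow-up for the new test, not part of this diff: also check that dropping the no-op clones leaves the numerics untouched. A hedged sketch reusing the model, inputs, and lowered names from the test above:

    eager_out = model(*inputs)
    lowered_out = lowered.exported_program().module()(*inputs)
    torch.testing.assert_close(lowered_out, eager_out)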