diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index 04c519a8ab8..dc8cec8aac2 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -154,6 +154,7 @@ python_library( deps = [ "//caffe2:torch", "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", ], ) diff --git a/exir/passes/remove_noop_pass.py b/exir/passes/remove_noop_pass.py index 943794ab1b0..5394bed5a28 100644 --- a/exir/passes/remove_noop_pass.py +++ b/exir/passes/remove_noop_pass.py @@ -6,9 +6,41 @@ # pyre-strict +from typing import List, Tuple + import torch -from executorch.exir.pass_base import ExportPass, ProxyValue -from torch.utils import _pytree as pytree +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule + +_DEQUANT_OPS: Tuple[torch._ops.OpOverload] = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, +) +_QUANT_OPS: Tuple[torch._ops.OpOverload] = ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, +) + + +def eliminate_dq_q( + graph_module: GraphModule, + dequant_nodes: List[torch.fx.Node], +) -> None: + for node in dequant_nodes: + assert node.target in _DEQUANT_OPS + for user in list(node.users): + if user.target in _QUANT_OPS: + # Drop the input arg and check that the qparams are the same. + qparams_dq = list(node.args)[1:] + qparams_q = list(user.args)[1:] + if qparams_dq != qparams_q: + continue + user.replace_all_uses_with(node.args[0]) class RemoveNoopPass(ExportPass): @@ -16,28 +48,45 @@ class RemoveNoopPass(ExportPass): Removes noops that pass through arguments. """ - # pyre-ignore - def call_operator(self, op, args, kwargs, meta): - if op not in ( - torch.ops.aten.to.dtype, - torch.ops.aten.dropout.default, - torch.ops.aten.slice_copy.Tensor, - ): - return super().call_operator(op, args, kwargs, meta) - - args_data, kwargs_data = pytree.tree_map_only( - ProxyValue, lambda x: x.data, (args, kwargs) - ) - orig_tensor = ( - args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] - ) - - if orig_tensor is op(*args_data, **kwargs_data): - return args[0] - - if op == torch.ops.aten.slice_copy.Tensor: - result = op(*args_data, **kwargs_data) - if orig_tensor.size() == result.size(): - return args[0] - - return super().call_operator(op, args, kwargs, meta) + def call(self, graph_module: GraphModule) -> PassResult: + + # In this list we'll collect all the dequant nodes that are inputs to ops that + # are removed in this pass and later check for redundant dq->q patterns and + # remove them. + dequant_nodes = [] + + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if node.target not in ( + torch.ops.aten.to.dtype, + torch.ops.aten.dropout.default, + torch.ops.aten.slice_copy.Tensor, + ): + continue + + orig_tensor = node.args[0].meta["val"] + + if orig_tensor is node.meta["val"]: + # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q. + # Otherwise, removing only the op will suffice. + if node.args[0].target in _DEQUANT_OPS: + dequant_nodes += [node.args[0]] + node.replace_all_uses_with(node.args[0]) + continue + + if node.target == torch.ops.aten.slice_copy.Tensor: + if orig_tensor.size() == node.meta["val"].size(): + # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q. + # Otherwise, removing only the op will suffice. + if node.args[0].target in _DEQUANT_OPS: + dequant_nodes += [node.args[0]] + node.replace_all_uses_with(node.args[0]) + + graph_module.graph.eliminate_dead_code() + eliminate_dq_q(graph_module, dequant_nodes) + graph_module.graph.lint() + graph_module.graph.eliminate_dead_code() + + return PassResult(graph_module, True) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 8512073428d..06f2b21ea00 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -15,7 +15,7 @@ # Import passes import executorch.exir.memory_planning # noqa import torch -from executorch.exir import EdgeCompileConfig, memory, to_edge +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.emit import emit_program @@ -50,6 +50,12 @@ from functorch.experimental import control_flow from torch import nn + +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) from torch.export import export from torch.fx import GraphModule, subgraph_rewriter from torch.fx.experimental.proxy_tensor import make_fx @@ -1244,3 +1250,173 @@ def forward(self, x): # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {}) # return (copy__default, aten_add_tensor) self.assertEqual(count_copies(gm), 1) + + def test_remove_quantized_op_noop_pass(self) -> None: + class TestAddSliceNoop(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x + x + x = x + x[:] + return x + + class TestAddSliceNotNoop(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x + x + x = x + x[:1] + return x + + def count_dq_nodes(gm: torch.fx.GraphModule) -> int: + return sum( + ( + node.target + in ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + ) + ) + for node in gm.graph.nodes + ) + + def count_q_nodes(gm: torch.fx.GraphModule) -> int: + return sum( + ( + node.target + in ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + ) + ) + for node in gm.graph.nodes + ) + + def quantize_model( + m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor] + ) -> Tuple[EdgeProgramManager, int, int]: + # program capture + m = torch._export.capture_pre_autograd_graph( + m_eager, + example_inputs, + ) + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config() + quantizer.set_global(quantization_config) + m = prepare_pt2e(m, quantizer) + m = convert_pt2e(m, fold_quantize=True) + ep = torch.export.export(m, example_inputs) + dq_nodes_pre = count_dq_nodes(ep.graph_module) + q_nodes_pre = count_q_nodes(ep.graph_module) + edge = to_edge( + ep, compile_config=EdgeCompileConfig(_check_ir_validity=False) + ) + return edge, dq_nodes_pre, q_nodes_pre + + example_inputs = (torch.randn(9, 8),) + model = TestAddSliceNoop() + m_eager = model.eval() + edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs) + + dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module) + q_nodes_post = count_q_nodes(edge.exported_program().graph_module) + # One dq and one q node around the slice copy should have been removed. + self.assertEqual(dq_nodes_pre - dq_nodes_post, 1) + self.assertEqual(q_nodes_pre - q_nodes_post, 1) + + # Check that the slice_copy is removed by the RemoveNoopPass. + for node in edge.exported_program().graph_module.graph.nodes: + self.assertFalse("slice" in str(node.target)) + + model = TestAddSliceNotNoop() + m_eager = model.eval() + edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs) + + dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module) + q_nodes_post = count_q_nodes(edge.exported_program().graph_module) + # One dq and one q node around the slice copy should have been removed. + self.assertEqual(dq_nodes_pre, dq_nodes_post) + self.assertEqual(q_nodes_pre, q_nodes_post) + + # Check that the slice_copy is not removed by the RemoveNoopPass. + self.assertTrue( + any( + "slice" in str(node.target) + for node in edge.exported_program().graph_module.graph.nodes + ) + ) + + def test_dq_q_no_op_pass(self) -> None: + class TestDqQ(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 1.0, 0, -128, 127, torch.int8 + ) + q = torch.ops.quantized_decomposed.quantize_per_tensor.default( + dq, 1.0, 0, -128, 127, torch.int8 + ) + return q + + model = TestDqQ() + m_eager = model.eval() + ep = torch.export.export(m_eager, (torch.randn(9, 8),)) + edge = to_edge(ep) + # Check that the dq and q nodes are not touched by the RemoveNoopPass. + self.assertTrue( + any( + "dequantize" in str(node.target) + for node in edge.exported_program().graph_module.graph.nodes + ) + ) + self.assertTrue( + any( + "quantize" in str(node.target) + for node in edge.exported_program().graph_module.graph.nodes + ) + ) + + def test_dq_q_different_qparams(self) -> None: + class TestDqQDifferentQParam(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 1.0, 0, -128, 127, torch.int8 + ) + slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0) + q = torch.ops.quantized_decomposed.quantize_per_tensor.default( + slice_copy_output, 1.0, 0, -127, 127, torch.int8 + ) + return q + + model = TestDqQDifferentQParam() + m_eager = model.eval() + ep = torch.export.export(m_eager, (torch.randn(9, 8),)) + edge = to_edge(ep) + print(edge.exported_program().graph_module.graph) + # Check that the dq and q nodes are not touched by the RemoveNoopPass. + self.assertTrue( + any( + "dequantize" in str(node.target) + for node in edge.exported_program().graph_module.graph.nodes + ) + ) + self.assertTrue( + any( + "quantize" in str(node.target) + for node in edge.exported_program().graph_module.graph.nodes + ) + ) + self.assertFalse( + any( + "slice" in str(node.target) + for node in edge.exported_program().graph_module.graph.nodes + ) + )