From f0ad62eeb3ec4447d8bcb0d87d7f0075b100f228 Mon Sep 17 00:00:00 2001
From: Ekaterina Ignasheva
Date: Thu, 22 May 2025 09:34:51 -0700
Subject: [PATCH] Use GraphBuilder in unit tests. (#10977)

Summary: Use GraphBuilder to create the model for unit testing.

Reviewed By: zonglinpeng

Differential Revision: D74907087
---
Reviewer notes (text in this position is ignored when the patch is applied):
  * The three touched tests no longer export a torch.nn.Module; each one
    assembles its graph node by node with GraphBuilder, and the
    export_to_edge / quantize_and_export_to_edge import is removed.
  * test_replace_dequant_quant_with_requantize used to carry the same
    dequant -> permute -> quant model as the permute test; it now builds
    dequant -> quant directly and its comment is updated to match.
  * test_replace_dequant_permute_quant_with_requantize additionally expects
    exactly one permute_copy node after the pass.
  * NOTE(review): the placeholders feeding dequantize_per_tensor are float32
    although the op is given torch.int8 as its dtype argument (the removed
    code cast the input with .to(torch.int8)) -- confirm this is intentional.
  * NOTE(review): in test_remove_nop_dequant_quant, linear2 reuses the
    [OUT_DIM, IN_DIM] weights of linear1 with IN_DIM=6 and OUT_DIM=12 --
    presumably only op counts matter here; confirm the shapes are acceptable.
  * NOTE(review): Final is used by the new constants but no import is added
    by this patch; presumably the file already imports it -- verify.

 .../aot/tests/test_fusion_ops_passes.py       | 172 +++++++++++-------
 1 file changed, 109 insertions(+), 63 deletions(-)

diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py
index 8f888c4c8bf..fff2963df29 100644
--- a/backends/cadence/aot/tests/test_fusion_ops_passes.py
+++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py
@@ -13,10 +13,6 @@
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 from executorch.backends.cadence.aot import compiler
-from executorch.backends.cadence.aot.compiler import (
-    export_to_edge,
-    quantize_and_export_to_edge,
-)
 from executorch.backends.cadence.aot.fuse_ops import (
     FuseFullThenReshapePass,
     FuseMulScalarIntoDequantPass,
@@ -336,29 +332,25 @@ def test_replace_quant_view_dequant_with_requantize(self):
         )
 
     def test_replace_dequant_quant_with_requantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.2, 3, 0, 127, torch.int8
-                )
-                x = torch.permute(x, [2, 0, 1, 3])
-                x = torch.ops.quantized_decomposed.quantize_per_tensor(
-                    x, 4.5, 6, 0, 127, torch.int8
-                )
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
-        model = M()
-        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 1.2, 3, 0, 127, torch.int8),
+        )
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(dequant, 4.5, 6, 0, 127, torch.int8),
+        )
+        builder.output(quant)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module
 
         self.check_op_counts(
             graph_module,
             expected_op_counts={
-                # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
+                # Verify that dequant -> quant was replaced with requantize.
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                 exir_ops.edge.cadence.requantize.default: 1,
@@ -366,24 +358,23 @@ def forward(self, x):
         )
 
     def test_replace_dequant_permute_quant_with_requantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.2, 3, 0, 127, torch.int8
-                )
-                x = torch.permute(x, [2, 0, 1, 3])
-                x = torch.ops.quantized_decomposed.quantize_per_tensor(
-                    x, 4.5, 6, 0, 127, torch.int8
-                )
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
-        model = M()
-        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 1.2, 3, 0, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3])
+        )
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 4.5, 6, 0, 127, torch.int8),
+        )
+        builder.output(quant)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module
 
         self.check_op_counts(
             graph_module,
@@ -391,39 +382,94 @@ def forward(self, x):
                 # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
+                exir_ops.edge.aten.permute_copy.default: 1,
                 exir_ops.edge.cadence.requantize.default: 1,
             },
         )
 
     def test_remove_nop_dequant_quant(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.lin1 = torch.nn.Linear(6, 12, bias=False)
-                self.lin2 = torch.nn.Linear(12, 24, bias=False)
+        LEADING_DIMS: Final[int] = 12
+        IN_DIM: Final[int] = 6
+        OUT_DIM: Final[int] = 12
 
-            def forward(self, x):
-                x = self.lin1(x)
-                # redundant dequant+quant will be created around this permute
-                x = torch.permute(x, [0, 2, 1, 3])
-                x = self.lin2(x)
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6)
-        model = M()
-        graph_module = (
-            quantize_and_export_to_edge(model, (inputs,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder(
+            "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32)
+        )
+        quant1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 4.5, 6, 0, 127, torch.int8),
+        )
+        weights = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1)
+        )
+        bias = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
+        )
+        weight_zero_point = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0)
+        )
+        out_multiplier = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
+        )
+        out_shift = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0)
         )
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        linear1 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quant1,
+                weights,
+                bias,
+                0,  # src_zero_point
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                0,  # out_zero_point
+                None,
+            ),
+        )
+        dequant1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(linear1, 1.2, 3, 0, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0])
+        )
+        quant2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 4.5, 6, 0, 127, torch.int8),
+        )
+        linear2 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quant2,
+                weights,
+                bias,
+                0,  # src_zero_point
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                0,  # out_zero_point
+                None,
+            ),
+        )
+        dequant2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(linear2, 1.2, 3, 0, 127, torch.int8),
+        )
+        builder.output(dequant2)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module
         self.check_op_counts(
             graph_module,
             expected_op_counts={
-                # Verify that one dequant/quant pair was removed
-                # Expect 1 quantize ops: 1 input
+                # Verify that one dequant/quant pair was removed from chain:
+                # quant->linear->dequant->permute->quant->linear->dequant
+                # gets converted to:
+                # quant->linear->permute->linear->dequant
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
-                # Expect 1 dequant op at the end (output of second linear)
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
             },
         )