diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index f97236bed7b..5d76ecd2d54 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -6,10 +6,8 @@ # pyre-unsafe -import unittest - from collections import Counter -from typing import Dict, Tuple +from typing import Tuple import torch from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( @@ -33,19 +31,15 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, TemporaryFileName, - TestCase, ) from torchao.quantization.pt2e import ( allow_exported_model_train_eval, compare_results, - CUSTOM_KEY, extract_results_from_loggers, - generate_numeric_debug_handle, - NUMERIC_DEBUG_HANDLE_KEY, + FROM_NODE_KEY, prepare_for_propagation_comparison, ) -from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, @@ -53,7 +47,10 @@ ) from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer -from torchao.testing.pt2e.utils import PT2EQuantizationTestCase +from torchao.testing.pt2e.utils import ( + PT2ENumericDebuggerTestCase, + PT2EQuantizationTestCase, +) class TestQuantizePT2E(PT2EQuantizationTestCase): @@ -495,7 +492,8 @@ def forward(self, x): for n in m.graph.nodes: if n.op == "get_attr" and "frozen_param" in n.target: for key in n.meta: - self.assertEqual(n.meta[key], weight_meta[key]) + if key != FROM_NODE_KEY: + self.assertEqual(n.meta[key], weight_meta[key]) def test_reentrant(self) -> None: """Test we can safely call quantization apis multiple times""" @@ -725,76 +723,59 @@ def test_save_load(self) -> None: instantiate_parametrized_tests(TestQuantizePT2E) -@unittest.skip("TODO: Reenable it after debug infrature 
finish update") -class TestNumericDebugger(TestCase): - def _extract_debug_handles(self, model) -> Dict[str, int]: - debug_handle_map: Dict[str, int] = {} - - def _extract_debug_handles_from_node(node: torch.fx.Node) -> None: - nonlocal debug_handle_map - if ( - CUSTOM_KEY in node.meta - and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] - ): - debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ - NUMERIC_DEBUG_HANDLE_KEY - ] - - bfs_trace_with_node_process(model, _extract_debug_handles_from_node) - return debug_handle_map - - def _assert_each_node_has_debug_handle(self, model) -> None: - def _assert_node_has_debug_handle(node: torch.fx.Node) -> None: - self.assertTrue( - CUSTOM_KEY in node.meta - and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY], - f"Node {node} doesn't have debug handle", - ) - - bfs_trace_with_node_process(model, _assert_node_has_debug_handle) +class TestXNNPACKQuantizerNumericDebugger(PT2ENumericDebuggerTestCase): - def test_quantize_pt2e_preserve_handle(self) -> None: + def test_quantize_pt2e_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() ep = export_for_training(m, example_inputs, strict=True) - generate_numeric_debug_handle(ep) m = ep.module() quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=False) ) - m = prepare_pt2e(m, quantizer) # pyre-ignore[6] - debug_handle_map = self._extract_debug_handles(m) - res_counter = Counter(debug_handle_map.values()) - repeated_debug_handle_ids = [1, 2, 3] - # 3 ids were repeated because we copy over the id from node to its output observer + m = prepare_pt2e(m, quantizer) + from_node_source_map = self._extract_from_node_source(m) + node_name_equip_with_output_observer = [ + "conv2d", + "conv1d", + "squeeze", + ] + res_counter = Counter(from_node_source_map.values()) + repeated_from_node_source = [ + from_node_source_map[n_name] + for n_name in node_name_equip_with_output_observer + ] + # 3 infos were 
repeated because we copy over the info from node to its output observer # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default - for dh_id in repeated_debug_handle_ids: - self.assertEqual(res_counter[dh_id], 2) + for from_node_source in repeated_from_node_source: + self.assertEqual(res_counter[from_node_source], 2) m(*example_inputs) m = convert_pt2e(m) - self._assert_each_node_has_debug_handle(ep) - debug_handle_map = self._extract_debug_handles(m) - res_counter = Counter(debug_handle_map.values()) - # same set of ids where repeated, because we copy over the id from observer/fake_quant to - # dequantize node - repeated_debug_handle_ids = [1, 2, 3] - for dh_id in repeated_debug_handle_ids: - self.assertEqual(res_counter[dh_id], 2) - - def test_extract_results_from_loggers(self) -> None: + self._assert_each_node_has_from_node_source(m) + from_node_source_map = self._extract_from_node_source(m) + res_counter = Counter(from_node_source_map.values()) + # same set of infos were repeated, because we copy over the info from observer/fake_quant to + # quantize/dequantize node + repeated_from_node_source = [ + from_node_source_map[n_name] + for n_name in node_name_equip_with_output_observer + ] + for from_node_source in repeated_from_node_source: + self.assertEqual(res_counter[from_node_source], 3) + + def test_extract_results_from_loggers(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() ep = export_for_training(m, example_inputs, strict=True) - generate_numeric_debug_handle(ep) m = ep.module() - m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6] + m_ref_logger = prepare_for_propagation_comparison(m) quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=False) ) - m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m = prepare_pt2e(m, quantizer) m(*example_inputs) m = convert_pt2e(m) m_quant_logger = prepare_for_propagation_comparison(m) @@ 
-803,29 +784,22 @@ def test_extract_results_from_loggers(self) -> None: m_quant_logger(*example_inputs) ref_results = extract_results_from_loggers(m_ref_logger) quant_results = extract_results_from_loggers(m_quant_logger) - comparison_results = compare_results( - ref_results, - quant_results, # pyre-ignore[6] - ) + comparison_results = compare_results(ref_results, quant_results) for node_summary in comparison_results.values(): if len(node_summary.results) > 0: - self.assertGreaterEqual( - node_summary.results[0].sqnr, - 35, # pyre-ignore[6] - ) + self.assertGreaterEqual(node_summary.results[0].sqnr, 35) - def test_extract_results_from_loggers_list_output(self) -> None: + def test_extract_results_from_loggers_list_output(self): m = TestHelperModules.Conv2dWithSplit() example_inputs = m.example_inputs() ep = export_for_training(m, example_inputs, strict=True) - generate_numeric_debug_handle(ep) m = ep.module() - m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6] + m_ref_logger = prepare_for_propagation_comparison(m) quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=False) ) - m = prepare_pt2e(m, quantizer) # pyre-ignore[6] + m = prepare_pt2e(m, quantizer) m(*example_inputs) m = convert_pt2e(m) m_quant_logger = prepare_for_propagation_comparison(m) @@ -834,10 +808,7 @@ def test_extract_results_from_loggers_list_output(self) -> None: m_quant_logger(*example_inputs) ref_results = extract_results_from_loggers(m_ref_logger) quant_results = extract_results_from_loggers(m_quant_logger) - comparison_results = compare_results( - ref_results, - quant_results, # pyre-ignore[6] - ) + comparison_results = compare_results(ref_results, quant_results) for node_summary in comparison_results.values(): if len(node_summary.results) > 0: sqnr = node_summary.results[0].sqnr @@ -845,4 +816,4 @@ def test_extract_results_from_loggers_list_output(self) -> None: for sqnr_i in sqnr: self.assertGreaterEqual(sqnr_i, 35) else: - 
self.assertGreaterEqual(sqnr, 35) # pyre-ignore[6] + self.assertGreaterEqual(sqnr, 35)