diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 755b82ef26b..6b6b4f583a6 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -73,6 +73,7 @@ from executorch.devtools.inspector.numerical_comparator import ( L1Comparator, MSEComparator, + NumericalComparatorBase, SNRComparator, ) from executorch.exir import ExportedProgram @@ -1404,7 +1405,9 @@ def get_exported_program( ) def calculate_numeric_gap( - self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False + self, + distance: Union[str, NumericalComparatorBase], + disable_debug_handle_valdiation: bool = False, ): """ Compares logged intermediate outputs from the exported graph (in ETRecord) @@ -1416,7 +1419,10 @@ def calculate_numeric_gap( compare the intermediate outputs from the AOT and the runtime. Args: - distance: The metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR". + distance: The metric the inspector will use for gap calculation. Can be either: + - A string: one of "MSE", "L1", or "SNR" for built-in comparators. + - A custom NumericalComparatorBase instance: allows you to define custom comparison logic + by subclassing NumericalComparatorBase and implementing the compare() method. disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. 
By default we validate that every edge IR @@ -1442,15 +1448,18 @@ def calculate_numeric_gap( mapping = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) - metric = distance.strip().upper() - if metric == "MSE": - comparator = MSEComparator() - elif metric == "L1": - comparator = L1Comparator() - elif metric == "SNR": - comparator = SNRComparator() + if isinstance(distance, NumericalComparatorBase): + comparator = distance else: - raise ValueError(f"Unsupported distance metric {distance!r}") + metric = distance.strip().upper() + if metric == "MSE": + comparator = MSEComparator() + elif metric == "L1": + comparator = L1Comparator() + elif metric == "SNR": + comparator = SNRComparator() + else: + raise ValueError(f"Unsupported distance metric {distance!r}") rows = [] for (aot_debug_handle, aot_intermediate_output), ( diff --git a/devtools/inspector/numerical_comparator/__init__.py b/devtools/inspector/numerical_comparator/__init__.py index daacb5496ae..0090c50025f 100644 --- a/devtools/inspector/numerical_comparator/__init__.py +++ b/devtools/inspector/numerical_comparator/__init__.py @@ -13,9 +13,13 @@ MSEComparator, ) +from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import ( + NumericalComparatorBase, +) + from executorch.devtools.inspector.numerical_comparator.snr_numerical_comparator import ( SNRComparator, ) -__all__ = ["L1Comparator", "MSEComparator", "SNRComparator"] +__all__ = ["L1Comparator", "MSEComparator", "SNRComparator", "NumericalComparatorBase"] diff --git a/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py b/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py deleted file mode 100644 index b6dac7e1970..00000000000 --- a/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - - -from abc import ABC, abstractmethod -from typing import Any - - -class InspectorNumericalComparatorBase(ABC): - @abstractmethod - def compare(self, a: Any, b: Any) -> float: - """Compare two intermediate output and return a result. - - This method should be overridden by subclasses to provide custom comparison logic. - - Args: - a: The first intermediate output to compare. - b: The second intermediate output to compare. - - Returns: - A numerical result indicating the comparison outcome. - """ - pass diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index cfea930c20e..93a74915e84 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -722,6 +722,79 @@ def test_calculate_numeric_gap(self): # gap should equal 3.0 self.assertEqual(row["gap"][0], 3.0) + def test_calculate_numeric_gap_with_custom_comparator(self): + """Test calculate_numeric_gap with a custom NumericalComparatorBase implementation.""" + from executorch.devtools.inspector.numerical_comparator import ( + NumericalComparatorBase, + ) + + # Create a custom comparator that returns the max absolute difference + class MaxAbsDiffComparator(NumericalComparatorBase): + def compare(self, a, b): + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return torch.max(torch.abs(a - b)).item() + return abs(a - b) + + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + 
etrecord=ETRECORD_PATH, + ) + + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 1.0, 5.0])], 1), + (1,): ([torch.tensor([3.0, 6.0, 5.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Create custom comparator instance + custom_comparator = MaxAbsDiffComparator() + + # Test with custom comparator + df = inspector_instance.calculate_numeric_gap(distance=custom_comparator) + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + cols = set(df.columns) + expected_cols = { + "aot_ops", + "aot_intermediate_output", + "runtime_ops", + "runtime_intermediate_output", + "gap", + } + self.assertEqual(cols, expected_cols) + + # Verify the custom comparator logic + # For (0,): max(|[1.0, 2.0, 3.0] - [2.0, 1.0, 5.0]|) = max([1.0, 1.0, 2.0]) = 2.0 + self.assertEqual(df.iloc[0]["gap"][0], 2.0) + # For (1,): max(|[4.0, 5.0, 6.0] - [3.0, 6.0, 5.0]|) = max([1.0, 1.0, 1.0]) = 1.0 + self.assertEqual(df.iloc[1]["gap"][0], 1.0) + @unittest.skip("ci config values are not propagated") def test_intermediate_tensor_comparison_with_torch_export(self): """Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower."""