Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from executorch.devtools.inspector.numerical_comparator import (
L1Comparator,
MSEComparator,
NumericalComparatorBase,
SNRComparator,
)
from executorch.exir import ExportedProgram
Expand Down Expand Up @@ -1404,7 +1405,9 @@ def get_exported_program(
)

def calculate_numeric_gap(
self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False
self,
distance: Union[str, NumericalComparatorBase],
disable_debug_handle_valdiation: bool = False,
):
"""
Compares logged intermediate outputs from the exported graph (in ETRecord)
Expand All @@ -1416,7 +1419,10 @@ def calculate_numeric_gap(
compare the intermediate outputs from the AOT and the runtime.

Args:
distance: The metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
distance: The metrics the inspector will use for gap calculation. Can be either:
- A string: one of "MSE", "L1", or "SNR" for built-in comparators.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be - A string: one of "MSE", "L1", "SNR", or user-defined comparators?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i separated the bulit-in comparators and custom comparators into two lines

- A custom NumericalComparatorBase instance: allows you to define custom comparison logic
by subclassing NumericalComparatorBase and implementing the compare() method.
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose
connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR
Expand All @@ -1442,15 +1448,18 @@ def calculate_numeric_gap(
mapping = map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs, runtime_intermediate_outputs
)
metric = distance.strip().upper()
if metric == "MSE":
comparator = MSEComparator()
elif metric == "L1":
comparator = L1Comparator()
elif metric == "SNR":
comparator = SNRComparator()
if isinstance(distance, NumericalComparatorBase):
comparator = distance
else:
raise ValueError(f"Unsupported distance metric {distance!r}")
metric = distance.strip().upper()
if metric == "MSE":
comparator = MSEComparator()
elif metric == "L1":
comparator = L1Comparator()
elif metric == "SNR":
comparator = SNRComparator()
else:
raise ValueError(f"Unsupported distance metric {distance!r}")

rows = []
for (aot_debug_handle, aot_intermediate_output), (
Expand Down
6 changes: 5 additions & 1 deletion devtools/inspector/numerical_comparator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
MSEComparator,
)

from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
NumericalComparatorBase,
)

from executorch.devtools.inspector.numerical_comparator.snr_numerical_comparator import (
SNRComparator,
)


__all__ = ["L1Comparator", "MSEComparator", "SNRComparator"]
__all__ = ["L1Comparator", "MSEComparator", "SNRComparator", "NumericalComparatorBase"]

This file was deleted.

73 changes: 73 additions & 0 deletions devtools/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,79 @@ def test_calculate_numeric_gap(self):
# gap should equal 3.0
self.assertEqual(row["gap"][0], 3.0)

def test_calculate_numeric_gap_with_custom_comparator(self):
"""Test calculate_numeric_gap with a custom NumericalComparatorBase implementation."""
from executorch.devtools.inspector.numerical_comparator import (
NumericalComparatorBase,
)

# Create a custom comparator that returns the max absolute difference
class MaxAbsDiffComparator(NumericalComparatorBase):
def compare(self, a, b):
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
return torch.max(torch.abs(a - b)).item()
return abs(a - b)

# Create a context manager to patch functions called by Inspector.__init__
with patch.object(
_inspector, "parse_etrecord", return_value=None
), patch.object(
_inspector, "gen_etdump_object", return_value=None
), patch.object(
EventBlock, "_gen_from_etdump"
), patch.object(
_inspector, "gen_graphs_from_etrecord"
):
# Call the constructor of Inspector
inspector_instance = Inspector(
etdump_path=ETDUMP_PATH,
etrecord=ETRECORD_PATH,
)

aot_intermediate_outputs = {
(0,): torch.tensor([1.0, 2.0, 3.0]),
(1,): torch.tensor([4.0, 5.0, 6.0]),
}

runtime_intermediate_outputs = {
(0,): ([torch.tensor([2.0, 1.0, 5.0])], 1),
(1,): ([torch.tensor([3.0, 6.0, 5.0])], 1),
}

aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}

inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: (
aot_intermediate_outputs,
aot_debug_handle_to_op_name,
)
inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name)
)

# Create custom comparator instance
custom_comparator = MaxAbsDiffComparator()

# Test with custom comparator
df = inspector_instance.calculate_numeric_gap(distance=custom_comparator)
self.assertIsInstance(df, pd.DataFrame)
self.assertEqual(len(df), 2)
cols = set(df.columns)
expected_cols = {
"aot_ops",
"aot_intermediate_output",
"runtime_ops",
"runtime_intermediate_output",
"gap",
}
self.assertEqual(cols, expected_cols)

# Verify the custom comparator logic
# For (0,): max(|[1.0, 2.0, 3.0] - [2.0, 1.0, 5.0]|) = max([1.0, 1.0, 2.0]) = 2.0
self.assertEqual(df.iloc[0]["gap"][0], 2.0)
# For (1,): max(|[4.0, 5.0, 6.0] - [3.0, 6.0, 5.0]|) = max([1.0, 1.0, 1.0]) = 1.0
self.assertEqual(df.iloc[1]["gap"][0], 1.0)

@unittest.skip("ci config values are not propagated")
def test_intermediate_tensor_comparison_with_torch_export(self):
"""Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower."""
Expand Down
Loading