42 changes: 29 additions & 13 deletions devtools/inspector/_inspector.py
@@ -48,6 +48,7 @@
EXCLUDED_COLUMNS_WHEN_PRINTING,
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT,
EXCLUDED_EVENTS_WHEN_PRINTING,
find_op_names,
find_populated_event,
FORWARD,
gen_etdump_object,
@@ -68,6 +69,7 @@
from executorch.devtools.inspector.numerical_comparator import (
L1Comparator,
MSEComparator,
SNRComparator,
)
from executorch.exir import ExportedProgram

@@ -1084,8 +1086,6 @@ def __init__(
# Key str is method name; value is list of ProgramOutputs because of list of test cases
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
self._enable_module_hierarchy = enable_module_hierarchy
self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None
self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None
self._consume_etrecord()

def _consume_etrecord(self) -> None:
@@ -1146,19 +1146,26 @@ def _consume_etrecord(self) -> None:
event_block.reference_output = self._reference_outputs[FORWARD][
index
]
# Capture intermediate outputs only if _representative_inputs are provided
# when using bundled program to create the etrecord

def _get_aot_intermediate_outputs_and_op_names(
self,
) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]:
"""
Return the AOT intermediate outputs and the debug-handle-to-op-name mapping.
Intermediate outputs are captured only if _representative_inputs were provided
when a bundled program was used to create the ETRecord; otherwise empty dicts are returned.
"""
if self._etrecord._representative_inputs is None:
return
return {}, {}
export_program = self._etrecord.edge_dialect_program
graph_module = export_program.module()
self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping(
aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(
graph_module
)
capturer = IntermediateOutputCapturer(graph_module)
self._aot_intermediate_outputs = capturer.run_and_capture(
aot_intermediate_outputs = capturer.run_and_capture(
self._etrecord._representative_inputs
)
return aot_intermediate_outputs, aot_debug_handle_to_op_name

# TODO: Make it more extensible to further merge overlapping debug handles
def _get_runtime_intermediate_outputs_and_op_names(
@@ -1366,22 +1373,27 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
pd.DataFrame: A DataFrame listing corresponding operator outputs from
both stages and their computed numerical gaps.
"""
if self._aot_intermediate_outputs is None:
aot_intermediate_outputs, aot_debug_handle_to_op_name = (
self._get_aot_intermediate_outputs_and_op_names()
)
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_name) == 0:
raise ValueError(
"The aot intermediate outputs is required but not populated."
"calculate_numerical_gap error: The aot debug information is required but not populated"
)
# runtime_debug_handle_to_op_name will be used later to map runtime debug handles to op names
runtime_intermediate_outputs, runtime_op_names = (
runtime_intermediate_outputs, runtime_debug_handle_to_op_name = (
self._get_runtime_intermediate_outputs_and_op_names()
)
mapping = map_runtime_aot_intermediate_outputs(
self._aot_intermediate_outputs, runtime_intermediate_outputs
aot_intermediate_outputs, runtime_intermediate_outputs
)
metric = distance.strip().upper()
if metric == "MSE":
comparator = MSEComparator()
elif metric == "L1":
comparator = L1Comparator()
elif metric == "SNR":
comparator = SNRComparator()
else:
raise ValueError(f"Unsupported distance metric {distance!r}")

@@ -1394,9 +1406,13 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
continue
rows.append(
{
"aot_debug_handle": aot_debug_handle,
"aot_ops": find_op_names(
aot_debug_handle, aot_debug_handle_to_op_name
),
"aot_intermediate_output": aot_intermediate_output,
"runtime_debug_handle": runtime_debug_handle,
"runtime_ops": find_op_names(
runtime_debug_handle, runtime_debug_handle_to_op_name
),
"runtime_intermediate_output": runtime_intermediate_output,
"gap": comparator.compare(
aot_intermediate_output, runtime_intermediate_output
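Taken together, these `_inspector.py` changes make `calculate_numeric_gap` resolve operator names on both the AOT and runtime side and add "SNR" as a supported metric. A minimal usage sketch, assuming the public `Inspector` import path and placeholder artifact paths (the ETRecord must have been generated from a bundled program with representative inputs, otherwise the new ValueError is raised):

```python
from executorch.devtools import Inspector  # assumed public import path

# Placeholder paths: a real run needs an ETDump plus an ETRecord generated
# from a bundled program so that representative inputs are available.
inspector = Inspector(
    etdump_path="path/to/etdump.etdp",
    etrecord="path/to/etrecord.bin",
)

# "MSE" (default), "L1", and the newly supported "SNR" are accepted metrics.
df = inspector.calculate_numeric_gap(distance="SNR")

# Rows now report operator names ("aot_ops"/"runtime_ops") instead of raw debug handles.
for _, row in df.iterrows():
    print(row["aot_ops"], row["runtime_ops"], row["gap"])
```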
20 changes: 20 additions & 0 deletions devtools/inspector/_inspector_utils.py
@@ -784,3 +784,23 @@ def get_aot_debug_handle_to_op_name_mapping(
)
debug_handle_to_op_name[key] = node.name
return debug_handle_to_op_name


def find_op_names(
target_debug_handle: Tuple[int, ...],
debug_handle_to_op_name: Dict[Tuple[int, ...], str],
) -> List[str]:
"""
Return the operator names whose debug handles are subsets of the target debug handle.
The keys of `debug_handle_to_op_name` may already be merged debug handles; they are not
modified here, and this function only identifies the operators they correspond to.
"""
dh_set = set(target_debug_handle)
result = []

for key_tuple, op_name in debug_handle_to_op_name.items():
# Check if key is a subset of the target_debug_handle
if set(key_tuple).issubset(dh_set):
result.append(op_name)

return result
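For illustration, a small standalone sketch of the subset semantics of `find_op_names`; the handles and names below are made up, and in practice the mapping comes from `get_aot_debug_handle_to_op_name_mapping`:

```python
from executorch.devtools.inspector._inspector_utils import find_op_names

# Hypothetical mapping from (possibly merged) debug handles to operator names.
debug_handle_to_op_name = {(1, 2): "op1", (2, 3): "op2", (4, 5, 6): "op3"}

# (1, 2) and (2, 3) are subsets of the target handle (1, 2, 3), so both
# operators are reported; (4, 5, 6) is not.
print(find_op_names((1, 2, 3), debug_handle_to_op_name))  # ['op1', 'op2']
```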
64 changes: 27 additions & 37 deletions devtools/inspector/tests/inspector_test.py
@@ -44,6 +44,7 @@
TimeScale,
)
from executorch.devtools.inspector.tests.inspector_test_utils import (
check_if_debug_handle_to_op_name_match,
check_if_final_outputs_match,
model_registry,
)
@@ -468,25 +469,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self):
events=events,
)

def test_no_capture_when_representative_inputs_are_none(self):
# Create a context manager to patch functions called by Inspector.__init__
with patch.object(
_inspector, "parse_etrecord", return_value=None
), patch.object(
_inspector, "gen_etdump_object", return_value=None
), patch.object(
EventBlock, "_gen_from_etdump"
), patch.object(
_inspector, "gen_graphs_from_etrecord"
):
# Call the constructor of Inspector
inspector_instance = Inspector(
etdump_path=ETDUMP_PATH,
etrecord=ETRECORD_PATH,
)
self.assertIsNone(inspector_instance._aot_intermediate_outputs)

def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
def test_etrecord_populates_correct_aot_intermediate_outputs(self):
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
etrecord_path = tmp_file.name
mod = model_registry["ConvLinearModel"]()
@@ -505,7 +488,6 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
generate_etrecord(
etrecord_path, edge_program_manager_copy, et_program_manager
)
original_consume_etrecord = Inspector._consume_etrecord
with patch.object(
Inspector, "_consume_etrecord", return_value=None
), patch.object(
@@ -529,11 +511,17 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
_representative_inputs=aten_model.example_inputs[0],
)
inspector_instance._etrecord = etrecord
Inspector._consume_etrecord = original_consume_etrecord
inspector_instance._consume_etrecord()
aot_intermediate_outputs, aot_debug_handle_to_op_name = (
inspector_instance._get_aot_intermediate_outputs_and_op_names()
)
self.assertTrue(
check_if_final_outputs_match(
"ConvLinearModel", inspector_instance._aot_intermediate_outputs
"ConvLinearModel", aot_intermediate_outputs
)
)
self.assertTrue(
check_if_debug_handle_to_op_name_match(
"ConvLinearModel", aot_debug_handle_to_op_name
)
)

@@ -605,6 +593,7 @@ def test_calculate_numeric_gap(self):
), patch.object(
_inspector, "gen_graphs_from_etrecord"
):

# Call the constructor of Inspector
inspector_instance = Inspector(
etdump_path=ETDUMP_PATH,
@@ -621,43 +610,44 @@ def test_calculate_numeric_gap(self):
(1,): torch.tensor([3.0, 6.0, 5.0]),
}

inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs
aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}

inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: (
aot_intermediate_outputs,
aot_debug_handle_to_op_name,
)
inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
lambda: (runtime_intermediate_outputs, {})
lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name)
)

df = inspector_instance.calculate_numeric_gap(distance="L1")
self.assertIsInstance(df, pd.DataFrame)
self.assertEqual(len(df), 2)
cols = set(df.columns)
expected_cols = {
"aot_debug_handle",
"aot_ops",
"aot_intermediate_output",
"runtime_debug_handle",
"runtime_ops",
"runtime_intermediate_output",
"gap",
}
self.assertEqual(cols, expected_cols)
founded_aot_debug_handle = set(df["aot_debug_handle"])
self.assertEqual(
founded_aot_debug_handle, set(aot_intermediate_outputs.keys())
)
for _, row in df.iterrows():
aot_debuh_handle = row["aot_debug_handle"]
for i, row in df.iterrows():
# Dummy key to get the expected aot/runtime intermediate outputs
key = (i,)
# aot_intermediate_output should equal aot_intermediate_outputs[h]
self.assertTrue(
torch.allclose(
row["aot_intermediate_output"],
aot_intermediate_outputs[aot_debuh_handle],
aot_intermediate_outputs[key],
)
)
# runtime_debug_hanlde equals aot_debug_handle at this case
self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle)
# runtime_intermediate_output should equal runtime_intermediate_outputs[h]
self.assertTrue(
torch.allclose(
row["runtime_intermediate_output"],
runtime_intermediate_outputs[aot_debuh_handle],
runtime_intermediate_outputs[key],
)
)
# gap should equal 3.0
38 changes: 38 additions & 0 deletions devtools/inspector/tests/inspector_test_utils.py
@@ -83,6 +83,26 @@ def get_expected_intermediate_outputs():
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
}

@staticmethod
def get_expected_debug_handle_to_op_name():
"""
Returns the expected mapping from debug handle to op name for this model.
"""
return {
(10,): "aten_convolution_default",
(11,): "aten_view_copy_default",
(12,): "aten_permute_copy_default",
(13,): "aten_addmm_default",
(14,): "aten_add_tensor",
(15,): "aten_sub_tensor",
(16,): "aten_mul_tensor",
(17,): "aten_add_tensor_1",
(18,): "aten_div_tensor",
(19,): "aten_relu_default",
(20,): "aten_sigmoid_default",
(21,): "aten_split_with_sizes_copy_default",
}


# Global model registry
model_registry = {
@@ -116,3 +136,21 @@ def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
if not torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5):
return False
return True


def check_if_debug_handle_to_op_name_match(model_name, actual_debug_handle_to_op_name):
"""
Checks if the actual op names match the expected op names for the specified model.
Returns True if all match, otherwise returns False.
"""
model_instance = model_registry[model_name]
expected_debug_handle_to_op_name = (
model_instance.get_expected_debug_handle_to_op_name()
)
if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name):
return False
for debug_handle, expected_op_name in expected_debug_handle_to_op_name.items():
actual_op_name = actual_debug_handle_to_op_name.get(debug_handle)
if actual_op_name != expected_op_name:
return False
return True
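A quick smoke-test sketch of the two new helpers; it is circular by construction, and a real test would feed in the mapping derived from the exported graph, as the updated Inspector test above does:

```python
from executorch.devtools.inspector.tests.inspector_test_utils import (
    check_if_debug_handle_to_op_name_match,
    model_registry,
)

# The expected mapping trivially matches itself, so the check returns True.
expected = model_registry["ConvLinearModel"].get_expected_debug_handle_to_op_name()
assert check_if_debug_handle_to_op_name_match("ConvLinearModel", expected)
```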
18 changes: 18 additions & 0 deletions devtools/inspector/tests/inspector_utils_test.py
@@ -32,6 +32,7 @@
convert_to_float_tensor,
create_debug_handle_to_op_node_mapping,
EDGE_DIALECT_GRAPH_KEY,
find_op_names,
find_populated_event,
gen_graphs_from_etrecord,
get_aot_debug_handle_to_op_name_mapping,
@@ -472,6 +473,23 @@ def test_node_op_type_mismatch(self):
# Test that the filter doesn't match the mock node (op_type mismatch)
self.assertFalse(node_filter.matches(mock_node_op_type_mismatch))

def test_find_op_names_empty_debug_handle(self):
debug_handle = ()
debug_handle_to_op_name = {(1, 2): "op1", (3, 4): "op2"}
self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), [])

def test_find_op_names_no_matching_handles(self):
debug_handle = (1, 2)
debug_handle_to_op_name = {(3, 4): "op1", (5, 6): "op2"}
self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), [])

def test_find_op_names_matching_handles(self):
debug_handle = (1, 2, 3)
debug_handle_to_op_name = {(1, 2): "op1", (2, 3): "op2", (4, 5, 6): "op3"}
self.assertEqual(
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2"]
)


def gen_mock_operator_graph_with_expected_map() -> (
Tuple[OperatorGraph, Dict[int, OperatorNode]]