diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index 3b8a71279fd..6c8a55d6220 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -68,7 +68,7 @@ def __init__( Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]] ] = None, _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None, - _representative_inputs: Optional[List[ProgramOutput]] = None, + _representative_inputs: Optional[List[ProgramInput]] = None, ): self.exported_program = exported_program self.export_graph_id = export_graph_id @@ -345,6 +345,56 @@ def add_edge_dialect_program( # Set the extracted data self.edge_dialect_program = processed_edge_dialect_program + def update_representative_inputs( + self, + representative_inputs: Union[List[ProgramInput], BundledProgram], + ) -> None: + """ + Update the representative inputs in the ETRecord. + + This method allows users to customize the representative inputs that will be + included when the ETRecord is saved. The representative inputs can be provided + directly as a list or extracted from a BundledProgram. + + Args: + representative_inputs: Either a list of ProgramInput objects or a BundledProgram + from which representative inputs will be extracted. + """ + if isinstance(representative_inputs, BundledProgram): + self._representative_inputs = _get_representative_inputs( + representative_inputs + ) + else: + self._representative_inputs = representative_inputs + + def update_reference_outputs( + self, + reference_outputs: Union[ + Dict[str, List[ProgramOutput]], List[ProgramOutput], BundledProgram + ], + ) -> None: + """ + Update the reference outputs in the ETRecord. + + This method allows users to customize the reference outputs that will be + included when the ETRecord is saved. The reference outputs can be provided + directly as a dictionary mapping method names to lists of outputs, as a + single list of outputs (which will be treated as {"forward": List[ProgramOutput]}), + or extracted from a BundledProgram. + + Args: + reference_outputs: Either a dictionary mapping method names to lists of + ProgramOutput objects, a single list of ProgramOutput objects (treated + as outputs for the "forward" method), or a BundledProgram from which + reference outputs will be extracted. + """ + if isinstance(reference_outputs, BundledProgram): + self._reference_outputs = _get_reference_outputs(reference_outputs) + elif isinstance(reference_outputs, list): + self._reference_outputs = {"forward": reference_outputs} + else: + self._reference_outputs = reference_outputs + def _get_reference_outputs( bundled_program: BundledProgram, diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 25ea5a25e1f..dbd7fdfb776 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -10,6 +10,7 @@ import json import tempfile import unittest +from typing import List import executorch.exir.tests.models as models import torch @@ -30,6 +31,42 @@ # TODO : T154728484 Add test cases to cover multiple entry points class TestETRecord(unittest.TestCase): + def assert_representative_inputs_equal( + self, + expected_inputs: List, + actual_inputs: List, + msg: str = "Representative inputs do not match", + ) -> None: + """ + Utility function to compare representative inputs. + + This function handles the comparison of representative inputs, which are lists of tuples + containing tensors. It compares each input tuple element by element using torch.equal(). + + Args: + expected_inputs: List of expected input tuples + actual_inputs: List of actual input tuples + msg: Optional message to display on assertion failure + """ + self.assertEqual( + len(expected_inputs), + len(actual_inputs), + f"{msg}: Different number of input sets", + ) + + for i, (expected, actual) in enumerate(zip(expected_inputs, actual_inputs)): + self.assertEqual( + len(expected), + len(actual), + f"{msg}: Input set {i} has different number of tensors", + ) + + for j, (exp_tensor, act_tensor) in enumerate(zip(expected, actual)): + self.assertTrue( + torch.equal(exp_tensor, act_tensor), + f"{msg}: Tensor {j} in input set {i} does not match", + ) + def assert_etrecord_has_no_exported_program(self, etrecord: ETRecord) -> None: """Assert that ETRecord has no exported program data.""" self.assertIsNone(etrecord.exported_program) @@ -73,8 +110,7 @@ def get_test_model(self): captured_output = exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) captured_output_copy = copy.deepcopy(captured_output) edge_output = captured_output.to_edge( - # TODO(gasoon): Remove _use_edge_ops=False once serde is fully migrated to Edge ops - exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + exir.EdgeCompileConfig(_check_ir_validity=False) ) edge_output_copy = copy.deepcopy(edge_output) et_output = edge_output.to_executorch() @@ -99,8 +135,7 @@ def get_test_model_with_bundled_program(self): captured_output = exir.capture(f, inputs[0], exir.CaptureConfig()) captured_output_copy = copy.deepcopy(captured_output) edge_output = captured_output.to_edge( - # TODO(gasoon): Remove _use_edge_ops=False once serde is fully migrated to Edge ops - exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + exir.EdgeCompileConfig(_check_ir_validity=False) ) edge_output_copy = copy.deepcopy(edge_output) et_output = edge_output.to_executorch() @@ -1230,3 +1265,283 @@ def test_add_all_programs_sequentially(self): parsed_etrecord._delegate_map, json.loads(json.dumps(et_output.delegate_map)), ) + + def test_update_representative_inputs_with_list(self): + """Test update_representative_inputs with a list of ProgramInput objects.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no representative inputs + self.assertIsNone(etrecord._representative_inputs) + + # Create custom representative inputs + f = models.BasicSinMax() + custom_inputs = [f.get_random_inputs() for _ in range(3)] + + # Update representative inputs + etrecord.update_representative_inputs(custom_inputs) + + # Verify representative inputs are now set + self.assertIsNotNone(etrecord._representative_inputs) + self.assertEqual(len(etrecord._representative_inputs), 3) + + # Compare the inputs using utility function + self.assert_representative_inputs_equal( + custom_inputs, + etrecord._representative_inputs, + "Custom inputs do not match ETRecord representative inputs", + ) + + def test_update_representative_inputs_with_bundled_program(self): + """Test update_representative_inputs with a BundledProgram.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=bundled_program.executorch_program.debug_handle_map, + _delegate_map=bundled_program.executorch_program.delegate_map, + ) + + # Verify initial state - no representative inputs + self.assertIsNone(etrecord._representative_inputs) + + # Update representative inputs using bundled program + etrecord.update_representative_inputs(bundled_program) + + # Verify representative inputs are now set + self.assertIsNotNone(etrecord._representative_inputs) + + # Compare with expected inputs from bundled program using utility function + expected_inputs = _get_representative_inputs(bundled_program) + self.assert_representative_inputs_equal( + expected_inputs, + etrecord._representative_inputs, + "Bundled program inputs do not match ETRecord representative inputs", + ) + + def test_update_representative_inputs_overwrite_existing(self): + """Test that update_representative_inputs overwrites existing inputs.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance with existing representative inputs + initial_inputs = _get_representative_inputs(bundled_program) + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=bundled_program.executorch_program.debug_handle_map, + _delegate_map=bundled_program.executorch_program.delegate_map, + _representative_inputs=initial_inputs, + ) + + # Verify initial inputs are set + self.assertIsNotNone(etrecord._representative_inputs) + + # Create new custom inputs + f = models.BasicSinMax() + new_inputs = [f.get_random_inputs() for _ in range(2)] + + # Update representative inputs with new inputs + etrecord.update_representative_inputs(new_inputs) + + # Verify inputs are updated using utility function + self.assertEqual(len(etrecord._representative_inputs), 2) + self.assert_representative_inputs_equal( + new_inputs, + etrecord._representative_inputs, + "New inputs do not match ETRecord representative inputs after overwrite", + ) + + def test_update_reference_outputs_with_dict(self): + """Test update_reference_outputs with a dictionary of outputs.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no reference outputs + self.assertIsNone(etrecord._reference_outputs) + + # Create custom reference outputs + f = models.BasicSinMax() + inputs = [f.get_random_inputs() for _ in range(2)] + custom_outputs = { + "forward": [f.forward(*inp) for inp in inputs], + "custom_method": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], + } + + # Update reference outputs + etrecord.update_reference_outputs(custom_outputs) + + # Verify reference outputs are now set + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIn("forward", etrecord._reference_outputs) + self.assertIn("custom_method", etrecord._reference_outputs) + + # Compare the outputs + self.assertEqual(len(etrecord._reference_outputs["forward"]), 2) + self.assertEqual(len(etrecord._reference_outputs["custom_method"]), 2) + + for expected, actual in zip( + custom_outputs["forward"], etrecord._reference_outputs["forward"] + ): + self.assertTrue(torch.equal(expected[0], actual[0])) + + for expected, actual in zip( + custom_outputs["custom_method"], + etrecord._reference_outputs["custom_method"], + ): + self.assertTrue(torch.equal(expected, actual)) + + def test_update_reference_outputs_with_list(self): + """Test update_reference_outputs with a single list of outputs.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no reference outputs + self.assertIsNone(etrecord._reference_outputs) + + # Create custom reference outputs as a single list + f = models.BasicSinMax() + inputs = [f.get_random_inputs() for _ in range(2)] + custom_outputs_list = [f.forward(*inp) for inp in inputs] + + # Update reference outputs with a single list + etrecord.update_reference_outputs(custom_outputs_list) + + # Verify reference outputs are now set and treated as "forward" method + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIn("forward", etrecord._reference_outputs) + self.assertEqual(len(etrecord._reference_outputs["forward"]), 2) + + # Compare the outputs + for expected, actual in zip( + custom_outputs_list, etrecord._reference_outputs["forward"] + ): + self.assertTrue(torch.equal(expected[0], actual[0])) + + def test_update_reference_outputs_with_bundled_program(self): + """Test update_reference_outputs with a BundledProgram.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=bundled_program.executorch_program.debug_handle_map, + _delegate_map=bundled_program.executorch_program.delegate_map, + ) + + # Verify initial state - no reference outputs + self.assertIsNone(etrecord._reference_outputs) + + # Update reference outputs using bundled program + etrecord.update_reference_outputs(bundled_program) + + # Verify reference outputs are now set + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIn("forward", etrecord._reference_outputs) + + # Compare with expected outputs from bundled program + expected_outputs = _get_reference_outputs(bundled_program) + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][0][0], + expected_outputs["forward"][0][0], + ) + ) + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][1][0], + expected_outputs["forward"][1][0], + ) + ) + + def test_update_apis_and_save_parse(self): + """Test that ETRecord with updated inputs/outputs can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Create custom inputs and outputs + f = models.BasicSinMax() + custom_inputs = [f.get_random_inputs() for _ in range(2)] + custom_outputs = { + "forward": [f.forward(*inp) for inp in custom_inputs], + } + + # Update both inputs and outputs + etrecord.update_representative_inputs(custom_inputs) + etrecord.update_reference_outputs(custom_outputs) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_custom_data.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Verify representative inputs are preserved using utility function + self.assertIsNotNone(parsed_etrecord._representative_inputs) + self.assertEqual(len(parsed_etrecord._representative_inputs), 2) + self.assert_representative_inputs_equal( + custom_inputs, + parsed_etrecord._representative_inputs, + "Custom inputs do not match parsed ETRecord representative inputs", + ) + + # Verify reference outputs are preserved + self.assertIsNotNone(parsed_etrecord._reference_outputs) + self.assertIn("forward", parsed_etrecord._reference_outputs) + self.assertEqual(len(parsed_etrecord._reference_outputs["forward"]), 2) + for expected, actual in zip( + custom_outputs["forward"], parsed_etrecord._reference_outputs["forward"] + ): + self.assertTrue(torch.equal(expected[0], actual[0]))