Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 156 additions & 45 deletions devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,151 @@ def _save_edge_dialect_program(
f"{base_name}_example_inputs", serialized_artifact.example_inputs
)

def add_extra_export_modules(
self,
extra_recorded_export_modules: Dict[
str,
Union[
ExportedProgram,
ExirExportedProgram,
EdgeProgramManager,
],
],
) -> None:
"""
Add extra export modules to the ETRecord after it has been created.

This method allows users to add more export modules they want to record
to an existing ETRecord instance. The modules will be added to the graph_map
and will be included when the ETRecord is saved.

Args:
extra_recorded_export_modules: A dictionary of graph modules with the key being
the user provided name and the value being the corresponding exported module.
The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`.
"""
if self.graph_map is None:
self.graph_map = {}

# Now self.graph_map is guaranteed to be non-None
graph_map = self.graph_map
for module_name, export_module in extra_recorded_export_modules.items():
_add_module_to_graph_map(graph_map, module_name, export_module)

def add_executorch_program(
self,
executorch_program: Union[
ExecutorchProgram,
ExecutorchProgramManager,
BundledProgram,
],
) -> None:
"""
Add executorch program data to the ETRecord after it has been created.

This method allows users to add executorch program data they want to record
to an existing ETRecord instance. The executorch program data includes debug handle map,
delegate map, reference outputs, and representative inputs that will be included
when the ETRecord is saved.

Args:
executorch_program: The ExecuTorch program for this model returned by the call to
`to_executorch()` or the `BundledProgram` of this model.

Raises:
RuntimeError: If executorch program data already exists in the ETRecord.
"""
# Check if executorch program data already exists
if (
self._debug_handle_map is not None
or self._delegate_map is not None
or self._reference_outputs is not None
or self._representative_inputs is not None
):
raise RuntimeError(
"Executorch program data already exists in the ETRecord. "
"Cannot add executorch program data when it already exists."
)

# Process executorch program and extract data
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
_process_executorch_program(executorch_program)
)

# Set the extracted data
self._debug_handle_map = debug_handle_map
self._delegate_map = delegate_map
self._reference_outputs = reference_outputs
self._representative_inputs = representative_inputs

def add_exported_program(
self,
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]],
) -> None:
"""
Add exported program to the ETRecord after it has been created.

This method allows users to add an exported program they want to record
to an existing ETRecord instance. The exported program will be included
when the ETRecord is saved.

Args:
exported_program: The exported program for this model returned by the call to
`torch.export()` or a dictionary with method names as keys and exported programs as values.
Can be None, in which case no exported program data will be added.

Raises:
RuntimeError: If exported program already exists in the ETRecord.
"""
# Check if exported program already exists
if self.exported_program is not None or self.export_graph_id is not None:
raise RuntimeError(
"Exported program already exists in the ETRecord. "
"Cannot add exported program when it already exists."
)

# Process exported program and extract data
processed_exported_program, export_graph_id = _process_exported_program(
exported_program
)

# Set the extracted data
self.exported_program = processed_exported_program
self.export_graph_id = export_graph_id

def add_edge_dialect_program(
self,
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
) -> None:
"""
Add edge dialect program to the ETRecord after it has been created.

This method allows users to add an edge dialect program they want to record
to an existing ETRecord instance. The edge dialect program will be included
when the ETRecord is saved.

Args:
edge_dialect_program: The edge dialect program for this model returned by the call to
`to_edge()` or `EdgeProgramManager` for this model.

Raises:
RuntimeError: If edge dialect program already exists in the ETRecord.
"""
# Check if edge dialect program already exists
if self.edge_dialect_program is not None:
raise RuntimeError(
"Edge dialect program already exists in the ETRecord. "
"Cannot add edge dialect program when it already exists."
)

# Process edge dialect program and extract data
processed_edge_dialect_program = _process_edge_dialect_program(
edge_dialect_program
)

# Set the extracted data
self.edge_dialect_program = processed_edge_dialect_program


def _get_reference_outputs(
bundled_program: BundledProgram,
Expand Down Expand Up @@ -285,37 +430,24 @@ def generate_etrecord(
Returns:
None
"""
# Process all inputs and prepare data for ETRecord construction
processed_exported_program, export_graph_id = _process_exported_program(
exported_program
)
graph_map = _process_extra_recorded_modules(extra_recorded_export_modules)
processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program)
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
_process_executorch_program(executorch_program)
)
etrecord = ETRecord()
etrecord.add_exported_program(exported_program)
etrecord.add_edge_dialect_program(edge_dialect_program)
etrecord.add_executorch_program(executorch_program)

# Create ETRecord instance and save
etrecord = ETRecord(
exported_program=processed_exported_program,
export_graph_id=export_graph_id,
edge_dialect_program=processed_edge_dialect_program,
graph_map=graph_map if graph_map else None,
_debug_handle_map=debug_handle_map,
_delegate_map=delegate_map,
_reference_outputs=reference_outputs,
_representative_inputs=representative_inputs,
)
# Add extra export modules if user provided
if extra_recorded_export_modules is not None:
etrecord.add_extra_export_modules(extra_recorded_export_modules)

etrecord.save(et_record)


def _process_exported_program(
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]]
) -> tuple[Optional[ExportedProgram], int]:
) -> tuple[Optional[ExportedProgram], Optional[int]]:
"""Process exported program and return the processed program and export graph id."""
processed_exported_program = None
export_graph_id = 0
export_graph_id = None

if exported_program is not None:
if isinstance(exported_program, dict) and "forward" in exported_program:
Expand All @@ -329,29 +461,6 @@ def _process_exported_program(
return processed_exported_program, export_graph_id


def _process_extra_recorded_modules(
extra_recorded_export_modules: Optional[
Dict[
str,
Union[
ExportedProgram,
ExirExportedProgram,
EdgeProgramManager,
],
]
]
) -> Dict[str, ExportedProgram]:
"""Process extra recorded export modules and return graph map."""
graph_map = {}

if extra_recorded_export_modules is not None:
for module_name, export_module in extra_recorded_export_modules.items():
_validate_module_name(module_name)
_add_module_to_graph_map(graph_map, module_name, export_module)

return graph_map


def _validate_module_name(module_name: str) -> None:
"""Validate that module name is not a reserved name."""
contains_reserved_name = any(
Expand All @@ -369,6 +478,8 @@ def _add_module_to_graph_map(
export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager],
) -> None:
"""Add export module to graph map based on its type."""
_validate_module_name(module_name)

if isinstance(export_module, ExirExportedProgram):
graph_map[f"{module_name}/forward"] = export_module.exported_program
elif isinstance(export_module, ExportedProgram):
Expand Down
Loading
Loading