Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ python_library(
deps = [
"fbsource//third-party/pypi/tabulate:tabulate",
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:memory",
"//executorch/exir/dialects:lib",
"//executorch/exir/dialects/edge:lib",
Expand Down
17 changes: 13 additions & 4 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from torch.export import export
from torch.export.exported_program import ExportedProgram

from .utils import print_ops_info


# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to convert. It is
Expand Down Expand Up @@ -193,16 +195,17 @@ def export_to_edge(


# Export the model and lower it to an EdgeProgramManager (in edge IR), and
# apply passes specific to Cadence DSP execution.
# apply passes specific to Cadence DSP execution. Return both to print the
# differences.
def export_to_cadence(
model: torch.nn.Module,
inputs: tuple[object, ...],
dump_graphs: bool = False,
) -> EdgeProgramManager:
edge_program_manager = export_to_edge(model, inputs)
edge_prog_manager = export_to_edge(model, inputs)

# Run a couple required passes for quant/dequant ops
cadence_program_manager = edge_program_manager.transform(
cadence_prog_manager = edge_prog_manager.transform(
[
InitializePipeline(),
RemoveZeroSizedCatArgsPass(),
Expand All @@ -216,4 +219,10 @@ def export_to_cadence(
]
)

return cadence_program_manager
# Print some information to terminal
print_ops_info(
edge_prog_manager.exported_program().graph_module,
cadence_prog_manager.exported_program().graph_module,
)

return cadence_prog_manager
67 changes: 14 additions & 53 deletions backends/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,61 +10,26 @@
import tempfile

from executorch.backends.cadence.aot.ops_registrations import * # noqa
import os
from typing import Any, Tuple

from executorch.backends.cadence.aot.compiler import (
convert_pt2,
export_to_cadence,
export_to_edge,
quantize_pt2,
fuse_pt2,
)
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.runtime import runtime
from executorch.backends.cadence.runtime.executor import BundledProgramManager
from executorch.exir import ExecutorchProgramManager
from torch import nn

from .utils import print_ops_info
from .utils import save_bpte_program, save_pte_program


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def _save_pte_program(
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
if model_name.endswith(".pte"):
filename = model_name
else:
filename = os.path.join(output_dir, f"{model_name}.pte")

try:
with open(filename, "wb") as file:
prog.write_to_file(file)
logging.info(f"Saved exported program to {filename}")
except Exception as e:
logging.error(f"Error while saving to {filename}: {e}")


def _save_bpte_program(
buffer: bytes,
model_name: str,
output_dir: str = "",
) -> None:
if model_name.endswith(".bpte"):
filename = model_name
else:
filename = os.path.join(output_dir, f"{model_name}.bpte")
try:
with open(filename, "wb") as f:
f.write(buffer)
logging.info(f"Saved exported program to {filename}")
except Exception as e:
logging.error(f"Error while saving to {output_dir}: {e}")


def export_model(
model: nn.Module,
example_inputs: Tuple[Any, ...],
Expand All @@ -74,32 +39,28 @@ def export_model(
working_dir = tempfile.mkdtemp(dir="/tmp")
logging.debug(f"Created work directory {working_dir}")

# convert the model (also called in quantize_pt2)
converted_model = convert_pt2(model, example_inputs, CadenceQuantizer())
# Instantiate the quantizer
quantizer = CadenceQuantizer()

# Get reference outputs from quantized_model
ref_outputs = converted_model(*example_inputs)
# Convert the model
converted_model = convert_pt2(model, example_inputs, quantizer)

# Quantize the model
quantized_model = quantize_pt2(model, example_inputs)
# Get reference outputs from converted model
ref_outputs = converted_model(*example_inputs)

# Get edge program (also called in export_to_cadence)
edge_prog_manager = export_to_edge(quantized_model, example_inputs)
# Quantize the model (note: quantizer needs to be the same as
# the one used in convert_pt2)
quantized_model = fuse_pt2(converted_model, quantizer)

# Get edge program after Cadence specific passes
cadence_prog_manager = export_to_cadence(quantized_model, example_inputs)

# Get executorch program after Cadence specific passes
exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch()

logging.info("Final exported graph:\n")
exec_prog.exported_program().graph_module.graph.print_tabular()

# Print some information to terminal
print_ops_info(
edge_prog_manager.exported_program().graph_module,
cadence_prog_manager.exported_program().graph_module,
)

forward_test_data = BundledProgramManager.bundled_program_test_data_gen(
method="forward", inputs=example_inputs, expected_outputs=ref_outputs
)
Expand All @@ -110,9 +71,9 @@ def export_model(
forward_test_data,
)
# Save the program as pte (default name is CadenceDemoModel.pte)
_save_pte_program(exec_prog, file_name, working_dir)
save_pte_program(exec_prog, file_name, working_dir)
# Save the program as btpe (default name is CadenceDemoModel.bpte)
_save_bpte_program(buffer, file_name, working_dir)
save_bpte_program(buffer, file_name, working_dir)

logging.debug(
f"Executorch bundled program buffer saved to {file_name} is {len(buffer)} total bytes"
Expand Down
37 changes: 36 additions & 1 deletion backends/cadence/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

import logging
import operator
import os
from typing import Dict, List, Tuple

import torch
from executorch.exir import memory

from executorch.exir import ExecutorchProgramManager, memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from tabulate import tabulate
Expand Down Expand Up @@ -185,3 +187,36 @@ def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
if node.target == torch.ops.aten.scaled_dot_product_attention.default:
return True
return False


def save_pte_program(
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
if model_name.endswith(".pte"):
filename = model_name
else:
filename = os.path.join(output_dir, f"{model_name}.pte")

try:
with open(filename, "wb") as file:
prog.write_to_file(file)
logging.info(f"Saved exported program to {filename}")
except Exception as e:
logging.error(f"Error while saving to {filename}: {e}")


def save_bpte_program(
buffer: bytes,
model_name: str,
output_dir: str = "",
) -> None:
if model_name.endswith(".bpte"):
filename = model_name
else:
filename = os.path.join(output_dir, f"{model_name}.bpte")
try:
with open(filename, "wb") as f:
f.write(buffer)
logging.info(f"Saved exported program to {filename}")
except Exception as e:
logging.error(f"Error while saving to {output_dir}: {e}")
Loading