examples/models/llama/export_llama_lib.py (98 changes: 66 additions & 32 deletions)
@@ -27,7 +27,7 @@
from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
from executorch.devtools.backend_debug import print_delegation_info

-from executorch.devtools.etrecord import generate_etrecord
+from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
from executorch.examples.models.llama.hf_download import (
download_and_convert_hf_checkpoint,
)
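The import rename above exists because the refactored lowering functions gain a boolean parameter that is itself named `generate_etrecord`, which would shadow a plain `generate_etrecord` import inside the function body; aliasing the import to `generate_etrecord_func` keeps the function callable. A minimal, self-contained sketch of the collision (toy function bodies, not the real executorch API):

```python
# Toy stand-in for the aliased import.
def generate_etrecord_func(et_record: str) -> None:
    print(f"writing {et_record}")

def to_edge_and_lower(generate_etrecord: bool = False) -> None:
    # `generate_etrecord` is now the bool flag; without the alias, calling
    # the imported function of the same name here would raise
    # "'bool' object is not callable".
    if generate_etrecord:
        generate_etrecord_func(et_record="etrecord.bin")

to_edge_and_lower(generate_etrecord=True)
```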
@@ -749,7 +749,9 @@ def _to_edge_and_lower_llama_xnnpack(
pt2e_quant_params,
quantizers,
quant_dtype,
-args,
+xnnpack_extended_ops: bool = False,
+generate_etrecord: bool = False,
+verbose: bool = False,
) -> LLMEdgeManager: # noqa: C901
partitioners = []

@@ -758,7 +760,7 @@

modelname = f"xnnpack_dq_{modelname}"

-if args.xnnpack_extended_ops:
+if xnnpack_extended_ops:
partitioners.append(
get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
)
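By default this path registers a dynamic-quantization-only XNNPACK partitioner, and the `xnnpack_extended_ops` flag (previously read from `args.xnnpack_extended_ops`) appends a second partitioner that also claims non-quantized ops. A rough sketch of that selection logic, with an assumed stand-in for `get_xnnpack_partitioner`:

```python
# Stand-in for executorch's get_xnnpack_partitioner; only the flag routing
# is illustrated here, not real partitioning.
def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True) -> str:
    return "xnnpack-dq-only" if dynamic_quant_only_partitioner else "xnnpack-extended"

def select_partitioners(xnnpack_extended_ops: bool = False) -> list:
    partitioners = [get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)]
    if xnnpack_extended_ops:
        partitioners.append(
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
        )
    return partitioners

print(select_partitioners(xnnpack_extended_ops=True))
# ['xnnpack-dq-only', 'xnnpack-extended']
```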
@@ -769,15 +771,15 @@
logging.info(f"--> {partitioner.__class__.__name__}")

# TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-if args.generate_etrecord:
+if generate_etrecord:
raise NotImplementedError(
"export_llama does not support XNNPack and generating ETRecord at the moment."
)

builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
partitioners
)
-if args.verbose:
+if verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)

return builder.to_executorch(passes=additional_passes)
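The substance of the change is visible in this hunk: the opaque `args` namespace is replaced with explicit keyword parameters, so each lowering function declares exactly which flags it reads and can be exercised without building a full argparse namespace. A hedged sketch of the pattern with a toy function (not the real signature):

```python
import argparse

def lower(modelname: str, *, xnnpack_extended_ops: bool = False,
          verbose: bool = False) -> str:
    # Only declared flags are available; no hidden args.* lookups remain.
    if xnnpack_extended_ops:
        modelname = f"xnnpack_{modelname}"
    if verbose:
        print(f"lowering {modelname}")
    return modelname

# Tests can now call the function directly...
assert lower("llama", xnnpack_extended_ops=True) == "xnnpack_llama"

# ...while the CLI entry point forwards parsed attributes explicitly,
# mirroring the _export_llama call sites further down this diff.
args = argparse.Namespace(xnnpack_extended_ops=True, verbose=True)
lower("llama", xnnpack_extended_ops=args.xnnpack_extended_ops, verbose=args.verbose)
```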
@@ -790,52 +792,66 @@ def _to_edge_and_lower_llama( # noqa: C901
pt2e_quant_params,
quantizers,
quant_dtype,
-args,
+vulkan: bool = False,
+mps: bool = False,
+coreml: bool = False,
+qnn: bool = False,
+dtype_override: str = "fp32",
+enable_dynamic_shape: bool = True,
+use_kv_cache: bool = False,
+embedding_quantize: Optional[str] = None,
+pt2e_quantize: Optional[str] = None,
+coreml_ios: int = 15,
+coreml_quantize: Optional[str] = None,
+coreml_compute_units: str = "cpu_only",
+use_qnn_sha: bool = False,
+num_sharding: int = 0,
+soc_model: str = "SM8650",
+generate_etrecord: bool = False,
+verbose: bool = False,
):
builder_exported_to_edge = builder_exported.pt2e_quantize(
quantizers
).export_to_edge()

# to_backend
partitioners = []
-if args.vulkan:
+if vulkan:
partitioners.append(
get_vulkan_partitioner(
-args.dtype_override,
-args.enable_dynamic_shape,
+dtype_override,
+enable_dynamic_shape,
)
)
modelname = f"vulkan_{modelname}"

# Need to remove asserts from the graph to prevent graph breaks
remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

-if args.mps:
-partitioners.append(get_mps_partitioner(args.use_kv_cache))
+if mps:
+partitioners.append(get_mps_partitioner(use_kv_cache))
modelname = f"mps_{modelname}"

-if args.coreml:
+if coreml:
coreml_partitioner = get_coreml_partitioner(
-args.coreml_ios,
-args.embedding_quantize,
-args.pt2e_quantize,
-args.coreml_quantize,
-args.coreml_compute_units,
+coreml_ios,
+embedding_quantize,
+pt2e_quantize,
+coreml_quantize,
+coreml_compute_units,
)
partitioners.append(coreml_partitioner)
modelname = f"coreml_{modelname}"

-if args.qnn:
+if qnn:
logging.warning(
"The model definition in current repro is not performant, please refer to the instruction"
" in https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama/README.md for better performance."
)
from executorch.extension.llm.custom_ops import model_sharding

partitioners.append(
-get_qnn_partitioner(
-args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
-)
+get_qnn_partitioner(use_kv_cache, pt2e_quantize, num_sharding, soc_model)
)
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
from executorch.backends.qualcomm._passes import (
@@ -864,7 +880,7 @@ def _to_edge_and_lower_llama( # noqa: C901
)

atten = builder_exported_to_edge.model.layers[0].attention
-if args.use_qnn_sha:
+if use_qnn_sha:
cache_shape = torch.Size(
(atten.max_batch_size, atten.max_context_len, atten.head_dim)
)
@@ -887,10 +903,10 @@ def _to_edge_and_lower_llama( # noqa: C901
passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
"get_quant_io_dtype_fn"
] = partial(get_custom_quant_ios_dtype, cache_shape)
-if args.num_sharding > 0:
+if num_sharding > 0:
SplitGraph, setting = model_sharding.get_split_graph_pass(
builder_exported_to_edge.metadata["get_n_layers"],
-shares=args.num_sharding,
+shares=num_sharding,
)
passes_job[SplitGraph] = setting
dep_table[SplitGraph] = [FoldQDQ]
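The `num_sharding` value feeds a QNN graph-split pass built from the model's `get_n_layers` metadata. Purely as an illustration of what splitting layers into `shares` shards implies (this is not the model_sharding API):

```python
def shard_layers(n_layers: int, shares: int) -> list:
    # Ceil division so the final shard absorbs any remainder.
    per_shard = -(-n_layers // shares)
    return [range(start, min(start + per_shard, n_layers))
            for start in range(0, n_layers, per_shard)]

# e.g. a 32-layer model split four ways:
print(shard_layers(32, 4))  # [range(0, 8), range(8, 16), range(16, 24), range(24, 32)]
```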
@@ -905,17 +921,17 @@ def _to_edge_and_lower_llama( # noqa: C901
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

-if args.generate_etrecord:
+if generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")

logging.info("Generating etrecord")
# Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
builder = builder_exported_to_edge.to_backend(partitioners)
-if args.verbose:
+if verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)
-if args.num_sharding > 0 and args.qnn:
+if num_sharding > 0 and qnn:
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`.
from executorch.backends.qualcomm.utils.utils import canonicalize_program

@@ -927,17 +943,17 @@ def _to_edge_and_lower_llama( # noqa: C901

# Generate ETRecord
if edge_manager_copy:
-generate_etrecord(
+generate_etrecord_func(
et_record="etrecord.bin",
edge_dialect_program=edge_manager_copy,
executorch_program=builder.export_program,
)
logging.info("Generated etrecord.bin")
else:
builder = builder_exported_to_edge.to_backend(partitioners)
-if args.verbose:
+if verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)
-if args.num_sharding > 0 and args.qnn:
+if num_sharding > 0 and qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

canonicalize_program(builder.edge_manager.exported_program())
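Both branches above implement the same idea: to emit an ETRecord, snapshot the edge-dialect program with a deep copy before `to_backend` lowers it, then serialize the snapshot alongside the lowered program so devtools can correlate the two. A condensed, self-contained sketch with stubbed types (toy classes, not the real edge manager):

```python
import copy
from dataclasses import dataclass, field

@dataclass
class EdgeManager:
    ops: list = field(default_factory=lambda: ["add", "matmul"])

    def to_backend(self) -> "EdgeManager":
        return EdgeManager(ops=[f"delegated::{op}" for op in self.ops])

def write_etrecord(path: str, edge: EdgeManager, lowered: EdgeManager) -> None:
    print(f"{path}: {edge.ops} -> {lowered.ops}")

edge_manager = EdgeManager()
edge_copy = copy.deepcopy(edge_manager)  # the memory-expensive snapshot noted in the diff
lowered = edge_manager.to_backend()
write_etrecord("etrecord.bin", edge_copy, lowered)
```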
@@ -976,7 +992,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
pt2e_quant_params,
quantizers,
quant_dtype,
-args,
+xnnpack_extended_ops=args.xnnpack_extended_ops,
+generate_etrecord=args.generate_etrecord,
+verbose=args.verbose,
)
else:
builder = _to_edge_and_lower_llama(
@@ -986,7 +1004,23 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
pt2e_quant_params,
quantizers,
quant_dtype,
-args,
+vulkan=args.vulkan,
+mps=args.mps,
+coreml=args.coreml,
+qnn=args.qnn,
+dtype_override=args.dtype_override,
+enable_dynamic_shape=args.enable_dynamic_shape,
+use_kv_cache=args.use_kv_cache,
+embedding_quantize=args.embedding_quantize,
+pt2e_quantize=args.pt2e_quantize,
+coreml_ios=args.coreml_ios,
+coreml_quantize=args.coreml_quantize,
+coreml_compute_units=args.coreml_compute_units,
+use_qnn_sha=args.use_qnn_sha,
+num_sharding=args.num_sharding,
+soc_model=args.soc_model,
+generate_etrecord=args.generate_etrecord,
+verbose=args.verbose,
)
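A practical aside on the expanded call site above: every forwarded keyword matches its argparse attribute one-to-one, so the call could also be written by projecting the namespace, at the cost of the explicitness this refactor is after. A sketch with a hypothetical helper (not part of this PR):

```python
import argparse

def forward_kwargs(args: argparse.Namespace, names: list) -> dict:
    # Project only whitelisted attributes; a missing attribute fails early.
    return {name: getattr(args, name) for name in names}

args = argparse.Namespace(vulkan=False, mps=False, coreml=True, verbose=True)
kwargs = forward_kwargs(args, ["vulkan", "mps", "coreml", "verbose"])
print(kwargs)  # {'vulkan': False, 'mps': False, 'coreml': True, 'verbose': True}
```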

if args.profile_memory: