244 changes: 159 additions & 85 deletions examples/apple/coreml/llama/export.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import os

import coremltools as ct
import torch
@@ -14,6 +15,7 @@
from executorch.examples.apple.coreml.llama.llama_transformer import (
InputManager,
load_model,
load_model_in_pieces_ITO,
)
from executorch.examples.apple.coreml.llama.utils import (
replace_linear_with_split_linear,
@@ -28,10 +30,9 @@

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass


def main() -> None:
def main() -> None: # noqa: C901
parser = argparse.ArgumentParser()
parser.add_argument(
"-n",
@@ -77,7 +78,10 @@ def main() -> None:
parser.add_argument(
"--coreml-quantize",
default=None,
choices=["b4w", "c4w"],
choices=[
"b4w",
"c4w",
],
help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)",
)
parser.add_argument(
@@ -102,67 +106,69 @@ def main() -> None:
type=str,
default="fp16",
)
parser.add_argument(
"--export_in_parts",
action="store_true",
help="Export model in 3 parts: input_block.pte, transformer_block.pte (all layers combined), output_block.pte",
)

export_args = parser.parse_args()
model = load_model(
export_args.checkpoint,
export_args.params,
max_seq_length=export_args.max_seq_length,
use_cache_list=export_args.use_cache_list,
)

float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[
export_args.dtype
] # dtype for model/inputs

model.eval()
model.to(float_dtype)

if export_args.target_split_size is not None:
replace_linear_with_split_linear(
model,
out_target_split_size=export_args.target_split_size,
out_max_splits=export_args.max_splits,
# I have not found splitting on in_features to be beneficial,
# and it often leads to OOM so I set in_max_splits to 1
in_target_split_size=1,
in_max_splits=1,
)
def maybe_split_model(model):
Review comment (Contributor): Is maybe_split_model mostly just split linear, or something else?

Reply (Contributor, author): It splits the linear only. But I have changes planned here.

if export_args.target_split_size is not None:
replace_linear_with_split_linear(
model,
out_target_split_size=export_args.target_split_size,
out_max_splits=export_args.max_splits,
# I have not found splitting on in_features to be beneficial,
# and it often leads to OOM so I set in_max_splits to 1
in_target_split_size=1,
in_max_splits=1,
)
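
For context on the exchange above: splitting a linear along out_features replaces one large projection with several smaller ones whose outputs are concatenated. A minimal conceptual sketch of that idea follows; it is not the actual replace_linear_with_split_linear from examples/apple/coreml/llama/utils, and the SplitLinear class and its split-size heuristic are illustrative only.

import torch
import torch.nn as nn

class SplitLinear(nn.Module):  # illustrative stand-in, not part of this PR
    def __init__(self, linear: nn.Linear, target_split_size: int, max_splits: int):
        super().__init__()
        out_features = linear.out_features
        # Cap the number of splits and keep each piece near target_split_size.
        n_splits = min(max(out_features // target_split_size, 1), max_splits)
        sizes = [out_features // n_splits] * n_splits
        sizes[-1] += out_features - sum(sizes)
        self.parts = nn.ModuleList()
        start = 0
        for size in sizes:
            part = nn.Linear(linear.in_features, size, bias=linear.bias is not None)
            # nn.Linear weight is (out_features, in_features), so slice rows.
            part.weight.data.copy_(linear.weight.data[start : start + size])
            if linear.bias is not None:
                part.bias.data.copy_(linear.bias.data[start : start + size])
            start += size
            self.parts.append(part)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Concatenating the partial outputs reproduces the original projection.
        return torch.cat([p(x) for p in self.parts], dim=-1)

Splitting only on out_features matches the comment in the code: splitting on in_features would require summing partial results and, per the author, often leads to OOM during export, so in_max_splits stays at 1.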

# Quantization
if export_args.embedding_quantize:
bitwidth, group_size = export_args.embedding_quantize.split(",")
bitwidth = int(bitwidth)
assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
group_size = int(group_size)
if group_size == 0:
granularity = PerAxis(0)
else:
granularity = PerGroup(group_size)
weight_dtype = getattr(torch, f"int{bitwidth}")
def maybe_quantize_model(model):
if export_args.embedding_quantize:
bitwidth, group_size = export_args.embedding_quantize.split(",")
bitwidth = int(bitwidth)
assert bitwidth in [
4,
8,
], "CoreML only supports 4-bit and 8-bit quantization"
group_size = int(group_size)
if group_size == 0:
granularity = PerAxis(0)
else:
granularity = PerGroup(group_size)
weight_dtype = getattr(torch, f"int{bitwidth}")

quantize_(
model,
IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=weight_dtype, granularity=granularity
),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

if export_args.coreml_quantize == "b4w":
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerGroup(32),
),
)
elif export_args.coreml_quantize == "c4w":
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerAxis(0),
),
)
if export_args.coreml_quantize == "b4w":
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerGroup(32),
),
)
elif export_args.coreml_quantize == "c4w":
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerAxis(0),
),
)
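
maybe_quantize_model reads the embedding_quantize argument as a "bitwidth,group_size" pair: group_size == 0 selects per-channel quantization of the embedding table (PerAxis(0)), anything else selects group-wise quantization (PerGroup(group_size)). The --coreml-quantize option then applies the same IntxWeightOnlyConfig style to the remaining weights, with b4w mapping to int4/PerGroup(32) and c4w to int4/PerAxis(0). Below is a small standalone sketch of the embedding mapping, using a toy embedding and a made-up argument value; only the torchao calls already used above are assumed.

import torch
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

embedding_quantize = "4,32"  # example value for the embedding_quantize argument
bitwidth, group_size = (int(v) for v in embedding_quantize.split(","))
assert bitwidth in (4, 8), "CoreML only supports 4-bit and 8-bit quantization"

granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
weight_dtype = getattr(torch, f"int{bitwidth}")  # torch.int4 or torch.int8

toy = torch.nn.Sequential(torch.nn.Embedding(1000, 64))  # toy model for illustration
quantize_(
    toy,
    IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),  # only touch embeddings
)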

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=ct.target.iOS18,
@@ -179,45 +185,113 @@ def main() -> None:
skip_ops_for_coreml_delegation=[],
)

input_manager = InputManager(
n_layers=model.params.n_layers,
max_batch_size=model.params.max_batch_size,
n_kv_heads=model.params.n_kv_heads,
max_seq_length=model.params.max_seq_len,
head_dim=model.params.head_dim,
use_cache_list=export_args.use_cache_list,
seq_length=export_args.seq_length,
dtype=float_dtype,
minus_infinity=-30000,
cache_size=export_args.cache_size,
executorch_config = ExecutorchBackendConfig(
extract_delegate_segments=True,
do_quant_fusion_and_const_prop=True,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
example_inputs = input_manager.get_inputs(tokens=[0])

model = unwrap_tensor_subclass(model)
def strip_pte(name):
if name.endswith(".pte"):
return name[:-4]
else:
return name

ep = torch.export.export(model, example_inputs, strict=True)
print("Exported program")
print(ep)
if not export_args.export_in_parts:
# Mode 0: Single monolithic model
model = load_model(
export_args.checkpoint,
export_args.params,
max_seq_length=export_args.max_seq_length,
use_cache_list=export_args.use_cache_list,
)
input_manager = InputManager(
n_layers=model.params.n_layers,
max_batch_size=model.params.max_batch_size,
n_kv_heads=model.params.n_kv_heads,
max_seq_length=model.params.max_seq_len,
head_dim=model.params.head_dim,
use_cache_list=export_args.use_cache_list,
seq_length=export_args.seq_length,
dtype=float_dtype,
minus_infinity=-30000,
cache_size=export_args.cache_size,
)
example_inputs = input_manager.get_inputs(tokens=[0])
model.eval()
model = model.to(float_dtype)
print("Model", model)
maybe_split_model(model)
print("Model after split", model)
maybe_quantize_model(model)
print("Model after quantize", model)

edge_manager = to_edge_transform_and_lower(
ep,
partitioner=[partitioner],
)
ep = torch.export.export(model, example_inputs, strict=True)
ep = ep.run_decompositions({})
print("Exported program")
print(ep)

edge_manager = to_edge_transform_and_lower(
ep,
partitioner=[partitioner],
)

print("Delegated program")
print(format_delegated_graph(edge_manager.exported_program().graph_module))

print("Delegated program")
print(format_delegated_graph(edge_manager.exported_program().graph_module))
executorch_program = edge_manager.to_executorch(executorch_config)
filename = save_pte_program(executorch_program, export_args.output_name)
print(f"Saved Executorch program to local {filename}")

executorch_program = edge_manager.to_executorch(
ExecutorchBackendConfig(
extract_delegate_segments=True,
do_quant_fusion_and_const_prop=True,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
else:
# Mode 1: Export in 3 parts with single transformer block
models, example_inputs = load_model_in_pieces_ITO(
export_args.checkpoint,
export_args.params,
max_seq_length=export_args.max_seq_length,
seq_length=export_args.seq_length,
float_dtype=float_dtype,
)
)

filename = save_pte_program(executorch_program, export_args.output_name)
print(f"Saved Executorch program to local {filename}")
for i, model in enumerate(models):
if i == 0:
ex_inputs = example_inputs[i]
suffix = "input_block"
elif i == len(models) - 1:
ex_inputs = example_inputs[-1]
suffix = "output_block"
else:
ex_inputs = example_inputs[1]
suffix = "transformer_block"

model.eval()
model = model.to(float_dtype)
print(f"Model {i}", model)
if i == len(models) - 1:
maybe_split_model(model)
print(f"Model {i} after split", model)
maybe_quantize_model(model)
print(f"Model {i} after quantize", model)
ep = torch.export.export(model, ex_inputs, strict=True)
ep = ep.run_decompositions({})
print(f"Exported program for model {i}", ep)

edge_manager = to_edge_transform_and_lower(
ep,
partitioner=[partitioner],
)

print(f"Delegated program for model {i}")
print(format_delegated_graph(edge_manager.exported_program().graph_module))

executorch_program = edge_manager.to_executorch(executorch_config)
os.makedirs(f"{strip_pte(export_args.output_name)}", exist_ok=True)
filename = save_pte_program(
executorch_program,
f"{strip_pte(export_args.output_name)}/{suffix}.pte",
)
print(f"Saved Executorch program to local {filename}")


if __name__ == "__main__":