In [1]:
import torch
import torch.nn.functional as F
print(torch.__version__)

2.4.0.dev20240507


In [2]:
from executorch.examples.models.llama2.builder import load_llama_model, DType
from executorch.examples.models.llama2.source_transformation.quantize import WeightOnlyInt8QuantHandler
from executorch.examples.models.llama2.source_transformation.sdpa import replace_sdpa_with_simple_sdpa

In [46]:
# stories110M

checkpoint = "stories15M.pt"
params = "params.json"
transforms = [
    lambda m: WeightOnlyInt8QuantHandler(m).quantized_model(),
    replace_sdpa_with_simple_sdpa,
]

model = load_llama_model(
    checkpoint=checkpoint,
    params_path=params,
    use_kv_cache=True,
).to_dtype(DType.fp32).source_transform(transforms)

[INFO 2024-05-29 14:04:08,422 builder.py:84] Loading model with checkpoint=stories15M.pt, params=params.json, use_kv_cache=True, weight_type=WeightType.LLAMA
[INFO 2024-05-29 14:04:08,451 builder.py:105] Loaded model with dtype=torch.float32


freqs_cos shape torch.Size([128, 24]), freqs_sin shape torch.Size([128, 24])
quantize * ('layers.0.attention.wq', Linear(in_features=288, out_features=288, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wk', Linear(in_features=288, out_features=288, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wv', Linear(in_features=288, out_features=288, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wo', Linear(in_features=288, out_features=288, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w1', Linear(in_features=288, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w2', Linear(in_features=768, out_features=288, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w3', Linear(in_features=288, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.1.attention.wq', Li

In [21]:
from executorch.backends.apple.mps.partition.mps_partitioner import (
    MPSPartitioner,
)
from executorch.exir.backend.backend_details import CompileSpec
compile_specs = [CompileSpec("use_fp16", bytes([True]))]

partitioners = [
    MPSPartitioner(compile_specs)
]

## Pass to replace aten::_weight_int8pack_mm to llama_cpp::_weight_int8pack_mm

In [22]:
from executorch.exir.pass_base import ExportPass
from executorch.examples.models.llama2.custom_ops.llama_cpp_linear import *
from executorch.exir.dialects._ops import ops as exir_ops

class ReplaceMMPass(ExportPass):

    def __init__(self):
        super().__init__()

    def call_operator(self, op, args, kwargs, meta):
        if op == exir_ops.edge.aten._weight_int8pack_mm.default:
            return super().call_operator(exir_ops.edge.llama_cpp._weight_int8pack_mm.default, args, kwargs, meta)
        else:
            return super().call_operator(op, args, kwargs, meta)



In [14]:
builder = model.export_to_edge()
builder.edge_manager = builder.edge_manager.transform([ReplaceMMPass()])
builder.to_backend(partitioners).to_executorch()

[INFO 2024-05-29 13:37:13,459 mps_partitioner.py:121] Found 25 subgraphs to be partitioned.
[INFO 2024-05-29 13:37:13,460 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 13:37:13,460 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 13:37:13,461 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 13:37:13,461 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 13:37:13,462 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 13:37:13,462 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 13:37:13,463 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 13:37:13,463 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 13:37:13,464 utils.

<executorch.examples.models.llama2.builder.LlamaEdgeManager at 0x32e5d10d0>

In [3]:
from executorch.examples.models.llama2.lib.partitioner_lib import get_xnnpack_partitioner

builder = model.export_to_edge(None).to_backend([get_xnnpack_partitioner()]).to_executorch().save_to_pte("stories110M_int8_xnnpack.pte")

[INFO 2024-05-08 23:48:05,952 xnnpack_partitioner.py:555] Found 85 subgraphs to be partitioned.
[INFO 2024-05-08 23:48:23,675 builder.py:341] Required memory for activation in bytes: [0, 459548672]
[INFO 2024-05-08 23:48:23,852 utils.py:113] Saved exported program to stories110M_int8_xnnpack.pte


In [3]:
builder = model.export_to_edge(None).to_executorch()

[INFO 2024-05-08 18:05:25,534 builder.py:341] Required memory for activation in bytes: [0, 418116608]


In [18]:
print(model.edge_manager._edge_programs['forward'].graph)

graph():
    %b_layers_0_attention_wq_weight : [num_users=1] = placeholder[target=b_layers_0_attention_wq_weight]
    %b_layers_0_attention_wq_scales : [num_users=1] = placeholder[target=b_layers_0_attention_wq_scales]
    %b_layers_0_attention_wk_weight : [num_users=1] = placeholder[target=b_layers_0_attention_wk_weight]
    %b_layers_0_attention_wk_scales : [num_users=1] = placeholder[target=b_layers_0_attention_wk_scales]
    %b_layers_0_attention_wv_weight : [num_users=1] = placeholder[target=b_layers_0_attention_wv_weight]
    %b_layers_0_attention_wv_scales : [num_users=1] = placeholder[target=b_layers_0_attention_wv_scales]
    %b_layers_0_attention_sdpa_kv_cache_k_cache : [num_users=2] = placeholder[target=b_layers_0_attention_sdpa_kv_cache_k_cache]
    %b_layers_0_attention_sdpa_kv_cache_v_cache : [num_users=2] = placeholder[target=b_layers_0_attention_sdpa_kv_cache_v_cache]
    %b_layers_0_attention_wo_weight : [num_users=1] = placeholder[target=b_layers_0_attention_wo_weight

In [15]:
from executorch.exir.backend.utils import print_delegated_graph

print_delegated_graph(builder.export_program.exported_program("forward").graph_module)

graph():
  %b_layers_0_attention_wq_weight : [num_users=1] = placeholder[target=b_layers_0_attention_wq_weight]
  %b_layers_0_attention_wq_scales : [num_users=1] = placeholder[target=b_layers_0_attention_wq_scales]
  %b_layers_0_attention_wk_weight : [num_users=1] = placeholder[target=b_layers_0_attention_wk_weight]
  %b_layers_0_attention_wk_scales : [num_users=1] = placeholder[target=b_layers_0_attention_wk_scales]
  %b_layers_0_attention_wv_weight : [num_users=1] = placeholder[target=b_layers_0_attention_wv_weight]
  %b_layers_0_attention_wv_scales : [num_users=1] = placeholder[target=b_layers_0_attention_wv_scales]
  %b_layers_0_attention_sdpa_kv_cache_k_cache : [num_users=2] = placeholder[target=b_layers_0_attention_sdpa_kv_cache_k_cache]
  %b_layers_0_attention_sdpa_kv_cache_v_cache : [num_users=2] = placeholder[target=b_layers_0_attention_sdpa_kv_cache_v_cache]
  %b_layers_0_attention_wo_weight : [num_users=1] = placeholder[target=b_layers_0_attention_wo_weight]
  %b_layers_0_at

In [51]:
builder.save_to_pte("stories15M_int8_mps_llama_cpp.pte")

[INFO 2024-05-29 09:33:29,995 utils.py:114] Saved exported program to stories15M_int8_mps_llama_cpp.pte


In [44]:
builder.save_to_pte("stories15M_int8_llama_cpp.pte")

[INFO 2024-05-29 09:31:26,949 utils.py:114] Saved exported program to stories15M_int8_llama_cpp.pte


In [20]:
from executorch.extension.pybindings.portable_lib import _get_operator_names

names = _get_operator_names()

In [23]:
print('\n'.join(names))
print(len(names))

aten::sym_size.int
aten::_local_scalar_dense
aten::sym_numel
executorch_prim::add.Scalar
executorch_prim::sub.Scalar
executorch_prim::mul.Scalar
executorch_prim::floordiv.Scalar
executorch_prim::truediv.Scalar
executorch_prim::eq.Scalar
executorch_prim::gt.Scalar
executorch_prim::lt.Scalar
executorch_prim::ge.Scalar
executorch_prim::le.Scalar
executorch_prim::floordiv.int
executorch_prim::et_copy_index.tensor
executorch_prim::et_view.default
aten::_cdist_forward.out
aten::_log_softmax.out
aten::_native_batch_norm_legit.out
aten::_native_batch_norm_legit.no_stats_out
aten::_native_batch_norm_legit_no_training.out
aten::_pdist_forward.out
aten::_softmax.out
aten::_to_copy.out
aten::abs.out
aten::acos.out
aten::acosh.out
aten::add.out
aten::add.Scalar_out
aten::addmm.out
aten::alias_copy.out
aten::amax.out
aten::amin.out
aten::any.all_out
aten::any.dims_out
aten::any.out
aten::arange.out
aten::arange.start_out
aten::argmax.out
aten::argmin.out
aten::as_strided_copy.out
aten::asin.out
aten

In [6]:
builder.save_to_pte("stories110M_int8.pte")

[INFO 2024-05-08 14:39:40,251 utils.py:112] Saved exported program to stories110M_int8.pte


In [7]:
builder.save_to_pte("stories110M_int8_mps.pte")

[INFO 2024-05-08 17:58:18,813 utils.py:113] Saved exported program to stories110M_int8_mps.pte


In [7]:
from executorch.extension.pybindings.portable_lib import _load_for_executorch

In [56]:
op1 = torch.ops.aten._weight_int8pack_mm.default
op2 = torch.ops.llama_cpp._weight_int8pack_mm.default
mps_device = torch.device("mps")  # Device object representing GPU.

A = torch.randn(4, 8, dtype=torch.float, device=mps_device)
B = torch.ones(8, 8, dtype=torch.int8, device=mps_device)
scales = torch.randn(8, dtype=torch.float, device=mps_device)


In [57]:
C1 = op1(A, B, scales)
print(C1)

tensor([[ 0.2965,  0.1211, -0.0488, -0.0738, -0.5035,  0.7861,  0.9290,  0.6385],
        [ 2.7348,  1.1167, -0.4501, -0.6806, -4.6442,  7.2503,  8.5685,  5.8890],
        [-0.0908, -0.0371,  0.0149,  0.0226,  0.1542, -0.2408, -0.2845, -0.1956],
        [-2.2508, -0.9191,  0.3705,  0.5601,  3.8223, -5.9672, -7.0521, -4.8468]],
       device='mps:0')


In [58]:
C2 = op2(A, B, scales)
print(C2)

tensor([[ 0.2965,  0.1211, -0.0488, -0.0738, -0.5035,  0.7861,  0.9290,  0.6385],
        [ 2.7348,  1.1167, -0.4501, -0.6806, -4.6442,  7.2503,  8.5685,  5.8890],
        [-0.0908, -0.0371,  0.0149,  0.0226,  0.1542, -0.2408, -0.2845, -0.1956],
        [-2.2508, -0.9191,  0.3705,  0.5601,  3.8223, -5.9672, -7.0521, -4.8468]],
       device='mps:0')


In [59]:
torch.allclose(C1, C2)

True

In [5]:
from sentencepiece import SentencePieceProcessor as SPP

sp_model = SPP(model_file="tokenizer.model")

In [6]:
prompt = "Once upon a time"

prompt_tokens = sp_model.encode(prompt)
prompt_tokens = [sp_model.bos_id()] + prompt_tokens
print(prompt_tokens)

[1, 9038, 2501, 263, 931]


In [44]:
t = torch.tensor([[prompt_tokens[0]]], dtype=torch.int64)
pos = torch.tensor([0], dtype=torch.int64)

model.model(t, pos)

tensor([[-5.8514,  1.4318, -5.8515,  ..., -5.8277, -5.8277, -5.8515]],
       grad_fn=<NotImplemented>)

In [47]:
# edge module
builder = model.export_to_edge()
edge_module = builder.edge_manager._edge_programs["forward"].module()

In [48]:
# custom op module
builder.edge_manager = builder.edge_manager.transform([ReplaceMMPass()])
custom_op_module = builder.edge_manager._edge_programs["forward"].module()

In [49]:
# lowered to mps module
builder.to_backend(partitioners)
mps_module = builder.edge_manager._edge_programs["forward"].module()

[INFO 2024-05-29 14:04:22,525 mps_partitioner.py:121] Found 25 subgraphs to be partitioned.
[INFO 2024-05-29 14:04:22,526 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 14:04:22,526 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 14:04:22,527 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 14:04:22,528 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 14:04:22,528 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 14:04:22,529 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 14:04:22,530 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 14:04:22,530 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-29 14:04:22,531 utils.

In [56]:
# to executorch
builder.to_executorch()
from executorch.extension.pybindings._portable_lib import _load_for_executorch_from_buffer

ep = _load_for_executorch_from_buffer(builder.export_program.buffer)

[program.cpp:130] InternalConsistency verification requested but not available
loc("mps_broadcast_to"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":1392:0)): error: 'anec.broadcast' op failed: input cannot be broadcasted to the target shape


loc("mps_broadcast_to"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":1392:0)): error: 'anec.broadcast' op failed: input cannot be broadcasted to the target shape


In [30]:
# eager result
print(model.model.forward(t, pos))

tensor([[-5.8514,  1.4318, -5.8515,  ..., -5.8277, -5.8277, -5.8515]],
       grad_fn=<NotImplemented>)


In [28]:
# edge result
print(edge_module.forward(t, pos))

tensor([[-5.8514,  1.4318, -5.8515,  ..., -5.8277, -5.8277, -5.8515]],
       grad_fn=<NotImplemented>)


In [31]:
# edge with custom op
print(custom_op_module.forward(t, pos))

tensor([[-5.8514,  1.4318, -5.8515,  ..., -5.8277, -5.8277, -5.8515]],
       grad_fn=<NotImplemented>)


In [32]:
# mps result
print(mps_module.forward(t, pos))

tensor([[-5.8514,  1.4318, -5.8515,  ..., -5.8277, -5.8277, -5.8515]],
       grad_fn=<NotImplemented>)


In [58]:
#ep result
print(ep.forward((t, pos)))

[tensor([[-1.5128,  3.8322, -1.5128,  ..., -1.5174, -1.5174, -1.5128]])]


In [67]:
print(builder.export_program)

<executorch.exir.program._program.ExecutorchProgramManager object at 0x35aba4cd0>
