In [3]:
import torch
import torch.nn.functional as F
print(torch.__version__)

2.4.0.dev20240505


In [4]:
from executorch.examples.models.llama2.source_transformation.quantize import dynamically_quantize_per_channel, WeightOnlyInt8Linear

In [5]:
range_min = -128
range_max = 127

weight_float = torch.randn(4, 4)
activation = torch.randn(4, 4)
weight, scales, _ = dynamically_quantize_per_channel(
    weight_float, range_min, range_max, torch.int8, None, scales_dtype=weight_float.dtype,
)

In [7]:
scales = scales.squeeze(dim=-1)

In [6]:
print(weight)

tensor([[  76,   14,  127,  -66],
        [ 127,   65,   64,  -12],
        [-128,   87,  -11,  -81],
        [ -15,  -15, -128,  -46]], dtype=torch.int8)


In [7]:
print(scales)

tensor([0.0102, 0.0142, 0.0177, 0.0196])


In [8]:
print(weight_float)

tensor([[ 0.7735,  0.1411,  1.3022, -0.6775],
        [ 1.8120,  0.9258,  0.9135, -0.1759],
        [-2.2527,  1.5385, -0.1945, -1.4349],
        [-0.2999, -0.2957, -2.5002, -0.9040]])


In [9]:
print(activation)

tensor([[-1.5251,  2.7241, -0.4740,  0.3813],
        [-0.9735, -0.1196, -1.0898,  1.3676],
        [ 0.6301,  0.2436,  0.8839, -0.4392],
        [-1.7449, -0.7084, -0.4021, -1.0763]])


In [10]:
res1 = F.linear(activation, weight.to(dtype=activation.dtype)) * scales

In [11]:
print(res1)

tensor([[-1.6662, -0.7324,  7.1829,  0.4931],
        [-3.1083, -3.0920,  0.2722,  1.8234],
        [ 1.9665,  2.2411, -0.5937, -2.0793],
        [-1.2517, -3.9859,  4.4758,  2.7017]])


In [3]:
from torch._C import DispatchKey
torch.ops.aten._weight_int8pack_mm.default.has_kernel_for_dispatch_key(DispatchKey.MPS)

True

In [14]:
res2 = torch.ops.aten._weight_int8pack_mm(activation.to(device='mps'), weight.to(device='mps'), scales=scales.to(device='mps'))
print(res2)

tensor([[-1.6662, -0.7324,  7.1829,  0.4931],
        [-3.1083, -3.0920,  0.2722,  1.8234],
        [ 1.9665,  2.2411, -0.5937, -2.0793],
        [-1.2517, -3.9859,  4.4758,  2.7017]], device='mps:0')


In [28]:
activation[2, :]

tensor([ 0.6301,  0.2436,  0.8839, -0.4392])

In [29]:
weight_float[1, :]

tensor([ 1.8120,  0.9258,  0.9135, -0.1759])

In [8]:
torch.dot(activation[2, :], weight_float[1, :])

tensor(2.5703)

In [9]:
F.linear(activation, weight_float)

tensor([[-2.1271, -1.0495, -1.1300,  0.1153],
        [ 5.1776,  3.3572,  3.4860, -0.3586],
        [ 0.8612,  2.5703, -2.0016,  0.2567],
        [-1.1237,  0.8256, -3.4993, -0.1149]])

In [1]:
from executorch.examples.models.llama2.builder import load_llama_model, DType
from executorch.examples.models.llama2.source_transformation.quantize import WeightOnlyInt8QuantHandler
from executorch.examples.models.llama2.source_transformation.sdpa import replace_sdpa_with_simple_sdpa

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# stories110M

checkpoint = "stories110M.pt"
params = "params.json"
transforms = [
    lambda m: WeightOnlyInt8QuantHandler(m).quantized_model(),
    replace_sdpa_with_simple_sdpa,
]

model = load_llama_model(
    checkpoint=checkpoint,
    params_path=params,
    use_kv_cache=True,
).to_dtype(DType.fp32).source_transform(transforms)

[INFO 2024-05-08 14:39:00,259 builder.py:84] Loading model with checkpoint=stories110M.pt, params=params.json, use_kv_cache=True, weight_type=WeightType.LLAMA
[INFO 2024-05-08 14:39:00,340 builder.py:105] Loaded model with dtype=torch.float32


quantize * ('layers.0.attention.wq', Linear(in_features=768, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wk', Linear(in_features=768, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wv', Linear(in_features=768, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wo', Linear(in_features=768, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w1', Linear(in_features=768, out_features=2048, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w2', Linear(in_features=2048, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w3', Linear(in_features=768, out_features=2048, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.1.attention.wq', Linear(in_features=768, out_features=768, bias=False)) with group_size None,

In [3]:
from executorch.backends.apple.mps.partition.mps_partitioner import (
    MPSPartitioner,
)
from executorch.exir.backend.backend_details import CompileSpec
compile_specs = [CompileSpec("use_fp16", bytes([True]))]

partitioners = [
    MPSPartitioner(compile_specs)
]

In [4]:
from executorch.backends.apple.mps.serialization import mps_graph_serialize
import inspect
print(inspect.getsource(mps_graph_serialize))

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import tempfile

import pkg_resources
from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSGraph
from executorch.exir._serialize._dataclass import _DataclassEncoder
from executorch.exir._serialize._flatbuffer import _flatc_compile


def convert_to_flatbuffer(mps_graph: MPSGraph) -> bytes:
    mps_graph_json = json.dumps(mps_graph, cls=_DataclassEncoder)
    with tempfile.TemporaryDirectory() as d:
        schema_path = os.path.join(d, "schema.fbs")
        with open(schema_path, "wb") as schema_file:
            schema_file.write(pkg_resources.resource_string(__name__, "schema.fbs"))
        json_path = os.path.join(d, "schema.json")
        with open(json_path, "wb") as json_file:
            json_file.write(mps_graph_json.encode("asci

In [5]:
builder = model.export_to_edge(None).to_backend(partitioners).to_executorch()

[INFO 2024-05-08 14:39:13,238 mps_partitioner.py:120] Found 13 subgraphs to be partitioned.
[INFO 2024-05-08 14:39:13,241 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-08 14:39:13,241 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-08 14:39:13,242 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-08 14:39:13,243 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-08 14:39:13,244 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-08 14:39:13,244 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-08 14:39:13,245 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-08 14:39:13,246 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-08 14:39:13,247 utils.

In [5]:
print(model.edge_manager._edge_programs['forward'].graph)

graph():
    %b_layers_0_attention_sdpa_kv_cache_k_cache : [num_users=4] = placeholder[target=b_layers_0_attention_sdpa_kv_cache_k_cache]
    %b_layers_0_attention_sdpa_kv_cache_v_cache : [num_users=4] = placeholder[target=b_layers_0_attention_sdpa_kv_cache_v_cache]
    %b_layers_1_attention_sdpa_kv_cache_k_cache : [num_users=4] = placeholder[target=b_layers_1_attention_sdpa_kv_cache_k_cache]
    %b_layers_1_attention_sdpa_kv_cache_v_cache : [num_users=4] = placeholder[target=b_layers_1_attention_sdpa_kv_cache_v_cache]
    %b_layers_2_attention_sdpa_kv_cache_k_cache : [num_users=4] = placeholder[target=b_layers_2_attention_sdpa_kv_cache_k_cache]
    %b_layers_2_attention_sdpa_kv_cache_v_cache : [num_users=4] = placeholder[target=b_layers_2_attention_sdpa_kv_cache_v_cache]
    %b_layers_3_attention_sdpa_kv_cache_k_cache : [num_users=4] = placeholder[target=b_layers_3_attention_sdpa_kv_cache_k_cache]
    %b_layers_3_attention_sdpa_kv_cache_v_cache : [num_users=4] = placeholder[target=b_l

In [25]:
from executorch.exir.backend.utils import print_delegated_graph

print_delegated_graph(builder.export_program.exported_program("forward").graph_module)

graph():
  %b_layers_0_attention_sdpa_kv_cache_k_cache : [num_users=4] = placeholder[target=b_layers_0_attention_sdpa_kv_cache_k_cache]
  %b_layers_0_attention_sdpa_kv_cache_v_cache : [num_users=4] = placeholder[target=b_layers_0_attention_sdpa_kv_cache_v_cache]
  %b_layers_1_attention_sdpa_kv_cache_k_cache : [num_users=4] = placeholder[target=b_layers_1_attention_sdpa_kv_cache_k_cache]
  %b_layers_1_attention_sdpa_kv_cache_v_cache : [num_users=4] = placeholder[target=b_layers_1_attention_sdpa_kv_cache_v_cache]
  %b_layers_2_attention_sdpa_kv_cache_k_cache : [num_users=4] = placeholder[target=b_layers_2_attention_sdpa_kv_cache_k_cache]
  %b_layers_2_attention_sdpa_kv_cache_v_cache : [num_users=4] = placeholder[target=b_layers_2_attention_sdpa_kv_cache_v_cache]
  %b_layers_3_attention_sdpa_kv_cache_k_cache : [num_users=4] = placeholder[target=b_layers_3_attention_sdpa_kv_cache_k_cache]
  %b_layers_3_attention_sdpa_kv_cache_v_cache : [num_users=4] = placeholder[target=b_layers_3_attentio

In [6]:
builder.save_to_pte("stories110M_int8.pte")

[INFO 2024-05-08 14:39:40,251 utils.py:112] Saved exported program to stories110M_int8.pte


In [6]:
builder.save_to_pte("stories110M_int8_mps.pte")

[INFO 2024-05-08 09:47:28,356 utils.py:112] Saved exported program to stories110M_int8_mps.pte


In [7]:
from executorch.extension.pybindings.portable_lib import _load_for_executorch

In [8]:
m = _load_for_executorch("stories110M_int8.pte")
res = m.forward((torch.tensor([[1]], dtype=torch.long), torch.tensor(0, dtype=torch.long)))

[program.cpp:130] InternalConsistency verification requested but not available

: 

: 

In [1]:
print(res)

NameError: name 'res' is not defined

In [7]:
import json
files = [
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpp7jovcm9/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpn0uo5apz/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmp0i_0v_b0/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpgjvbi4s1/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpt79bbpbt/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmp6ychs24g/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpqawc68v7/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpt6sq63zo/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmp9d080mdv/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmp6jpstr7y/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpzutwk8fg/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpp1cerjiu/schema.json",
    "/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpz4iu5_hx/schema.json",
]
nodes = []
for j in files:
    with open(j) as f:
        data = json.load(f)
        for i, node in enumerate(data["mps_nodes"]):
            if node["mpsnode_union_type"] == "MPSIndexTensor":
                print(i)
                nodes.append(node)



15
19
19
19
19
19
19
19
19
19
19
19
1
2


In [2]:
print(data.keys())

dict_keys(['version', 'mps_nodes', 'mps_values', 'input_ids', 'output_ids', 'constant_ids', 'graph_type'])


15


In [10]:
print(len(nodes))

14


In [11]:
print(nodes[0])

{'mpsnode_union': {'input1_id': 32, 'output_id': 34, 'indices_id': [-1, -1, 3]}, 'mpsnode_union_type': 'MPSIndexTensor', 'min_max': None}


In [12]:
for i, n in enumerate(nodes):
    print(i)
    print(n["mpsnode_union"]["indices_id"])

0
[-1, -1, 3]
1
[-1, -1, 5]
2
[-1, -1, 5]
3
[-1, -1, 5]
4
[-1, -1, 5]
5
[-1, -1, 5]
6
[-1, -1, 5]
7
[-1, -1, 5]
8
[-1, -1, 5]
9
[-1, -1, 5]
10
[-1, -1, 5]
11
[-1, -1, 5]
12
[1]
13
[1]


In [12]:
print(data["output_ids"])

[11, 12, 13, 15, 16, 21, 32, 46]


In [17]:
values = []
for n in custom_nodes:
    values.extend([data["mps_values"][i] for i in n["mpsnode_union"].values() if data["mps_values"][i]["datatype"] == 3])

In [19]:
print(len(values))

2


In [20]:
data["mps_values"][7]

{'datatype': 3,
 'num_dims': 1,
 'dims': [768],
 'constant_buffer_size': 1536,
 'constant_buffer': {'storage': [20,
   59,
   34,
   59,
   21,
   59,
   43,
   59,
   161,
   58,
   232,
   58,
   165,
   58,
   0,
   59,
   205,
   58,
   10,
   59,
   186,
   58,
   195,
   58,
   215,
   58,
   199,
   58,
   152,
   58,
   177,
   58,
   226,
   58,
   239,
   58,
   162,
   58,
   164,
   58,
   137,
   58,
   217,
   58,
   4,
   59,
   245,
   58,
   153,
   58,
   229,
   58,
   148,
   58,
   198,
   58,
   15,
   59,
   143,
   58,
   148,
   58,
   221,
   58,
   56,
   58,
   86,
   59,
   30,
   59,
   93,
   58,
   72,
   58,
   2,
   59,
   90,
   59,
   76,
   58,
   20,
   58,
   4,
   60,
   149,
   58,
   150,
   59,
   255,
   58,
   186,
   58,
   185,
   58,
   185,
   58,
   54,
   59,
   158,
   58,
   131,
   58,
   53,
   59,
   132,
   58,
   31,
   59,
   225,
   58,
   206,
   58,
   241,
   58,
   197,
   58,
   221,
   58,
   228,
   58,
   219,
   58,
 

In [3]:
import inspect
print(inspect.getsource(WeightOnlyInt8QuantHandler))

class WeightOnlyInt8QuantHandler(QuantHandler):
    def __init__(
        self,
        mod,
        device="cpu",
        *,
        node_type: str = "*",
        bitwidth: Optional[int] = None,
        group_size: Optional[int] = None,
    ):
        self.mod = mod
        self.group_size = group_size
        self.node_type = node_type
        if bitwidth is None:
            self.bitwidth = 8
        else:
            self.bitwidth = bitwidth

    @torch.no_grad()
    def create_quantized_state_dict(self) -> Dict:
        cur_state_dict = self.mod.state_dict()

        if self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, mod in self.mod.named_modules():
            # print(f"maybe? quantize {fqn}...{type(mod)}")
            if isinstance(mod, torch.nn.Linear) or isinstanc

In [6]:
import flatbuffers
from mpsgraph.MPSGraph import MPSGraph
from mpsgraph.MPSIndexTensor import MPSIndexTensor
with open("/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpudja56ka/schema.bin", 'rb') as f:
    g = MPSGraph.GetRootAsMPSGraph(f.read(), 0)
    


In [8]:
import json
nodes = []
with open("/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpudja56ka/schema.json") as f:
    data = json.load(f)
    nodes = [node for node in data["mps_nodes"]]

In [9]:
for i, node in enumerate(nodes):
    print(i, node["mpsnode_union_type"])

0 MPSCast
1 MPSCast
2 MPSCast
3 MPSCast
4 MPSCast
5 MPSCast
6 MPSSqueeze
7 MPSUnsqueeze
8 MPSUnsqueeze
9 MPSUnsqueeze
10 MPSExpand
11 MPSExpand
12 MPSPermute
13 MPSUnsqueeze
14 MPSPermute
15 MPSPermute
16 MPSPermute
17 MPSPermute
18 MPSPermute
19 MPSMatMul
20 MPSIndexTensor
21 MPSView
22 MPSView
23 MPSMul
24 MPSPermute
25 MPSExpand
26 MPSView
27 MPSExpand
28 MPSView
29 MPSView
30 MPSView
31 MPSSlice
32 MPSSlice
33 MPSSqueeze
34 MPSSqueeze
35 MPSMul
36 MPSMul
37 MPSMul
38 MPSMul
39 MPSSub
40 MPSAdd
41 MPSUnsqueeze
42 MPSUnsqueeze
43 MPSCat
44 MPSView
45 MPSPermute
46 MPSExpand
47 MPSView
48 MPSMatMul
49 MPSView
50 MPSMul
51 MPSAdd
52 MPSSoftmax
53 MPSExpand
54 MPSView
55 MPSMatMul
56 MPSView
57 MPSPermute
58 MPSView
59 MPSSqueeze
60 MPSMatMul
61 MPSMul
62 MPSAdd
63 MPSMul
64 MPSMean
65 MPSAdd
66 MPSRsqrt
67 MPSMul
68 MPSMul
69 MPSSqueeze
70 MPSSqueeze
71 MPSMatMul
72 MPSMatMul
73 MPSMul
74 MPSMul
75 MPSSigmoid
76 MPSMul
77 MPSMul
78 MPSMatMul
79 MPSMul
80 MPSAdd
81 MPSMul
82 MPSMean
83 

In [10]:
for i in range(g.MpsNodesLength()):
    print(i, g.MpsNodes(i).MpsnodeUnionType())

0 75
1 75
2 75
3 75
4 75
5 75
6 69
7 70
8 70
9 70
10 67
11 67
12 65
13 70
14 65
15 65
16 65
17 65
18 65
19 53
20 62
21 66
22 66
23 9
24 65
25 67
26 66
27 67
28 66
29 66
30 66
31 72
32 72
33 69
34 69
35 9
36 9
37 9
38 9
39 8
40 7
41 70
42 70
43 68
44 66
45 65
46 67
47 66
48 53
49 66
50 9
51 7
52 5
53 67
54 66
55 53
56 66
57 65
58 66
59 69
60 53
61 9
62 7
63 9
64 64
65 7
66 32
67 9
68 9
69 69
70 69
71 53
72 53
73 9
74 9
75 33
76 9
77 9
78 53
79 9
80 7
81 9
82 64
83 7
84 32
85 9
86 9
87 69
88 53
89 9
90 75


In [13]:
nodes[20]["mpsnode_union"]["indices_id"]

[-1, -1, 3]

In [14]:
index_tensor = g.MpsNodes(20)

In [15]:
print(index_tensor.MpsnodeUnionType())

62


In [22]:
from mpsgraph.MPSIndexTensor import MPSIndexTensor

t = MPSIndexTensor.GetRootAs(g._tab.Bytes, 20)

In [23]:
t.IndicesIdAsNumpy()

TypeError: bad number -31589593 for type uint32