In [1]:
import torch
import torch.nn.functional as F
# Show which PyTorch build this notebook is running against.
print(torch.__version__)

2.4.0.dev20240505


In [2]:
from executorch.examples.models.llama2.source_transformation.quantize import dynamically_quantize_per_channel, WeightOnlyInt8Linear

In [4]:
# Seed the global RNG so the random weight/activation tensors below (and every
# result derived from them) are identical on each Restart & Run All.
torch.manual_seed(0)

# Bounds of the signed 8-bit integer range passed to the quantizer.
range_min = -128
range_max = 127

# Small random float weight and activation used for the kernel comparison.
weight_float = torch.randn(4, 4)
activation = torch.randn(4, 4)

# Quantize the float weight to int8. The scales are requested in the same dtype
# as the float weight; the third return value is not used in this notebook.
weight, scales, _ = dynamically_quantize_per_channel(
    weight_float,
    range_min,
    range_max,
    torch.int8,
    None,
    scales_dtype=weight_float.dtype,
)

In [5]:
# Drop the trailing dimension of `scales` if it has size 1 (no-op otherwise).
scales = torch.squeeze(scales, dim=-1)

In [6]:
# Inspect the int8 quantized weight.
print(weight)

tensor([[  76,   14,  127,  -66],
        [ 127,   65,   64,  -12],
        [-128,   87,  -11,  -81],
        [ -15,  -15, -128,  -46]], dtype=torch.int8)


In [7]:
# Inspect the quantization scales returned alongside the int8 weight.
print(scales)

tensor([0.0102, 0.0142, 0.0177, 0.0196])


In [8]:
# Inspect the original float weight that was quantized.
print(weight_float)

tensor([[ 0.7735,  0.1411,  1.3022, -0.6775],
        [ 1.8120,  0.9258,  0.9135, -0.1759],
        [-2.2527,  1.5385, -0.1945, -1.4349],
        [-0.2999, -0.2957, -2.5002, -0.9040]])


In [9]:
# Inspect the random activation used as the linear-layer input.
print(activation)

tensor([[-1.5251,  2.7241, -0.4740,  0.3813],
        [-0.9735, -0.1196, -1.0898,  1.3676],
        [ 0.6301,  0.2436,  0.8839, -0.4392],
        [-1.7449, -0.7084, -0.4021, -1.0763]])


In [10]:
# CPU reference: apply the linear layer with the int8 weight cast to the
# activation dtype, then multiply the result by `scales` (broadcast over the
# last dimension).
res1 = torch.mul(
    F.linear(activation, weight.to(activation.dtype)),
    scales,
)

In [11]:
# Show the CPU reference result.
print(res1)

tensor([[-1.6662, -0.7324,  7.1829,  0.4931],
        [-3.1083, -3.0920,  0.2722,  1.8234],
        [ 1.9665,  2.2411, -0.5937, -2.0793],
        [-1.2517, -3.9859,  4.4758,  2.7017]])


In [3]:
from torch._C import DispatchKey
# Ask whether aten._weight_int8pack_mm has a kernel registered for the MPS
# dispatch key; the bare expression is displayed as the cell's result.
torch.ops.aten._weight_int8pack_mm.default.has_kernel_for_dispatch_key(DispatchKey.MPS)

True

In [14]:
# Run the int8 weight-only matmul op with all three operands moved to the MPS
# device.
res2 = torch.ops.aten._weight_int8pack_mm(
    activation.to(device='mps'),
    weight.to(device='mps'),
    scales=scales.to(device='mps'),
)
print(res2)

# Verify the MPS result against the CPU reference (`res1`) explicitly rather
# than by eyeballing the printed tensors. Done after the print so the values
# are still visible if the check fails. Tolerances are loose enough to absorb
# float32 accumulation-order differences between devices.
torch.testing.assert_close(res2.cpu(), res1, rtol=1e-3, atol=1e-3)

tensor([[-1.6662, -0.7324,  7.1829,  0.4931],
        [-3.1083, -3.0920,  0.2722,  1.8234],
        [ 1.9665,  2.2411, -0.5937, -2.0793],
        [-1.2517, -3.9859,  4.4758,  2.7017]], device='mps:0')


In [28]:
# Row 2 of the activation, used for the manual dot-product check.
activation[2, :]

tensor([ 0.6301,  0.2436,  0.8839, -0.4392])

In [29]:
# Row 1 of the float weight, used for the manual dot-product check.
weight_float[1, :]

tensor([ 1.8120,  0.9258,  0.9135, -0.1759])

In [30]:
# Manual spot check: activation row 2 dotted with float-weight row 1, which is
# element [2, 1] of F.linear(activation, weight_float).
torch.dot(activation[2, :], weight_float[1, :])

tensor(2.2519)

In [27]:
# Unquantized float reference, for comparison with the int8 results.
F.linear(activation, weight_float)

tensor([[-1.6709, -0.7415,  7.1717,  0.4920],
        [-3.1156, -3.1108,  0.2585,  1.8158],
        [ 1.9703,  2.2519, -0.5863, -2.0738],
        [-1.2442, -3.9955,  4.4636,  2.7110]])

In [3]:
from executorch.examples.models.llama2.builder import load_llama_model, DType
from executorch.examples.models.llama2.source_transformation.quantize import WeightOnlyInt8QuantHandler
from executorch.examples.models.llama2.source_transformation.sdpa import replace_sdpa_with_simple_sdpa

In [4]:
# stories110M

checkpoint = "stories110M.pt"
params = "params.json"


def _quantize_weights_int8(m):
    """Return the quantized model WeightOnlyInt8QuantHandler builds from ``m``."""
    return WeightOnlyInt8QuantHandler(m).quantized_model()


# Source transforms handed to the loaded model: int8 weight-only quantization
# followed by the simple-SDPA replacement.
transforms = [_quantize_weights_int8, replace_sdpa_with_simple_sdpa]

# Load the checkpoint with the KV cache enabled, apply the transforms, then
# request fp32 as the model dtype.
model = (
    load_llama_model(checkpoint=checkpoint, params_path=params, use_kv_cache=True)
    .source_transform(transforms)
    .to_dtype(DType.fp32)
)

[INFO 2024-05-07 00:51:48,098 builder.py:84] Loading model with checkpoint=stories110M.pt, params=params.json, use_kv_cache=True, weight_type=WeightType.LLAMA
[INFO 2024-05-07 00:51:48,199 builder.py:105] Loaded model with dtype=torch.float32


quantize * ('layers.0.attention.wq', Linear(in_features=768, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wk', Linear(in_features=768, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wv', Linear(in_features=768, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.attention.wo', Linear(in_features=768, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w1', Linear(in_features=768, out_features=2048, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w2', Linear(in_features=2048, out_features=768, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.0.feed_forward.w3', Linear(in_features=768, out_features=2048, bias=False)) with group_size None, bitwidth 8
quantize * ('layers.1.attention.wq', Linear(in_features=768, out_features=768, bias=False)) with group_size None,

In [5]:
from executorch.backends.apple.mps.partition.mps_partitioner import MPSPartitioner
from executorch.exir.backend.backend_details import CompileSpec

# Single compile spec keyed "use_fp16"; bytes([True]) is the one-byte value b"\x01".
compile_specs = [CompileSpec("use_fp16", bytes([True]))]

# The MPS partitioner is the only partitioner used for lowering.
partitioners = [MPSPartitioner(compile_specs)]

In [6]:
# Lower the model in three chained stages: export to edge, hand the graph to
# the configured partitioners, then convert to an ExecuTorch program.
builder = (
    model.export_to_edge(None)
    .to_backend(partitioners)
    .to_executorch()
)

[INFO 2024-05-07 00:51:58,492 mps_partitioner.py:121] Found 13 subgraphs to be partitioned.
[INFO 2024-05-07 00:51:58,494 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-07 00:51:58,495 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-07 00:51:58,496 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-07 00:51:58,496 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-07 00:51:58,498 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-07 00:51:58,498 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-07 00:51:58,499 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-07 00:51:58,500 utils.py:527] The buffer node is a mutated buffer node, which is not constant.
[INFO 2024-05-07 00:51:58,501 utils.

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmp_au6tw7u


[INFO 2024-05-07 00:52:03,120 mps_preprocess.py:115] Visiting: aten_view_copy_default_225, aten.view_copy.default
[INFO 2024-05-07 00:52:03,121 mps_preprocess.py:115] Visiting: aten_view_copy_default_226, aten.view_copy.default
[INFO 2024-05-07 00:52:03,121 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_100, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:03,122 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_87, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:03,122 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_86, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:03,122 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_84, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:03,123 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_70, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:03,123 mps_preprocess.py:115] Visiting: aten_expand_copy_default_61, aten.expand_copy.default
[INFO 2024-05-07 00:52:03,123 mps_prepr

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmp3rsyg4kt


[INFO 2024-05-07 00:52:04,219 mps_preprocess.py:115] Visiting: aten_view_copy_default_205, aten.view_copy.default
[INFO 2024-05-07 00:52:04,219 mps_preprocess.py:115] Visiting: aten_view_copy_default_206, aten.view_copy.default
[INFO 2024-05-07 00:52:04,220 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_90, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:04,220 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_79, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:04,220 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_78, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:04,220 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_76, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:04,221 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_63, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:04,221 mps_preprocess.py:115] Visiting: aten_expand_copy_default_55, aten.expand_copy.default
[INFO 2024-05-07 00:52:04,221 mps_prepro

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpgrrrt1xz


[INFO 2024-05-07 00:52:05,471 mps_preprocess.py:115] Visiting: aten_view_copy_default_185, aten.view_copy.default
[INFO 2024-05-07 00:52:05,471 mps_preprocess.py:115] Visiting: aten_view_copy_default_186, aten.view_copy.default
[INFO 2024-05-07 00:52:05,471 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_80, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:05,472 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_71, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:05,472 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_70, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:05,472 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_68, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:05,472 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_56, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:05,473 mps_preprocess.py:115] Visiting: aten_expand_copy_default_49, aten.expand_copy.default
[INFO 2024-05-07 00:52:05,473 mps_prepro

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmp48ojccmy


[INFO 2024-05-07 00:52:06,675 mps_preprocess.py:115] Visiting: aten_view_copy_default_165, aten.view_copy.default
[INFO 2024-05-07 00:52:06,677 mps_preprocess.py:115] Visiting: aten_view_copy_default_166, aten.view_copy.default
[INFO 2024-05-07 00:52:06,677 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_70, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:06,678 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_63, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:06,678 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_62, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:06,678 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_60, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:06,679 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_49, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:06,679 mps_preprocess.py:115] Visiting: aten_expand_copy_default_43, aten.expand_copy.default
[INFO 2024-05-07 00:52:06,679 mps_prepro

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpahyjskj0


[INFO 2024-05-07 00:52:08,452 mps_preprocess.py:115] Visiting: aten_view_copy_default_145, aten.view_copy.default
[INFO 2024-05-07 00:52:08,453 mps_preprocess.py:115] Visiting: aten_view_copy_default_146, aten.view_copy.default
[INFO 2024-05-07 00:52:08,453 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_60, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:08,453 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_55, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:08,454 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_54, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:08,454 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_52, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:08,454 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_42, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:08,455 mps_preprocess.py:115] Visiting: aten_expand_copy_default_37, aten.expand_copy.default
[INFO 2024-05-07 00:52:08,455 mps_prepro

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpn4qi0bu8


[INFO 2024-05-07 00:52:10,053 mps_preprocess.py:115] Visiting: aten_view_copy_default_125, aten.view_copy.default
[INFO 2024-05-07 00:52:10,054 mps_preprocess.py:115] Visiting: aten_view_copy_default_126, aten.view_copy.default
[INFO 2024-05-07 00:52:10,054 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_50, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:10,055 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_47, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:10,055 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_46, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:10,056 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_44, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:10,056 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_35, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:10,056 mps_preprocess.py:115] Visiting: aten_expand_copy_default_31, aten.expand_copy.default
[INFO 2024-05-07 00:52:10,057 mps_prepro

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpnji86b0d


[INFO 2024-05-07 00:52:11,713 mps_preprocess.py:115] Visiting: aten_view_copy_default_105, aten.view_copy.default
[INFO 2024-05-07 00:52:11,714 mps_preprocess.py:115] Visiting: aten_view_copy_default_106, aten.view_copy.default
[INFO 2024-05-07 00:52:11,714 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_40, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:11,715 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_39, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:11,715 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_38, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:11,715 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_36, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:11,715 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_28, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:11,716 mps_preprocess.py:115] Visiting: aten_expand_copy_default_25, aten.expand_copy.default
[INFO 2024-05-07 00:52:11,716 mps_prepro

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpy92b0kjs


[INFO 2024-05-07 00:52:13,344 mps_preprocess.py:115] Visiting: aten_view_copy_default_85, aten.view_copy.default
[INFO 2024-05-07 00:52:13,345 mps_preprocess.py:115] Visiting: aten_view_copy_default_86, aten.view_copy.default
[INFO 2024-05-07 00:52:13,345 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_30, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:13,345 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_31, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:13,346 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_30, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:13,346 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_28, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:13,347 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_21, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:13,347 mps_preprocess.py:115] Visiting: aten_expand_copy_default_19, aten.expand_copy.default
[INFO 2024-05-07 00:52:13,347 mps_preproce

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpitx_2qs5


[INFO 2024-05-07 00:52:15,008 mps_preprocess.py:115] Visiting: aten_view_copy_default_65, aten.view_copy.default
[INFO 2024-05-07 00:52:15,008 mps_preprocess.py:115] Visiting: aten_view_copy_default_66, aten.view_copy.default
[INFO 2024-05-07 00:52:15,009 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_20, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:15,009 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_23, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:15,009 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_22, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:15,009 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_20, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:15,010 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_14, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:15,010 mps_preprocess.py:115] Visiting: aten_expand_copy_default_13, aten.expand_copy.default
[INFO 2024-05-07 00:52:15,010 mps_preproce

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmp4wxym6zn


[INFO 2024-05-07 00:52:16,549 mps_preprocess.py:115] Visiting: aten_view_copy_default_45, aten.view_copy.default
[INFO 2024-05-07 00:52:16,549 mps_preprocess.py:115] Visiting: aten_view_copy_default_46, aten.view_copy.default
[INFO 2024-05-07 00:52:16,550 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims_10, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:16,550 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_15, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:16,550 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_14, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:16,550 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_12, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:16,551 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default_7, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:16,551 mps_preprocess.py:115] Visiting: aten_expand_copy_default_7, aten.expand_copy.default
[INFO 2024-05-07 00:52:16,551 mps_preprocess

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpwxden7p5


[INFO 2024-05-07 00:52:17,876 mps_preprocess.py:115] Visiting: aten_view_copy_default_25, aten.view_copy.default
[INFO 2024-05-07 00:52:17,877 mps_preprocess.py:115] Visiting: aten_view_copy_default_26, aten.view_copy.default
[INFO 2024-05-07 00:52:17,877 mps_preprocess.py:115] Visiting: aten_squeeze_copy_dims, aten.squeeze_copy.dims
[INFO 2024-05-07 00:52:17,877 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_7, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:17,878 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_6, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:17,878 mps_preprocess.py:115] Visiting: aten_unsqueeze_copy_default_4, aten.unsqueeze_copy.default
[INFO 2024-05-07 00:52:17,878 mps_preprocess.py:115] Visiting: aten__weight_int8pack_mm_default, aten._weight_int8pack_mm.default
[INFO 2024-05-07 00:52:17,878 mps_preprocess.py:115] Visiting: aten_expand_copy_default_1, aten.expand_copy.default
[INFO 2024-05-07 00:52:17,879 mps_preprocess.py:115]

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmptp8s2iqz


[INFO 2024-05-07 00:52:19,217 mps_preprocess.py:115] Visiting: aten_embedding_default, aten.embedding.default
[INFO 2024-05-07 00:52:19,218 mps_preprocess.py:115] Visiting: aten_index_tensor, aten.index.Tensor
[INFO 2024-05-07 00:52:19,218 mps_preprocess.py:115] Visiting: aten_index_tensor_1, aten.index.Tensor
[INFO 2024-05-07 00:52:19,218 mps_preprocess.py:115] Visiting: aten_mul_tensor, aten.mul.Tensor
[INFO 2024-05-07 00:52:19,219 mps_preprocess.py:115] Visiting: aten_view_copy_default_5, aten.view_copy.default
[INFO 2024-05-07 00:52:19,219 mps_preprocess.py:115] Visiting: aten_view_copy_default_6, aten.view_copy.default
[INFO 2024-05-07 00:52:19,220 mps_preprocess.py:115] Visiting: aten_mean_dim, aten.mean.dim
[INFO 2024-05-07 00:52:19,220 mps_preprocess.py:115] Visiting: aten_add_tensor, aten.add.Tensor
[INFO 2024-05-07 00:52:19,220 mps_preprocess.py:115] Visiting: aten_rsqrt_default, aten.rsqrt.default
[INFO 2024-05-07 00:52:19,221 mps_preprocess.py:115] Visiting: aten_mul_tensor

/var/folders/21/pcyct_g904x1pf8l_b2pvpy00000gn/T/tmpg747gx_c


[INFO 2024-05-07 00:52:30,531 builder.py:340] Required memory for activation in bytes: [0, 19816448]


In [7]:
# Print the graph module of the edge program registered under 'forward'.
# NOTE(review): this reads the private `_edge_programs` attribute, which may
# change between ExecuTorch versions.
print(model.edge_manager._edge_programs['forward'].graph_module)

GraphModule(
  (lowered_module_0): LoweredBackendModule()
  (lowered_module_1): LoweredBackendModule()
  (lowered_module_2): LoweredBackendModule()
  (lowered_module_3): LoweredBackendModule()
  (lowered_module_4): LoweredBackendModule()
  (lowered_module_5): LoweredBackendModule()
  (lowered_module_6): LoweredBackendModule()
  (lowered_module_7): LoweredBackendModule()
  (lowered_module_8): LoweredBackendModule()
  (lowered_module_9): LoweredBackendModule()
  (lowered_module_10): LoweredBackendModule()
  (lowered_module_11): LoweredBackendModule()
  (lowered_module_12): LoweredBackendModule()
)



def forward(self, b_layers_0_attention_sdpa_kv_cache_k_cache, b_layers_0_attention_sdpa_kv_cache_v_cache, b_layers_1_attention_sdpa_kv_cache_k_cache, b_layers_1_attention_sdpa_kv_cache_v_cache, b_layers_2_attention_sdpa_kv_cache_k_cache, b_layers_2_attention_sdpa_kv_cache_v_cache, b_layers_3_attention_sdpa_kv_cache_k_cache, b_layers_3_attention_sdpa_kv_cache_v_cache, b_layers_4_attention_sdp

In [None]:
# Print the graph of the exported 'forward' program.
# NOTE(review): this cell has no execution count or recorded output, so it
# appears not to have been run in the saved session.
print(builder.export_program.exported_program("forward").graph)

In [9]:
# Write the exported program to a .pte file (path is relative to the working
# directory).
# NOTE(review): the recorded log timestamp for this cell is earlier than the
# export logs above, so this output looks stale — re-run after exporting.
builder.save_to_pte("stories110M_int8_mps.pte")

[INFO 2024-05-07 00:43:24,641 utils.py:112] Saved exported program to stories110M_int8_mps.pte


In [10]:
from executorch.extension.pybindings.portable_lib import _load_for_executorch

In [11]:
# Load the saved .pte program and run one forward pass on a single token id
# (a 1x1 int64 tensor).
# NOTE(review): the recorded output of this cell shows an assertion failure in
# getMPSDataType ("Invalid MPS data type: 3"), so this cell did not complete
# successfully in the saved run; the root cause is not visible here.
m = _load_for_executorch("stories110M_int8_mps.pte")
res = m.forward(torch.tensor([[1]], dtype=torch.long))

[program.cpp:130] InternalConsistency verification requested but not available
[OperationUtils.mm:42] In function getMPSDataType(), assert failed (false): [ERROR] Invalid MPS data type: 3!

: 

: 