 # LICENSE file in the root directory of this source tree.

 import os
-from typing import Optional
+from typing import Dict, Optional

 import torch
+import torch._inductor
 import torch.nn as nn

 from torch.export import Dim
-import torch._inductor

 from torchchat.cli.builder import (
     _initialize_model,
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,8 +68,10 @@ def export_for_server(
         dynamic_shapes = None

     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
+        options = {
+            "aot_inductor.package": package,
+            "aot_inductor.metadata": metadata or {},
+        }
         if not package:
             options = {"aot_inductor.output_path": output_path}
@@ -81,6 +84,7 @@ def export_for_server(

     if package:
         from torch._inductor.package import package_aoti
+
         path = package_aoti(output_path, path)

     print(f"The generated packaged model can be found at: {path}")
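
With this change, whatever dict the caller passes as `metadata` flows into Inductor's `aot_inductor.metadata` option and is baked into the packaged .pt2. A minimal sketch of a call site (model initialization is elided; the paths and metadata values here are placeholders, not taken from this patch):

    # Hypothetical caller: embed custom metadata in the packaged artifact.
    # `model` is any torchchat model prepared by _initialize_model.
    pt2_path = export_for_server(
        model,
        builder_args.device,
        output_path="llama3.pt2",
        dynamic_shapes=False,
        package=True,
        metadata={"tokenizer_type": "3"},  # stored via "aot_inductor.metadata"
    )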
@@ -102,13 +106,13 @@ def export_for_server(
 from typing import Any, Dict, Tuple, Union

 import executorch.exir as exir
+from executorch.backends.xnnpack._passes.convert_to_linear import (
+    ConvertToLinearPass,
+)

 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
 )
-from executorch.backends.xnnpack._passes.convert_to_linear import (
-    ConvertToLinearPass,
-)
 from executorch.exir import EdgeProgramManager, to_edge

 from executorch.exir.capture._config import (
@@ -166,18 +170,22 @@ def __init__(self, attention: Attention):

         self.wo = attention.wo

-        max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache[0].k_cache.shape
-        )
+        max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+            0
+        ].k_cache.shape
         cache_dtype = attention.kv_cache[0].k_cache.dtype
         # The `Attention` module being replaced can have multiple KV caches
         # (denoted by `cache_lanes`). Thus we follow the same setup format
         # as in `Attention.setup_cache`.
         cache_lanes = len(attention.kv_cache)
-        self.kv_cache = nn.ModuleList([
-            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = nn.ModuleList(
+            [
+                CustomKVCache(
+                    max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                )
+                for _ in range(cache_lanes)
+            ]
+        )

         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
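
The `cache_lanes` comment is the key invariant here: the replacement module must allocate one independent cache per lane, mirroring `Attention.setup_cache`. A toy stand-in showing just that layout (TinyKVCache is hypothetical; the real CustomKVCache comes from executorch):

    import torch
    import torch.nn as nn

    class TinyKVCache(nn.Module):
        # Hypothetical stand-in for CustomKVCache, sized like the shapes
        # unpacked above from attention.kv_cache[0].k_cache.
        def __init__(self, bsz, seq_len, n_heads, head_dim, dtype):
            super().__init__()
            shape = (bsz, seq_len, n_heads, head_dim)
            self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
            self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    cache_lanes = 2
    kv_cache = nn.ModuleList(
        TinyKVCache(1, 128, 8, 64, torch.float16) for _ in range(cache_lanes)
    )
    assert len(kv_cache) == cache_lanes  # one independent cache per lane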
@@ -215,9 +223,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         return self.wo(output)

 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.extension.llm.custom_ops import (  # noqa
-        sdpa_with_kv_cache,
-    )
+    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa

     for name, child in module.named_children():
         if isinstance(child, Attention):
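
The loop this hunk touches follows the standard recursive module-swap pattern: replace matching direct children in place, recurse into everything else. A generic sketch of the pattern (the helper name is hypothetical, not the exact body of this function):

    import torch.nn as nn

    def swap_modules(module: nn.Module, old_cls: type, make_new):
        # Replace each direct child that is an instance of `old_cls`,
        # then recurse into the children that were not replaced.
        for name, child in module.named_children():
            if isinstance(child, old_cls):
                setattr(module, name, make_new(child))
            else:
                swap_modules(child, old_cls, make_new)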
@@ -238,7 +244,9 @@ def _to_core_aten(
         raise ValueError(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
-    core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
+    core_aten_ep = export_for_training(
+        model, example_inputs, dynamic_shapes=dynamic_shapes
+    )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
     return core_aten_ep
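
`export_for_training` here is `torch.export.export_for_training` (available in recent PyTorch); dynamic dimensions are declared with the same `Dim` imported at the top of the file. A minimal self-contained sketch (the module and sizes are made up):

    import torch
    from torch.export import Dim, export_for_training

    class Double(torch.nn.Module):
        def forward(self, x):
            return x * 2

    seq = Dim("seq", max=2048)  # dim 1 of x may vary up to 2048
    ep = export_for_training(
        Double(), (torch.ones(2, 8),), dynamic_shapes={"x": {1: seq}}
    )
    print(ep.graph)  # the same graph _to_core_aten logs when verbose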
@@ -350,7 +358,11 @@ def main(args):

     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )

     builder_args.dso_path = None
     builder_args.pte_path = None
@@ -372,6 +384,7 @@ def main(args):

     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
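
Predefining `tokenizer_args = None` is what makes the new metadata logic at the end of `main` safe: when the gguf branch (or a failed tokenizer lookup) never assigns it, export falls back to tokenizer_type "0" instead of raising a NameError.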
@@ -382,9 +395,8 @@ def main(args):

     if builder_args.max_seq_length is None:
         if (
-            (output_dso_path is not None or output_aoti_package_path is not None)
-            and not builder_args.dynamic_shapes
-        ):
+            output_dso_path is not None or output_aoti_package_path is not None
+        ) and not builder_args.dynamic_shapes:
             print("Setting max_seq_length to 300 for DSO export.")
             builder_args.max_seq_length = 300
         elif output_pte_path is not None:
@@ -397,7 +409,8 @@ def main(args):
         quantize,
         tokenizer,
         max_seq_length=builder_args.max_seq_length,
-        support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+        support_tensor_subclass=output_dso_path is None
+        and output_aoti_package_path is None,
     )
     model_to_pte = model
     model_to_dso = model
@@ -435,7 +448,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! Compiling a standalone .dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,
@@ -446,11 +461,23 @@ def main(args):

     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
-        print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
+        print(
+            f"Exporting model using AOT Inductor to {output_aoti_package_path}."
+        )
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
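
The string codes written into the package are what a loader can key on later: "2" for the SentencePiece (llama2-style) tokenizer, "3" for the Tiktoken (llama3-style) one, and "0" when no tokenizer information was available at export time. A tiny self-check of that mapping (the helper dict is hypothetical, not part of this patch):

    TOKENIZER_TYPE = {"none": "0", "sentencepiece": "2", "tiktoken": "3"}
    assert TOKENIZER_TYPE["sentencepiece"] == "2"  # llama2-style
    assert TOKENIZER_TYPE["tiktoken"] == "3"  # llama3-style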