@@ -126,18 +126,32 @@ def _kv_calibrate(
     else:
         raise RuntimeError("Unknown tokenizer")
 
+
     with torch.no_grad():
         while token_list[-1] != tokenizer.eos_id and pos < max_cache_len:
             logits, new_k_caches, new_v_caches = module(
                 torch.full((1, 1), token_list[pos], dtype=torch.int32),
                 atten_mask,
-                torch.full((1, 1), pos),
+                freq_cos,
+                freq_sin,
                 *k_caches,
                 *v_caches,
             )
             atten_mask, pos, k_caches, v_caches = updator(
                 atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
             )
+            k_caches = [
+                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                for i, k_cache in enumerate(k_caches)
+            ]
+            v_caches = [
+                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                for i, v_cache in enumerate(v_caches)
+            ]
+
+            pos += 1
+            atten_mask[0][-pos - 1] = 0
+            print("pos", pos)
             if pos >= len(token_list):
                 token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
 
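For context, the cache handling added in this hunk implements a fixed-size sliding window: every decode step drops the oldest cached position and appends the slice the model just returned. A minimal standalone sketch of that update (`roll_kv_caches` is a hypothetical helper; it assumes k caches are laid out as `(batch, head_dim, seq_len)` and v caches as `(batch, seq_len, head_dim)`, which is what the concat dims above imply):

```python
import torch

def roll_kv_caches(k_caches, v_caches, new_k_caches, new_v_caches):
    # Drop the oldest position and append the newest one, keeping each
    # per-layer cache at a fixed length. Layout assumptions: k caches are
    # (batch, head_dim, seq_len), v caches are (batch, seq_len, head_dim).
    k_caches = [
        torch.cat([k[:, :, 1:], new_k_caches[i]], dim=-1)
        for i, k in enumerate(k_caches)
    ]
    v_caches = [
        torch.cat([v[:, 1:, :], new_v_caches[i]], dim=1)
        for i, v in enumerate(v_caches)
    ]
    return k_caches, v_caches

# Toy shapes: one layer, batch 1, head_dim 4, window of 8 cached positions.
k = [torch.zeros(1, 4, 8)]
v = [torch.zeros(1, 8, 4)]
new_k = [torch.ones(1, 4, 1)]
new_v = [torch.ones(1, 1, 4)]
k, v = roll_kv_caches(k, v, new_k, new_v)
assert k[0].shape == (1, 4, 8) and v[0].shape == (1, 8, 4)
```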
@@ -206,7 +220,7 @@ def calibrate(
             tokenizer,
             max_seq_len,
         )
-    elif len(example_inputs) == 5:
+    elif len(example_inputs) == 6:
         _kv_calibrate(
             example_inputs,
             user_prompts,
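The `== 5` → `== 6` bump tracks the new KV-mode signature from the first hunk: the single positional-ids tensor is replaced by the two precomputed rotary tensors `freq_cos` and `freq_sin`. A hypothetical stand-in for the tuple shape the `== 6` branch expects (dummy shapes; the real tensors come from `get_example_inputs`):

```python
import torch

# Dummy placeholders; actual shapes depend on the model config.
tokens = torch.zeros(1, 1, dtype=torch.int32)
atten_mask = torch.zeros(1, 128)
freq_cos = torch.zeros(1, 64)
freq_sin = torch.zeros(1, 64)
k_caches, v_caches = [], []  # per-layer cache lists

# k_caches/v_caches stay as lists, so the tuple has exactly six entries.
example_inputs = (tokens, atten_mask, freq_cos, freq_sin, k_caches, v_caches)
assert len(example_inputs) == 6
```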
@@ -220,18 +234,17 @@ def calibrate(
 
 
 class SingleLlama:
-    def __init__(self, llama_model, pte_filename) -> None:
+    def __init__(self, llama_model, pte_filename, input_len) -> None:
         super().__init__()
         self.llama_model = llama_model
         self.quant_dtype = None
         self.llama_meta = self.llama_model.get_metadata()
         self.has_quant_io = False
         self.pte_filename = pte_filename
+        self.input_len = input_len
         if self.llama_meta["get_use_kv_cache"]:
-            tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs(
-                use_kv_cache=True
-            )
-            self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches)
+            tokens, atten_mask, freq_cos, freq_sin, k_caches, v_caches = self.get_example_inputs(self.input_len)
+            self.inputs = (tokens, atten_mask, freq_cos, freq_sin, *k_caches, *v_caches)
         else:
             tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
             self.inputs = (tokens, atten_mask)
@@ -346,7 +359,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
 
         logging.info("Quantizing the model...")
         calibrate(
-            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            self.get_example_inputs(self.input_len),
             args.prompt,
             fx_graph_module,
             tokenizer=tokenizer,
@@ -417,8 +430,8 @@ def lowering_modules(
         with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
             exec_prog_mgr.write_to_file(file)
 
-    def get_example_inputs(self, use_kv_cache=True):
-        return self.llama_model.get_example_inputs(use_kv_cache)
+    def get_example_inputs(self, input_len):
+        return self.llama_model.get_example_inputs(self.llama_meta, input_len)
 
     def get_quant_attrs(self):
         return self.quant_attrs
@@ -437,7 +450,7 @@ def compile(args, pte_filename, tokenizer):
 
     prefill_config = copy.copy(kv_config)
     prefill_config.max_seq_len = args.prefill_seq_len
-    prefill_config.use_kv_cache = False
+    prefill_config.use_kv_cache = True
 
     state_dict = torch.load(
         args.checkpoint, weights_only=True, map_location="cpu", mmap=True
@@ -451,14 +464,14 @@ def compile(args, pte_filename, tokenizer):
         )
     elif args.model_mode == "prefill":
         llama_instance_list.append(
-            LlamaModel(prefill_config, output_new_cache_only=False)
+            LlamaModel(prefill_config, output_new_cache_only=True)
         )
     elif args.model_mode == "hybrid":
         llama_instance_list.append(
             LlamaModel(kv_config, output_new_cache_only=True)
         )
         llama_instance_list.append(
-            LlamaModel(prefill_config, output_new_cache_only=False)
+            LlamaModel(prefill_config, output_new_cache_only=True)
         )
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
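Flipping `output_new_cache_only` to `True` for the prefill model makes its outputs match the KV model's: each forward pass returns only the cache slice for the tokens just processed, which is exactly the piece the calibration loop in the first hunk splices into its rolling caches. A toy sketch of the assumed semantics (not the actual `LlamaModel` code):

```python
import torch

def attention_step(k_cache, new_k, output_new_cache_only):
    # With output_new_cache_only=True the caller receives just the freshly
    # computed slice and splices it into its own rolling cache; with False
    # it gets the whole updated cache back.
    if output_new_cache_only:
        return new_k  # shape: (batch, head_dim, n_new)
    return torch.cat([k_cache[:, :, new_k.shape[-1]:], new_k], dim=-1)

k_cache, new_k = torch.zeros(1, 4, 8), torch.ones(1, 4, 1)
print(attention_step(k_cache, new_k, True).shape)   # torch.Size([1, 4, 1])
print(attention_step(k_cache, new_k, False).shape)  # torch.Size([1, 4, 8])
```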
@@ -506,11 +519,13 @@ def compile(args, pte_filename, tokenizer):
         llama_instance_list[i] = llama_instance_list[i].to(
             dtype_override.to_torch_dtype()
         )
-
+
     for i in range(len(llama_instance_list)):
         llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i])
+        print(llama_instance_list[i].output_new_cache_only)
+        seq_len = 1 if i == 0 else args.prefill_seq_len
         llama_instance_list[i] = SingleLlama(
-            llama_instance_list[i].eval(), pte_filename
+            llama_instance_list[i].eval(), pte_filename, seq_len
         )
 
     if args.ptq:
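The `seq_len` picked here follows the hybrid-mode ordering from the `model_mode` branches: index 0 is the KV (decode) model, which consumes one token per step, and index 1 is the prefill model sized to `args.prefill_seq_len`. A compact restatement of that mapping (hypothetical helper, assuming that list order):

```python
def input_len_for(index: int, prefill_seq_len: int) -> int:
    # Index 0 = decode/KV model (one token at a time); index 1 = prefill model.
    return 1 if index == 0 else prefill_seq_len

assert input_len_for(0, 32) == 1
assert input_len_for(1, 32) == 32
```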
@@ -523,6 +538,7 @@ def compile(args, pte_filename, tokenizer):
     if args.ptq != None:
         kv_quant_attrs = {}
         for i, llama_instance in enumerate(llama_instance_list):
+            print(f"Quantizing {i}th model")
             llama_instance.quantize(
                 quant_dtype=quant_dtype,
                 args=args,