
Commit 38d22fb

billmguo authored and facebook-github-bot committed
support input_pos > 0 for prefill model
Summary: Test input_pos > 0 for prefill. Not intended for landing, but for syncing with QC.

Differential Revision: D68847677
1 parent 92e7dbd commit 38d22fb

File tree

2 files changed (+44, -36 lines)


examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 31 additions & 15 deletions
@@ -126,18 +126,32 @@ def _kv_calibrate(
     else:
         raise RuntimeError("Unkown tokenizer")

+
     with torch.no_grad():
         while token_list[-1] != tokenizer.eos_id and pos < max_cache_len:
             logits, new_k_caches, new_v_caches = module(
                 torch.full((1, 1), token_list[pos], dtype=torch.int32),
                 atten_mask,
-                torch.full((1, 1), pos),
+                freq_cos,
+                freq_sin,
                 *k_caches,
                 *v_caches,
             )
             atten_mask, pos, k_caches, v_caches = updator(
                 atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
             )
+            k_caches = [
+                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                for i, k_cache in enumerate(k_caches)
+            ]
+            v_caches = [
+                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                for i, v_cache in enumerate(v_caches)
+            ]
+
+            pos += 1
+            atten_mask[0][-pos - 1] = 0
+            print("pos", pos)
             if pos >= len(token_list):
                 token_list.append(torch.argmax(logits[:, -1], dim=-1).item())

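In the updated calibration loop above, the KV caches are rolled and the attention mask is opened one slot per step on the host, rather than passing an `input_pos` tensor into the graph. A minimal standalone sketch of that rolling update on toy shapes, illustrative only and not the commit's code:

import torch

# Toy shapes: 1 batch, head_dim 2, cache length 4 (illustrative only).
max_cache_len = 4
k_cache = torch.zeros(1, 2, max_cache_len)  # keys cached along the last dim
v_cache = torch.zeros(1, max_cache_len, 2)  # values cached along dim 1
new_k = torch.ones(1, 2, 1)                 # cache entries for the current token
new_v = torch.ones(1, 1, 2)

# Drop the oldest slot and append the newest, keeping the cache length fixed,
# mirroring the k_caches / v_caches list comprehensions in the hunk above.
k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)

# Advance the position and unmask one more slot of the additive mask,
# mirroring `pos += 1` and `atten_mask[0][-pos - 1] = 0` above.
atten_mask = torch.full((1, max_cache_len), -255.0)
atten_mask[:, -1] = 0
pos = 1
atten_mask[0][-pos - 1] = 0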
@@ -206,7 +220,7 @@ def calibrate(
             tokenizer,
             max_seq_len,
         )
-    elif len(example_inputs) == 5:
+    elif len(example_inputs) == 6:
         _kv_calibrate(
             example_inputs,
             user_prompts,
@@ -220,18 +234,17 @@ def calibrate(


 class SingleLlama:
-    def __init__(self, llama_model, pte_filename) -> None:
+    def __init__(self, llama_model, pte_filename, input_len) -> None:
         super().__init__()
         self.llama_model = llama_model
         self.quant_dtype = None
         self.llama_meta = self.llama_model.get_metadata()
         self.has_quant_io = False
         self.pte_filename = pte_filename
+        self.input_len = input_len
         if self.llama_meta["get_use_kv_cache"]:
-            tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs(
-                use_kv_cache=True
-            )
-            self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches)
+            tokens, atten_mask, freq_cos, freq_sin, k_caches, v_caches = self.get_example_inputs(self.input_len)
+            self.inputs = (tokens, atten_mask, freq_cos, freq_sin, *k_caches, *v_caches)
         else:
             tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
             self.inputs = (tokens, atten_mask)
@@ -346,7 +359,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):

         logging.info("Quantizing the model...")
         calibrate(
-            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            self.get_example_inputs(self.input_len),
             args.prompt,
             fx_graph_module,
             tokenizer=tokenizer,
@@ -417,8 +430,8 @@ def lowering_modules(
         with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
             exec_prog_mgr.write_to_file(file)

-    def get_example_inputs(self, use_kv_cache=True):
-        return self.llama_model.get_example_inputs(use_kv_cache)
+    def get_example_inputs(self, input_len):
+        return self.llama_model.get_example_inputs(self.llama_meta, input_len)

     def get_quant_attrs(self):
         return self.quant_attrs
@@ -437,7 +450,7 @@ def compile(args, pte_filename, tokenizer):

     prefill_config = copy.copy(kv_config)
     prefill_config.max_seq_len = args.prefill_seq_len
-    prefill_config.use_kv_cache = False
+    prefill_config.use_kv_cache = True

     state_dict = torch.load(
         args.checkpoint, weights_only=True, map_location="cpu", mmap=True
@@ -451,14 +464,14 @@ def compile(args, pte_filename, tokenizer):
         )
     elif args.model_mode == "prefill":
         llama_instance_list.append(
-            LlamaModel(prefill_config, output_new_cache_only=False)
+            LlamaModel(prefill_config, output_new_cache_only=True)
         )
     elif args.model_mode == "hybrid":
         llama_instance_list.append(
             LlamaModel(kv_config, output_new_cache_only=True)
         )
         llama_instance_list.append(
-            LlamaModel(prefill_config, output_new_cache_only=False)
+            LlamaModel(prefill_config, output_new_cache_only=True)
         )
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
@@ -506,11 +519,13 @@ def compile(args, pte_filename, tokenizer):
             llama_instance_list[i] = llama_instance_list[i].to(
                 dtype_override.to_torch_dtype()
             )
-
+
     for i in range(len(llama_instance_list)):
         llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i])
+        print(llama_instance_list[i].output_new_cache_only)
+        seq_len = 1 if i==0 else args.prefill_seq_len
         llama_instance_list[i] = SingleLlama(
-            llama_instance_list[i].eval(), pte_filename
+            llama_instance_list[i].eval(), pte_filename, seq_len
         )

     if args.ptq:
@@ -523,6 +538,7 @@ def compile(args, pte_filename, tokenizer):
     if args.ptq != None:
         kv_quant_attrs = {}
         for i, llama_instance in enumerate(llama_instance_list):
+            print(f"Quantizing {i}th model")
             llama_instance.quantize(
                 quant_dtype=quant_dtype,
                 args=args,

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 13 additions & 21 deletions
@@ -14,7 +14,7 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.llama_transformer import (
     ModelArgs,
-    precompute_freqs_cis,
+    Rope,
 )


@@ -309,9 +309,11 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
         self.n_kv_heads = config.n_kv_heads
         self.n_layers = config.n_layers
         self.vocab_size = config.vocab_size
-        self.rope_freq_base = config.rope_freq_base
         self.use_kv_cache = config.use_kv_cache
         self.output_new_cache_only = output_new_cache_only
+        rope = Rope(config)
+        pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
+        self.freqs_cos, self.freqs_sin = rope.get_freqs(pos_ids, self.max_seq_len)

         self.layers = nn.ModuleList(
             [
@@ -322,13 +324,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
         self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        freqs_cos, freqs_sin = precompute_freqs_cis(
-            config.dim // config.n_heads,
-            config.max_seq_len,
-            config.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+

     def prepare_output_conv(self):
         def forward_output_conv(x):
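The constructor above now pulls its cos/sin tables from the shared `Rope` helper via `rope.get_freqs(pos_ids, self.max_seq_len)` instead of calling `precompute_freqs_cis` and registering buffers itself. For orientation, the sketch below shows the conventional RoPE table precompute such helpers are generally based on; it is a hypothetical standalone function, not the ExecuTorch `Rope` implementation:

import torch

def precompute_rope_tables(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    # Standard RoPE: one frequency per pair of channels, one row per position.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, inv_freq)  # (max_seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)

# Example: tables large enough to cover the full context window.
freqs_cos, freqs_sin = precompute_rope_tables(head_dim=64, max_seq_len=128)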
@@ -350,20 +346,14 @@ def forward(
         self,
         tokens: torch.Tensor,
         atten_mask: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:

         output_k_cache = []
         output_v_cache = []
-        # following tensors should be invariant across batches
-        freqs_cos = (
-            self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos[:-1]
-        )
-        freqs_sin = (
-            self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin[:-1]
-        )
-
+
         hidden_states = self.tok_embeddings(tokens)
         for ind, decoder_layer in enumerate(self.layers):
             k_caches = None
@@ -389,12 +379,13 @@ def forward(

         return logits, output_k_cache, output_v_cache

-    def get_example_inputs(self, use_kv_cache=True):
+    def get_example_inputs(self, llama_meta, input_len):
+        use_kv_cache=llama_meta["get_use_kv_cache"]
         if use_kv_cache:
             tokens = torch.randint(
                 self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
             )
-            pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
+
             k_cache, v_cache = [], []
             atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
             atten_mask[:, -1] = 0
@@ -418,7 +409,8 @@ def get_example_inputs(self, use_kv_cache=True):
             return (
                 tokens,
                 atten_mask,
-                pos_ids,
+                self.freqs_cos[:input_len],
+                self.freqs_sin[:input_len],
                 k_cache,
                 v_cache,
             )
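Combined with `seq_len = 1 if i==0 else args.prefill_seq_len` in llama.py, `get_example_inputs` now hands each graph the first `input_len` rows of the precomputed tables: one row for the token-by-token decode model, `prefill_seq_len` rows for the prefill model. A small sketch of that slicing on stand-in tensors, not the commit's code:

import torch

max_seq_len, half_head_dim = 128, 32
freqs_cos = torch.randn(max_seq_len, half_head_dim)  # stand-in for the precomputed table
freqs_sin = torch.randn(max_seq_len, half_head_dim)

prefill_seq_len = 32
# Decode graph: one token per step -> one row of each table.
decode_cos, decode_sin = freqs_cos[:1], freqs_sin[:1]
# Prefill graph: prefill_seq_len tokens per step -> that many rows.
prefill_cos, prefill_sin = freqs_cos[:prefill_seq_len], freqs_sin[:prefill_seq_len]

print(decode_cos.shape, prefill_cos.shape)  # torch.Size([1, 32]) torch.Size([32, 32])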
