From 8f515452bdbeedcde8340ee3cf87167608f7758f Mon Sep 17 00:00:00 2001
From: Rohan Joshi
Date: Tue, 2 Sep 2025 15:18:02 -0700
Subject: [PATCH] Fix prefill_inference

Summary:
Fixes bugs in the prefill_inference function: drops the extra ar_len
argument passed from _model_call, keeps the sampled token id in a
separate variable instead of overwriting the logits tensor, and, when
collect_logits is set, returns the prefill logits slice logits[:, :pos]
directly instead of concatenating per-step values.

Differential Revision: D81532886
---
 examples/qualcomm/oss_scripts/llama/decoder_utils.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py
index 60730802233..85749232f94 100644
--- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py
+++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py
@@ -83,7 +83,6 @@ def _model_call(self, inps):
             inps,
             self._model,
             self._tokenizer,
-            self.ar_len,
             self.max_seq_length,
             use_i64_token=self.use_i64_token,
             collect_logits=True,
@@ -458,15 +457,13 @@ def prefill_inference(
                 logits, new_k_caches, new_v_caches = results
             elif len(results) == 1:
                 logits = results
-            logits = torch.argmax(logits[:, pos - 1], dim=-1).item()
-            token_list.append(logits)
+            token = torch.argmax(logits[:, pos - 1], dim=-1).item()
+            token_list.append(token)
             if collect_logits:
-                result_logits.append(logits)
+                result_logits = logits[:, :pos]
             pos += 1
 
     logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
-    if collect_logits:
-        result_logits = torch.cat(result_logits, dim=1)
     return result_logits
 
 
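Note: the snippet below is an illustrative, standalone sketch of the prefill flow after this fix, not code from decoder_utils.py. The greedy_prefill name, the module callable, and the assumed [batch, seq, vocab] logits shape are hypothetical and chosen only for the example.

import torch


def greedy_prefill(module, prompt_ids, collect_logits=False):
    """Sketch of the fixed behavior (illustrative assumptions only).

    `module` is assumed to return logits of shape [batch, seq, vocab]
    when called with a [1, seq] tensor of token ids.
    """
    token_list = list(prompt_ids)
    pos = len(token_list)
    result_logits = None

    with torch.no_grad():
        logits = module(torch.tensor([token_list], dtype=torch.long))
        # Keep the sampled id in a separate variable so the logits tensor
        # is not clobbered by the argmax result.
        token = torch.argmax(logits[:, pos - 1], dim=-1).item()
        token_list.append(token)
        if collect_logits:
            # Return the logits for the prefilled positions directly
            # rather than concatenating per-step values afterwards.
            result_logits = logits[:, :pos]

    return token_list, result_logits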