From bcaa477e2aa29597f58b73a6fef6e41a42fc7e9e Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Tue, 15 Oct 2024 17:51:40 -0700
Subject: [PATCH 1/2] Update (base update)

[ghstack-poisoned]

From f85785eb45f11101e62d3ac25f78fe8bbff02837 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Tue, 15 Oct 2024 17:51:40 -0700
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 examples/models/llama2/runner/generation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama2/runner/generation.py b/examples/models/llama2/runner/generation.py
index 6d43c84932f..6d643b3857d 100644
--- a/examples/models/llama2/runner/generation.py
+++ b/examples/models/llama2/runner/generation.py
@@ -45,9 +45,9 @@ def sample_top_p(probs, p):
 
 def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
     if temperature > 0:
-        probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+        probs = torch.softmax(logits / temperature, dim=-1)
         return sample_top_p(probs, top_p).item()
-    return torch.argmax(logits[:, -1], dim=-1).item()
+    return torch.argmax(logits, dim=-1).item()
 
 
 class LlamaRunner(ABC):
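
For reviewers, a self-contained sketch of the changed sampling path: after this
patch, next_token expects logits for the last position only, so the [:, -1]
slice is dropped from both the sampling and the greedy branch. The body of
sample_top_p is not part of this hunk, so the version below assumes the
standard Llama-style nucleus sampling; the vocab size in the usage line is
illustrative.

    import torch

    def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
        # Sort probabilities in descending order and accumulate mass.
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Zero out tokens outside the nucleus (cumulative mass > p).
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        # Renormalize, sample one token, and map back to vocab indices.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        token = torch.multinomial(probs_sort, num_samples=1)
        return torch.gather(probs_idx, -1, token)

    def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
        # `logits` already holds the last position's scores
        # (e.g. shape [vocab_size]), so no `[:, -1]` slice is needed.
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            return sample_top_p(probs, top_p).item()
        return torch.argmax(logits, dim=-1).item()

    # Usage: sample from last-position logits (32000 is an assumed vocab size).
    logits = torch.randn(32000)
    print(next_token(logits, temperature=0.8, top_p=0.9))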