diff --git a/examples/models/llama2/runner/generation.py b/examples/models/llama2/runner/generation.py
index 6d43c84932f..6d643b3857d 100644
--- a/examples/models/llama2/runner/generation.py
+++ b/examples/models/llama2/runner/generation.py
@@ -45,9 +45,9 @@ def sample_top_p(probs, p):
 
 def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
     if temperature > 0:
-        probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+        probs = torch.softmax(logits / temperature, dim=-1)
         return sample_top_p(probs, top_p).item()
-    return torch.argmax(logits[:, -1], dim=-1).item()
+    return torch.argmax(logits, dim=-1).item()
 
 
 class LlamaRunner(ABC):