diff --git a/generate.py b/generate.py index dc0921ebb..735ee6e43 100644 --- a/generate.py +++ b/generate.py @@ -92,7 +92,7 @@ def multinomial_sample_one_no_sync( def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): - logits = logits / max(temperature, 1e-5) + logits = logits / max(temperature, 1e-5 if logits.dtype != torch.float16 else 1e-3) if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1)))