diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 7a4b224822bd..02ea658b7f20 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -127,15 +127,6 @@ def forward_cpu(
         elif self.logprobs_mode == "processed_logprobs":
             logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

-        # Note: this is a workaround for
-        # https://github.com/pytorch/pytorch/pull/151218
-        @torch.compile(dynamic=True)
-        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            q = torch.empty_like(probs)
-            q.exponential_()
-            return probs.div(q).argmax(dim=-1).view(-1)
-
         if len(generators) != logits.shape[0]:
             return compiled_random_sample(logits), logits_to_return
         else:
@@ -148,6 +139,16 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
             return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return


+# Note: this is a workaround for
+# https://github.com/pytorch/pytorch/pull/151218
+@torch.compile(dynamic=True)
+def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    q = torch.empty_like(probs)
+    q.exponential_()
+    return probs.div(q).argmax(dim=-1).view(-1)
+
+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: torch.Tensor | None,
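
The diff hoists the `@torch.compile`-decorated helper from inside `forward_cpu` to module level, so the compiled wrapper is built once at import time rather than being recreated on every call. The sampling itself relies on the exponential-race identity: for i.i.d. `q_i ~ Exp(1)`, `argmax_i p_i / q_i` equals `argmin_i q_i / p_i`, each `q_i / p_i` is `Exp(rate=p_i)`, and the argmin of independent exponentials lands on index `i` with probability proportional to its rate, i.e. exactly `p_i`. Below is a minimal sketch, not part of this diff, that checks the identity empirically; the probability vector and sample count are arbitrary choices for illustration.

```python
# Sketch (assumed values): verify that the exponential-race trick used by
# compiled_random_sample matches the target categorical distribution.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])  # hypothetical distribution
n = 200_000

# q_i ~ Exp(1) i.i.d.; argmax of p_i / q_i samples index i with prob. p_i.
q = torch.empty(n, probs.numel()).exponential_()
samples = (probs / q).argmax(dim=-1)

# Empirical frequencies should be close to probs.
freq = torch.bincount(samples, minlength=probs.numel()).float() / n
print(freq)  # ~tensor([0.1, 0.2, 0.3, 0.4])
```

This is the same computation `compiled_random_sample` performs batch-wise on softmaxed logits, minus the per-request generator handling that `forward_cpu` keeps on the eager path.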