19 changes: 10 additions & 9 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -127,15 +127,6 @@ def forward_cpu(
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)

if len(generators) != logits.shape[0]:
return compiled_random_sample(logits), logits_to_return
else:
@@ -148,6 +139,16 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return


# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)
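
Editorial aside (not part of the diff or review): the probs.div(q).argmax(...) pattern above relies on the exponential-race trick for categorical sampling. If each q_i is drawn independently from Exp(1), then argmax_i p_i / q_i is distributed as Categorical(p). A minimal, self-contained sketch (toy vocabulary, CPU only, values chosen for illustration) checking this empirically against torch.multinomial:

import torch

# Empirical check of the exponential-race trick used by compiled_random_sample.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

n = 200_000
q = torch.empty(n, probs.numel()).exponential_()   # q ~ Exp(1), one row per draw
race = (probs / q).argmax(dim=-1)                  # exponential-race samples
race_freq = torch.bincount(race, minlength=4) / n

multi = torch.multinomial(probs, n, replacement=True)
multi_freq = torch.bincount(multi, minlength=4) / n

print(race_freq)   # approximately [0.1, 0.2, 0.3, 0.4]
print(multi_freq)  # approximately [0.1, 0.2, 0.3, 0.4]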
Comment on lines +145 to +149
Contributor
critical

While moving this function to the module level is a great performance optimization, its usage in forward_cpu reveals a pre-existing critical bug. forward_cpu calls this function when len(generators) != logits.shape[0]. This condition is met when some, but not all, requests have custom generators (0 < len(generators) < logits.shape[0]). In this scenario, compiled_random_sample is invoked, which uses the default torch generator for all requests, thereby silently ignoring the user-provided generators. This behavior is incorrect and inconsistent with random_sample, which correctly handles this mixed-generator case.

To fix this, the logic in forward_cpu should be adjusted. The most direct fix is to take the compiled path only when no custom generators are provided at all, and let the uncompiled path handle every case that involves generators, including partial ones. This would look like:

# In forward_cpu method

if not generators:
    return compiled_random_sample(logits), logits_to_return
else:
    # This logic correctly handles all cases with generators, including partial ones.
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    if len(generators) != probs.shape[0]:
        # If not all requests have a generator, initialize with default.
        q.exponential_()
    # Overwrite with per-request generators where available.
    for i, generator in generators.items():
        q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return

Since this bug is in code not directly modified by this PR, I recommend creating a follow-up issue or pull request to address this critical correctness issue.
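
As an illustration only (the request count, shapes, and seed below are hypothetical, not from the PR), this sketch walks through the mixed-generator case the comment describes, using the same fill-then-overwrite pattern as the proposed fix:

import torch

# Hypothetical batch of 3 requests where only request 1 supplies a seeded generator,
# i.e. 0 < len(generators) < logits.shape[0].
logits = torch.randn(3, 16)
generators = {1: torch.Generator().manual_seed(1234)}

probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()                        # default RNG for rows without a generator
for i, gen in generators.items():
    q[i].exponential_(generator=gen)    # overwrite seeded rows per request
token_ids = probs.div_(q).argmax(dim=-1).view(-1)

# Row 1 is reproducible across runs that reseed its generator to 1234;
# rows 0 and 2 follow the default RNG stream, matching random_sample's handling.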

Contributor
high

For consistency with other sampling implementations in this file (e.g., random_sample and the else branch in forward_cpu) and to avoid an unnecessary tensor allocation, consider using the in-place div_ operation.

Suggested change
return probs.div(q).argmax(dim=-1).view(-1)
return probs.div_(q).argmax(dim=-1).view(-1)
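
A small, hedged sketch (standalone, not from the PR) of the difference this suggestion targets: div allocates a fresh result tensor, while div_ writes into probs, which is safe here because probs is itself a fresh tensor produced by softmax.

import torch

logits = torch.randn(4, 8)
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs).exponential_()

out_of_place = probs.div(q)    # new tensor, one extra allocation
in_place = probs.div_(q)       # reuses probs' storage

assert in_place.data_ptr() == probs.data_ptr()
assert out_of_place.data_ptr() != probs.data_ptr()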



def apply_top_k_top_p(
logits: torch.Tensor,
k: torch.Tensor | None,