4 changes: 1 addition & 3 deletions csrc/moe/topk_softmax_kernels.cu
@@ -21,7 +21,6 @@
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#include "../core/batch_invariant.hpp"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -406,8 +405,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
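For reference, a minimal sketch (not part of this diff) of the grid-sizing arithmetic the launcher now always uses; ROWS_PER_WARP and WARPS_PER_TB stand in for the kernel's compile-time constants, and the values below are illustrative only.

# Illustrative sketch of the launch-dimension arithmetic kept by this hunk.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def launch_dims(num_rows: int, rows_per_warp: int = 4, warps_per_tb: int = 8):
    # Before this change, batch-invariant mode forced num_warps to 32; now the
    # warp count always tracks the number of input rows.
    num_warps = ceil_div(num_rows, rows_per_warp)
    num_blocks = ceil_div(num_warps, warps_per_tb)
    return num_warps, num_blocks

print(launch_dims(130))  # (33, 5) with the illustrative constants above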
65 changes: 37 additions & 28 deletions tests/v1/generation/test_batch_invariance.py
@@ -76,18 +76,21 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
seed.
- Keep max_tokens and max_model_len bounded for speed and memory use.
"""
random.seed(12345)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)

# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."

# Keep GPU memory usage low to avoid startup allocation failures.
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))

# Sampling parameters: longer outputs with a more random-sounding
@@ -111,7 +114,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
# Engine with bs=1 behavior
llm_bs1 = LLM_with_max_seqs(
model=model,
max_num_seqs=1,
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
swap_space=swap_space_gb,
@@ -126,7 +129,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
# Engine with larger batch limit (e.g., 64)
llm_bsN = LLM_with_max_seqs(
model=model,
max_num_seqs=batch_size,
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
swap_space=swap_space_gb,
@@ -135,15 +138,16 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
mismatches = 0

for trial in range(num_trials):
# Create a batch of size `batch_size` and insert the needle at
# Create a batch of size `max_batch_size` and insert the needle at
# a random index
prompts: list[str] = []
batch_size = random.randint(max_batch_size // 2, max_batch_size)
needle_pos = random.randint(0, batch_size - 1)
for i in range(batch_size):
if i == needle_pos:
prompts.append(needle_prompt)
else:
prompts.append(_random_prompt())
prompts.append(_random_prompt(min_random_prompt, max_random_prompt))

# Generate with the larger-batch engine
outputs = llm_bsN.generate(prompts, sampling)
@@ -154,19 +158,20 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
text = needle_output.outputs[0].text

if text != baseline_text:
print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
mismatches += 1

passes = num_trials - mismatches
# Dump how many passed vs failed
print(
f"[determinism] total={num_trials}, passed={passes}, "
f"failed={mismatches}, batch_size={batch_size}"
f"failed={mismatches}, max_batch_size={max_batch_size}"
)

if mismatches > 0:
pytest.fail(
f"Nondeterministic outputs detected: {mismatches} failed out "
f"of {num_trials} trials (batch_size={batch_size})."
f"of {num_trials} trials (max_batch_size={max_batch_size})."
)

finally:
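For readers skimming the diff, a condensed sketch of the needle strategy this test now uses (illustrative only; the engine objects and the random_prompt helper are placeholders for the test's own fixtures):

# Illustrative sketch of one needle trial (not the test itself).
import random

def needle_trial(llm_baseline, llm_batched, needle_prompt, sampling,
                 max_batch_size, random_prompt):
    """Hide one fixed prompt at a random position in a randomly sized batch
    and check its output matches the baseline generation of that prompt."""
    baseline_text = llm_baseline.generate([needle_prompt], sampling)[0].outputs[0].text

    batch_size = random.randint(max_batch_size // 2, max_batch_size)
    needle_pos = random.randint(0, batch_size - 1)
    prompts = [needle_prompt if i == needle_pos else random_prompt()
               for i in range(batch_size)]

    outputs = llm_batched.generate(prompts, sampling)
    return outputs[needle_pos].outputs[0].text == baseline_text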
@@ -199,25 +204,28 @@ def _extract_step_logprobs(request_output):
not torch.cuda.is_available(),
reason="Requires CUDA to match production inference path.",
)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
# model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend

seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

# Force float32 to avoid precision-induced differences.
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enforce_eager=True, # helps reduce nondeterminism from some backends
enforce_eager=True,
enable_prefix_caching=False,
)

prompts = [
"The capital of France is",
"The capital of Germany is",
]
prompts = [_random_prompt(10, 1024) for i in range(100)]

sp = SamplingParams(
temperature=0.0,
temperature=0.6,
top_p=1.0,
max_tokens=8,
# Seed shouldn't matter at temperature=0, but keeping it stable anyway.
Expand All @@ -238,29 +246,29 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
)
bs1_logprobs_per_prompt.append(step_logprobs)

# BS=2: run prompts in a batch and collect logprobs per step for each
# BS=N: run prompts in a batch and collect logprobs per step for each
# prompt.
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
assert len(outs_batched) == len(prompts)
bs2_logprobs_per_prompt = []
bsN_logprobs_per_prompt = []
for o in outs_batched:
step_logprobs = _extract_step_logprobs(o)
if step_logprobs is None:
pytest.skip(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bs2_logprobs_per_prompt.append(step_logprobs)
bsN_logprobs_per_prompt.append(step_logprobs)

# Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
for i, (logprobs_bs1, logprobs_bs2) in enumerate(
zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
for i, (logprobs_bs1, logprobs_bsN) in enumerate(
zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)
):
assert len(logprobs_bs1) == len(logprobs_bs2), (
assert len(logprobs_bs1) == len(logprobs_bsN), (
f"Different number of generation steps for prompt index {i}: "
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)"
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
)
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
assert a.shape == b.shape, (
f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
)
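The hunk is truncated just before the equality check; a sketch of what the per-prompt comparison amounts to is below (hedged: per-step logprob tensors are assumed already extracted, and torch.equal is used because the test's claim is bitwise rather than approximate equality):

# Illustrative sketch of the bs=1 vs bs=N comparison (not the test itself).
import torch

def assert_steps_bitwise_equal(logprobs_bs1: list[torch.Tensor],
                               logprobs_bsN: list[torch.Tensor],
                               prompt_idx: int) -> None:
    assert len(logprobs_bs1) == len(logprobs_bsN), f"prompt {prompt_idx}: step count differs"
    for step, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
        assert a.shape == b.shape, f"prompt {prompt_idx}, step {step}: shape mismatch"
        # torch.equal is exact: any batch-size-dependent change in reduction
        # order or kernel split strategy surfaces as a failure here.
        assert torch.equal(a, b), f"prompt {prompt_idx}, step {step}: values differ"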
@@ -297,6 +305,7 @@ def LLM_with_max_seqs(
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
enable_prefix_caching=False,
enforce_eager=True,
# Enable for MOE models
# enable_expert_parallel=True,
)
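A minimal sketch of the engine setup these tests rely on (hedged; it mirrors what LLM_with_max_seqs passes through rather than defining the helper itself):

# Illustrative sketch of the determinism-friendly engine construction.
import os
from vllm import LLM

def build_test_engine(model: str, max_num_seqs: int) -> LLM:
    return LLM(
        model=model,
        max_num_seqs=max_num_seqs,
        gpu_memory_utilization=float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4")),
        max_model_len=int(os.getenv("VLLM_MAX_MODEL_LEN", "5120")),
        # Prefix caching can reuse KV blocks across requests, and eager
        # execution avoids CUDA-graph capture; both are pinned for these tests.
        enable_prefix_caching=False,
        enforce_eager=True,
    )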
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/batch_invariant.py
@@ -8,8 +8,12 @@

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton

logger = init_logger(__name__)


def _matmul_launch_metadata(
grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
@@ -564,5 +568,14 @@ def vllm_kernel_override_batch_invariant():
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
if curr_attn_backend not in supported_backends:
warning = (
"Forcibly updating attention backend to"
f" {supported_backends[0]} for batch_invariant. "
f" Supported backends: {supported_backends}."
)
logger.warning_once(warning)
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
enable_batch_invariant_mode()
32 changes: 29 additions & 3 deletions vllm/v1/attention/backends/flashinfer.py
@@ -27,6 +27,9 @@
)
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
@@ -52,6 +55,7 @@
from vllm.v1.kv_cache_interface import AttentionSpec

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
@@ -290,6 +294,15 @@ def __init__(
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape)

if vllm_kernel_override_batch_invariant():
self.decode_fixed_split_size = 2048
self.prefill_fixed_split_size = 4096
self.disable_split_kv = True
else:
self.decode_fixed_split_size = -1
self.prefill_fixed_split_size = -1
self.disable_split_kv = False

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(
self.model_config.max_model_len, self.kv_cache_spec.block_size
@@ -393,8 +406,11 @@ def __init__(

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
if vllm_kernel_override_batch_invariant():
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
self._workspace_buffer = torch.zeros(
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device
buffer_size, dtype=torch.uint8, device=self.device
)
return self._workspace_buffer

@@ -671,6 +687,8 @@ def build(
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
@@ -732,6 +750,8 @@ def build(
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.decode_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
return attn_metadata

@@ -1123,6 +1143,8 @@ def fast_plan_decode(
rope_scale: float | None = None,
rope_theta: float | None = None,
non_blocking: bool = True,
fixed_split_size: int = -1,
disable_split_kv: bool = False,
) -> None:
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
@@ -1159,6 +1181,10 @@ def fast_plan_decode(
rope_scale,
rope_theta,
non_blocking,
None, # block_tables
None, # seq_lens
fixed_split_size,
disable_split_kv,
)
self.vllm_first_call = False
return
@@ -1224,8 +1250,8 @@ def fast_plan_decode(
head_dim,
False, # causal
window_left,
-1, # fixed_split_size
False, # disable_split_kv
fixed_split_size,
disable_split_kv,
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e
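To summarize what the new plan arguments do, here is a small illustrative sketch (values taken from this diff; the dataclass and function names are placeholders, not vLLM API):

# Illustrative sketch of the plan parameters this diff wires through.
from dataclasses import dataclass

MIB = 1024 * 1024

@dataclass
class FlashInferPlanConfig:
    workspace_bytes: int
    decode_fixed_split_size: int   # -1 = leave split sizing to FlashInfer's planner
    prefill_fixed_split_size: int
    disable_split_kv: bool

def plan_config(batch_invariant: bool) -> FlashInferPlanConfig:
    if batch_invariant:
        # Fixed split sizes plus disabled split-KV keep the KV reduction
        # schedule from depending on how many requests share the batch,
        # at the cost of a larger workspace buffer (2 GiB vs 256 MiB).
        return FlashInferPlanConfig(2048 * MIB, 2048, 4096, True)
    return FlashInferPlanConfig(256 * MIB, -1, -1, False)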