24 changes: 20 additions & 4 deletions vllm/envs.py
@@ -143,7 +143,11 @@
VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = True
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_DEEP_GEMM_WARMUP: Literal[
"skip",
"full",
"relax",
] = "relax"
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@@ -1070,9 +1074,21 @@ def get_vllm_port() -> Optional[int]:
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
# startup time by a couple of minutes.
# Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
"VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(
int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))
# Available options:
# - "skip" : Skip the warmup entirely.
# - "full" : Warm up DeepGEMM by running every gemm shape the engine
# could encounter.
# - "relax" : Warm up a subset of gemm shapes chosen heuristically. The
# heuristic aims for the same coverage as "full", but provides no
# guarantees.
"VLLM_DEEP_GEMM_WARMUP": env_with_choices(
"VLLM_DEEP_GEMM_WARMUP",
"relax",
[
"skip",
"full",
"relax",
],
),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
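The change above replaces the boolean `VLLM_SKIP_DEEP_GEMM_WARMUP` flag with a three-way choice. Below is a minimal, hypothetical migration sketch for launch scripts that still export the old flag; it is not part of the PR, and mapping the old default onto `"full"` is an assumption based on the old behaviour of warming every shape.

```python
import os

# Hypothetical shim for launch scripts that still export the removed flag.
# Old: VLLM_SKIP_DEEP_GEMM_WARMUP=1 disabled warmup; unset/0 warmed every shape.
# New: VLLM_DEEP_GEMM_WARMUP is one of "skip", "full", "relax" (default "relax").
if os.environ.pop("VLLM_SKIP_DEEP_GEMM_WARMUP", "0") == "1":
    os.environ.setdefault("VLLM_DEEP_GEMM_WARMUP", "skip")
else:
    # Preserve the old exhaustive behaviour explicitly; drop this line to
    # accept the new heuristic default "relax".
    os.environ.setdefault("VLLM_DEEP_GEMM_WARMUP", "full")
```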
83 changes: 74 additions & 9 deletions vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -26,6 +26,55 @@
from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous


def _generate_optimal_warmup_m_values(
max_tokens: int, n: int, device: torch.device
) -> list[int]:
"""
Generate M values that cover all possible DeepGEMM kernel configurations.
Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp

Args:
max_tokens: Maximum number of tokens to warmup for
n: The actual N dimension from the weight tensor
device: The torch device to get properties from.
"""

def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
Review comment (Member) on lines +42 to +43, suggesting the existing helper instead of a local `ceil_div`:

from vllm.utils import cdiv


# DeepGEMM's possible block sizes
block_ms = [64, 128, 256]
block_ns = list(range(16, min(257, n + 1), 16))
num_sms = torch.cuda.get_device_properties(device).multi_processor_count

m_values = set()

# Always include small cases
m_values.update([1, 2, 4] + [i for i in range(8, 65, 8)])

# Collect M values where different wave patterns occur
for block_m in block_ms:
for block_n in block_ns:
if block_n > n:
continue

# Add key M boundaries for this block combination
for wave in range(1, 11): # Up to 10 waves
# M where this block config transitions to next wave
target_blocks = wave * num_sms
m = target_blocks * block_m // ceil_div(n, block_n)
if 1 <= m <= max_tokens:
m_values.add(m)

# Add block_m boundaries
for multiple in range(1, max_tokens // block_m + 1):
m = multiple * block_m
if m <= max_tokens:
m_values.add(m)

return sorted(m_values)


def _extract_data_from_linear_base_module(
m: torch.nn.Module,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
@@ -136,14 +185,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
)
out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)

pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
num_tokens = max_tokens
while num_tokens > 0:
# Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
# Otherwise warm up all token sizes to avoid JIT compilation in the hot path.
if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
else:
assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
"Expected "
'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
f"{envs.VLLM_DEEP_GEMM_WARMUP}"
)
m_values = list(range(1, max_tokens + 1))
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"

pbar = tqdm(total=len(m_values), desc=desc)

for num_tokens in m_values:
fp8_gemm_nt(
(a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
)
pbar.update(1)
num_tokens -= 1

FP8_GEMM_NT_WARMUP_CACHE.add(w.size())

@@ -195,20 +257,23 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
)
out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)

# Generate M values in block_m increments (already optimized for MoE)
m_values = list(range(block_m, MAX_M + 1, block_m))

pbar = tqdm(
total=MAX_BLOCKS,
desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})",
total=len(m_values),
desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
f"[{len(m_values)} values, block_m={block_m}]",
)
num_tokens = MAX_M
while num_tokens > 0:

for num_tokens in m_values:
m_grouped_fp8_gemm_nt_contiguous(
(a1q[:num_tokens], a1q_scales[:num_tokens]),
(w, w_scale),
out[:num_tokens],
expert_ids[:num_tokens],
)
pbar.update(1)
num_tokens = num_tokens - block_m

for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
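For context, here is a standalone sketch (not vLLM code) of the relaxed-mode heuristic added in `_generate_optimal_warmup_m_values`: instead of warming every M in `[1, max_tokens]`, it collects the M values at which DeepGEMM's block/wave tiling changes, plus small-batch sizes and `block_m` multiples. The SM count used below is an illustrative assumption; the real code reads it from the CUDA device.

```python
def relaxed_warmup_m_values(max_tokens: int, n: int, num_sms: int) -> list[int]:
    """Simplified, torch-free restatement of the PR's heuristic."""

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    block_ms = [64, 128, 256]                      # DeepGEMM's possible block_m sizes
    block_ns = range(16, min(257, n + 1), 16)      # possible block_n sizes, capped at n

    m_values = {1, 2, 4, *range(8, 65, 8)}         # always cover small batches
    for block_m in block_ms:
        for block_n in block_ns:
            # M values where this block config crosses a wave boundary
            for wave in range(1, 11):
                m = wave * num_sms * block_m // ceil_div(n, block_n)
                if 1 <= m <= max_tokens:
                    m_values.add(m)
        # block_m multiples up to max_tokens
        m_values.update(range(block_m, max_tokens + 1, block_m))
    return sorted(m_values)


if __name__ == "__main__":
    # Toy numbers: num_sms=132 is an assumption (roughly an H100); real values
    # come from the weight shape and the active device.
    ms = relaxed_warmup_m_values(max_tokens=8192, n=4096, num_sms=132)
    print(f"{len(ms)} warmup shapes instead of 8192")
```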
2 changes: 1 addition & 1 deletion vllm/model_executor/warmup/kernel_warmup.py
@@ -29,7 +29,7 @@ def kernel_warmup(worker: "Worker"):
do_deep_gemm_warmup = (
envs.VLLM_USE_DEEP_GEMM
and is_deep_gemm_supported()
and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP
and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
)
if do_deep_gemm_warmup:
model = worker.get_model()
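Taken together, the gate above and the branch in `_deepgemm_fp8_gemm_nt_warmup` give the following decision flow at engine startup. This is a hedged standalone restatement for illustration, not vLLM code; the function name is hypothetical.

```python
def plan_deep_gemm_warmup(mode: str, use_deep_gemm: bool, supported: bool) -> str:
    """Summarize what happens at startup for each VLLM_DEEP_GEMM_WARMUP setting."""
    if not (use_deep_gemm and supported) or mode == "skip":
        return "no DeepGEMM warmup"
    if mode == "relax":
        return "warm heuristically chosen M values"
    # Any other accepted value must be "full"; env_with_choices rejects the rest.
    return "warm every M in [1, max_tokens]"


assert plan_deep_gemm_warmup("skip", True, True) == "no DeepGEMM warmup"
assert plan_deep_gemm_warmup("relax", True, True) == "warm heuristically chosen M values"
assert plan_deep_gemm_warmup("full", True, False) == "no DeepGEMM warmup"
```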