diff --git a/vllm/envs.py b/vllm/envs.py index 9485aeeb8a82..7d7fd22f85ca 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -143,7 +143,11 @@ VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True - VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False + VLLM_DEEP_GEMM_WARMUP: Literal[ + "skip", + "full", + "relax", + ] = "relax" VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False @@ -1070,9 +1074,21 @@ def get_vllm_port() -> Optional[int]: # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine # startup time by a couple of minutes. - # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup. - "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool( - int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0")) + # Available options: + # - "skip" : Skip warmup. + # - "full" : Warmup deepgemm by running all possible gemm shapes the + # engine could encounter. + # - "relax" : Select gemm shapes to run based on some heuristics. The + # heuristic aims to have the same effect as running all possible gemm + # shapes, but provides no guarantees. + "VLLM_DEEP_GEMM_WARMUP": env_with_choices( + "VLLM_DEEP_GEMM_WARMUP", + "relax", + [ + "skip", + "full", + "relax", + ], ), # Whether to use fused grouped_topk used for MoE expert selection. 
def _generate_optimal_warmup_m_values(
    max_tokens: int,
    n: int,
    device: torch.device,
    *,
    num_sms: "int | None" = None,
) -> list[int]:
    """
    Pick a reduced set of M (token-count) values to warm up DeepGEMM with.

    This is a heuristic: instead of running every M in ``[1, max_tokens]``,
    it selects the M values at which the (block_m, block_n, wave count)
    combination is expected to change, plus a handful of small sizes.
    Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp

    Args:
        max_tokens: Maximum number of tokens to warmup for.
        n: The actual N dimension from the weight tensor.
        device: The torch device to get properties from. Only queried when
            ``num_sms`` is not given.
        num_sms: Optional override for the number of SMs used to compute the
            wave boundaries. Defaults to the device's multiprocessor count.
            NOTE(review): presumably useful when DeepGEMM is configured to
            use fewer SMs than the device has -- confirm against DeepGEMM.

    Returns:
        Sorted, de-duplicated M values, all within ``[1, max_tokens]``.
        ``max_tokens`` itself is always included when it is >= 1; the list
        is empty when ``max_tokens`` < 1.
    """
    if max_tokens < 1:
        return []

    if num_sms is None:
        num_sms = torch.cuda.get_device_properties(device).multi_processor_count

    # DeepGEMM's possible block sizes. block_n never exceeds n (or 256).
    block_ms = (64, 128, 256)
    block_ns = range(16, min(256, n) + 1, 16)

    # Always include small cases, and the largest size the caller asked for
    # so the upper bound is warmed up even if no boundary lands on it.
    m_values = {1, 2, 4, *range(8, 65, 8), max_tokens}

    for block_m in block_ms:
        for block_n in block_ns:
            # Number of blocks along N; independent of the wave, so hoisted.
            n_blocks = (n + block_n - 1) // block_n
            for wave in range(1, 11):  # Up to 10 waves
                # M where this block config transitions to the next wave.
                m_values.add(wave * num_sms * block_m // n_blocks)

        # Every multiple of block_m up to max_tokens.
        m_values.update(range(block_m, max_tokens + 1, block_m))

    # Drop anything outside the valid range: the unconditional small cases
    # and wave boundaries may exceed max_tokens (or be 0).
    return sorted(m for m in m_values if 1 <= m <= max_tokens)
-136,14 +185,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: ) out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16) - pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") - num_tokens = max_tokens - while num_tokens > 0: + # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax". + # Otherwise warmup all token sizes to avoid JIT compilation in hotpath + if envs.VLLM_DEEP_GEMM_WARMUP == "relax": + m_values = _generate_optimal_warmup_m_values(max_tokens, n, device) + desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]" + else: + assert envs.VLLM_DEEP_GEMM_WARMUP == "full", ( + "Expected " + 'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got ' + f"{envs.VLLM_DEEP_GEMM_WARMUP}" + ) + m_values = list(range(1, max_tokens + 1)) + desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]" + + pbar = tqdm(total=len(m_values), desc=desc) + + for num_tokens in m_values: fp8_gemm_nt( (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens] ) pbar.update(1) - num_tokens -= 1 FP8_GEMM_NT_WARMUP_CACHE.add(w.size()) @@ -195,12 +257,16 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor): ) out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) + # Generate M values in block_m increments (already optimized for MoE) + m_values = list(range(block_m, MAX_M + 1, block_m)) + pbar = tqdm( - total=MAX_BLOCKS, - desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})", + total=len(m_values), + desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) " + f"[{len(m_values)} values, block_m={block_m}]", ) - num_tokens = MAX_M - while num_tokens > 0: + + for num_tokens in m_values: m_grouped_fp8_gemm_nt_contiguous( (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), @@ -208,7 +274,6 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor): expert_ids[:num_tokens], ) pbar.update(1) - num_tokens = num_tokens - 
block_m for w, ws in [(w1, w1_scale), (w2, w2_scale)]: if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 23227065ee95..28792338f036 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -29,7 +29,7 @@ def kernel_warmup(worker: "Worker"): do_deep_gemm_warmup = ( envs.VLLM_USE_DEEP_GEMM and is_deep_gemm_supported() - and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP + and envs.VLLM_DEEP_GEMM_WARMUP != "skip" ) if do_deep_gemm_warmup: model = worker.get_model()