Merged

31 commits
955030c
pass mgr util
ProExpertProg Aug 27, 2025
d97634d
remove old multi_output_match workaround, add new workaround (tests p…
ProExpertProg Sep 3, 2025
c34cbdf
Improve workaround to properly handle inputs of inputs
ProExpertProg Sep 3, 2025
37446e5
Remove prints, add debug statement
ProExpertProg Sep 3, 2025
8d22102
Reduce fusion test running time
ProExpertProg Sep 3, 2025
7765aa8
Rename FusionPass, remove instance() workaround, other fusion pass cl…
ProExpertProg Sep 3, 2025
e8fae3d
Remove old QuantMultiOutputMatch
ProExpertProg Sep 5, 2025
4bdb73c
Remove MultiOutputMatch workaround, add cleanup pass
ProExpertProg Sep 6, 2025
4f4cf01
Add LazyInitPass
ProExpertProg Sep 6, 2025
e584fcd
Add pattern match debugging to tests
ProExpertProg Sep 6, 2025
a5c2e3b
Fix seq-par test, add assert to silu-fp4 for clearer error
ProExpertProg Sep 6, 2025
5fa85ea
more fusion.py cleanup
ProExpertProg Sep 9, 2025
e0872fd
Cleanup custom_op logging and add warnings for incorrectly specified ops
ProExpertProg Sep 9, 2025
11e33ad
Add decorator to time and log custom passes
ProExpertProg Sep 9, 2025
7b8a355
seq_par remove hack
ProExpertProg Sep 10, 2025
c4f2158
common superclass for pattern matcher passes, add matched_count to pa…
ProExpertProg Sep 10, 2025
57763f2
add match count assert to tests
ProExpertProg Sep 10, 2025
26ed452
Update tests/compile/test_fusion_attn.py
ProExpertProg Sep 10, 2025
5e65e5d
PR comments
ProExpertProg Sep 12, 2025
0bae653
Merge branch 'main' into luka/custom-pass-cleanup
ProExpertProg Sep 15, 2025
f2028f2
Add indices to test pass manager
ProExpertProg Sep 15, 2025
6c18f65
fix pre-commit
ProExpertProg Sep 15, 2025
950078e
cleanup pattern printing
ProExpertProg Sep 18, 2025
26cbe93
Merge branch 'main' into luka/custom-pass-cleanup
ProExpertProg Sep 19, 2025
6de810d
fix ci
ProExpertProg Sep 20, 2025
f2c1f97
reorder passes
ProExpertProg Sep 20, 2025
8b8f414
Merge branch 'main' into luka/custom-pass-cleanup
ProExpertProg Sep 20, 2025
ec3cac2
fix ci
ProExpertProg Sep 20, 2025
3dc0e04
Merge branch 'main' into luka/custom-pass-cleanup
ProExpertProg Sep 20, 2025
dba5449
Remove env
ProExpertProg Sep 22, 2025
7fcbe74
Remove env
ProExpertProg Sep 22, 2025
28 changes: 27 additions & 1 deletion tests/compile/backend.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref
from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union
@@ -10,7 +11,26 @@

from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config


class LazyInitPass(InductorPass):
"""
If there's a pass that we want to initialize lazily in a test,
we can wrap it in LazyInitPass, which will initialize the pass when invoked
and then immediately invoke it.
"""

def __init__(self, pass_cls: type[VllmInductorPass],
vllm_config: VllmConfig):
self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle

def __call__(self, graph: fx.Graph) -> None:
self.pass_ = self.pass_cls(self.vllm_config)
self.pass_(graph)


class TestBackend:
@@ -40,10 +60,16 @@ def __call__(self, graph: fx.GraphModule, example_inputs):
example_inputs,
config_patches=self.inductor_config)

@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)

VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes:
pass_(graph)
VllmInductorPass.dump_prefix += 1

VllmInductorPass.dump_prefix = None

self.graph_post_pass = deepcopy(graph)
# assign by reference, will reflect the final state of the graph
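The LazyInitPass above replaces the ad-hoc lambda workaround the attention-fusion tests used before: AttnFusionPass needs the attention layers to be registered in the config at construction time, so the wrapper defers construction until the backend actually invokes the pass; post_pass also numbers each custom pass via VllmInductorPass.dump_prefix so per-pass graph dumps stay ordered. A minimal sketch of the intended usage, with the config assumed to be fully populated by the surrounding test:

from tests.compile.backend import LazyInitPass, TestBackend
from vllm.compilation.fusion_attn import AttnFusionPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import VllmConfig

vllm_config = VllmConfig()  # assumed: configured by the test fixture

# Construction of AttnFusionPass is deferred until LazyInitPass.__call__ runs,
# i.e. once compilation reaches the custom post-grad passes.
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)

# After compilation, the wrapped instance is reachable as attn_pass.pass_,
# e.g. for asserting on attn_pass.pass_.matched_count.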
2 changes: 2 additions & 0 deletions tests/compile/test_async_tp.py
@@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)

assert async_tp_pass.matched_count == 1

# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
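The new matched_count assertion gives the test a direct signal that the async-TP pattern actually fired, rather than relying only on graph inspection. A hedged sketch of the idiom; model, backend, async_tp_pass, and hidden_states are assumed to be built as in async_tp_pass_on_test_model above:

import torch

# Only the new check is shown; the setup comes from the surrounding test.
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)

# Exactly one matmul + collective pattern is expected to be rewritten.
assert async_tp_pass.matched_count == 1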
10 changes: 9 additions & 1 deletion tests/compile/test_config.py
@@ -4,7 +4,7 @@

import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config import CompilationConfig, VllmConfig
from vllm.utils import _is_torch_equal_or_newer


@@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch):
assert not vllm_config.compilation_config.use_cudagraph


def test_custom_op():
# proper syntax
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

with pytest.raises(ValueError, match="Invalid syntax '"):
_ = CompilationConfig(custom_ops=["quant_fp8"])


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
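The new test_custom_op test pins down the custom_ops syntax check: individual op names must carry an explicit '+' (enable) or '-' (disable) prefix, and a bare name now raises a ValueError. A small sketch of both forms, using the same op names as the test:

import pytest

from vllm.config import CompilationConfig

# Accepted: each op name is explicitly prefixed with + (enable) or - (disable).
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

# Rejected: an unprefixed name fails validation with "Invalid syntax '...'".
with pytest.raises(ValueError, match="Invalid syntax '"):
    CompilationConfig(custom_ops=["quant_fp8"])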
10 changes: 6 additions & 4 deletions tests/compile/test_functionalization.py
@@ -8,9 +8,10 @@
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FUSED_OPS, FusionPass
from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
@@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

passes = [noop_pass, fusion_pass, act_quant_fusion_pass
] if do_fusion else [noop_pass]
passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
] if do_fusion else [noop_pass, cleanup_pass]
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
17 changes: 10 additions & 7 deletions tests/compile/test_fusion.py
@@ -4,11 +4,11 @@
import pytest
import torch

import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass)
RMSNormQuantFusionPass)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -79,15 +79,15 @@ def ops_in_model_after(self):


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cuda_force_torch):
@@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

backend = TestBackend(noop_pass, fusion_pass)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)

# First dimension dynamic
@@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,

torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

assert fusion_pass.matched_count == 2

# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

6 changes: 5 additions & 1 deletion tests/compile/test_fusion_all_reduce.py
@@ -9,6 +9,7 @@
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
@@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
cleanup_pass)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
@@ -227,6 +230,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)

assert all_reduce_fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
20 changes: 13 additions & 7 deletions tests/compile/test_fusion_attn.py
@@ -6,18 +6,19 @@
import pytest
import torch._dynamo

from tests.compile.backend import TestBackend
from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
@@ -105,7 +106,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,

# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
@@ -198,7 +199,8 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
device=self.device,
)

def build_attn_metadata(self, batch_size: int, use_hnd: bool):
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
-> AttentionMetadata:
"""Initialize attention metadata."""

# Create common attn metadata
@@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,

# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
)
test_backend = TestBackend(noop_pass, attn_pass)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
@@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True)

# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

# Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP,
21 changes: 14 additions & 7 deletions tests/compile/test_sequence_parallelism.py
@@ -6,10 +6,12 @@

import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FusionPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
@@ -104,7 +106,7 @@ def __init__(self,
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)

self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.fp8_linear = Fp8LinearOp(act_quant_static=True)

self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
@@ -137,8 +139,7 @@ def forward(self, hidden_states, residual):
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)

# for static input quantization
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(norm_output,
self.w,
self.wscale,
@@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model(
dtype=dtype,
seed=42)

sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

passes_for_backend = [noop_pass, sequence_parallelism_pass]
passes_for_backend: list[VllmInductorPass] = \
[noop_pass, sequence_parallelism_pass]

if enable_fusion:
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)

passes_for_backend.append(cleanup_pass)

backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)

@@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model(
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)

assert sequence_parallelism_pass.matched_count == 1

# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
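Taken together, the sequence-parallelism test now builds its pass list in the order used throughout this PR: noop elimination first, then the pattern-matcher passes, then PostCleanupPass last, with FixFunctionalizationPass appended only for the functionalized backend. A hedged sketch of that wiring, assuming vllm_config and enable_fusion come from the test parameters:

from tests.compile.backend import TestBackend
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass

passes: list[VllmInductorPass] = [
    NoOpEliminationPass(vllm_config),
    SequenceParallelismPass(vllm_config),
]
if enable_fusion:
    # RMSNormQuantFusionPass replaces the old FusionPass.instance() singleton.
    passes.append(RMSNormQuantFusionPass(vllm_config))
passes.append(PostCleanupPass(vllm_config))  # cleanup always runs last

func_pass = FixFunctionalizationPass(vllm_config)
backend_no_func = TestBackend(*passes)
backend_func = TestBackend(*passes, func_pass)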
13 changes: 12 additions & 1 deletion tests/compile/test_silu_mul_quant_fusion.py
@@ -15,6 +15,7 @@
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):

def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
super().__init__()
from vllm.compilation.activation_quant_fusion import (
silu_and_mul_nvfp4_quant_supported)
assert silu_and_mul_nvfp4_quant_supported

self.silu_and_mul = SiluAndMul()

# create nvfp4 weight
@@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config)

backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
passes = [
NoOpEliminationPass(config), fusion_pass,
PostCleanupPass(config)
]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size,
cuda_force_torch=cuda_force_torch,
x=x)
@@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
atol=atol,
rtol=rtol)

assert fusion_pass.matched_count == 1

# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())
