From 3148be5ac93336e7972e7658f7c5c542363783db Mon Sep 17 00:00:00 2001
From: asafg <39553475+Josephasafg@users.noreply.github.com>
Date: Mon, 15 Sep 2025 17:04:17 +0300
Subject: [PATCH 1/2] Add support for fp32 to conv state kernel

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
---
 tests/models/language/generation/test_hybrid.py       | 8 +++++---
 vllm/model_executor/layers/mamba/ops/causal_conv1d.py | 5 ++++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index d0e42062099e..dab359cd87ba 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -418,7 +418,8 @@ def test_full_cuda_graph(
 @pytest.mark.parametrize("model", FP32_STATE_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_fp32_state(
+@pytest.mark.parametrize("cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
+def test_fp32_cache_state(
     hf_runner,
     vllm_runner,
     example_prompts,
@@ -426,6 +427,7 @@
     model: str,
     max_tokens: int,
     num_logprobs: int,
+    cache_dtype_param: str,
 ) -> None:
 
     try:
@@ -443,13 +445,13 @@
         m.setenv("VLLM_USE_V1", "0")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32") as vllm_model:
+                         **{cache_dtype_param: "float32"}) as vllm_model:
             vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
     with vllm_runner(model,
                      max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="float32") as vllm_model:
+                     **{cache_dtype_param: "float32"}) as vllm_model:
         vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index b8d4bbc37105..9480f0abeeba 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -415,6 +415,9 @@ def causal_conv1d_fn(
         activation = "silu"
 
     args = None
+    # Store original dtype to cast back at the end
+    original_x_dtype = x.dtype
+    x = x.to(conv_states.dtype)
     out = torch.empty_like(x)
     if metadata is not None:
         cu_seqlen = metadata.cu_seqlen
@@ -611,7 +614,7 @@ def grid(META):
         BLOCK_N=256,
         num_stages=2,
     )
-    return out
+    return out.to(original_x_dtype)
 
 
 @triton.jit()

From b81854a4f58ab0e21945183039ef6cc64482f6bb Mon Sep 17 00:00:00 2001
From: asafg <39553475+Josephasafg@users.noreply.github.com>
Date: Mon, 15 Sep 2025 18:48:52 +0300
Subject: [PATCH 2/2] Added fp32 cast to conv update

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
---
 tests/models/language/generation/test_hybrid.py       | 3 ++-
 vllm/model_executor/layers/mamba/ops/causal_conv1d.py | 5 ++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index dab359cd87ba..206ad1352e06 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -418,7 +418,8 @@ def test_full_cuda_graph(
 @pytest.mark.parametrize("model", FP32_STATE_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
+@pytest.mark.parametrize("cache_dtype_param",
+                         ["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
 def test_fp32_cache_state(
     hf_runner,
     vllm_runner,
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 9480f0abeeba..201b63928e14 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -859,6 +859,9 @@ def causal_conv1d_update(
         activation = "silu" if activation is True else None
     elif activation is not None:
         assert activation in ["silu", "swish"]
+
+    original_x_dtype = x.dtype
+    x = x.to(conv_state.dtype)
     unsqueeze = x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
@@ -945,4 +948,4 @@ def grid(META):
     )
     if unsqueeze:
         out = out.squeeze(-1)
-    return out
+    return out.to(original_x_dtype)