9 changes: 6 additions & 3 deletions tests/models/language/generation/test_hybrid.py
@@ -418,14 +418,17 @@ def test_full_cuda_graph(
@pytest.mark.parametrize("model", FP32_STATE_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_fp32_state(
@pytest.mark.parametrize("cache_dtype_param",
["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
def test_fp32_cache_state(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
cache_dtype_param: str,
) -> None:

try:
@@ -443,13 +446,13 @@ def test_fp32_state(
m.setenv("VLLM_USE_V1", "0")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32") as vllm_model:
**{cache_dtype_param: "float32"}) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32") as vllm_model:
**{cache_dtype_param: "float32"}) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

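The dictionary unpacking in the rewritten test is what lets a single test body exercise both cache-dtype flags: the parametrized string becomes the keyword argument name. A minimal standalone illustration of that expansion (show is a throwaway helper, not part of the test suite):

def show(**kwargs):
    print(kwargs)

for cache_dtype_param in ("mamba_ssm_cache_dtype", "mamba_cache_dtype"):
    # **{cache_dtype_param: "float32"} expands to mamba_ssm_cache_dtype="float32"
    # or mamba_cache_dtype="float32", depending on the parametrization.
    show(**{cache_dtype_param: "float32"})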
10 changes: 8 additions & 2 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -415,6 +415,9 @@ def causal_conv1d_fn(
activation = "silu"

args = None
# Store original dtype to cast back at the end
original_x_dtype = x.dtype
x = x.to(conv_states.dtype)
out = torch.empty_like(x)
if metadata is not None:
cu_seqlen = metadata.cu_seqlen
@@ -613,7 +616,7 @@ def grid(META):
BLOCK_N=256,
num_stages=2,
)
return out
return out.to(original_x_dtype)


@triton.jit()
@@ -971,6 +974,9 @@ def causal_conv1d_update(
activation = "silu" if activation is True else None
elif activation is not None:
assert activation in ["silu", "swish"]

original_x_dtype = x.dtype
x = x.to(conv_state.dtype)
Comment on lines +978 to +979

Member:
Do we definitely want to cast x to the conv_state dtype, rather than casting conv_state to the x dtype?

Contributor Author:
It's a good question. Since the user picks fp32 for the cache dtype, I'm afraid that by downcasting it to fp16 and then back to fp32 we could lose accuracy. I had also figured that by choosing fp32, we want the computations to be done in that dtype, don't we?

Member:
How does it work for the SSM state? I guess we want it to be consistent.

Contributor Author:
We cast to float in the kernel.
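A quick standalone illustration of the downcasting concern discussed above (plain PyTorch, independent of the kernel): storing an fp32 value in an fp16 cache and reading it back drops the low-order mantissa bits.

import torch

x = torch.tensor([0.1234567], dtype=torch.float32)
# Simulate writing the fp32 state into an fp16 cache and reading it back.
roundtrip = x.to(torch.float16).to(torch.float32)
print(x.item())          # ~0.1234567
print(roundtrip.item())  # ~0.1234741, low-order bits lost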

unsqueeze = query_start_loc is None and x.dim() == 2
if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1
@@ -1079,4 +1085,4 @@ def grid(META):
)
if unsqueeze:
out = out.squeeze(-1)
return out
return out.to(original_x_dtype)
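Both causal_conv1d_fn and causal_conv1d_update now follow the same shape: remember the caller's activation dtype, promote x to the cache dtype so the convolution state is never downcast, and return the output in the original dtype. A stripped-down sketch of that flow, where conv_kernel is a placeholder standing in for the Triton launch:

import torch

def conv1d_with_cache_dtype(x: torch.Tensor, conv_state: torch.Tensor,
                            conv_kernel) -> torch.Tensor:
    # Remember the activation dtype the caller passed in (e.g. fp16/bf16).
    original_x_dtype = x.dtype
    # Compute in the cache dtype (e.g. fp32 when the cache dtype is "float32"),
    # so the state keeps its full precision through the kernel.
    x = x.to(conv_state.dtype)
    out = conv_kernel(x, conv_state)
    # Hand the result back in the dtype the caller expects.
    return out.to(original_x_dtype)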