
Conversation

@Josephasafg (Contributor) commented Sep 15, 2025

Purpose

There was a bug in the recently added FP32 support for the Mamba conv state: passing the arg --mamba-cache-dtype "float32" would produce this error:

E       triton.compiler.errors.CompilationError: at 102:8:
E           w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
E
E           # Does 2 things:
E           # 1. READ prior-block init-state data - [done by every Triton programs]
E           # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0]
E           if chunk_offset == 0:
E               # read from conv_states
E               load_init_state = False
E               if HAS_INITIAL_STATES:  # the new HAS_INITIAL_STATES
E                   load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(
E                       tl.int1)
E               if load_init_state:
E               ^
E       AssertionError("Mismatched type for col0 between then block (<['256'], fp32>) and else block (<['256'], bf16>)")

The fix is to cast the kernel input to the same dtype as the conv state.
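
For illustration, here is a minimal, self-contained sketch of that casting approach, with plain PyTorch standing in for the Triton kernel; the simplified signature, shapes, and the cast-back of the output are assumptions for the sketch, not vLLM's actual implementation:

```python
import torch
import torch.nn.functional as F

def causal_conv1d_sketch(x: torch.Tensor,
                         weight: torch.Tensor,
                         conv_state: torch.Tensor) -> torch.Tensor:
    """Toy stand-in for the kernel wrapper.

    x:          (batch, dim, seqlen) activations, e.g. bf16
    weight:     (dim, width) depthwise filter
    conv_state: (batch, dim, width - 1) cached tail, fp32 when
                --mamba-cache-dtype "float32" is used
    """
    original_x_dtype = x.dtype
    # The fix: run everything in the conv-state dtype so the kernel sees a
    # single, consistent element type instead of mixing bf16 and fp32.
    x = x.to(conv_state.dtype)
    weight = weight.to(conv_state.dtype)

    # Prepend the cached tail, then do a causal depthwise conv
    # (plain PyTorch here; the real code launches a Triton kernel).
    xp = torch.cat([conv_state, x], dim=-1)
    out = F.conv1d(xp, weight.unsqueeze(1), groups=weight.shape[0])

    # Update the cache in place and return activations in their original dtype.
    conv_state.copy_(xp[..., -conv_state.shape[-1]:])
    return out.to(original_x_dtype)

# Usage: an fp32 cache with bf16 activations no longer mixes dtypes in the conv.
batch, dim, width, seqlen = 2, 256, 4, 8
x = torch.randn(batch, dim, seqlen, dtype=torch.bfloat16)
w = torch.randn(dim, width, dtype=torch.bfloat16)
state = torch.zeros(batch, dim, width - 1, dtype=torch.float32)
y = causal_conv1d_sketch(x, w, state)
assert y.dtype == torch.bfloat16 and state.dtype == torch.float32
```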

Test Plan

Added two tests to test_hybrid.py, since coverage for this case was missing. Both tests pass.

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>

@gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses a Triton compilation error for Mamba's FP32 convolution state during prefill by casting the input tensor to match the state's data type. The fix in causal_conv1d_fn is appropriate. However, the same bug likely exists in the causal_conv1d_update function, which handles the decode path, and this has not been addressed. This could lead to failures during token generation after the prefill phase. I've added a critical comment highlighting this omission. The test enhancements are good, improving coverage for different cache dtype parameters.

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
@Josephasafg changed the title from "[Bug][Mamba] - Fix Conv State Kernel FP32 Support" to "[Bugfix][Mamba] - Fix Conv State Kernel FP32 Support" Sep 15, 2025

mergify bot commented Sep 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Josephasafg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Sep 17, 2025
@Josephasafg (Contributor, Author) commented:

@heheda12345 @tdoublep Can I please get a review?

@mergify bot removed the needs-rebase label Sep 18, 2025
Comment on lines +978 to +979
original_x_dtype = x.dtype
x = x.to(conv_state.dtype)
Member:

Do we definitely want to cast x to the conv_state dtype, rather than casting conv_state to the x_dtype?

Contributor Author (@Josephasafg):

It's a good question. Since the user explicitly picks fp32 for the cache dtype, I'm afraid that downcasting the state to fp16 and then back to fp32 could lose accuracy. I had also figured that by choosing fp32, we want the computations to be done in that dtype, don't we?
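
To put a number on the round-trip concern (an aside for illustration, not from the PR; it assumes bf16 as the lower-precision dtype):

```python
import torch

x = torch.randn(1024, dtype=torch.float32)
roundtrip = x.to(torch.bfloat16).to(torch.float32)  # fp32 -> bf16 -> fp32
# Non-zero: bf16 keeps only ~8 bits of mantissa, so the cast is lossy.
print((x - roundtrip).abs().max())
```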

Member:

How does it work for the SSM state? I guess we want it to be consistent.

Contributor Author (@Josephasafg):

For the SSM state, we cast to float inside the kernel.
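
As a rough illustration of what casting inside the kernel looks like (a Triton sketch with assumed pointer arguments, not vLLM's actual SSM kernel):

```python
import triton
import triton.language as tl

@triton.jit
def _ssm_cast_sketch(state_ptr, x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # Inputs may be bf16/fp16 (activations) or fp32 (cache); compute in fp32
    # either way, then the store casts back to the output buffer's dtype.
    state = tl.load(state_ptr + offs, mask=mask).to(tl.float32)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    tl.store(out_ptr + offs, state + x, mask=mask)
```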

@tdoublep (Member) left a comment:

LGTM

@tdoublep enabled auto-merge (squash) September 18, 2025 10:48
@github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 18, 2025
@tdoublep merged commit 66072b3 into vllm-project:main Sep 18, 2025
53 checks passed
debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025