-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[Bugfix][Mamba] - Fix Conv State Kernel FP32 Support #24883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3148be5
b81854a
94b2fe5
bfdaf6f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -415,6 +415,9 @@ def causal_conv1d_fn( | |
activation = "silu" | ||
|
||
args = None | ||
# Store original dtype to cast back at the end | ||
original_x_dtype = x.dtype | ||
x = x.to(conv_states.dtype) | ||
out = torch.empty_like(x) | ||
if metadata is not None: | ||
cu_seqlen = metadata.cu_seqlen | ||
|
@@ -613,7 +616,7 @@ def grid(META): | |
BLOCK_N=256, | ||
num_stages=2, | ||
) | ||
return out | ||
return out.to(original_x_dtype) | ||
|
||
|
||
@triton.jit() | ||
|
@@ -971,6 +974,9 @@ def causal_conv1d_update( | |
activation = "silu" if activation is True else None | ||
elif activation is not None: | ||
assert activation in ["silu", "swish"] | ||
|
||
original_x_dtype = x.dtype | ||
x = x.to(conv_state.dtype) | ||
Comment on lines
+978
to
+979
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we definitely want to cast There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its a good question - Since the user picks fp32 for the cache type, Im afraid that by downcasting it to fp16 and then back to fp32, we could lose accuracy doing it. I had also figured that by choosing fp32, we want the computations to be done in that type dont we? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does it work for the SSM state? I guess we want it to be consistent. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we cast to float in the kernel |
||
unsqueeze = query_start_loc is None and x.dim() == 2 | ||
if unsqueeze: | ||
# make it (batch, dim, seqlen) with seqlen == 1 | ||
|
@@ -1079,4 +1085,4 @@ def grid(META): | |
) | ||
if unsqueeze: | ||
out = out.squeeze(-1) | ||
return out | ||
return out.to(original_x_dtype) |
Uh oh!
There was an error while loading. Please reload this page.