-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[PERF] Add conv1d
metadata to GDN attn
#25105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces an optimization for the Qwen3-next model by adding mamba2-style convolution metadata to the GDN attention mechanism. This change aims to reduce GPU-host memory transfers in causal_conv1d_fn
, leading to a significant performance improvement as demonstrated by the benchmark results. The implementation correctly extends GDNAttentionMetadata
and integrates the metadata preparation and usage within the model's forward pass. The changes are well-targeted and effective. I have one minor suggestion to improve type hint correctness.
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
conv1d
metadata to GDN attn
Could someone please run the CI? |
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: charlifu <charlifu@amd.com>
Purpose
The on of the main gap in performance for Qwen3-next model is low utilisation of GPU in GDN attn.
Low utilization caused by several GPU<->host memory transfers.
These transfers caused by
causal_conv1d_fn
.Qwen3-next haven't supported this specific convolution metadata. The purpose of the metadata is avoid these memory transfer.
Add conv metadata (similar to mamba2).
+ corrected
tensor.tensor
->tensor.Tensor
types annotation.Test Result
H200, tp=4
Before
After
Speedup is 26%