Slow Mamba 2 training speeds with higher d_state values #479

Open
jacob-morrison opened this issue Jul 18, 2024 · 2 comments

@jacob-morrison

Hi! I'm training a small Mamba2 model (~60m non-embedding parameters to start), and I'm doing some benchmarks before committing to larger runs. Any ideas why I'm seeing slow speeds with higher d_state values? I'm importing the Mamba2 blocks directly from this repo, and I've tested Mamba2 vs Mamba2Simple and haven't noticed any differences there.

With different d_state values, I get:
d_state=128: 65k tokens/device/second (58m non-embedding parameters total)
d_state=64: 100k tokens/device/second (56m non-embedding parameters total)
d_state=32: 150k tokens/device/second (53.7m non-embedding parameters total)
d_state=16: 195k tokens/device/second (52.9m non-embedding parameters total)

All experiments are running on H100s using DDP, and I'm using these other parameters:

n_layers: 16
vocab_size: 50280
embedding_size: 50304
d_conv: 4
expand: 2
ngroups: 1 (have also tried 8)
headdim: 64

In comparison, we get ~350k tokens/device/second with a 60m-parameter transformer model in the same codebase.
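
For reference, here's roughly how I isolate per-block throughput (a simplified sketch, not the actual training loop; the batch size, sequence length, d_model, and dtype below are placeholders rather than the values from the runs above):

```python
import time
import torch
from mamba_ssm import Mamba2

batch, seqlen, d_model = 8, 2048, 512  # placeholder shapes, not the real config

for d_state in (16, 32, 64, 128):
    block = Mamba2(
        d_model=d_model,
        d_state=d_state,
        d_conv=4,
        expand=2,
        headdim=64,
        ngroups=1,
    ).to("cuda", dtype=torch.bfloat16)

    x = torch.randn(batch, seqlen, d_model, device="cuda", dtype=torch.bfloat16)

    # Warm up so Triton autotuning/compilation is excluded from the timing.
    for _ in range(3):
        block(x).sum().backward()
    torch.cuda.synchronize()

    iters = 20
    t0 = time.time()
    for _ in range(iters):
        block(x).sum().backward()
    torch.cuda.synchronize()

    tok_per_s = batch * seqlen * iters / (time.time() - t0)
    print(f"d_state={d_state}: {tok_per_s / 1e3:.0f}k tokens/s (single block)")
```

Timing a single block like this keeps DDP and data loading out of the picture, so any d_state-dependent slowdown should come from the block itself.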

@YaoMufeng

I guess the higher the d_state, the larger the model's parameter count.
Happy to see that you got Mamba2 running successfully.
I ran into many errors when installing Mamba2.
Would you mind telling me your torch version, triton version, and causal_conv1d version?

@zh7117

zh7117 commented Jul 31, 2024

I failed to build mamba-ssm from source, but I successfully installed mamba-ssm and causal_conv1d from the prebuilt wheels and got Mamba2 running.
cuda version: 11.8
torch version: 2.1.2; install command: pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
mamba wheel: https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
causal_conv1d wheel: https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
(you should select the wheel matching your Python version.)
triton version: 2.1.0
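
If you just want to confirm the wheels installed correctly, a minimal smoke test looks something like this (the sizes here are arbitrary, purely for illustration):

```python
import torch
from mamba_ssm import Mamba2

# One forward pass through a single Mamba2 block on the GPU.
block = Mamba2(d_model=256, d_state=64, d_conv=4, expand=2, headdim=64).cuda()
x = torch.randn(2, 128, 256, device="cuda")
y = block(x)
print(y.shape)  # expect torch.Size([2, 128, 256]) if the kernels are working
```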

Hope this is helpful to you.
