
On small models, the actual GPU memory usage of Mamba2 is much higher than that of Mamba1 #439

Open
AlwaysFHao opened this issue Jul 2, 2024 · 17 comments

AlwaysFHao commented Jul 2, 2024

The parameters of my Mamba2 model are d_state=32, d_conv=4, expand=2, and head_dim=32 (using nn.Conv1d with padding instead of causal_conv1d, so without the d_model / head_dim % 8 == 0 constraint). Mamba1 keeps the same parameters, minus head_dim. Although Mamba2's inference speed has almost doubled relative to Mamba1, its actual memory usage increased from 4.82 GB to 7.55 GB (in my task). Is this due to the base computational cost of Mamba2's semiseparable matrices, which would be a disadvantage in small-scale models? I see in your paper that at larger model scales the actual memory usage of Mamba2 is lower.

tridao commented Jul 2, 2024

nn.Conv1d is probably not great for memory usage. You should try to use causal_conv1d.
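(For context: a causal depthwise 1-D convolution is just a left-padded "valid" convolution, which is what causal_conv1d fuses into a single kernel, whereas the nn.Conv1d fallback pads both sides' worth of positions and then slices off the tail, materializing extra activations. A minimal NumPy sketch of the causal-padding semantics, illustrative only and not the library's code:)

```python
import numpy as np

def causal_conv1d_ref(x, w):
    """Depthwise causal 1-D convolution (illustrative reference).

    x: (dim, seqlen), w: (dim, width). y[:, t] depends only on
    x[:, t - width + 1 : t + 1], so no future information leaks in.
    """
    dim, seqlen = x.shape
    width = w.shape[1]
    x_pad = np.pad(x, ((0, 0), (width - 1, 0)))  # left-pad with zeros only
    y = np.empty_like(x)
    for t in range(seqlen):
        y[:, t] = (x_pad[:, t:t + width] * w).sum(axis=1)
    return y

x = np.arange(6, dtype=float).reshape(1, 6)  # one channel, seqlen 6
w = np.array([[0.0, 0.0, 0.0, 1.0]])         # width 4, only the "current" tap
print(causal_conv1d_ref(x, w))               # reproduces x: a causality check
```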

@AlwaysFHao (Author)

> nn.Conv1d is probably not great for memory usage. You should try to use causal_conv1d.

Okay, thank you for your reply. I am trying causal_conv1d at the same scale, keeping all other parameters unchanged and only changing head_dim to 4 (to meet the requirement of d_state / head_dim % 8 == 0). In my understanding this should actually reduce the parameter count (similar to a depthwise convolution?), but the actual GPU memory usage is now 8.12 GB, still far more than Mamba1's 4.82 GB (in my task). I am not sure whether this is due to the warm-up optimization or the base cost of the SSD's semiseparable matrices. Could you please give me some guidance? Thank you very much!

tridao commented Jul 3, 2024

huh there's no requirement d_state / head_dim % 8 == 0
there's d_model / head_dim % 8 == 0
you can try the dimensions similar to the language models we've released (e.g. d_model = 1024)
I don't have much experience with the kind of small dimensions you're working with

@AlwaysFHao (Author)

> huh there's no requirement d_state / head_dim % 8 == 0, there's d_model / head_dim % 8 == 0 […]

Oh! Yes, it should be d_model / head_dim % 8 == 0. I looked at the source code again and found that I had confused d_model with d_state; I'm sorry. I will try a higher-dimensional experiment next and report back. Additionally, there is an issue in the Mamba2 source code path that uses nn.Conv1d with padding (as the fallback for causal_conv1d); please see the other issue I raised, #437.
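(The constraint discussed above can be written down as a small sanity check. This is a hypothetical helper, not part of the repo; it only restates the rule as clarified in this thread.)

```python
def check_mamba2_head_dim(d_model, head_dim):
    """Check the constraint as stated in this thread:
    d_model / head_dim must be a multiple of 8 (note: d_model, not d_state).
    Returns the resulting ratio if valid."""
    if d_model % head_dim != 0:
        raise ValueError("head_dim must divide d_model")
    if (d_model // head_dim) % 8 != 0:
        raise ValueError("d_model / head_dim must be a multiple of 8")
    return d_model // head_dim

print(check_mamba2_head_dim(1024, 32))  # 1024 / 32 = 32, and 32 % 8 == 0
# check_mamba2_head_dim(128, 32) would raise: 128 / 32 = 4 is not a multiple of 8
```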

@AlwaysFHao (Author)

> huh there's no requirement d_state / head_dim % 8 == 0, there's d_model / head_dim % 8 == 0 […]

Hello, I have continued the experiment at a small dimension: d_model=128, d_state=32, d_conv=4, expand=2, head_dim=32, with 4 encoder layers stacked (this time using causal_conv1d rather than nn.Conv1d; due to device limitations I cannot run experiments at d_model=1024). With the parameters matched as closely as possible, in my task:

- Mamba2: 60083328.0 FLOPs, 876256 parameters, 16.32 GB actual memory usage, 265.22 s training time
- Mamba1: 281728.0 FLOPs, 953728 parameters, 13.35 GB actual memory usage, 586.78 s training time

I don't quite understand why Mamba2 has fewer parameters but higher memory usage, while Mamba1 has fewer FLOPs but slower training. From your experiments at high dimensions, Mamba2's memory usage should be smaller. Is it because Mamba2's Triton-based kernels require some extra memory, or is it the base cost of the semiseparable matrices in the SSD architecture?

@AlwaysFHao (Author)

> huh there's no requirement d_state / head_dim % 8 == 0, there's d_model / head_dim % 8 == 0 […]

If possible, could you release an official pure-PyTorch (rather than Triton) version of SSD? It seems that ssd_minimal does not include the discretization step, and due to my own limitations I cannot guarantee an equivalent reimplementation of your Triton version. Thank you very much!

tridao commented Jul 4, 2024

We already have a reference implementation:

```python
def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
```

@AlwaysFHao (Author)

> We already have a reference implementation: ssd_chunk_scan_combined_ref […]

Thank you for your great work. I am currently trying ssd_chunk_scan_combined_ref, but chunk_scan_ref (line 1806 in ssd_chunk_state.py) contains the assertion:

assert seqlen == nchunks * chunk_size

I don't quite understand why this restriction does not seem to exist in the Triton version. According to the SSD architecture there should indeed be such a limitation, yet I never ran into it when using the Triton version. Did you add some handling in the Triton version? Could you please help clear up my confusion? Thank you very much!

tridao commented Jul 5, 2024

You can always pad the seqlen. The `assert seqlen == nchunks * chunk_size` in the reference is there for simplicity of implementation. This ref version is not used to train models, only for testing.
The Triton version implicitly pads things inside the kernel, so it supports all kinds of seqlen.
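(The padding described here just rounds seqlen up to the next multiple of chunk_size; because the model is causal, right-padding with zeros does not change the outputs at the original positions. A NumPy sketch, illustrative only; the Triton kernel does this implicitly:)

```python
import numpy as np

def pad_to_chunk_multiple(x, chunk_size):
    """Right-pad (batch, seqlen, dim) with zeros so that seqlen becomes a
    multiple of chunk_size, satisfying the reference implementation's
    `assert seqlen == nchunks * chunk_size`."""
    batch, seqlen, dim = x.shape
    nchunks = -(-seqlen // chunk_size)        # ceil(seqlen / chunk_size)
    pad = nchunks * chunk_size - seqlen
    x_pad = np.pad(x, ((0, 0), (0, pad), (0, 0)))
    return x_pad, seqlen                      # keep seqlen to trim the output

x = np.ones((2, 100, 16))
x_pad, orig_len = pad_to_chunk_multiple(x, chunk_size=64)
print(x_pad.shape, orig_len)                  # (2, 128, 16) 100
```

After running the scan on the padded input, slice the result back with `y[:, :orig_len]`.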

@AlwaysFHao (Author)

> You can always pad the seqlen. […] The Triton version implicitly pads things inside the kernel, so it supports all kinds of seqlen.

Thank you for your prompt reply! I have seen the seqlen-padding code and will keep studying it. In addition, in my task it seems that at the same parameter count Mamba1 always performs better than Mamba2. However, as the feature dimension increases, Mamba2's training and inference speed advantage grows, and the memory-usage gap between Mamba2 and Mamba1 shrinks (Mamba2 uses more memory at low dimensions). By a feature dimension of 512, the memory usage of the two is almost equal. By adjusting chunk_size I found that memory usage changes significantly with the number of chunks, so I think Mamba2's disadvantage at low dimensions involves the base cost of the semiseparable matrices. If you agree with this tentative idea, could Mamba2 be partially optimized in this regard?

@TimothyChen225
> Hello, I have continued the experiment at a small dimension […]

How did you accelerate Mamba2? A warm-up strategy?

@AlwaysFHao (Author)

> How did you accelerate Mamba2? A warm-up strategy?

yes

@dumpmemory

> > How did you accelerate Mamba2? A warm-up strategy?
>
> yes

How can you calculate Mamba2's FLOPs?

@AlwaysFHao (Author)

> How can you calculate Mamba2's FLOPs?

With the recbole.utils.get_flops method in the RecBole framework:
https://github.com/RUCAIBox/RecBole/blob/2b6e209372a1a666fe7207e6c2a96c7c3d49b427/recbole/utils/utils.py#L250

@dumpmemory

> With the recbole.utils.get_flops method in the RecBole framework. https://github.com/RUCAIBox/RecBole/blob/2b6e209372a1a666fe7207e6c2a96c7c3d49b427/recbole/utils/utils.py#L250

Did RecBole take the Triton ops' FLOPs into account?

@chairman-lu

> If possible, could you release an official version of an SSD based on Python instead of Triton? […]

Actually, I'm not quite sure why it is said that ssd_minimal has no discrete implementation. Additionally, I tested the accuracy difference between ssd_minimal_discrete and mamba_chunk_scan_combined, and the maximum difference can reach 0.05. Why is that?

AlwaysFHao commented Aug 28, 2024

> Actually, I'm not quite sure why it is said that ssd_minimal has no discrete implementation. Additionally, I tested the accuracy difference between ssd_minimal_discrete and mamba_chunk_scan_combined, and the maximum difference can reach 0.05. Why is that?

From the implementation in the official blog post and code, ssd_minimal only performs the computation of the SSD kernel itself; it does not include the step that discretizes the A and B matrices. So ssd_minimal assumes discretization has already been applied to its inputs.
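(Concretely, the discretization that mamba_chunk_scan_combined performs internally but ssd_minimal_discrete leaves to the caller can be sketched as below. Shapes and the calling convention follow the Mamba-2 blog post's usage `ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)`; this is my paraphrase in NumPy, not the released code.)

```python
import numpy as np

# Illustrative small shapes: batch=1, seqlen=4, nheads=2, head_dim=3.
rng = np.random.default_rng(0)
x  = rng.standard_normal((1, 4, 2, 3))        # (b, l, h, p)
dt = rng.uniform(0.001, 0.1, size=(1, 4, 2))  # step sizes, e.g. post-softplus
A  = -np.exp(rng.standard_normal(2))          # negative per-head A

# The discretization step that ssd_minimal assumes has already happened:
x_discrete = x * dt[..., None]  # B-side: dt is folded into the input (Euler step)
A_discrete = A * dt             # A-side: per-step log-decay; the kernel exponentiates it

print(x_discrete.shape, A_discrete.shape)     # (1, 4, 2, 3) (1, 4, 2)
```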
