Loss NaN in Mamba2 #352
Comments
Sorry to bother you. I haven't run into the same problem as you, but I'd like to ask a couple of questions. 1. Have you encountered the issue below? 2. How did you use Mamba2 directly in your own project? Did you download the whl file and update the mamba_ssm package in your virtual environment, or did you download the whole repository and use the Mamba2 file from the project source? Sorry again for the trouble, looking forward to your reply~
Yes, the same issue.
I found that this line of code causes the NaN: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_combined.py#L176
Have you tried using float32 instead of bfloat16?
The NaN remains when I train with fp32.
Wow. Do you know how to change this line of code (`acc += tl.dot(cb, dout)`) to solve the NaN problem? (I'm not good at this, so I don't know how to do it.)
I am working on it. The code around this line may also be causing a bug.
I don't really understand the details, but I'll look into it myself too. Do you think this is a bug in Mamba2 itself, or does the NaN only appear because people apply it to their own tasks (in other words, does everyone need to adapt the code to their particular task)?
I applied Mamba2 to image generation rather than NLP tasks, and got a NaN loss.
Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us?
So that we can reproduce it like this:
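(A hedged sketch of what such a save-and-replay script could look like; the argument names follow my understanding of the `mamba_chunk_scan_combined` interface, so adjust them to the actual call site in your model:)

```python
# Hedged sketch: capture the inputs of the failing call once, then replay
# forward + backward in isolation to reproduce the NaN gradients.
import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

# At the failing call site, dump the inputs once, e.g.:
# torch.save(dict(x=x, dt=dt, A=A, B=B, C=C, chunk_size=chunk_size,
#                 D=D, z=z, dt_bias=dt_bias), "nan_repro.pt")

# Then, in a standalone script, reload and replay:
t = torch.load("nan_repro.pt", map_location="cuda")
for k in ("x", "dt", "B", "C"):
    t[k] = t[k].detach().requires_grad_(True)

out = mamba_chunk_scan_combined(
    t["x"], t["dt"], t["A"], t["B"], t["C"],
    chunk_size=t["chunk_size"], D=t.get("D"), z=t.get("z"), dt_bias=t.get("dt_bias"),
)
out.sum().backward()
for k in ("x", "dt", "B", "C"):
    print(k, "grad has NaN:", bool(torch.isnan(t[k].grad).any()))
```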
Thank you for your reply!!! The error occurs on this line, so I saved the input and output tensors into a zip file. I also tried some operations like changing [...] into [...], but this does not work on other variables like [...].
Thanks for the suggestion. I set chunk_size to 1 and the NaN appeared later in training, but it still appeared.
I think [...]
It did.
This can stabilize training for the first few iterations, but the loss becomes NaN later.
Maybe you can also locate the NaN values and provide the tensors like I did 😂.
If it is still not stable, the last resort is to lower the learning rate.
Tried that. The result is the same as decreasing the chunk size: it stays stable for more epochs, but the loss becomes NaN later.
@ZijianYY You can follow the instructions here: `mamba_ssm/ops/triton/ssd_combined.py`, line 403 (at commit 26283fb). You may also check it.
Here's another very primitive, barebones "model" that very quickly produces NaN on a 3080 Ti laptop. The model creates a random 8x8 "RGB image" and then tries to produce an "upscaled" 64x64 version. The number of layers and d_model matter. With Mamba2Simple I get [...]. The start of the file sets the parameters (the most important is THE_MAMBA, which selects the class used for the Mamba block: Mamba, Mamba2, or Mamba2Simple).
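(A hedged reconstruction of that kind of toy setup, not the original script; the layer count, d_model, and the 8x8-to-64x64 head are assumptions, and `Mamba2Simple` would be imported from `mamba_ssm.modules.mamba2_simple`:)

```python
# Toy repro sketch: stack THE_MAMBA blocks and "upscale" a random 8x8 RGB
# image (64 tokens of dim 3) to a 64x64 target, watching for NaN loss.
import torch
import torch.nn as nn
from mamba_ssm import Mamba, Mamba2

THE_MAMBA, D_MODEL, N_LAYERS = Mamba2, 256, 8  # swap THE_MAMBA to compare block classes

class ToyUpscaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(3, D_MODEL)      # 64 RGB pixels -> 64 tokens
        self.blocks = nn.ModuleList(THE_MAMBA(d_model=D_MODEL) for _ in range(N_LAYERS))
        self.head = nn.Linear(D_MODEL, 64 * 3)  # each token predicts 64 output pixels

    def forward(self, img8):                    # img8: (B, 64, 3), flattened 8x8 image
        h = self.embed(img8)
        for blk in self.blocks:
            h = blk(h)
        return self.head(h).reshape(img8.shape[0], 64, 64, 3)

model = ToyUpscaler().cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(1000):
    x = torch.rand(4, 64, 3, device="cuda")
    target = torch.rand(4, 64, 64, 3, device="cuda")
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss = nn.functional.mse_loss(model(x).float(), target)
    opt.zero_grad(); loss.backward(); opt.step()
    if torch.isnan(loss):
        print("NaN at step", step)
        break
```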
Same here, NaN raised when training with Mamba2.
I also get NaN on an image classification task :(
So is this problem still unresolved? (I get NaN on voice classification tasks.)
Thanks for the bug reports. We were able to reproduce the NaN gradients when the sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?
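(For anyone who cannot upgrade right away, a rough workaround, my own sketch rather than anything from the maintainers, is to pad the sequence up to a multiple of 256 before the Mamba2 block and trim the output afterwards:)

```python
# Sketch: right-pad the sequence dimension to a multiple of 256, then slice
# the block's output back to the original length.
import torch.nn.functional as F

def pad_seq_to_multiple(x, multiple=256):
    """x: (batch, seqlen, d_model) -> (padded x, original seqlen)."""
    seqlen = x.shape[1]
    pad = (-seqlen) % multiple
    return (F.pad(x, (0, 0, 0, pad)), seqlen) if pad else (x, seqlen)

# usage:
# x_pad, seqlen = pad_seq_to_multiple(x)
# y = mamba2_block(x_pad)[:, :seqlen]
```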
My sequence length is 256 and my task is object detection; it still raised NaN 😭
Please make a reproducible script (e.g. save the tensors right before the function that causes the NaN). If we can't reproduce it, we can't do anything.
Thank you! We will try it.
This fix works for me. At least the loss remains stable over a few thousand training iterations. Thanks again!
Huge fix! Training works with d_model = 256 and 512, but once I lower d_model to 192, the error [...] is produced. If I follow #362 to force the training to run, it produces NaN at any d_model. Any plan on fixing this issue? Let me know what you need.
Works now!
My bad for not investigating further. Your fix is perfect and we should not use #362 with it; all we need is to make sure that d_model * expand / headdim is a multiple of 8.
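(For reference, a quick check of that constraint; this is a sketch of my own, and `expand` and `headdim` default to 2 and 64 in Mamba2 as far as I know, so verify them against your config:)

```python
# Check that d_model * expand / headdim is a whole multiple of 8.
def mamba2_dims_ok(d_model: int, expand: int = 2, headdim: int = 64) -> bool:
    d_inner = d_model * expand
    return d_inner % headdim == 0 and (d_inner // headdim) % 8 == 0

for d in (192, 256, 512):
    print(d, mamba2_dims_ok(d))  # 192 -> False (6 heads), 256 and 512 -> True
```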
Yes, it did work for me.
After running for a while, the NaN came back.
Hello guys,
When I applied Mamba2 to image generation, I found several NaN values in the gradients (`ddt_bias`, `dx`, and `ddt_given`) in `_mamba_chunk_scan_combined_bwd` of `mamba_ssm/ops/triton/ssd_combined.py`, so the loss becomes NaN. The image generation code is DiM; I just replaced the original Mamba-1 block with Mamba-2. I used bf16 precision for training from scratch, and the NaN appears in the first training iteration.
My environment is `triton==2.2.0, torch==2.2.1+cu121`.
If anyone can help me, I will be very grateful!
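(In case it helps with debugging, here is a small sketch, an illustrative snippet rather than code from the report, of one way to locate which parameter gradients contain NaN after `loss.backward()`:)

```python
# Print every parameter whose gradient contains NaN after a backward pass.
import torch

def report_nan_grads(model: torch.nn.Module) -> None:
    for name, p in model.named_parameters():
        if p.grad is not None and torch.isnan(p.grad).any():
            print(f"NaN gradient in {name}")

# usage, right after loss.backward():
# report_nan_grads(model)
```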