
Loss NaN in Mamba2 #352

Closed
tyshiwo1 opened this issue Jun 4, 2024 · 37 comments

Comments

@tyshiwo1

tyshiwo1 commented Jun 4, 2024

Hello guys,

When I applied Mamba2 to image generation, I found several NaN values in the gradients (ddt_bias, dx, and ddt_given) in _mamba_chunk_scan_combined_bwd of mamba_ssm/ops/triton/ssd_combined.py, so the loss becomes NaN.

The image generation code is DiM. I just replaced the original Mamba-1 block with Mamba-2. I used bf16 precision for training from scratch, and the NaN appears in the first training iteration.

My environment is triton==2.2.0, torch==2.2.1+cu121.

If anyone can help me, I will be very grateful!
[screenshot: NaN gradient values]

@zzzendurance

Sorry to bother you. I haven't run into the same problem as you, but I'd like to ask you a couple of questions.

1. Have you ever hit the following error?
File "/data/zh/miniconda3/envs/man/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
return fwd(*args, **kwargs)
File "/data/zh/wa1/-main/IP/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 757, in forward
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
(arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor
From what I found, this is caused by the causal_conv1d version. Mine was originally 1.1.1, and after switching to 1.0.2 it still errors out. Which causal_conv1d version are you using?

2. How do you use Mamba2 directly in your own project? Do you download the whl file and update the mamba_ssm package in your virtual environment, or do you download the whole repository and use the mamba2 file inside it?
If you downloaded a whl file, how did you choose the version? Is there any real difference between the abi variants here?
mamba_ssm-2.0.3+cu118torch1.13cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
mamba_ssm-2.0.3+cu118torch1.13cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

Sorry again for the bother, looking forward to your reply~

@tyshiwo1
Author

tyshiwo1 commented Jun 4, 2024

  1. My causal_conv1d version is 1.2.2.post1.
  2. I downloaded the whole project and compiled it locally using CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .

@zzzendurance

Hahaha, thanks! I updated my causal_conv1d version to 1.2.2.post1 and got it running, and then hit the same problem as you. With the same hyperparameters, Mamba-1 never produced NaN before, as far as I remember.

@tyshiwo1
Author

tyshiwo1 commented Jun 4, 2024

Yes, the same issue


@Kiet0712

Kiet0712 commented Jun 4, 2024

Have you tried using float32 instead of bfloat16?

@tyshiwo1
Author

tyshiwo1 commented Jun 4, 2024

Have you tried using float32 instead of bfloat16?

The NaN remains when I train with fp32.

@zzzendurance

Wow. Do you know how to change this line of code (acc += tl.dot(cb, dout)) to solve the NaN problem? (I'm not good at this, so I don't know how to do it.)

@tyshiwo1
Author

tyshiwo1 commented Jun 4, 2024

Wow. Do you know how to change this line of code (acc += tl.dot(cb, dout)) to solve the NaN problem? (I'm not good at this, so I don't know how to do it.)

I am working on it. The code around this line may also be at fault.
Besides, other variables like ddt_bias also contain NaN, so a lot of code may need to change.

@zzzendurance

I don't fully understand this, but I'll dig into it myself too.

Do you think this is a bug in Mamba2 itself, or does the NaN only show up because people are applying it to their own tasks (i.e., everyone needs to adapt the code to their own task)?

@tyshiwo1
Author

tyshiwo1 commented Jun 4, 2024

I don't fully understand this, but I'll dig into it myself too.

Do you think this is a bug in Mamba2 itself, or does the NaN only show up because people are applying it to their own tasks (i.e., everyone needs to adapt the code to their own task)?

I applied Mamba2 to image generation rather than NLP tasks and got a NaN loss.

@tridao
Collaborator

tridao commented Jun 4, 2024

Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us?
Sth like

if dx.isnan().any():
    # save tensors to disk with torch.save

So that we can reproduce it like this:

# load x, b, c, dt, etc from disk with torch.load
# whatever function here that caused NaN
# we observe NaN in dx for example.
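
For illustration, a minimal sketch of that save/reload workflow (the helper name, tensor names, and file path are assumptions, not part of the library):

import torch

def dump_if_nan(path, **tensors):
    # Save the given tensors to disk if any of them contains NaN,
    # so the failing call can be replayed offline with torch.load.
    if any(t is not None and t.isnan().any() for t in tensors.values()):
        torch.save({k: v for k, v in tensors.items() if v is not None}, path)
        return True
    return False

# Inside _mamba_chunk_scan_combined_bwd, right after the gradients are computed (illustrative):
#     dump_if_nan("nan_repro.pt", x=x, dt=dt, dA_cumsum=dA_cumsum, B=B, CB=CB,
#                 dout=dout, dstates=dstates, D=D, dx=dx, ddt=ddt)
# Then, in a standalone script, torch.load("nan_repro.pt"), call the same
# function with the saved inputs, and check which outputs contain NaN.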

@tyshiwo1
Author

tyshiwo1 commented Jun 4, 2024

Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us? Sth like

if dx.isnan().any():
    # save tensors to disk with torch.save

So that we can reproduce it like this:

# load x, b, c, dt, etc from disk with torch.load
# whatever function here that caused NaN
# we observe NaN in dx for example.

Thank you for your reply!
I have uploaded my tensors to Google Drive.

The error occurs on this line, so I saved the input and output tensors into a zip file.

I also tried some operations like changing

dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)

into

dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=-1e6).to(tl.float32)

, but this does not help with other variables like dout.

@XiudingCai

A smaller A range (close to 1) and a smaller chunk size may make the training more stable

Thanks for the suggestion. I set chunk_size to 1, which delayed the point where the NaN appears, but it still shows up.

@realwenlongwang

I think chunk_size does not really matter. Empirically, I found that setting A_init_range to (1, 1.1) works for me.
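
For anyone trying this, a minimal sketch of the suggestion, assuming the Mamba2 constructor exposes an A_init_range argument as in mamba_ssm/modules/mamba2.py (the other hyperparameters are placeholders):

import torch
from mamba_ssm import Mamba2

# Narrow A's initialization range to (1, 1.1) as suggested above, instead of the wider default.
block = Mamba2(d_model=512, headdim=64, A_init_range=(1, 1.1)).cuda()

x = torch.randn(2, 1024, 512, device="cuda")  # (batch, seqlen, d_model)
y = block(x)
print(y.shape)  # torch.Size([2, 1024, 512])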

@XiudingCai

I think chunk_size does not really matter. Empirically, I found that setting A_init_range to (1, 1.1) works for me.

It did.

@tyshiwo1
Author

tyshiwo1 commented Jun 5, 2024

I think chunk_size does not really matter. Empirically, I found that setting A_init_range to (1, 1.1) works for me.

This stabilizes the first few training iterations, but the loss still becomes NaN later.

@ZijianYY

ZijianYY commented Jun 5, 2024

Same issue. In my code, the loss suddenly becomes NaN.
[Screenshot 2024-06-05 21:32: training log where the loss turns NaN]

@tyshiwo1
Author

tyshiwo1 commented Jun 5, 2024

Same issue. In my code, the loss suddenly becomes NaN.

Maybe you can also locate the NaN values and provide the tensors like I did 😂.
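
As an aside, a quick way to locate NaN gradients after loss.backward() (a generic PyTorch sketch; find_nan_grads is a hypothetical helper, not part of mamba_ssm):

import torch

# Option 1: let autograd flag the op that produced the NaN during the backward pass.
torch.autograd.set_detect_anomaly(True)

# Option 2: after loss.backward(), list the parameters whose gradients contain NaN.
def find_nan_grads(model: torch.nn.Module):
    return [name for name, p in model.named_parameters()
            if p.grad is not None and p.grad.isnan().any()]

# Usage in a training loop:
#     loss.backward()
#     bad = find_nan_grads(model)
#     if bad:
#         print("NaN gradients in:", bad)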

@bio-mlhui

If it is still not stable, the last resort is to lower the learning rate.

@ZijianYY

ZijianYY commented Jun 6, 2024

Same issue. In my code, the loss suddenly becomes NaN.

Maybe you can also locate the NaN values and provide the tensors like I did 😂.

Checked. It is also the ddt_bias tensor as you mentioned before.
[screenshot: ddt_bias gradient containing NaN]

@ZijianYY

ZijianYY commented Jun 6, 2024

If it is still not stable, the last resort is to lower the learning rate.

Tried. The result is the same as with decreasing the chunk size: it stays stable for more epochs, but the loss becomes NaN later.

@tyshiwo1
Author

tyshiwo1 commented Jun 6, 2024

Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us? Sth like

if dx.isnan().any():
    # save tensors to disk with torch.save

So that we can reproduce it like this:

# load x, b, c, dt, etc from disk with torch.load
# whatever function here that caused NaN
# we observe NaN in dx for example.

@ZijianYY You can follow the instructions here.
I have uploaded my tensors to
https://drive.google.com/drive/folders/1ojmQNDsAToNZaP3ZNOAeMJBu1AshMnXS?usp=sharing
for this function:

dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)

You may also check it.

@Maykeye

Maykeye commented Jun 6, 2024

Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us? Sth like

Here's another very primitive, barebones "model" that very quickly generates NaN on a 3080 Ti laptop.

The model creates a random 8x8 "RGB image" and then tries to create an "upscaled" 64x64 version
using Mamba and 64x64 random values that are supposed to represent "I'm the N-th pixel; who am I, considering the past?".

The number of layers and d_model matter.
The dtype doesn't (both bfloat16 and float32 fail at the same epoch); expand, d_state, etc. are left at their defaults.

With mamba2simple I get ValueError: NaN loss at epoch #2
With mamba2 I get ValueError: NaN loss at epoch #1
With mamba1 I lose patience after 100 iterations: no NaN appears.

The start of the file sets the parameters (the most important is THE_MAMBA, which chooses the class used for Mamba: Mamba, Mamba2, or Mamba2Simple).
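
For readers without access to the original script, a rough, self-contained sketch of that kind of toy setup (the ToyUpscaler wrapper, layer count, widths, and learning rate are assumptions; only the overall shape follows the description above):

import torch
import torch.nn as nn
from mamba_ssm import Mamba2  # swap in Mamba (or Mamba2Simple) to compare behavior

D_MODEL, N_LAYERS = 256, 4
device = "cuda"

class ToyUpscaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(3, D_MODEL)
        self.blocks = nn.ModuleList(
            [Mamba2(d_model=D_MODEL) for _ in range(N_LAYERS)]
        )
        self.out_proj = nn.Linear(D_MODEL, 3)

    def forward(self, seq):  # seq: (batch, length, 3)
        h = self.in_proj(seq)
        for blk in self.blocks:
            h = blk(h)
        return self.out_proj(h)

model = ToyUpscaler().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

small = torch.rand(1, 8 * 8, 3, device=device)      # random 8x8 "RGB image", flattened
queries = torch.rand(1, 64 * 64, 3, device=device)  # per-pixel random "who am I?" queries
target = torch.rand(1, 64 * 64, 3, device=device)   # fake "upscaled" 64x64 image

for epoch in range(100):
    # Condition on the small image by prepending it, then predict the 64x64 pixels.
    pred = model(torch.cat([small, queries], dim=1))[:, small.shape[1]:]
    loss = nn.functional.mse_loss(pred, target)
    if loss.isnan():
        raise ValueError(f"NaN loss at epoch #{epoch}")
    opt.zero_grad()
    loss.backward()
    opt.step()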

@EddieEduardo

Same here: it raised NaN when training with Mamba2.

@catalpaaa

I don't fully understand this, but I'll dig into it myself too.
Do you think this is a bug in Mamba2 itself, or does the NaN only show up because people are applying it to their own tasks (i.e., everyone needs to adapt the code to their own task)?

I applied Mamba2 to image generation rather than NLP tasks and got a NaN loss.

I also get NaN on an image classification task :(

@zzzendurance

I don't fully understand this, but I'll dig into it myself too.
Do you think this is a bug in Mamba2 itself, or does the NaN only show up because people are applying it to their own tasks (i.e., everyone needs to adapt the code to their own task)?

I applied Mamba2 to image generation rather than NLP tasks and got a NaN loss.

I also get NaN on an image classification task :(

So is this problem still unresolved? (I get NaN on voice classification tasks.)

@tridao
Collaborator

tridao commented Jun 12, 2024

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?
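
For anyone stuck on an earlier version, a hedged workaround sketch: pad the sequence up to a multiple of 256 before the block and slice the outputs back. Because the model is causal, trailing padding should not change the outputs at the original positions; forward_padded and the zero padding are illustrative, not part of the library.

import torch
import torch.nn.functional as F

def forward_padded(block, x, multiple=256):
    # x: (batch, seqlen, d_model); pad seqlen on the right to a multiple of `multiple`.
    B, L, D = x.shape
    pad = (-L) % multiple
    if pad:
        x = F.pad(x, (0, 0, 0, pad))  # zero-pad the sequence dimension
    y = block(x)
    return y[:, :L]  # drop the padded tail; causality keeps the earlier outputs unchanged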

@EddieEduardo

EddieEduardo commented Jun 12, 2024 via email

@tridao
Collaborator

tridao commented Jun 12, 2024

My sequence length is 256, my task is object detection, it raised nan 😭


Please make a reproducible script (e.g. save the tensors right before the function that causes NaN). If we can't reproduce it, we can't do anything.

@tyshiwo1
Author

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

Thank you! We will try it.

@tyshiwo1
Author

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

Thank you! We will try it.

This fix works for me. At least the loss remains stable over a few thousand training iterations. Thanks again!

@catalpaaa

catalpaaa commented Jun 12, 2024

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

Huge fix; training works with d_model = 256 and 512. But once I lower d_model to 192, the error

RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8

was produced. If I follow #362 to force the training to run, it produces NaN for any d_model.

Any plan on fixing this issue? Let me know what you need.

@Maykeye

Maykeye commented Jun 12, 2024

We pushed a fix, can you guys try v2.0.4?

Works now!

@catalpaaa

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

Huge fix; training works with d_model = 256 and 512. But once I lower d_model to 192, the error

RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8

was produced. If I follow #362 to force the training to run, it produces NaN for any d_model.

Any plan on fixing this issue? Let me know what you need.

My bad for not investigating further. Your fix is perfect, and we should not use #362 with it; all we need is to make sure d_model * expand / headdim is a multiple of 8.
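
A tiny sanity check of that constraint (check_heads is a hypothetical helper; expand=2 and headdim=64 are assumed defaults used only for illustration):

def check_heads(d_model, expand=2, headdim=64):
    # d_model * expand / headdim is the number of heads; per the thread above,
    # it needs to be a multiple of 8 to avoid the causal_conv1d stride error.
    nheads = d_model * expand / headdim
    ok = nheads.is_integer() and nheads % 8 == 0
    print(f"d_model={d_model}: {nheads:g} heads -> {'ok' if ok else 'expect the stride error'}")

check_heads(192)  # 6 heads  -> hits the error reported above
check_heads(256)  # 8 heads  -> fine
check_heads(512)  # 16 heads -> fine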

@TimothyChen225

  1. My causal_conv1d version is 1.2.2.post1.
  2. I downloaded the whole project and compiled it locally using CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .

Yes, it did work for me.

@drhuangliwei

I think chunk_size does not really matter. Empirically, I found that setting A_init_range to (1, 1.1) works for me.

After running for a while, the NaN comes back.
