
Loss NaN in Mamba2 #352

Closed
@tyshiwo1


Hello guys,

When I applied Mamba2 to image generation, I found NaN values in several gradients (`ddt_bias`, `dx`, and `ddt_given`) computed in `_mamba_chunk_scan_combined_bwd` of `mamba_ssm/ops/triton/ssd_combined.py`, so the loss becomes NaN.
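A minimal sketch of how one might confirm which parameter gradients become non-finite right after `loss.backward()`, using only standard PyTorch calls. The helper name `find_nonfinite_grads` and the toy `nn.Sequential` model are hypothetical stand-ins, not part of DiM or `mamba_ssm`. Note that `dx` is a gradient with respect to an activation rather than a parameter, so catching it would need a tensor hook (`Tensor.register_hook`) or `torch.autograd.set_detect_anomaly(True)` instead.

```python
import torch
import torch.nn as nn


def find_nonfinite_grads(model: nn.Module) -> list[str]:
    """Return names of parameters whose .grad contains NaN or Inf.

    Hypothetical helper: call it after loss.backward() and before
    optimizer.step(), so a poisoned update is never applied.
    """
    bad = []
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(name)
    return bad


if __name__ == "__main__":
    # Toy model standing in for a Mamba-2 block (hypothetical; not DiM code).
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    print(find_nonfinite_grads(model))  # [] for a healthy backward pass

    # Simulate the reported failure by poisoning one gradient entry.
    model[0].bias.grad[0] = float("nan")
    print(find_nonfinite_grads(model))  # ['0.bias']
```

Running this check on the real model would only narrow down where the NaN surfaces; it does not explain why the backward kernel produces it.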

The image generation code is DiM; I only replaced the original Mamba-1 blocks with Mamba-2 blocks. I train from scratch in bf16 precision, and the NaN appears in the very first training iteration.
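For context on the bf16 setting: bfloat16 keeps float32's 8-bit exponent but stores only 7 explicit mantissa bits, so it has roughly the same range as float32 with far less precision. The pure-Python sketch below (hypothetical helper `to_bf16`, no PyTorch needed) rounds a value to bfloat16 with round-to-nearest-even and shows a running sum stalling at 256. It illustrates the precision limit only; whether this has anything to do with the NaN here is an assumption, not a diagnosis.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    Hypothetical helper. x is first rounded to float32, so it must be
    representable as float32; struct raises OverflowError otherwise.
    """
    if math.isnan(x):
        return x  # leave NaN as-is
    (bits,) = struct.unpack("<I", struct.pack("<f", x))  # float32 bit pattern
    # Keep the top 16 bits (sign, 8-bit exponent, 7 mantissa bits),
    # rounding the discarded low 16 bits to nearest, ties to even.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    (y,) = struct.unpack("<f", struct.pack("<I", bits))
    return y


if __name__ == "__main__":
    print(to_bf16(1.0 + 2**-7))  # 1.0078125: still representable
    print(to_bf16(1.0 + 2**-9))  # 1.0: the small increment is lost
    print(math.isfinite(to_bf16(1e38)))  # True: float32-like range

    # Adding 1.0 a thousand times in bf16 stalls at 256,
    # because 257 is not representable and rounds back to 256.
    total = 0.0
    for _ in range(1000):
        total = to_bf16(total + 1.0)
    print(total)  # 256.0
```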

My environment is triton==2.2.0, torch==2.2.1+cu121.

If anyone can help, I would be very grateful!
[Screenshot: nan]
