Closed
Description
Hello guys,
When I applied Mamba2 to image generation, I found several NaN values in the gradients (ddt_bias, dx, and ddt_given) in _mamba_chunk_scan_combined_bwd of mamba_ssm/ops/triton/ssd_combined.py, so the loss is NaN.
The image generation code is DiM. I just replaced the original Mamba-1 block with Mamba-2. I am training from scratch in bf16 precision, and the NaN appears in the first training iteration.
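For anyone trying to reproduce this, here is a minimal sketch of how the affected gradients could be listed. It is not the check I ran inside the kernel. The helper name find_nonfinite and the sample values are made up for illustration, and it works on plain Python floats so it runs without a GPU.

```python
import math
from typing import Dict, Iterable, Tuple


def find_nonfinite(
    named_values: Iterable[Tuple[str, Iterable[float]]]
) -> Dict[str, int]:
    """Return {name: number of NaN/Inf entries} for each entry that has any."""
    report: Dict[str, int] = {}
    for name, values in named_values:
        # math.isfinite is False for NaN, +inf and -inf.
        bad = sum(1 for v in values if not math.isfinite(v))
        if bad:
            report[name] = bad
    return report


# Hypothetical values standing in for real gradient tensors.
grads = [
    ("ddt_bias", [0.1, float("nan"), 0.3]),
    ("dx", [float("inf"), float("nan")]),
    ("other", [0.0, -1.5]),
]
print(find_nonfinite(grads))  # {'ddt_bias': 1, 'dx': 2}
```

With PyTorch, the same helper could presumably be fed after loss.backward() with (name, p.grad.float().flatten().tolist()) for each entry of model.named_parameters() whose grad is not None. That only covers parameter gradients; intermediates such as dx and ddt_given would have to be checked inside the backward function itself.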
My environment is triton==2.2.0, torch==2.2.1+cu121.