
[ROCM] Float8 deepseekv3_671b IntOverflow in triton kernels during training #4016

Merged
danielvegamyhre merged 2 commits into pytorch:main from alex-minooka:float8-deepseek-int-overflow
Mar 7, 2026

Conversation

@alex-minooka
Contributor

Issue
When training deepseekv3_671b on 1 node with a reduced layer count, the Triton kernels raised an IntOverflow exception after a few iterations:

File "/opt/conda/lib/python3.11/site-packages/triton/runtime/jit.py", line 744, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/opt/conda/lib/python3.11/site-packages/triton/backends/amd/driver.py", line 828, in __call__
    self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
OverflowError: signed integer is greater than maximum

Changes:
Two fixes applied to both _triton_fp8_per_group_rowwise_scales_kernel and _triton_fp8_per_group_colwise_scales_kernel:

Removed unused num_elements parameter — This was computed as hp_tensor.numel() and passed to the Triton kernels but never referenced in the kernel bodies. For DeepSeek V3's large MoE tensors (256 experts, dim=7168), numel() can exceed 2^31 - 1, and the AMD Triton driver packs int kernel args as signed 32-bit integers, directly causing the OverflowError: signed integer is greater than maximum at kernel launch.
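The scale of the problem is easy to verify in plain Python; a minimal sketch, assuming an illustrative per-expert weight shape of (256, 7168, 2048) (the exact DeepSeek V3 tensor shapes are not spelled out in this PR):

```python
# Hypothetical MoE weight shape (num_experts, dim, hidden) for illustration;
# the real tensor shapes in the training run may differ.
num_experts, dim, hidden = 256, 7168, 2048
numel = num_experts * dim * hidden

INT32_MAX = 2**31 - 1  # largest value a signed 32-bit kernel arg can hold
print(numel)               # 3758096384
print(numel > INT32_MAX)   # True: packing this as a signed int32 arg overflows
```

Since the kernels never read num_elements, dropping the argument avoids the launch-time overflow without touching kernel logic.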

Cast pointer arithmetic to tl.int64 — All stride multiplications and offset computations inside both kernels now use 64-bit integers (block_row_offs.to(tl.int64), stride_input_row.to(tl.int64), etc.). This prevents potential int32 overflow in pointer calculations like block_row_offs * stride_input_row for large activation tensors (e.g., ~1M routed tokens × stride 7168 ≈ 7B, which overflows int32).
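The wraparound can be reproduced outside Triton by emulating signed 32-bit arithmetic on the ~1M-token example above; the helper below is illustrative only, not part of the kernel:

```python
def as_int32(x: int) -> int:
    """Emulate signed 32-bit wraparound, as would happen if the
    offset product were computed in int32 inside the kernel."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x >= 2**31 else x

row, stride = 1_000_000, 7168      # ~1M routed tokens, stride 7168
offset = row * stride              # true offset, fits comfortably in int64
print(offset)                      # 7168000000
print(as_int32(offset))            # -1421934592: a corrupted negative offset
```

Computing the same product after .to(tl.int64) keeps the full 64-bit value, so the pointer arithmetic stays correct for large activation tensors.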

@pytorch-bot

pytorch-bot bot commented Mar 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4016


✅ No Failures

As of commit cad25ac with merge base 42bcdc4:
💚 Looks good so far! There are no failures yet. 💚


@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 6, 2026
@danielvegamyhre danielvegamyhre self-requested a review March 6, 2026 18:40
block_row_offs = block_row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

# cast to int64 to avoid overflow in pointer arithmetic for large tensors
block_row_offs_i64 = block_row_offs.to(tl.int64)
Contributor


You can just update the type annotation in the kernel signature to int64 i think, can you try

Contributor Author


ack

@danielvegamyhre danielvegamyhre self-requested a review March 6, 2026 18:41
@danielvegamyhre danielvegamyhre merged commit 5045d76 into pytorch:main Mar 7, 2026
19 checks passed

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: rocm
