[ROCM] Float8 deepseekv3_671b IntOverflow in triton kernels during training #4016
Merged
danielvegamyhre merged 2 commits intopytorch:mainfrom Mar 7, 2026
Merged
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4016
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit cad25ac with merge base 42bcdc4 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| block_row_offs = block_row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
|
|
||
| # cast to int64 to avoid overflow in pointer arithmetic for large tensors | ||
| block_row_offs_i64 = block_row_offs.to(tl.int64) |
Contributor
There was a problem hiding this comment.
You can just update the type annotation in the kernel signature to int64 i think, can you try
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Issue
When running training on deepseekv3_671b on 1 node reduced layers, triton kernels were producing IntOverflow exception after a few iterations.
File "/opt/conda/lib/python3.11/site-packages/triton/runtime/jit.py", line 744, in run kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, File "/opt/conda/lib/python3.11/site-packages/triton/backends/amd/driver.py", line 828, in __call__ self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args) OverflowError: signed integer is greater than maximumChanges:
Two fixes applied to both _triton_fp8_per_group_rowwise_scales_kernel and _triton_fp8_per_group_colwise_scales_kernel:
Removed unused num_elements parameter — This was computed as hp_tensor.numel() and passed to the Triton kernels but never referenced in the kernel bodies. For DeepSeek V3's large MoE tensors (256 experts, dim=7168), numel() can exceed 2^31 - 1, and the AMD Triton driver packs int kernel args as signed 32-bit integers, directly causing the OverflowError: signed integer is greater than maximum at kernel launch.
Cast pointer arithmetic to tl.int64 — All stride multiplications and offset computations inside both kernels now use 64-bit integers (block_row_offs.to(tl.int64), stride_input_row.to(tl.int64), etc.). This prevents potential int32 overflow in pointer calculations like block_row_offs * stride_input_row for large activation tensors (e.g., ~1M routed tokens × stride 7168 ≈ 7B, which overflows int32).