[mxfp8] bug fix and test updates related to new triton_calculate_scale param#3522
[mxfp8] bug fix and test updates related to new triton_calculate_scale param#3522danielvegamyhre merged 2 commits intomainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3522
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 Cancelled Job, 1 Unrelated FailureAs of commit dca46b4 with merge base 7035fb7 ( NEW FAILURE - The following job has failed:
CANCELLED JOB - The following job was cancelled. Please retry:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
903f72e to
795e430
Compare
795e430 to
6124e58
Compare
|
cc @drisspg for review |
|
|
||
| def triton_to_mxfp8_dim1_reference( | ||
| x_hp: torch.Tensor, block_size | ||
| x_hp: torch.Tensor, |
There was a problem hiding this comment.
if we can we should try to have 1 default arg at the top user facing and then not populate for inner funcs should help with these types of bug
There was a problem hiding this comment.
the launcher functions do have defaults for scaling_mode (rceil) that are passed in to the actual kernels. (let me know if i misunderstood your comment)
There was a problem hiding this comment.
oh you mean in this torch reference impl? i will add a default.
triton_calculate_scaleparam needs to be updated, this slipped through in [mxfp8] support RCEIL in triton_to_mxfp8_dim0 kernel with inline PTX #3498, motivation for getting blackwells in CI soon!Tests
pytest test/prototype/mx_formats/test_kernels.pyverified whole test suite is passing locally this time