mxtensor: add pre-swizzle support #3200
Conversation
Stack from ghstack (oldest at bottom):
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3200
Note: Links to docs will display an error until the docs builds have been completed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Summary:

Adds the ability to pre-swizzle scales for `MXTensor`, and turns it on for the inference workflow.

For activations, this is a no-op for now, but if we write a fused kernel we'll hook into the pre-swizzled path. For weights, this is a performance win in this PR, as we now swizzle ahead of time.

Rough magnitude of the weight pre-swizzling win: on M, K, N == 4096, 4096, 4096, the inference fwd speedup on mxfp8 increases from 1.24x to 1.30x.

The name of `_is_swizzled_scales` is not final, but IMO we should finalize it in a future PR together with `NVFP4Tensor`. For now I'm staying consistent with `NVFP4Tensor`.

Test Plan:

```bash
# correctness
CUDA_VISIBLE_DEVICES=5 pytest test/prototype/mx_formats/ -s

# performance
CUDA_VISIBLE_DEVICES=5 python benchmarks/float8/float8_inference_roofline.py ~/local/tmp/20251017_test.csv --recipe_name mxfp8_cublas --shape_gen_name pow2_extended
# before: https://www.internalfb.com/phabricator/paste/view/P1996942931
# after: https://www.internalfb.com/phabricator/paste/view/P1996941798
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 46b8d23
ghstack-comment-id: 3415966576
Pull-Request: #3200
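For readers skimming the diff: "pre-swizzle" here means rearranging the per-block scales into the 128x4-tile blocked layout that the mxfp8 scaled-matmul kernels consume, and caching that on the tensor so the gemm call doesn't redo it. Below is a minimal sketch of that rearrangement in plain PyTorch; the helper name and the exact tile interleaving are my own illustration, not this PR's code (IIRC the repo's `to_blocked` helper in `torchao/prototype/mx_formats/utils.py` is the source of truth).

```python
# Sketch: pack a (rows, cols) scale matrix into flattened 128x4 tiles, the
# "blocked"/swizzled layout the mxfp8 gemms expect. Helper name is hypothetical.
import torch


def pre_swizzle_scales(scales: torch.Tensor) -> torch.Tensor:
    rows, cols = scales.shape
    n_row_tiles = -(-rows // 128)  # ceil division
    n_col_tiles = -(-cols // 4)

    # Pad up to whole 128x4 tiles.
    padded = scales
    if (rows, cols) != (n_row_tiles * 128, n_col_tiles * 4):
        padded = torch.zeros(
            n_row_tiles * 128, n_col_tiles * 4,
            device=scales.device, dtype=scales.dtype,
        )
        padded[:rows, :cols] = scales

    # Split into 128x4 tiles, then interleave each tile as 32x16.
    tiles = padded.view(n_row_tiles, 128, n_col_tiles, 4).permute(0, 2, 1, 3)
    swizzled = tiles.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return swizzled.flatten()


# Weights: swizzle once at quantization/conversion time so every forward pass
# skips the rearrangement. Activations would pay it per call unless a fused
# quantize kernel emits scales already in this layout.
w_scales = torch.randint(0, 255, (4096, 4096 // 32), dtype=torch.uint8)  # mxfp8, block size 32
w_scales_swizzled = pre_swizzle_scales(w_scales)
```

Doing this once for weights is where the per-gemm win comes from; the activation path keeps its current behavior until a fused kernel can produce scales in this layout directly.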
```python
    torch.float8_e5m2,
    torch.uint8,
), "unsupported"
if elem_dtype in (
```
this code doesn't really have a strong purpose, removing instead of making it handle swizzling
```python
    return mx_tensor


def _swizzle_aware_slice(
```
Extracted this out; the only things that change between call sites are the various shape calculations (fp8 vs fp4 data, 32 vs 16 block size).
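For context, here's a rough sketch of the bookkeeping such a helper has to do when slicing along dim 0 (hypothetical signature, fp8-only; the real `_swizzle_aware_slice` also parameterizes over fp8 vs fp4 data and the 32 vs 16 block size, per the comment above):

```python
# Sketch only: illustrates why slicing swizzled scales needs tile-aligned bounds.
import torch


def swizzle_aware_slice_dim0(
    qdata: torch.Tensor,   # (M, K) fp8 data; fp4 would pack two values per byte
    scale: torch.Tensor,   # unswizzled: (M, K // block_size); swizzled: flat 128x4 tiles
    is_swizzled: bool,
    block_size: int,       # 32 for mxfp8; a 16-element block doubles the scale columns
    start: int,
    end: int,
):
    M, K = qdata.shape
    scale_cols = K // block_size
    qdata_sliced = qdata[start:end]

    if not is_swizzled:
        # Plain layout: scale rows line up 1:1 with data rows.
        scale_sliced = scale.reshape(M, scale_cols)[start:end]
    else:
        # Swizzled layout: scales are stored as flattened 128x4 tiles, so a row
        # slice is only representable if it lands on 128-row tile boundaries.
        assert start % 128 == 0 and end % 128 == 0, "slice must align to 128-row tiles"
        n_col_tiles = -(-scale_cols // 4)           # ceil division
        elems_per_row_band = 128 * n_col_tiles * 4  # one full band of tiles
        scale_sliced = scale.reshape(-1)[
            (start // 128) * elems_per_row_band : (end // 128) * elems_per_row_band
        ]
    return qdata_sliced, scale_sliced
```

Slicing along the last dimension would need the analogous 4-column tile alignment.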
```python
    Output: sliced qdata and scale, does the right thing for unswizzled and swizzled scales
    """

    M, K = x.shape[0], x.shape[1]
```
Nit: I should probably have used a generic (m/n) term like rows, columns.
Looks good, nice refactor