add a_1_128_w_128_128 (DeepSeek) float8 scaling for inference #3257
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3257. Note: links to docs will display an error until the docs builds have completed.
❌ 2 New Failures as of commit c4769a6 with merge base 1e473ed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Summary: Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in torchao inference. In detail:
1. bring the 128x128 gemm triton kernel we have out of prototype and wrap it with a custom op for `torch.compile` compatibility
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add a test for e2e numerical correctness via SQNR comparison vs a high-precision baseline

For now there is a fallback which only requires triton and is numerically correct, but it may not reach optimal performance. Performance optimization is left for future PRs:
1. map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. enable an fbgemm_gpu_genai path, if available in the user env
3. map to a triton kernel for quantizing the weights, as `torch.compile` is currently known to be slow for 128x128 block quantization

ghstack-source-id: db464e1
ghstack-comment-id: 3460951962
Pull-Request: #3257
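For intuition, here is a minimal pure-PyTorch sketch of the scaling granularity the recipe name describes: activations get one float8 scale per 1x128 tile, weights one per 128x128 tile. The helper names and scale layouts below are illustrative assumptions, not the PR's API.

```python
import torch

# Illustrative only: names and scale layouts are assumptions, not torchao's API.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_act_1x128(x: torch.Tensor, block: int = 128):
    # x: (M, K) with K divisible by `block`; one scale per 1x128 activation tile.
    M, K = x.shape
    x_tiles = x.view(M, K // block, block)
    amax = x_tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX                                  # (M, K//block, 1)
    x_fp8 = (x_tiles / scale).to(torch.float8_e4m3fn).view(M, K)
    return x_fp8, scale.squeeze(-1)                         # scales: (M, K//block)

def quantize_weight_128x128(w: torch.Tensor, block: int = 128):
    # w: (N, K) with both dims divisible by `block`; one scale per 128x128 weight tile.
    N, K = w.shape
    w_tiles = w.view(N // block, block, K // block, block)
    amax = w_tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX                                  # (N//block, 1, K//block, 1)
    w_fp8 = (w_tiles / scale).to(torch.float8_e4m3fn).view(N, K)
    return w_fp8, scale.squeeze(1).squeeze(-1)              # scales: (N//block, K//block)
```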
| triton.cdiv(N, meta["BLOCK_SIZE"]), |
| from torch.utils._triton import has_triton |
| if has_triton(): |
most of the changes in this file are just indentation from adding the `if has_triton()` statement
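For readers outside the diff, a hedged sketch of the gating pattern being described: the triton kernels and their wrappers only get defined when `has_triton()` returns True, with a stub otherwise. The toy kernel and function names below are placeholders, not the PR's code.

```python
import torch
from torch.utils._triton import has_triton

if has_triton():
    import triton
    import triton.language as tl

    @triton.jit
    def _double_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Toy kernel standing in for the real blockwise fp8 gemm/quant kernels.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x * 2.0, mask=mask)

    def double(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n = x.numel()
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        _double_kernel[grid](x, out, n, BLOCK_SIZE=1024)
        return out
else:
    # Importing the module still works without triton; using it fails loudly.
    def double(x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("triton is required for this kernel")
```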
| mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) |
| tl.store(c_ptrs, c, mask=mask) |
| @torch.library.custom_op("ao::blockwise_fp8_gemm", mutates_args=()) |
non-indent change 1
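As context for this change, a hedged sketch of what such a custom-op wrapper looks like. The `mylib` namespace, the function name, and the dequant-then-matmul body below are stand-ins for illustration; the real `ao::blockwise_fp8_gemm` launches the triton gemm kernel instead.

```python
import torch

@torch.library.custom_op("mylib::blockwise_fp8_gemm", mutates_args=())
def blockwise_fp8_gemm_sketch(
    a: torch.Tensor,    # (M, K) float8 activations
    a_s: torch.Tensor,  # (M, K // 128) activation scales, float32
    b: torch.Tensor,    # (N, K) float8 weights
    b_s: torch.Tensor,  # (N // 128, K // 128) weight scales, float32
) -> torch.Tensor:
    # Stand-in body: dequantize in float32 and matmul; the real op launches
    # the triton blockwise gemm kernel here instead.
    M, K = a.shape
    N = b.shape[0]
    a_f = (a.to(torch.float32).view(M, K // 128, 128) * a_s[..., None]).view(M, K)
    b_f = (
        b.to(torch.float32).view(N // 128, 128, K // 128, 128)
        * b_s[:, None, :, None]
    ).view(N, K)
    return a_f @ b_f.t()
```

Wrapping the kernel launch in a custom op makes it appear as a single opaque operator to `torch.compile`, rather than the compiler trying to trace into the triton launch.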
| ) |
| return c |
| @blockwise_fp8_gemm.register_fake |
non-indent change 2
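Continuing the hypothetical `mylib::blockwise_fp8_gemm` sketch above, a fake (meta) implementation tells fake-tensor tracing and `torch.compile` the output shape and dtype without running the kernel:

```python
# Continues the hypothetical custom-op sketch above.
@blockwise_fp8_gemm_sketch.register_fake
def _(a, a_s, b, b_s):
    # Output is (M, N); dtype mirrors the stand-in eager implementation above.
    return a.new_empty(a.shape[0], b.shape[0], dtype=torch.float32)
```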
| fp8_blockwise_weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) |
| return y |
| else: |
non-indent change 3
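For reference, a hedged pure-PyTorch equivalent of 128x128 blockwise weight dequantization, roughly what the triton dequant kernel in the diff computes; the function name and scale layout are assumptions.

```python
import torch

def dequant_weight_128x128(w_fp8: torch.Tensor, scale: torch.Tensor, block: int = 128) -> torch.Tensor:
    # w_fp8: (N, K) float8, scale: (N//block, K//block) float32 -> (N, K) float32.
    N, K = w_fp8.shape
    w = w_fp8.to(torch.float32).view(N // block, block, K // block, block)
    w = w * scale[:, None, :, None]
    return w.view(N, K)
```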
Looks good, I think we can also keep the 1x128 block to align with what DeepSeek is calling it now.
oh since we are changing the meaning of PerBlock, please update the doc a bit as well:
ao/torchao/quantization/granularity.py, line 107 at 1e473ed:
| class PerBlock(Granularity): |
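A hedged sketch of how the two granularities in this recipe could be spelled with `PerBlock`, assuming it is parameterized by a block-size tuple as the referenced class suggests; the exact config wiring in the PR may differ.

```python
from torchao.quantization.granularity import PerBlock

# Assumed usage: one scale per (1, 128) activation tile and per (128, 128) weight tile.
act_granularity = PerBlock((1, 128))
weight_granularity = PerBlock((128, 128))
```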
CI failures exist on the main branch, landing.
Summary:
Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in
torchao inference. In detail:
1. bring the 128x128 gemm triton kernel we have out of prototype and wrap it with a custom op for `torch.compile` compatibility
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add a test for e2e numerical correctness via SQNR comparison vs a high-precision baseline

For now we only have fallback kernels which require triton and are numerically
correct but may not reach optimal performance. Performance optimization is
left for future PRs:
1. map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. enable an fbgemm_gpu_genai path, if available in the user env
3. map to a triton kernel for quantizing the weights, as `torch.compile` is currently known to be slow for 128x128 block quantization

Further accuracy testing and enablement of more features is left for future PRs, to keep PR size small.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
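The e2e correctness test described in the summary compares quantized output against a high-precision baseline via SQNR. A minimal sketch of such a comparison; the metric definition is standard, but the threshold and tensor names are illustrative, not the PR's test code.

```python
import torch

def compute_sqnr(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means closer to the baseline.
    signal = reference.float().pow(2).mean()
    noise = (reference.float() - candidate.float()).pow(2).mean()
    return 10 * torch.log10(signal / noise)

# Illustrative usage: y_ref from a bf16 linear, y_q from the float8-quantized one.
# assert compute_sqnr(y_ref, y_q) > 25.0
```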