(prototype) add support for MXFP8 and MXFP4 QAT #3644
andrewor14 merged 23 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3644
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 02d271e with merge base 30fcb15. This comment was automatically generated by Dr. CI and updates every 15 minutes.
    ],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_reconstruction(dtype, shape):
this should already be covered by existing tests, remove?
    ],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_scaling_modes(scaling_mode):
this should already be covered by existing tests, remove?
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_fake_quantize_config():
move to test/quantization/test_qat.py
    "scaling_mode",
    [ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL],
)
def test_mxfp4_fake_quantized_linear_forward(bias, input_shape, scaling_mode):
move to test/quantization/test_qat.py
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("bias", [True, False])
def test_mxfp4_fake_quantized_linear_backward(bias):
move to test/quantization/test_qat.py
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_fake_quantized_linear_to_linear():
move to test/quantization/test_qat.py
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_vs_nvfp4_block_size():
this is already covered by existing tests, remove?
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_config_error_handling():

    ],
    ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
)
def test_mxfp4_matmul_sqnr(shapes):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_training_simulation():
torchao/prototype/qat/mxfp4.py (Outdated)
@dataclass
class MXFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
these should handle mxfp8 and mxfp4 (see MXTensor) instead of being hardcoded for mxfp4
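To illustrate the suggestion above, here is a minimal sketch of a single config class that covers both MXFP8 and MXFP4 rather than being hardcoded for one format. The class name, dtype strings, and field names are illustrative stand-ins only; the real config would accept torch dtypes (e.g. torch.float4_e2m1fn_x2, torch.float8_e4m3fn) and integrate with MXTensor.

```python
from dataclasses import dataclass

# Illustrative stand-ins for the torch dtypes the real config would accept.
_SUPPORTED_ELEM_DTYPES = {"float4_e2m1fn_x2", "float8_e4m3fn", "float8_e5m2"}


@dataclass
class MXFakeQuantizeConfigSketch:
    """One config covering MXFP4 and MXFP8 instead of a hardcoded MXFP4 class."""

    dtype: str = "float4_e2m1fn_x2"  # MXFP4 by default
    block_size: int = 32             # OCP MX default block size

    def __post_init__(self):
        # Reject unsupported element dtypes at construction time.
        if self.dtype not in _SUPPORTED_ELEM_DTYPES:
            raise ValueError(f"Unsupported MX element dtype: {self.dtype}")


# The same class expresses both formats:
fp4_cfg = MXFakeQuantizeConfigSketch()
fp8_cfg = MXFakeQuantizeConfigSketch(dtype="float8_e4m3fn")
```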
torchao/prototype/qat/mxfp4.py (Outdated)
    return grad_input, grad_weight, None, None, None

class MXFP4FakeQuantizedLinear(torch.nn.Linear):
thanks for working on this! Made some initial comments inline.
torchao/prototype/qat/mx.py (Outdated)
"""
MX (Microscaling) Quantization-Aware Training (QAT) support.

This module provides QAT support for the OCP Microscaling MX formats (MXFP4, MXFP8, MXFP6).
is there any demand for mxfp6 right now? If not, I'd prefer to not support it in QAT for now and add it later when there is demand.
our users need it axolotl-ai-cloud/axolotl#3333 🥹
the issue you linked mentions mxfp4. I am commenting about mxfp6, which is a different format.
ahh my bad, I misread 🤕
torchao/prototype/qat/mx.py (Outdated)
    self.weight_config = weight_config

def forward(self, x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 3:
I think this should go in _MXQuantizedForwardFakeQuantizedBackward instead of here, and also not hardcode rank 3 but instead handle all ranks
orig_shape = x.shape
x_reshaped = x.view(-1, orig_shape[-1])
...
fq_reshaped = ...
fq_orig_shape = fq_reshaped.view(*orig_shape[:-1], -1)
return fq_orig_shape
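The reshape trick sketched above generalizes from rank 3 to any rank. A minimal runnable sketch, using NumPy arrays as a stand-in for torch tensors and an identity function as a stand-in for the 2D fake-quantize (function names here are illustrative, not from the PR):

```python
import numpy as np


def fake_quantize_any_rank(x, fq_2d):
    """Flatten all leading dims, apply a 2D fake-quantize, restore the shape."""
    orig_shape = x.shape
    # Collapse every dim except the last, so fq_2d only ever sees rank 2.
    x_2d = x.reshape(-1, orig_shape[-1])
    fq_out_2d = fq_2d(x_2d)
    # Restore the original leading dims.
    return fq_out_2d.reshape(*orig_shape[:-1], -1)


x = np.random.randn(2, 5, 8)                      # rank 3, but any rank works
out = fake_quantize_any_rank(x, lambda t: t)      # identity stand-in
assert out.shape == (2, 5, 8)
```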
this looks great so far! I kicked off a CI run and also made some comments inline. It would also be great to include a test plan, both for unit testing and for any e2e QAT run results we can get with this.
torchao/prototype/qat/mx.py (Outdated)
# Backwards compatibility aliases
MXFP4FakeQuantizeConfig = MXFakeQuantizeConfig
MXFP4FakeQuantizedLinear = MXFakeQuantizedLinear
I don't think you need these since these were never released
torchao/prototype/qat/__init__.py (Outdated)
"MXFakeQuantizeConfig",
"MXFakeQuantizedLinear",
"MXFP4FakeQuantizeConfig",
"MXFP4FakeQuantizedLinear",
- MXFP8: torch.float8_e4m3fn, torch.float8_e5m2
- MXFP6: "fp6_e2m3", "fp6_e3m2" (string constants)

Key differences from NVFP4:
another key difference not mentioned here: NVFP4 does an extra per-tensor scaling, but MXFP4 doesn't
seems like there are two docblocks in this file which detail the differences; can we keep one and delete the other?
torchao/prototype/qat/mx.py (Outdated)
"""

# Use Any type hint since elem_dtype can be torch.dtype or str (for fp6 formats)
elem_dtype: Any = field(default_factory=lambda: torch.float4_e2m1fn_x2)
maybe drop fp6 for now since they're not as popular? Then we can just use torch.dtype to express this
also why use a field here instead of just assigning to the dtype directly?
also nit: just rename this to dtype to be consistent
# Check that weights have been updated
self.assertFalse(torch.allclose(mx_model[0].weight, initial_weight))
Can you add a test like this one? ao/test/quantization/test_qat.py, line 2076 in a5f2693. Basically we should compare QAT forward against PTQ forward, since they're using the same kernels and so should match exactly.
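The property being requested above (QAT forward bit-exactly matches PTQ forward when both go through the same quantize kernel) can be demonstrated with a toy example. This is a sketch only: the round-to-grid quantizer below is a NumPy stand-in for the real MX kernels, and all names are illustrative.

```python
import numpy as np


def fake_quantize(w, step=0.25):
    """Toy quantize-dequantize: round to a fixed grid (stand-in for the MX kernel)."""
    return np.round(w / step) * step


def qat_forward(x, w):
    # QAT: fake-quantize the weight on the fly, then matmul in high precision.
    return x @ fake_quantize(w).T


def ptq_forward(x, w):
    # PTQ: quantize-dequantize the weight ahead of time, then the same matmul.
    w_q = fake_quantize(w)
    return x @ w_q.T


rng = np.random.default_rng(0)
x, w = rng.normal(size=(4, 8)), rng.normal(size=(16, 8))
# Same kernel, same math: the two paths match bit-exactly, not just approximately.
assert np.array_equal(qat_forward(x, w), ptq_forward(x, w))
```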
test/quantization/test_qat.py (Outdated)
    "scaling_mode",
    ["FLOOR", "RCEIL"],
)
def test_mx_fake_quantized_linear_forward(self, bias, input_shape, scaling_mode):
This is testing specifically mxfp4, right? Can we call that out explicitly in the function name? Also, this test looks very similar to test_mx_fake_quantized_linear_forward_fp8. Any chance we can share some code?
here is the e2e test
high level this looks good to me. Let's make sure CI is green, and the PR summary should have a reproducible test plan. It would be ideal to also include some data from an e2e convergence run. I'll let @andrewor14 accept when he is also good with everything.
andrewor14 left a comment
Looks great, thanks for your work @ved1beta! Really appreciate all the extra unit tests you added.
By the way have you had a chance to test this in a real training job? Just wondering if you were able to compare these two:
(0) bf16 model -> fine-tune without QAT -> lm_eval
(1) bf16 model -> fine-tune without QAT -> quantize to mxfp4 -> lm_eval
(2) bf16 model -> fine-tune with QAT -> quantize to mxfp4 -> lm_eval
Ideally we'll see (0) > (2) > (1).
test/quantization/test_qat.py (Outdated)
# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
_DEVICE = get_current_accelerator_device()
_MXFP4_TORCH_AVAILABLE = torch_version_at_least("2.10.0")
any reason to guard on pytorch 2.10? Seems like MXTensor only requires 2.8
yes, changed it to 2.8.0; Perplexity told me it was 2.10.0 🥹
torchao/prototype/qat/mx.py (Outdated)
This is the OCP Microscaling MX variant which differs from NVFP4 in:
- Block size: 32 (default) vs 16
- Scale format: E8M0 vs float8_e4m3fn
I see this comparison between MX and NVFP4 duplicated in multiple places. Probably don't need it in every doc block? Can we just keep it in the MXFakeQuantizeConfig since that's the most user facing?
torchao/prototype/qat/mx.py (Outdated)
- Block size: 32 (default, OCP standard) vs NVFP4's fixed 16
- Scale format: E8M0 (float8_e8m0fnu) vs NVFP4's float8_e4m3fn
- Supports multiple scale calculation modes
- Supports multiple element dtypes (MXFP4, MXFP8)
same here, don't need this list again here
kernel_preference: KernelPreference = KernelPreference.EMULATED

def __post_init__(self):
    _validate_elem_dtype(self.dtype)
should we also _validate_kernel_preference here?
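A minimal sketch of what validating both fields in __post_init__ could look like. This is illustrative only: the class, enum, and dtype names below are hypothetical stand-ins, not the PR's actual _validate_elem_dtype / _validate_kernel_preference helpers.

```python
from dataclasses import dataclass
from enum import Enum


class KernelPreference(Enum):
    EMULATED = "emulated"


_SUPPORTED_DTYPES = {"float4_e2m1fn_x2", "float8_e4m3fn"}  # illustrative names


@dataclass
class ConfigWithValidation:
    dtype: str = "float4_e2m1fn_x2"
    kernel_preference: KernelPreference = KernelPreference.EMULATED

    def __post_init__(self):
        # Validate both fields up front so a bad config fails at construction
        # time rather than deep inside a forward pass.
        if self.dtype not in _SUPPORTED_DTYPES:
            raise ValueError(f"Unsupported dtype: {self.dtype}")
        if not isinstance(self.kernel_preference, KernelPreference):
            raise ValueError(
                f"Unsupported kernel preference: {self.kernel_preference}"
            )
```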
for this, I was not able to test it with real training due to hardware constraints! I feel the unit tests will be more than enough. Similarly for CI, everything should pass now.
that makes sense. @andrewor14, do you want to do a quick e2e test on this?
self.assertFalse(torch.allclose(mx_model[0].weight, initial_weight))

@unittest.skipIf(not torch_version_at_least("2.10.0"), "Need pytorch 2.10+")
@unittest.skipIf(not _MXFP4_TORCH_AVAILABLE, "Need pytorch 2.10+ for MXFP4")
Looks like these two are still referring to 2.10, make these 2.8?
umm, some tests were failing with 2.8, asking for 2.10
test/prototype/test_parq.py (Outdated)
assert self.embed_tokens.weight.shape == self.linear2.weight.shape
self.tie_weights()
self._tied_weights_keys["linear2.weight"] = "embed_tokens.weight"
self._tied_weights_keys = {"linear2.weight": "embed_tokens.weight"}
Can you revert these changes? I think we fixed it in main
Yeah no problem, I can test it. The code looks good to me. After the latest comments are addressed and tests are passing, I think we can merge this first. @ved1beta do you mind getting started with an axolotl PR that integrates this config? Thanks for all your hard work so far!
on it! It was great working with you guys ❤️ will surely apply at PyTorch some day 🫡
Thanks @ved1beta, merging this |
Title
#3547

Classes
- MXFakeQuantizedLinear
- MXFakeQuantizeConfig
- _MXQuantizedForwardFakeQuantizedBackward (tried following the nvfp4 implementation)

Tests included + e2e test