
(prototype) add support for MXFP8 and MXFP4 QAT#3644

Merged
andrewor14 merged 23 commits into pytorch:main from ved1beta:mvfp4
Feb 3, 2026

Conversation

@ved1beta
Contributor

@ved1beta ved1beta commented Jan 15, 2026

#3547

New classes:
- `MXFakeQuantizedLinear`
- `MXFakeQuantizeConfig`
- `_MXQuantizedForwardFakeQuantizedBackward`

Followed the NVFP4 implementation.

Tests included + e2e test

"""MXFP4 QAT end-to-end training validation."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class SimpleMLP(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=256, num_classes=32):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear3 = nn.Linear(hidden_dim, num_classes, bias=False)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)


def create_data(num_samples=1000, input_dim=512, num_classes=32, device="cuda"):
    X = torch.randn(num_samples, input_dim, device=device)
    y = torch.randint(0, num_classes, (num_samples,), device=device)
    return DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)


def train_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0
    for data, target in loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def run_training(model, loader, epochs=5, lr=0.01):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    losses = []
    for _ in range(epochs):
        losses.append(train_epoch(model, loader, optimizer))
    return losses


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loader = create_data(device=device)

    torch.manual_seed(42)
    baseline = SimpleMLP().to(device)
    baseline_losses = run_training(baseline, train_loader)

    from torchao.quantization import quantize_
    from torchao.quantization.qat import QATConfig
    from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

    torch.manual_seed(42)
    qat_model = SimpleMLP().to(device)
    base_config = MXDynamicActivationMXWeightConfig(
        activation_dtype=torch.float4_e2m1fn_x2,
        weight_dtype=torch.float4_e2m1fn_x2,
    )
    quantize_(qat_model, QATConfig(base_config, step="prepare"))
    qat_losses = run_training(qat_model, train_loader)

    print("Epoch | Baseline | MXFP4 QAT")
    print("-" * 30)
    for i, (b, q) in enumerate(zip(baseline_losses, qat_losses)):
        print(f"  {i+1}   |  {b:.4f}  |  {q:.4f}")


if __name__ == "__main__":
    main()

@pytorch-bot

pytorch-bot bot commented Jan 15, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3644

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 02d271e with merge base 30fcb15:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 15, 2026
@jerryzh168 jerryzh168 requested a review from andrewor14 January 15, 2026 23:22
],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_reconstruction(dtype, shape):
Contributor


this should be already covered by existing tests, remove?

],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_scaling_modes(scaling_mode):
Contributor


this should be already covered by existing tests, remove?



@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_fake_quantize_config():
Contributor


move to test/quantization/test_qat.py

"scaling_mode",
[ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL],
)
def test_mxfp4_fake_quantized_linear_forward(bias, input_shape, scaling_mode):
Contributor


move to test/quantization/test_qat.py


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("bias", [True, False])
def test_mxfp4_fake_quantized_linear_backward(bias):
Contributor


move to test/quantization/test_qat.py



@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_fake_quantized_linear_to_linear():
Contributor


move to test/quantization/test_qat.py



@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_vs_nvfp4_block_size():
Contributor


this is already covered by existing tests, remove?



@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_config_error_handling():
Contributor


move to qat test file

],
ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
)
def test_mxfp4_matmul_sqnr(shapes):
Contributor


move to qat test file



@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mxfp4_training_simulation():
Contributor


move to qat test file



@dataclass
class MXFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
Contributor


these should handle mxfp8 and mxfp4 (see MXTensor) instead of being hardcoded for mxfp4

return grad_input, grad_weight, None, None, None


class MXFP4FakeQuantizedLinear(torch.nn.Linear):
Contributor


MX instead of MXFP4

@vkuzo
Contributor

vkuzo commented Jan 16, 2026

thanks for working on this! Made some initial comments inline.

@ved1beta ved1beta requested a review from vkuzo January 16, 2026 16:38
"""
MX (Microscaling) Quantization-Aware Training (QAT) support.

This module provides QAT support for the OCP Microscaling MX formats (MXFP4, MXFP8, MXFP6).
Contributor


is there any demand for mxfp6 right now? if not, I'd prefer to not support in in QAT for now and add it later when there is demand.

Contributor Author


our users need it axolotl-ai-cloud/axolotl#3333 🥹

Contributor


the issue you linked mentions mxfp4. I am commenting about mxfp6, which is a different format.

Contributor Author


ahh my bad, I misread 🤕

self.weight_config = weight_config

def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() == 3:
Contributor

@vkuzo vkuzo Jan 20, 2026


I think this should go in _MXQuantizedForwardFakeQuantizedBackward instead of here, and also not hardcode rank 3 but instead handle all ranks

orig_shape = x.shape
x_reshaped = x.view(-1, orig_shape[-1])
...
fq_reshaped = ...
fq_orig_shape = fq_reshaped.view(*orig_shape[:-1], -1)
return fq_orig_shape

@vkuzo
Contributor

vkuzo commented Jan 20, 2026

this looks great so far! I kicked off a CI run, and also made some comments inline.

It would also be great to include a test plan, both for unit testing as well as any e2e QAT run results we can run with this.

@vkuzo vkuzo added the topic: new feature Use this tag if this PR adds a new feature label Jan 20, 2026
Contributor

@andrewor14 andrewor14 left a comment


Thanks @ved1beta, looks great overall. I agree with @vkuzo we can drop fp6 for now to simplify the changes. Just left mostly minor comments on the API/testing.


# Backwards compatibility aliases
MXFP4FakeQuantizeConfig = MXFakeQuantizeConfig
MXFP4FakeQuantizedLinear = MXFakeQuantizedLinear
Contributor


I don't think you need these since these were never released

"MXFakeQuantizeConfig",
"MXFakeQuantizedLinear",
"MXFP4FakeQuantizeConfig",
"MXFP4FakeQuantizedLinear",
Contributor


Remove these?

- MXFP8: torch.float8_e4m3fn, torch.float8_e5m2
- MXFP6: "fp6_e2m3", "fp6_e3m2" (string constants)

Key differences from NVFP4:
Contributor


another key difference that's not mentioned here is NVFP4 does an extra per tensor scaling but MXFP4 doesn't
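To illustrate the point, here is a toy sketch of a pure power-of-two (E8M0-style) block scale with no second per-tensor level; it ignores the element-dtype max normalization the real kernels perform, and the helper name is made up:

```python
import math
import torch

def toy_mx_block_scale(block: torch.Tensor) -> float:
    # E8M0 scales are pure powers of two: snap amax down to 2**floor(log2(amax)).
    # Real MX kernels also normalize by the element dtype's max representable
    # value; this toy version skips that step. NVFP4 would additionally apply
    # one fp32 scale shared across the whole tensor -- MX does not.
    amax = block.abs().max().item()
    return 2.0 ** math.floor(math.log2(amax)) if amax > 0 else 1.0

block = torch.tensor([0.3, -1.7, 0.9, 5.2])
assert toy_mx_block_scale(block) == 4.0  # floor(log2(5.2)) == 2, 2**2 == 4
```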

Contributor


seems like there are two docblocks in this file which detail the differences, can we keep one and delete the other one

Contributor


ping on this

"""

# Use Any type hint since elem_dtype can be torch.dtype or str (for fp6 formats)
elem_dtype: Any = field(default_factory=lambda: torch.float4_e2m1fn_x2)
Contributor


maybe drop fp6 for now since they're not as popular? Then we can just use torch.dtype to express this

Contributor


also why use a field here instead of just assigning to the dtype directly?

Contributor


also nit: just rename this to dtype to be consistent


# Check that weights have been updated
self.assertFalse(torch.allclose(mx_model[0].weight, initial_weight))

Contributor


Can you add a test like this one?

def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):

Basically we should compare QAT forward against PTQ forward, since they're using the same kernels and so should match exactly
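The parity check being suggested can be illustrated with a toy fake-quantize (a stand-in for the real MX kernels): if QAT's fake-quantize and PTQ's quantize-then-dequantize share the same rounding, the two forwards match bit-exactly.

```python
import torch

def toy_fake_quantize(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Stand-in for the real MX kernel: round onto a uniform grid in one shot.
    return torch.round(x / scale) * scale

torch.manual_seed(0)
x = torch.randn(4, 8)
scale = 0.1
qat_out = toy_fake_quantize(x, scale)  # "QAT": fake-quantize in high precision
q = torch.round(x / scale)             # "PTQ": quantize to grid indices...
ptq_out = q * scale                    # ...then dequantize
torch.testing.assert_close(qat_out, ptq_out, rtol=0, atol=0)  # exact match
```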

"scaling_mode",
["FLOOR", "RCEIL"],
)
def test_mx_fake_quantized_linear_forward(self, bias, input_shape, scaling_mode):
Contributor


This is testing specifically mxfp4 right, can we call that out explicitly in the function name? Also this test looks very similar to test_mx_fake_quantized_linear_forward_fp8. Any chance we can share some code?

@ved1beta ved1beta requested a review from andrewor14 January 20, 2026 20:01
@ved1beta
Contributor Author

here is the e2e test (same script as in the PR description above)


@ved1beta ved1beta requested a review from vkuzo January 21, 2026 15:55
@vkuzo vkuzo changed the title from "mvfp4 support ao" to "add support for MXFP8 and MXFP4 QAT" Jan 22, 2026
@vkuzo vkuzo changed the title from "add support for MXFP8 and MXFP4 QAT" to "(prototype) add support for MXFP8 and MXFP4 QAT" Jan 22, 2026
@vkuzo
Contributor

vkuzo commented Jan 22, 2026

high level this looks good to me. Let's make sure CI is green, and also the PR summary should have a reproducible test plan. It would be ideal to also include some data from an e2e convergence run.

I'll let @andrewor14 accept when he is also good with everything.

Contributor

@andrewor14 andrewor14 left a comment


Looks great, thanks for your work @ved1beta! Really appreciate all the extra unit tests you added.

By the way, have you had a chance to test this in a real training job? Just wondering if you were able to compare these three:

(0) bf16 model -> fine-tune without QAT -> lm_eval
(1) bf16 model -> fine-tune without QAT -> quantize to mxfp4 -> lm_eval
(2) bf16 model -> fine-tune with QAT -> quantize to mxfp4 -> lm_eval

Ideally we'll see (0) > (2) > (1).

# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
_DEVICE = get_current_accelerator_device()
_MXFP4_TORCH_AVAILABLE = torch_version_at_least("2.10.0")
Contributor


any reason to guard on pytorch 2.10? Seems like MXTensor only requires 2.8

Contributor Author


yes, changed it to 2.8.0. Perplexity told me it was 2.10.0 🥹


This is the OCP Microscaling MX variant which differs from NVFP4 in:
- Block size: 32 (default) vs 16
- Scale format: E8M0 vs float8_e4m3fn
Contributor


I see this comparison between MX and NVFP4 duplicated in multiple places. Probably don't need it in every doc block? Can we just keep it in the MXFakeQuantizeConfig since that's the most user facing?

- Block size: 32 (default, OCP standard) vs NVFP4's fixed 16
- Scale format: E8M0 (float8_e8m0fnu) vs NVFP4's float8_e4m3fn
- Supports multiple scale calculation modes
- Supports multiple element dtypes (MXFP4, MXFP8)
Contributor


same here, don't need this list again here

kernel_preference: KernelPreference = KernelPreference.EMULATED

def __post_init__(self):
_validate_elem_dtype(self.dtype)
Contributor


should we also _validate_kernel_preference here?

@ved1beta
Contributor Author

ved1beta commented Jan 27, 2026

For this, I was not able to test it with real training due to hardware constraints! I feel the unit tests will be more than enough; similarly for CI, everything should pass now.

Looks great, thanks for your work @ved1beta! Really appreciate all the extra unit tests you added.
By the way have you had a chance to test this in a real training job? Just wondering if you were able to compare these two

@vkuzo
Contributor

vkuzo commented Jan 27, 2026

For this, I was not able to test it with real training due to hardware constraints! I feel the unit tests will be more than enough; similarly for CI, everything should pass now.

that makes sense. @andrewor14 , do you want to do a quick e2e test on this?

self.assertFalse(torch.allclose(mx_model[0].weight, initial_weight))

@unittest.skipIf(not torch_version_at_least("2.10.0"), "Need pytorch 2.10+")
@unittest.skipIf(not _MXFP4_TORCH_AVAILABLE, "Need pytorch 2.10+ for MXFP4")
Contributor


Looks like these two are still referring to 2.10, make these 2.8?

Contributor Author

@ved1beta ved1beta Jan 27, 2026


umm some tests were failing with 2.8 asking for 2.10

Contributor


just use PyTorch 2.10

Contributor Author


done

assert self.embed_tokens.weight.shape == self.linear2.weight.shape
self.tie_weights()
self._tied_weights_keys["linear2.weight"] = "embed_tokens.weight"
self._tied_weights_keys = {"linear2.weight": "embed_tokens.weight"}
Contributor


Can you revert these changes? I think we fixed it in main

@andrewor14
Contributor

For this, I was not able to test it with real training due to hardware constraints! I feel the unit tests will be more than enough; similarly for CI, everything should pass now.

Yeah, no problem, I can test it. The code looks good to me. Once the latest comments are addressed and tests pass, I think we can merge this first. @ved1beta do you mind getting started with an axolotl PR that integrates this config? Thanks for all your hard work so far!

@ved1beta
Contributor Author

do you mind getting started with an axolotl PR that integrates this config?

on it, it was great working with you guys ❤️ will surely apply at PyTorch some day 🫡

@ved1beta ved1beta requested a review from andrewor14 February 3, 2026 09:17
@andrewor14
Contributor

Thanks @ved1beta, merging this

@andrewor14 andrewor14 merged commit 82e58a6 into pytorch:main Feb 3, 2026
19 checks passed

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature
