From 8f09c9149acade4d04aab849e86f244843594598 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Wed, 10 Dec 2025 12:44:10 -0800
Subject: [PATCH] Fix NVFP4 QAT backward typo

**Summary:** Fix `to_dtype` -> `dequantize`. This was broken in
https://github.com/pytorch/ao/pull/3169.

**Test Plan:**
```
python test/quantization/test_qat.py -k nvfp4_training
```
---
 test/quantization/test_qat.py  | 34 ++++++++++++++++++++++++++++++++++
 torchao/prototype/qat/nvfp4.py |  4 ++--
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 6a84771657..1bcc5e3349 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -2148,6 +2148,40 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreaterEqual(sqnr, float("inf"))
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @parametrize("use_per_tensor_scale", [True, False])
+    def test_qat_nvfp4_training(self, use_per_tensor_scale: bool):
+        from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
+
+        torch.manual_seed(self.SEED)
+        m = M().cuda()
+        base_config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=use_per_tensor_scale
+        )
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Simulate training
+        num_steps = 10
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        for i in range(num_steps):
+            example_inputs = m.example_inputs("cuda")
+            prev_weight = copy.deepcopy(m.linear1.weight)
+            optimizer.zero_grad()
+            target = torch.randn(1, 512).float().cuda()
+            out = m(*example_inputs)
+            loss = loss_fn(out, target)
+            loss.backward()
+            optimizer.step()
+            # Assert that weights have valid gradients and are being updated
+            new_weight = m.linear1.weight
+            self.assertIsNotNone(new_weight.grad)
+            self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
+            self.assertFalse(torch.equal(new_weight, prev_weight))
+
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
         not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
diff --git a/torchao/prototype/qat/nvfp4.py b/torchao/prototype/qat/nvfp4.py
index 9d635faed6..4eec4bda22 100644
--- a/torchao/prototype/qat/nvfp4.py
+++ b/torchao/prototype/qat/nvfp4.py
@@ -91,8 +91,8 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
         _input, weight = ctx.saved_tensors
         assert isinstance(_input, NVFP4Tensor)
         assert isinstance(weight, NVFP4Tensor)
-        _input = _input.to_dtype(_input._orig_dtype)
-        weight = weight.to_dtype(weight._orig_dtype)
+        _input = _input.dequantize(_input._orig_dtype)
+        weight = weight.dequantize(weight._orig_dtype)
         grad_input = torch.mm(grad_output, weight)
         grad_weight = torch.mm(grad_output.t(), _input)
         return grad_input, grad_weight, None, None, None