Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,40 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
sqnr = compute_error(out, baseline_out).item()
self.assertGreaterEqual(sqnr, float("inf"))

@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@parametrize("use_per_tensor_scale", [True, False])
def test_qat_nvfp4_training(self, use_per_tensor_scale: bool):
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

torch.manual_seed(self.SEED)
m = M().cuda()
base_config = NVFP4DynamicActivationNVFP4WeightConfig(
use_dynamic_per_tensor_scale=use_per_tensor_scale
)
quantize_(m, QATConfig(base_config, step="prepare"))

# Simulate training
num_steps = 10
optimizer = torch.optim.SGD(
m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
)
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(num_steps):
example_inputs = m.example_inputs("cuda")
prev_weight = copy.deepcopy(m.linear1.weight)
optimizer.zero_grad()
target = torch.randn(1, 512).float().cuda()
out = m(*example_inputs)
loss = loss_fn(out, target)
loss.backward()
optimizer.step()
# Assert that weights have valid gradients and are being updated
new_weight = m.linear1.weight
self.assertIsNotNone(new_weight.grad)
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
self.assertFalse(torch.equal(new_weight, prev_weight))

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(
not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
Expand Down
4 changes: 2 additions & 2 deletions torchao/prototype/qat/nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
_input, weight = ctx.saved_tensors
assert isinstance(_input, NVFP4Tensor)
assert isinstance(weight, NVFP4Tensor)
_input = _input.to_dtype(_input._orig_dtype)
weight = weight.to_dtype(weight._orig_dtype)
_input = _input.dequantize(_input._orig_dtype)
weight = weight.dequantize(weight._orig_dtype)
grad_input = torch.mm(grad_output, weight)
grad_weight = torch.mm(grad_output.t(), _input)
return grad_input, grad_weight, None, None, None
Expand Down
Loading