As of #3050, NVFP4 QAT is only usable on Blackwell GPUs, because the forward pass reuses all of the quantization primitives used in NVFP4Tensor in the PTQ path. However, a potentially important use case is to train a model with QAT on other GPU architectures (e.g. A100 or H100), save it, and load it on a different machine with a B200 for inference.
More details: The main known numerical discrepancy comes from these two paths:
# (1) Current implementation: only supported on B200
# https://github.com/pytorch/ao/blob/main/torchao/prototype/qat/nvfp4.py#L82
_addmm_nvfp4_dispatch(nvfp4_input, nvfp4_weight.t(), None, bias)
# (2) More general implementation: probably works on A100 or H100, need to verify
# https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/linear.py#L115
orig_dtype = torch.bfloat16
fq_input = nvfp4_input.to_dtype(orig_dtype)
fq_weight = nvfp4_weight.to_dtype(orig_dtype)
F.linear(fq_input, fq_weight, bias)
Currently, (1) and (2) do not match: (1) gives inf SQNR compared to the PTQ flow, while (2) gives only ~21 for this QAT test. We want to restructure our code toward (2) so we can reuse existing QAT components like FakeQuantizer, while keeping the inf SQNR from (1).
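For reference, here is a minimal, self-contained sketch of how an SQNR comparison like the one above can be computed. The sqnr helper and the fp32/bf16 stand-ins are illustrative assumptions only, not the actual torchao test or the NVFP4 kernels:
import torch
import torch.nn.functional as F

def sqnr(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio in dB; inf means an exact match.
    error_power = (reference - candidate).float().pow(2).mean()
    if error_power == 0:
        return float("inf")
    signal_power = reference.float().pow(2).mean()
    return (10 * torch.log10(signal_power / error_power)).item()

# Toy stand-ins for the two paths: treat an fp32 linear as the PTQ reference,
# and a bf16 round-trip as the dequantize-then-F.linear path (2).
x = torch.randn(8, 64)
w = torch.randn(32, 64)
reference_out = F.linear(x, w)                              # stand-in for the PTQ output
dequant_out = F.linear(x.bfloat16(), w.bfloat16()).float()  # stand-in for path (2)
print(sqnr(reference_out, reference_out))  # inf, analogous to path (1)
print(sqnr(reference_out, dequant_out))    # finite, analogous to path (2)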