Extend NVFP4 QAT to support emulation #3102

@andrewor14

Description


As of #3050, NVFP4 QAT is only usable on Blackwell GPUs, because the forward pass reuses all the quantization primitives used by NVFP4Tensor in the PTQ path. However, a potentially important use case is to train the model with QAT on other GPU architectures (e.g. A100 or H100), save it, and load it on a different machine with B200 support for inference. A sketch of that workflow follows.
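
For concreteness, here is a minimal sketch of the intended workflow. It assumes the QAT APIs introduced in #3050; the import paths, the NVFP4FakeQuantizeConfig name, and its constructor arguments are assumptions rather than verified APIs:

import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
# Assumed import path for the config class added in #3050
from torchao.prototype.qat.nvfp4 import NVFP4FakeQuantizeConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16).cuda()

# On A100/H100: insert NVFP4 fake quantization into the forward pass
# (this is the step that currently requires Blackwell and needs emulation)
quantize_(
    model,
    QATConfig(
        activation_config=NVFP4FakeQuantizeConfig(),  # assumed default args
        weight_config=NVFP4FakeQuantizeConfig(),
        step="prepare",
    ),
)
# ... fine-tune, then save the checkpoint ...
torch.save(model.state_dict(), "nvfp4_qat_checkpoint.pt")
# Later, on a B200 machine: load the checkpoint and quantize for inference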

More details: the main known numerical discrepancy comes from the difference between these two paths:

import torch
import torch.nn.functional as F

# nvfp4_input and nvfp4_weight below are NVFP4Tensor instances

# (1) Current implementation: only supported on B200 currently
# https://github.com/pytorch/ao/blob/main/torchao/prototype/qat/nvfp4.py#L82
_addmm_nvfp4_dispatch(nvfp4_input, nvfp4_weight.t(), None, bias)

# (2) More general implementation: probably works on A100 or H100, needs verification
# https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/linear.py#L115
orig_dtype = torch.bfloat16
fq_input = nvfp4_input.to_dtype(orig_dtype)
fq_weight = nvfp4_weight.to_dtype(orig_dtype)
F.linear(fq_input, fq_weight, bias)

(1) and (2) do not currently match: (1) gives inf SQNR compared to the PTQ flow, while (2) gives only ~21 dB for this QAT test. We want to restructure the code around (2) so we can reuse existing QAT components like FakeQuantizer, while keeping the inf SQNR of (1).
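
To reproduce and close that gap, it helps to have a pure-PyTorch emulation of the NVFP4 quantize-dequantize plus an SQNR metric. Below is a minimal sketch, not torchao's implementation: the value grid and block size follow the NVFP4 format (E2M1 elements with float8_e4m3fn scales over 16-element blocks), but the per-tensor fp32 scale and the real NVFP4Tensor layout are omitted, and fake_quantize_nvfp4 / sqnr are hypothetical helper names.

import torch

# E2M1 (FP4) representable magnitudes
_FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quantize_nvfp4(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Quantize-dequantize through an NVFP4-like format. Assumes x.numel()
    # is divisible by block_size; omits NVFP4's per-tensor fp32 scale.
    xb = x.reshape(-1, block_size)
    # Per-block scale mapping the block max to the FP4 max (6.0), rounded
    # through float8_e4m3fn since NVFP4 stores scales in E4M3
    scale = xb.abs().amax(dim=-1, keepdim=True) / 6.0
    scale = scale.to(torch.float8_e4m3fn).to(x.dtype).clamp(min=1e-12)
    # Snap each scaled element to the nearest representable FP4 magnitude
    vals = _FP4_VALUES.to(device=x.device, dtype=x.dtype)
    scaled = (xb / scale).clamp(-6.0, 6.0)
    idx = (scaled.abs().unsqueeze(-1) - vals).abs().argmin(dim=-1)
    dq = vals[idx] * scaled.sign() * scale
    return dq.reshape(x.shape)

def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; inf means an exact match
    return 20 * torch.log10(ref.norm() / (ref - actual).norm())

With an emulation like this inside FakeQuantizer, the QAT linear becomes F.linear(fake_quantize_nvfp4(input), fake_quantize_nvfp4(weight), bias) and runs on any architecture; recovering the inf SQNR of (1) then reduces to making the emulated quantize-dequantize-mm sequence bit-exact with _addmm_nvfp4_dispatch.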
