diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index fd4939232f..c1d63c48cc 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -1910,7 +1910,6 @@ def _test_quantize_api_against_ptq(
         quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
         out_prepared = m(*example_inputs)
         prepare_sqnr = compute_error(out_prepared, out_baseline)
-        self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
 
         # compare convert
@@ -2088,21 +2087,27 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
         self._test_quantize_api_against_ptq(
             NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=float("inf"),
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @parametrize("use_per_tensor_scale", [True, False])
     def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
         from torchao.prototype.qat import NVFP4FakeQuantizeConfig
 
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
+        quantize_(
+            baseline_model,
+            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+        )
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
@@ -2116,7 +2121,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         out = m(*x)
         baseline_out = baseline_model(*x)
         sqnr = compute_error(out, baseline_out).item()
-        self.assertGreater(sqnr, 24)
+        self.assertGreaterEqual(sqnr, float("inf"))
 
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index 799b373dd6..c22f7793bb 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -771,29 +771,6 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
-    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
-
-
-class _Float8Round(torch.autograd.Function):
-    """
-    Cast a tensor to float8 and back to float32 with backward STE.
-    """
-
-    @staticmethod
-    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-        return x.to(torch.float8_e4m3fn).to(torch.float32)
-
-    @staticmethod
-    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
-        return gy
-
-
-def _nvfp4_quantize(
-    data_hp: torch.Tensor,
-    block_size: int = 16,
-    per_tensor_scale: Optional[torch.Tensor] = None,
-    skip_dtype_cast_and_packing: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
@@ -801,7 +778,6 @@ def _nvfp4_quantize(
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"
 
-    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -813,8 +789,10 @@ def _nvfp4_quantize(
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
-        block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
+            torch.float8_e4m3fn
+        )
+        block_scale_fp32 = block_scale_fp8.to(torch.float32)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
     else:
@@ -826,8 +804,8 @@ def _nvfp4_quantize(
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        )
-        scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
+        ).to(torch.float8_e4m3fn)
+        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
         total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -836,11 +814,8 @@ def _nvfp4_quantize(
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
 
-    if skip_dtype_cast_and_packing:
-        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
-    else:
-        data_lp = f32_to_f4_unpacked(data_scaled)
-        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-        data_lp = pack_uint4(data_lp)
-        return out_scales.to(torch.float8_e4m3fn), data_lp
+    data_lp = f32_to_f4_unpacked(data_scaled)
+    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+    data_lp = pack_uint4(data_lp)
+    return out_scales, data_lp
diff --git a/torchao/prototype/qat/__init__.py b/torchao/prototype/qat/__init__.py
index 0727a1c673..b6b5825c5d 100644
--- a/torchao/prototype/qat/__init__.py
+++ b/torchao/prototype/qat/__init__.py
@@ -3,10 +3,10 @@
 from .nvfp4 import (
     NVFP4FakeQuantizeConfig,
-    NVFP4FakeQuantizer,
+    NVFP4FakeQuantizedLinear,
 )
 
 __all__ = [
     "NVFP4FakeQuantizeConfig",
-    "NVFP4FakeQuantizer",
+    "NVFP4FakeQuantizedLinear",
 ]
diff --git a/torchao/prototype/qat/nvfp4.py b/torchao/prototype/qat/nvfp4.py
index ed709dba1d..396389d22e 100644
--- a/torchao/prototype/qat/nvfp4.py
+++ b/torchao/prototype/qat/nvfp4.py
@@ -1,15 +1,14 @@
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 
 from torchao.prototype.mx_formats.nvfp4_tensor import (
-    _nvfp4_quantize,
+    NVFP4Tensor,
+    _addmm_nvfp4_dispatch,
     per_tensor_amax_to_scale,
 )
-from torchao.quantization.qat import (
-    FakeQuantizeConfigBase,
-    FakeQuantizerBase,
-)
+from torchao.quantization.qat import FakeQuantizeConfigBase
 
 
 @dataclass
@@ -23,47 +22,166 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     Args:
         use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
             after the initial fp8 (e4m3) block-wise scaling (default True)
+        use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        use_triton_kernel (bool): Whether to use triton kernels during fake quantization
     """
 
     use_per_tensor_scale: bool = True
+    use_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
+# TODO: support emulation on non-Blackwell GPUs
+class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
+    """
+    Autograd function for NVFP4 quantization + addmm in low precision during forward,
+    and fake quantization in high precision during backward.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        activation_config: NVFP4FakeQuantizeConfig,
+        weight_config: NVFP4FakeQuantizeConfig,
+    ) -> torch.Tensor:
+        # quantize input activations
+        if activation_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(_input))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        _input = NVFP4Tensor.to_nvfp4(
+            _input,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=activation_config.use_swizzled_scales,
+            use_triton_kernel=activation_config.use_triton_kernel,
+        )
+
+        # quantize weights
+        if weight_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(weight))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        weight = NVFP4Tensor.to_nvfp4(
+            weight,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=weight_config.use_swizzled_scales,
+            use_triton_kernel=False,
+        )
+        # Follow `NVFP4InferenceConfig`, always use traditional construction
+        # for weights and set `use_triton_kernel` afterwards
+        weight.use_triton_kernel = weight_config.use_triton_kernel
 
-class NVFP4FakeQuantizer(FakeQuantizerBase):
+        ctx.save_for_backward(_input, weight)
+
+        return _addmm_nvfp4_dispatch(
+            _input,
+            weight.t(),
+            None,  # aten_op, not used
+            bias,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        _input, weight = ctx.saved_tensors
+        assert isinstance(_input, NVFP4Tensor)
+        assert isinstance(weight, NVFP4Tensor)
+        _input = _input.to_dtype(_input._orig_dtype)
+        weight = weight.to_dtype(weight._orig_dtype)
+        grad_input = torch.mm(grad_output, weight)
+        grad_weight = torch.mm(grad_output.t(), _input)
+        return grad_input, grad_weight, None, None, None
+
+
+class NVFP4FakeQuantizedLinear(torch.nn.Linear):
     """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    Linear module for fake quantized NVFP4 weights and/or activations.
+
+    The forward pass follows quantization and addmm numerics in `NVFP4Tensor`
+    in lower precision exactly, while the backward pass uses dequantized
+    (fake quantized) values in high precision.
+
+    Currently this is only applicable on Blackwell and future generations.
+    See https://github.com/pytorch/ao/issues/3102 for more details.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+
+        base_config = NVFP4InferenceConfig()
+        quantize_(model, QATConfig(base_config, step="prepare"))
+        # Model contains `NVFP4FakeQuantizedLinear` now
+
+        train_loop(model)
+        quantize_(model, QATConfig(base_config, step="convert"))
+        # Model contains `nn.Linear` with `NVFP4Tensor` weights now
     """
 
-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+        if weight_config is None:
+            raise ValueError("Must specify `weight_config`")
+        if activation_config is None:
+            raise ValueError("Weight only NVFP4 QAT not supported yet")
+        self.activation_config = activation_config
+        self.weight_config = weight_config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        block_size = 16
-        original_shape = x.shape
         if x.dim() == 3:
+            batch_size = x.shape[0]
             x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = None
+            batch_size = None
+        fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
+            x, self.weight, self.bias, self.activation_config, self.weight_config
+        )
+        assert fq.dtype == x.dtype
+        if batch_size is not None:
+            return fq.view(batch_size, -1, fq.shape[-1])
+        else:
+            return fq
 
-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+    ):
+        new_linear = NVFP4FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias is not None,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
+        return new_linear
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 1287126bac..a7d89850aa 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -208,7 +208,24 @@ def _qat_config_transform(
         act_config = config.activation_config
         weight_config = config.weight_config
         if isinstance(module, torch.nn.Linear):
-            return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
+            # TODO: rewrite this using a registration API so
+            # specific quantization schemes do not leak here
+            from torchao.prototype.qat import (
+                NVFP4FakeQuantizeConfig,
+                NVFP4FakeQuantizedLinear,
+            )
+
+            if isinstance(weight_config, NVFP4FakeQuantizeConfig):
+                assert act_config is None or isinstance(
+                    act_config, NVFP4FakeQuantizeConfig
+                )
+                return NVFP4FakeQuantizedLinear.from_linear(
+                    module, act_config, weight_config
+                )
+            else:
+                return FakeQuantizedLinear.from_linear(
+                    module, act_config, weight_config
+                )
         elif isinstance(module, torch.nn.Embedding):
             if act_config is not None:
                 raise ValueError(
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index 336a419af5..3a1c7c78f1 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -444,12 +444,16 @@ def _infer_fake_quantize_configs(
     elif isinstance(base_config, NVFP4InferenceConfig):
         if NVFP4MMConfig.DYNAMIC:
             act_config = NVFP4FakeQuantizeConfig(
-                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+                use_swizzled_scales=False,
+                use_triton_kernel=False,
             )
         else:
             act_config = None
         weight_config = NVFP4FakeQuantizeConfig(
-            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+            use_swizzled_scales=True,
+            use_triton_kernel=base_config.use_triton_kernel,
         )
     elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
         assert base_config.version >= 2, "Only version 2+ is supported"
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 9c06264be8..595dafaba8 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -60,20 +60,12 @@ def __repr__(self) -> str:
 
     @staticmethod
     def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
-        # TODO: rewrite using registration API so we don't need to import here
-        from torchao.prototype.qat import (
-            NVFP4FakeQuantizeConfig,
-            NVFP4FakeQuantizer,
-        )
-
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
         elif isinstance(config, Int4WeightFakeQuantizeConfig):
             return Int4WeightFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
-        elif isinstance(config, NVFP4FakeQuantizeConfig):
-            return NVFP4FakeQuantizer(config)
         else:
             raise ValueError(f"Unknown config type: {config}")