diff --git a/test/quantization/pt2e/test_learnable_fake_quantize.py b/test/quantization/pt2e/test_learnable_fake_quantize.py new file mode 100644 index 0000000000..c2ed3bf3d9 --- /dev/null +++ b/test/quantization/pt2e/test_learnable_fake_quantize.py @@ -0,0 +1,794 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import unittest + +import numpy as np +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import TestCase + +from torchao.quantization.pt2e.learnable_fake_quantize import ( + LearnableFakeQuantize, +) +from torchao.quantization.pt2e.observer import ( + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, +) + + +# Reference methods for fake quantize operations +def _fake_quantize_per_tensor_affine_reference( + X, scale, zero_point, quant_min, quant_max +): + """Reference implementation of per-tensor fake quantization.""" + dtype = X.dtype + res = ( + torch.clamp( + torch.round(X.to(torch.float32) * (1.0 / scale) + zero_point), + quant_min, + quant_max, + ) + - zero_point + ) * scale + return res.to(dtype) + + +def _fake_quantize_per_tensor_affine_grad_reference( + dY, X, scale, zero_point, quant_min, quant_max +): + """Reference implementation of per-tensor fake quantization gradient.""" + dtype = X.dtype + Xq = torch.round(X.to(torch.float32) * (1.0 / scale) + zero_point) + mask = (Xq >= quant_min) * (Xq <= quant_max) + res = torch.zeros_like(dY) + res[mask] = dY[mask] + return res.to(dtype) + + +def _fake_quantize_learnable_per_tensor_affine_grad_reference( + dY, X, scale, zero_point, quant_min, quant_max, device +): + """Reference implementation of learnable per-tensor fake quantization gradients.""" + zero_point_rounded = int((zero_point + 0.5).clamp(quant_min, quant_max).item()) + Xq = torch.round(X * (1.0 / scale) + zero_point_rounded) + + indicate_small_scale = (Xq < quant_min).float().to(device) + indicate_big_scale = (Xq > quant_max).float().to(device) + indicate_middle_scale = ( + torch.ones(indicate_small_scale.shape).to(device) + - indicate_small_scale + - indicate_big_scale + ) + + indicate_saturate_zp = ((Xq < quant_min).float() + (Xq > quant_max).float()).to( + device + ) + indicate_unsaturate_zp = ( + torch.ones(indicate_saturate_zp.shape).to(device) - indicate_saturate_zp + ) + + Xq = Xq.clamp(quant_min, quant_max) + Xfq = (Xq - zero_point_rounded) * scale + + grad_small_scale = quant_min - zero_point_rounded + grad_big_scale = quant_max - zero_point_rounded + grad_middle_scale = ((Xfq - X) / scale).to(device) + + grad_saturate_zp = -scale.to(device) + grad_unsaturate_zp = 0 + + grad_scale = ( + indicate_small_scale * grad_small_scale + + indicate_big_scale * grad_big_scale + + indicate_middle_scale * grad_middle_scale + ) + grad_zp = ( + indicate_saturate_zp * grad_saturate_zp + + indicate_unsaturate_zp * grad_unsaturate_zp + ) + grad_X = _fake_quantize_per_tensor_affine_grad_reference( + dY, X, scale, zero_point, quant_min, quant_max + ).to(device) + + grad_scale = (grad_scale * dY).sum().unsqueeze(dim=0) + grad_zp = (grad_zp * dY).sum().unsqueeze(dim=0) + return grad_X, grad_scale, grad_zp + + +# Removed unused helper functions _get_tensor_min_max and _get_scale_zp +# These were not being used in any of the tests + + +NP_RANDOM_SEED = 19 +tolerance = 1e-6 + + +class TestLearnableFakeQuantize(TestCase): + """Test cases for LearnableFakeQuantize module.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + torch.manual_seed(42) + np.random.seed(NP_RANDOM_SEED) + + def test_initialization_per_tensor(self): + """Test initialization of LearnableFakeQuantize module for per-tensor quantization.""" + observer = MovingAverageMinMaxObserver + quant_min = 0 + quant_max = 255 + + lfq = LearnableFakeQuantize( + observer=observer, quant_min=quant_min, quant_max=quant_max + ) + + # Test that the module is properly initialized + self.assertEqual(lfq.quant_min, quant_min) + self.assertEqual(lfq.quant_max, quant_max) + # Initially scale and zero_point should be None + self.assertIsNone(lfq.scale) + self.assertIsNone(lfq.zero_point) + self.assertFalse(lfq._initialized) + + def test_initialization_per_channel(self): + """Test initialization of LearnableFakeQuantize module for per-channel quantization.""" + observer = MovingAveragePerChannelMinMaxObserver + quant_min = 0 + quant_max = 255 + + lfq = LearnableFakeQuantize( + observer=observer, quant_min=quant_min, quant_max=quant_max, ch_axis=0 + ) + + # Test that the module is properly initialized for per-channel + self.assertEqual(lfq.quant_min, quant_min) + self.assertEqual(lfq.quant_max, quant_max) + # Initially scale and zero_point should be None + self.assertIsNone(lfq.scale) + self.assertIsNone(lfq.zero_point) + self.assertFalse(lfq._initialized) + + def test_enable_range_learning(self): + """Test enabling range learning functionality.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Initially learning should be disabled and scale/zero_point should be None + self.assertEqual(lfq.learning_enabled[0], 0) + self.assertIsNone(lfq.scale) + self.assertIsNone(lfq.zero_point) + + # Enable range learning + lfq.enable_range_learning() + + # Check that learning is enabled + self.assertEqual(lfq.learning_enabled[0], 1) + # scale and zero_point are still None until first forward pass + self.assertIsNone(lfq.scale) + self.assertIsNone(lfq.zero_point) + self.assertEqual(lfq.fake_quant_enabled[0], 1) + self.assertEqual(lfq.observer_enabled[0], 0) + + def test_disable_range_learning(self): + """Test disabling range learning functionality.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Enable range learning first + lfq.enable_range_learning() + + # Run a forward pass to initialize scale and zero_point + X = torch.randn(4, 4) + lfq(X) + + # Then disable range learning + lfq.disable_range_learning() + + # Check that learning is disabled + self.assertEqual(lfq.learning_enabled[0], 0) + self.assertFalse(lfq.scale.requires_grad) + self.assertFalse(lfq.zero_point.requires_grad) + + def test_enable_observer(self): + """Test enabling observer functionality.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Enable observer + lfq.enable_observer(True) + + # Check that observer is enabled and learning is disabled + self.assertEqual(lfq.observer_enabled[0], 1) + self.assertEqual(lfq.learning_enabled[0], 0) + + # Test disable observer + lfq.disable_observer() + self.assertEqual(lfq.observer_enabled[0], 0) + + def test_fake_quant_control(self): + """Test fake quantization control functionality.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Test enable_fake_quant + lfq.enable_fake_quant(True) + self.assertEqual(lfq.fake_quant_enabled[0], 1) + + # Test disable_fake_quant + lfq.disable_fake_quant() + self.assertEqual(lfq.fake_quant_enabled[0], 0) + + def test_calculate_qparams(self): + """Test calculation of quantization parameters.""" + observer = MovingAverageMinMaxObserver + scale_val = 0.1 + zero_point_val = 128.0 + quant_min = 0 + quant_max = 255 + + lfq = LearnableFakeQuantize( + observer=observer, quant_min=quant_min, quant_max=quant_max + ) + + # Initialize parameters by running a forward pass first + X = torch.randn(4, 4) + lfq(X) + + # Manually set the scale and zero_point values for testing + lfq.scale.data.fill_(scale_val) + lfq.zero_point.data.fill_(zero_point_val) + + scale, zero_point = lfq.calculate_qparams() + + # Check that scale is properly clamped and zero_point is properly rounded/clamped + self.assertGreaterEqual(scale.item(), lfq.eps.item()) + self.assertGreaterEqual(zero_point.item(), quant_min) + self.assertLessEqual(zero_point.item(), quant_max) + self.assertEqual(zero_point.dtype, torch.long) + + def test_forward_observer_enabled(self): + """Test forward pass with observer enabled.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Enable observer + lfq.enable_observer(True) + + # Create test input + X = torch.randn(4, 4) * 10 + + # Forward pass + output = lfq(X) + + # Check that output has correct shape and type + self.assertEqual(output.shape, X.shape) + self.assertEqual(output.dtype, X.dtype) + + # Check that scale and zero_point have been initialized + self.assertIsNotNone(lfq.scale) + self.assertIsNotNone(lfq.zero_point) + + def test_forward_learning_enabled(self): + """Test forward pass with range learning enabled.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Enable range learning + lfq.enable_range_learning() + + # Create test input that requires grad + X = torch.randn(4, 4, requires_grad=True) * 10 + + # Run forward pass to initialize learnable fake quantizers + output = lfq(X) + + # Check that output has correct shape and type + self.assertEqual(output.shape, X.shape) + self.assertEqual(output.dtype, X.dtype) + + # Check that gradients can flow through + loss = output.sum() + loss.backward() + # Note: X may not have grad if it's not a leaf tensor; focus on testing quantizer gradients + self.assertIsNotNone(lfq.scale.grad) + self.assertIsNotNone(lfq.zero_point.grad) + + def test_forward_fake_quant_disabled(self): + """Test forward pass with fake quantization disabled.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Disable fake quantization + lfq.disable_fake_quant() + + # Create test input + X = torch.randn(4, 4) * 10 + + # Forward pass + output = lfq(X) + + # Output should be identical to input when fake quantization is disabled + torch.testing.assert_close(output, X) + + def test_symmetric_quantization(self): + """Test symmetric quantization behavior.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Enable fake quantization + lfq.enable_fake_quant(True) + + # Create test input + X = torch.randn(4, 4) * 10 + + # Forward pass to initialize parameters + lfq(X) + + # For symmetric quantization, zero_point should be zero + # (Note: This test assumes the implementation handles symmetric mode) + self.assertIsNotNone(lfq.zero_point) + + def test_per_channel_quantization(self): + """Test per-channel quantization functionality.""" + observer = MovingAveragePerChannelMinMaxObserver + channel_len = 4 + + lfq = LearnableFakeQuantize(observer=observer, ch_axis=0) + + # Enable fake quantization + lfq.enable_fake_quant(True) + + # Create test input with correct channel dimension + X = torch.randn(channel_len, 8) * 10 + + # Forward pass + output = lfq(X) + + # Check that output has correct shape + self.assertEqual(output.shape, X.shape) + self.assertEqual(lfq.scale.shape[0], channel_len) + self.assertEqual(lfq.zero_point.shape[0], channel_len) + + def test_gradient_scaling(self): + """Test gradient scaling functionality.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer, use_grad_scaling=True) + + # Enable range learning + lfq.enable_range_learning() + + # Create test input that requires grad + X = torch.randn(4, 4, requires_grad=True) * 10 + + # Run forward pass to initialize learnable fake quantizers + output = lfq(X) + + # Check that gradients can flow through + loss = output.sum() + loss.backward() + self.assertIsNotNone(lfq.scale.grad) + self.assertIsNotNone(lfq.zero_point.grad) + + def test_error_conditions(self): + """Test error conditions during initialization.""" + observer = MovingAverageMinMaxObserver + + # Test quant_min >= quant_max + with self.assertRaises(AssertionError): + LearnableFakeQuantize(observer=observer, quant_min=255, quant_max=0) + + def test_state_persistence(self): + """Test that module state is properly maintained across forward passes.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer) + + # Initial state + initial_fake_quant_enabled = lfq.fake_quant_enabled[0].item() + initial_observer_enabled = lfq.observer_enabled[0].item() + initial_learning_enabled = lfq.learning_enabled[0].item() + + # Forward pass + X = torch.randn(4, 4) + lfq(X) # We don't need to store the output, just run the forward pass + + # State should be preserved + self.assertEqual(lfq.fake_quant_enabled[0].item(), initial_fake_quant_enabled) + self.assertEqual(lfq.observer_enabled[0].item(), initial_observer_enabled) + self.assertEqual(lfq.learning_enabled[0].item(), initial_learning_enabled) + + def test_learnable_forward_per_tensor(self): + """Test learnable forward pass for per-tensor quantization.""" + X = torch.randn(5, 5, dtype=torch.float32) + scale_base = torch.tensor([0.1], dtype=torch.float32) + zero_point_base = torch.tensor([128.0], dtype=torch.float32) + + for n_bits in (4, 8): + quant_min, quant_max = 0, 2**n_bits - 1 + + X_test = X.clone().float() + scale = scale_base.clone() + zero_point = zero_point_base.clamp(quant_min, quant_max) + + Y = _fake_quantize_per_tensor_affine_reference( + X_test, scale, zero_point, quant_min, quant_max + ) + + for grad_factor in [0.1, 1.0, 10.0]: + Y_prime = torch._fake_quantize_learnable_per_tensor_affine( + X_test, scale, zero_point, quant_min, quant_max, grad_factor + ) + self.assertTrue( + torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance), + "Expected kernel forward function to have results match the reference forward function", + ) + + def test_learnable_backward_per_tensor(self): + """Test learnable backward pass for per-tensor quantization.""" + X = torch.randn(5, 5, dtype=torch.float32) + scale_base = torch.tensor([0.1], dtype=torch.float32) + zero_point_base = torch.tensor([128.0], dtype=torch.float32) + device = "cpu" + + for n_bits in (4, 8): + quant_min, quant_max = 0, 2**n_bits - 1 + + X_test = X.clone().float() + X_test.requires_grad_() + scale = scale_base.clone() + scale.requires_grad_() + zero_point = zero_point_base.clone().clamp(quant_min, quant_max) + zero_point.requires_grad_() + + for grad_factor in [0.1, 1.0, 10.0]: + Y_prime = torch._fake_quantize_learnable_per_tensor_affine( + X_test, scale, zero_point, quant_min, quant_max, grad_factor + ) + dout = torch.rand_like(X_test, dtype=torch.float) + dX, dScale, dZeroPoint = ( + _fake_quantize_learnable_per_tensor_affine_grad_reference( + dout, X_test, scale, zero_point, quant_min, quant_max, device + ) + ) + Y_prime.backward(dout) + + expected_dX = dX.detach() + actual_dX = X_test.grad.detach() + expected_dScale = dScale.detach() + actual_dScale = scale.grad.detach() + expected_dZeroPoint = dZeroPoint.detach() + actual_dZeroPoint = zero_point.grad.detach() + + self.assertTrue( + torch.allclose( + expected_dX, actual_dX, rtol=tolerance, atol=tolerance + ), + "Expected dX to match X.grad", + ) + self.assertTrue( + torch.allclose( + expected_dScale * grad_factor, + actual_dScale, + rtol=tolerance, + atol=tolerance, + ), + "Expected dScale to match scale.grad", + ) + self.assertTrue( + torch.allclose( + expected_dZeroPoint * grad_factor, + actual_dZeroPoint, + rtol=tolerance, + atol=tolerance, + ), + "Expected dZeroPoint to match zero_point.grad", + ) + X_test.grad.data.zero_() + scale.grad.data.zero_() + zero_point.grad.data.zero_() + + def test_fake_quant_and_observer_control(self): + """Test fake quantization and observer control functionality.""" + observer = MovingAverageMinMaxObserver + lfq = LearnableFakeQuantize(observer=observer, quant_min=0, quant_max=255) + + torch.manual_seed(42) + X = torch.rand(20, 10, dtype=torch.float32) + + # Output of fake quant should not be identical to input initially + Y = lfq(X) + # Note: Initially output might be close to input if scale is 1.0 and zero_point is 0.0 + # Let's just check the shape and type are correct + self.assertEqual(Y.shape, X.shape) + self.assertEqual(Y.dtype, X.dtype) + + # Disable fake quantization + lfq.disable_fake_quant() + X = torch.rand(20, 10, dtype=torch.float32) + Y = lfq(X) + # Fake quant is disabled, output should be identical to input + torch.testing.assert_close(Y, X) + + # Disable observer and enable fake quant + lfq.disable_observer() + lfq.enable_fake_quant(True) + + # Store current scale and zero_point + scale = lfq.scale.detach().clone() + zero_point = lfq.zero_point.detach().clone() + + X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0 + Y = lfq(X) + self.assertNotEqual(Y.shape, torch.Size([0])) # Output should exist + # Observer is disabled, scale and zero-point should not change + torch.testing.assert_close(lfq.scale, scale) + torch.testing.assert_close(lfq.zero_point, zero_point) + + # Enable observer + lfq.enable_observer(True) + Y = lfq(X) + self.assertNotEqual(Y.shape, torch.Size([0])) # Output should exist + # Observer is enabled, scale and zero-point may be different + # (though they might not change significantly with this data) + + +class TestLearnableFakeQuantizeIntegration(TestCase): + """Integration tests for LearnableFakeQuantize with neural network modules.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + torch.manual_seed(42) + + def test_integration_with_linear_layer(self): + """Test LearnableFakeQuantize integration with linear layer.""" + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + self.fake_quant = LearnableFakeQuantize( + observer=MovingAverageMinMaxObserver + ) + + def forward(self, x): + x = self.linear(x) + x = self.fake_quant(x) + return x + + model = SimpleModel() + model.fake_quant.enable_range_learning() + + x = torch.randn(4, 10) + # Run model forward to initialize learnable fake quantizers + output = model(x) + + self.assertEqual(output.shape, (4, 5)) + + # Test backward pass + loss = output.sum() + loss.backward() + + # Check that all gradients exist + self.assertIsNotNone(model.linear.weight.grad) + self.assertIsNotNone(model.fake_quant.scale.grad) + self.assertIsNotNone(model.fake_quant.zero_point.grad) + + def test_multiple_fake_quant_modules(self): + """Test multiple LearnableFakeQuantize modules in one model.""" + + class MultiQuantModel(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 8) + self.fake_quant1 = LearnableFakeQuantize( + observer=MovingAverageMinMaxObserver + ) + self.linear2 = nn.Linear(8, 5) + self.fake_quant2 = LearnableFakeQuantize( + observer=MovingAverageMinMaxObserver + ) + + def forward(self, x): + x = self.linear1(x) + x = self.fake_quant1(x) + x = self.linear2(x) + x = self.fake_quant2(x) + return x + + model = MultiQuantModel() + model.fake_quant1.enable_range_learning() + model.fake_quant2.enable_range_learning() + + x = torch.randn(4, 10) + # Run model forward to initialize learnable fake quantizers + output = model(x) + + self.assertEqual(output.shape, (4, 5)) + + # Test backward pass + loss = output.sum() + loss.backward() + + # Check that all gradients exist + self.assertIsNotNone(model.linear1.weight.grad) + self.assertIsNotNone(model.linear2.weight.grad) + self.assertIsNotNone(model.fake_quant1.scale.grad) + self.assertIsNotNone(model.fake_quant1.zero_point.grad) + self.assertIsNotNone(model.fake_quant2.scale.grad) + self.assertIsNotNone(model.fake_quant2.zero_point.grad) + + def test_training_mode_switching(self): + """Test switching between training and evaluation modes.""" + + class TrainableModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 3) + self.fake_quant = LearnableFakeQuantize( + observer=MovingAverageMinMaxObserver + ) + + def forward(self, x): + x = self.linear(x) + x = self.fake_quant(x) + return x + + model = TrainableModel() + x = torch.randn(2, 5) + + # Test in training mode + model.train() + model.fake_quant.enable_range_learning() + # Run model forward to initialize learnable fake quantizers + output_train = model(x) + self.assertEqual(output_train.shape, (2, 3)) + + # Test in evaluation mode + model.eval() + model.fake_quant.enable_observer(True) + output_eval = model(x) + self.assertEqual(output_eval.shape, (2, 3)) + + def test_device_compatibility(self): + """Test LearnableFakeQuantize with different devices.""" + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + + for device in devices: + with self.subTest(device=device): + lfq = LearnableFakeQuantize(observer=MovingAverageMinMaxObserver).to( + device + ) + + x = torch.randn(4, 4, device=device) + output = lfq(x) + + self.assertEqual(output.device, x.device) + self.assertEqual(output.shape, x.shape) + + def test_optimizer_updates_scale_and_zero_point(self): + """Test that optimizer.step() actually updates scale and zero_point parameters.""" + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + self.fake_quant = LearnableFakeQuantize( + observer=MovingAverageMinMaxObserver + ) + + def forward(self, x): + x = self.linear(x) + x = self.fake_quant(x) + return x + + model = SimpleModel() + model.fake_quant.enable_range_learning() + + x = torch.randn(4, 10) + output = model(x) + + initial_scale = model.fake_quant.scale.data.clone() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + for _ in range(5): + optimizer.zero_grad() + output = model(x) + loss = output.sum() + loss.backward() + optimizer.step() + final_scale = model.fake_quant.scale.data + self.assertFalse( + torch.allclose(initial_scale, final_scale, atol=1e-6), + "Scale should change after optimizer.step()", + ) + + +class TestLearnableFakeQuantizeComparison(TestCase): + """Test cases comparing LearnableFakeQuantize with reference implementations.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + torch.manual_seed(42) + np.random.seed(NP_RANDOM_SEED) + + def test_serialization(self): + """Test serialization and deserialization of LearnableFakeQuantize.""" + observer = MovingAverageMinMaxObserver + quant_min = 0 + quant_max = 127 + + lfq_module = LearnableFakeQuantize(observer, quant_min, quant_max) + X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32) + lfq_module(X) # Run forward pass to initialize parameters + + # Get state dict and test serialization + state_dict = lfq_module.state_dict() + self.assertIn("scale", state_dict) + self.assertIn("zero_point", state_dict) + + # Create new module and load state dict + loaded_lfq_module = LearnableFakeQuantize(observer, quant_min, quant_max) + # Initialize parameters first before loading state dict + loaded_lfq_module(X) + loaded_lfq_module.load_state_dict(state_dict) + + # Compare qparams + original_qparams = lfq_module.calculate_qparams() + loaded_qparams = loaded_lfq_module.calculate_qparams() + self.assertEqual(original_qparams[0], loaded_qparams[0]) # scale + self.assertEqual(original_qparams[1], loaded_qparams[1]) # zero_point + + def test_numerical_consistency_per_tensor(self): + """Test numerical consistency of per-tensor quantization.""" + torch_types = [torch.qint8, torch.quint8] + float_types = [torch.float, torch.float16, torch.bfloat16, torch.float64] + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + + for torch_type, float_type, device in itertools.product( + torch_types, float_types, devices + ): + with self.subTest( + torch_type=torch_type, float_type=float_type, device=device + ): + X = torch.randn(3, 3, device=device).to(float_type) + scale = (10 * torch.randn(1, device=device)).abs().item() + zero_point = (10 * torch.randn(1, device=device)).abs().item() + quant_min = torch.iinfo(torch_type).min + quant_max = torch.iinfo(torch_type).max + + # Quantize/dequantize operation + Y = ( + torch.dequantize( + torch.quantize_per_tensor( + X.to("cpu").to(torch.float), + scale, + int(zero_point), + torch_type, + ) + ) + .to(device) + .to(float_type) + ) + + # Fake quantize operation + Y_prime = torch.fake_quantize_per_tensor_affine( + X, scale, int(zero_point), quant_min, quant_max + ) + + torch.testing.assert_close( + Y, + Y_prime, + rtol=tolerance, + atol=tolerance, + msg="Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/pt2e/__init__.py b/torchao/quantization/pt2e/__init__.py index c7030023dc..101838c1ae 100644 --- a/torchao/quantization/pt2e/__init__.py +++ b/torchao/quantization/pt2e/__init__.py @@ -55,6 +55,11 @@ enable_fake_quant, enable_observer, ) +from .learnable_fake_quantize import ( + LearnableFakeQuantize, + disable_range_learning, + enable_range_learning, +) from .observer import ( AffineQuantizedObserverBase, FixedQParamsObserver, @@ -103,6 +108,7 @@ "FusedMovingAvgObsFakeQuantize", # old observers "HistogramObserver", + "LearnableFakeQuantize", "MinMaxObserver", "MovingAverageMinMaxObserver", "MovingAveragePerChannelMinMaxObserver", @@ -121,6 +127,8 @@ "enable_observer", "disable_fake_quant", "disable_observer", + "enable_range_learning", + "disable_range_learning", # export_utils "move_exported_model_to_eval", "move_exported_model_to_train", diff --git a/torchao/quantization/pt2e/learnable_fake_quantize.py b/torchao/quantization/pt2e/learnable_fake_quantize.py new file mode 100644 index 0000000000..43f8698a96 --- /dev/null +++ b/torchao/quantization/pt2e/learnable_fake_quantize.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +from typing import Optional + +import torch +from torch.nn.parameter import Parameter + +from torchao.quantization.pt2e.fake_quantize import FakeQuantizeBase + +__all__ = ["LearnableFakeQuantize"] + + +class LearnableFakeQuantize(FakeQuantizeBase): + r"""Generalized extension of the FakeQuantize module. + + This is an extension of the FakeQuantize module, which + supports more generalized lower-bit quantization and supports learning of the scale + and zero point parameters through backpropagation. + + In addition to the attributes in the original FakeQuantize module, the LearnableFakeQuantize + module also includes the following attributes to support quantization parameter learning. + + * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are + normalized by the constant, which is proportional to the square root of the number of + elements in the tensor. The related literature justifying the use of this particular constant + can be found here: https://openreview.net/pdf?id=rkgO66VKDS. + + * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point. + """ + + def __init__( + self, + observer, + quant_min=0, + quant_max=255, + use_grad_scaling=False, + **observer_kwargs, + ): + super().__init__() + assert quant_min < quant_max, "quant_min must be strictly less than quant_max." + self.quant_min = quant_min + self.quant_max = quant_max + # also pass quant_min and quant_max to observer + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + self.use_grad_scaling = use_grad_scaling + + # Initialize scale and zero_point as None, will be initialized during first forward pass + self.scale: Optional[torch.nn.Parameter] = None + self.zero_point: Optional[torch.nn.Parameter] = None + + self.activation_post_process = observer(**observer_kwargs) + assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, ( + "quant_min out of bound" + ) + assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, ( + "quant_max out of bound" + ) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = ( + self.activation_post_process.ch_axis + if hasattr(self.activation_post_process, "ch_axis") + else -1 + ) + self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8)) + self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps])) + + self._initialized = False + + @torch.jit.export + def enable_range_learning(self) -> None: + r"""Enable quantization parameter learning. + + Enables learning of quantization parameters and + disables observer estimates. + """ + self.learning_enabled[0] = 1 + self.disable_observer() + if self.scale is not None: + self.scale.requires_grad = True + if self.zero_point is not None: + self.zero_point.requires_grad = True + + @torch.jit.export + def disable_range_learning(self) -> None: + r"""Disable quantization parameter learning. + + Disables learning of quantization parameters + """ + self.learning_enabled[0] = 0 + if self.scale is not None: + self.scale.requires_grad = False + if self.zero_point is not None: + self.zero_point.requires_grad = False + + @torch.jit.export + def enable_observer(self, enabled: bool = True) -> None: + r"""Enable observer. + + Enables observer estimates and disables learning of + quantization parameters. + """ + self.observer_enabled[0] = 1 if enabled else 0 + if enabled: + self.disable_range_learning() + + @torch.jit.export + def disable_observer(self): + self.enable_observer(False) + + @torch.jit.export + def enable_fake_quant(self, enabled: bool = True) -> None: + self.fake_quant_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_fake_quant(self): + self.enable_fake_quant(False) + + @torch.jit.export + def observe_quant_params(self): + print(f"LearnableFakeQuantize Scale: {self.scale.detach()}") + print(f"LearnableFakeQuantize Zero Point: {self.zero_point.detach()}") + + @torch.jit.export + def calculate_qparams(self): + self.scale.data.clamp_(min=self.eps.item()) + scale = self.scale.detach() + zero_point = ( + self.zero_point.detach() + .round() + .clamp(self.quant_min, self.quant_max) + .long() + ) + return scale, zero_point + + @torch.jit.export + def extra_repr(self): + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"learning_enabled={self.learning_enabled}, quant_min={self.quant_min}, quant_max={self.quant_max}, " + f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, " + f"use_grad_scaling={self.use_grad_scaling}, scale={self.scale}, zero_point={self.zero_point}" + ) + + def _initialize_or_update_qparams( + self, scale: torch.Tensor, zero_point: torch.Tensor + ) -> None: + """ + Initialize scale and zero_point parameters if they are not initialized yet. + Update them if they are already initialized. + """ + if not self._initialized: + self.scale = Parameter(scale) + # Convert zero_point to float for learnable parameters + self.zero_point = Parameter(zero_point.float()) + # Set requires_grad based on current learning state + if self.learning_enabled[0] == 1: + self.scale.requires_grad = True + self.zero_point.requires_grad = True + else: + self.scale.requires_grad = False + self.zero_point.requires_grad = False + self._initialized = True + else: + self.scale.data.copy_(scale) + self.zero_point.data.copy_(zero_point.float()) + + def forward(self, X): + if self.observer_enabled[0] == 1 or not self._initialized: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + self._initialize_or_update_qparams(_scale, _zero_point) + + if self.fake_quant_enabled[0] == 1: + if self.qscheme in ( + torch.per_channel_symmetric, + torch.per_tensor_symmetric, + ): + self.zero_point.data.zero_() + + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5 + else: + grad_factor = 1.0 + if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine): + X = torch._fake_quantize_learnable_per_channel_affine( + X, + self.scale, + self.zero_point, + self.ch_axis, + self.quant_min, + self.quant_max, + grad_factor, + ) + else: + X = torch._fake_quantize_learnable_per_tensor_affine( + X, + self.scale, + self.zero_point, + self.quant_min, + self.quant_max, + grad_factor, + ) + + return X + + +def enable_range_learning(mod): + """Enable quantization parameter learning. + + Enable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torchao.quantization.pt2e.enable_range_learning) + + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_range_learning() + + +def disable_range_learning(mod): + """Enable quantization parameter learning. + + Enable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torchao.quantization.pt2e.disable_range_learning) + + """ + if isinstance(mod, LearnableFakeQuantize): + mod.disable_range_learning()