From f9f7eb4a3e32198c81ced42baf7f40298a5e4cad Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 9 Oct 2025 20:44:49 -0400 Subject: [PATCH] Revert "add learnable_fake_quantize in pt2e" This reverts commit 233cfc1fd345881c7bcbb5de838d82d76fe8be8e. --- .../pt2e/test_learnable_fake_quantize.py | 804 ------------------ torchao/quantization/pt2e/__init__.py | 8 - .../pt2e/learnable_fake_quantize.py | 233 ----- 3 files changed, 1045 deletions(-) delete mode 100644 test/quantization/pt2e/test_learnable_fake_quantize.py delete mode 100644 torchao/quantization/pt2e/learnable_fake_quantize.py diff --git a/test/quantization/pt2e/test_learnable_fake_quantize.py b/test/quantization/pt2e/test_learnable_fake_quantize.py deleted file mode 100644 index ede8ffa935..0000000000 --- a/test/quantization/pt2e/test_learnable_fake_quantize.py +++ /dev/null @@ -1,804 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import itertools -import unittest - -import numpy as np -import torch -import torch.nn as nn -from torch.testing._internal.common_utils import TestCase - -from torchao.quantization.pt2e.learnable_fake_quantize import ( - LearnableFakeQuantize, -) -from torchao.quantization.pt2e.observer import ( - MovingAverageMinMaxObserver, - MovingAveragePerChannelMinMaxObserver, -) - - -# Reference methods for fake quantize operations -def _fake_quantize_per_tensor_affine_reference( - X, scale, zero_point, quant_min, quant_max -): - """Reference implementation of per-tensor fake quantization.""" - dtype = X.dtype - res = ( - torch.clamp( - torch.round(X.to(torch.float32) * (1.0 / scale) + zero_point), - quant_min, - quant_max, - ) - - zero_point - ) * scale - return res.to(dtype) - - -def _fake_quantize_per_tensor_affine_grad_reference( - dY, X, scale, zero_point, quant_min, quant_max -): - """Reference implementation of per-tensor fake quantization gradient.""" - dtype = X.dtype - Xq = torch.round(X.to(torch.float32) * (1.0 / scale) + zero_point) - mask = (Xq >= quant_min) * (Xq <= quant_max) - res = torch.zeros_like(dY) - res[mask] = dY[mask] - return res.to(dtype) - - -def _fake_quantize_learnable_per_tensor_affine_grad_reference( - dY, X, scale, zero_point, quant_min, quant_max, device -): - """Reference implementation of learnable per-tensor fake quantization gradients.""" - zero_point_rounded = int((zero_point + 0.5).clamp(quant_min, quant_max).item()) - Xq = torch.round(X * (1.0 / scale) + zero_point_rounded) - - indicate_small_scale = (Xq < quant_min).float().to(device) - indicate_big_scale = (Xq > quant_max).float().to(device) - indicate_middle_scale = ( - torch.ones(indicate_small_scale.shape).to(device) - - indicate_small_scale - - indicate_big_scale - ) - - indicate_saturate_zp = ((Xq < quant_min).float() + (Xq > quant_max).float()).to( - device - ) - indicate_unsaturate_zp = ( - torch.ones(indicate_saturate_zp.shape).to(device) - indicate_saturate_zp - ) - - Xq = Xq.clamp(quant_min, quant_max) - Xfq = (Xq - zero_point_rounded) * scale - - grad_small_scale = quant_min - zero_point_rounded - grad_big_scale = quant_max - zero_point_rounded - grad_middle_scale = ((Xfq - X) / scale).to(device) - - grad_saturate_zp = -scale.to(device) - grad_unsaturate_zp = 0 - - grad_scale = ( - indicate_small_scale * grad_small_scale - + indicate_big_scale * grad_big_scale - + indicate_middle_scale * grad_middle_scale - ) - grad_zp = ( - 
indicate_saturate_zp * grad_saturate_zp - + indicate_unsaturate_zp * grad_unsaturate_zp - ) - grad_X = _fake_quantize_per_tensor_affine_grad_reference( - dY, X, scale, zero_point, quant_min, quant_max - ).to(device) - - grad_scale = (grad_scale * dY).sum().unsqueeze(dim=0) - grad_zp = (grad_zp * dY).sum().unsqueeze(dim=0) - return grad_X, grad_scale, grad_zp - - -# Removed unused helper functions _get_tensor_min_max and _get_scale_zp -# These were not being used in any of the tests - - -NP_RANDOM_SEED = 19 -tolerance = 1e-6 - - -class TestLearnableFakeQuantize(TestCase): - """Test cases for LearnableFakeQuantize module.""" - - def setUp(self): - """Set up test fixtures before each test method.""" - torch.manual_seed(42) - np.random.seed(NP_RANDOM_SEED) - - def test_initialization_per_tensor(self): - """Test initialization of LearnableFakeQuantize module for per-tensor quantization.""" - observer = MovingAverageMinMaxObserver - quant_min = 0 - quant_max = 255 - - lfq = LearnableFakeQuantize( - observer=observer, quant_min=quant_min, quant_max=quant_max - ) - - # Test that the module is properly initialized - self.assertEqual(lfq.quant_min, quant_min) - self.assertEqual(lfq.quant_max, quant_max) - # Initially scale and zero_point should be None - self.assertIsNone(lfq.scale) - self.assertIsNone(lfq.zero_point) - self.assertFalse(lfq._initialized) - - def test_initialization_per_channel(self): - """Test initialization of LearnableFakeQuantize module for per-channel quantization.""" - observer = MovingAveragePerChannelMinMaxObserver - quant_min = 0 - quant_max = 255 - - lfq = LearnableFakeQuantize( - observer=observer, quant_min=quant_min, quant_max=quant_max, ch_axis=0 - ) - - # Test that the module is properly initialized for per-channel - self.assertEqual(lfq.quant_min, quant_min) - self.assertEqual(lfq.quant_max, quant_max) - # Initially scale and zero_point should be None - self.assertIsNone(lfq.scale) - self.assertIsNone(lfq.zero_point) - self.assertFalse(lfq._initialized) - - def test_enable_range_learning(self): - """Test enabling range learning functionality.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Initially learning should be disabled and scale/zero_point should be None - self.assertEqual(lfq.learning_enabled[0], 0) - self.assertIsNone(lfq.scale) - self.assertIsNone(lfq.zero_point) - - # Enable range learning - lfq.enable_range_learning() - - # Check that learning is enabled - self.assertEqual(lfq.learning_enabled[0], 1) - # scale and zero_point are still None until first forward pass - self.assertIsNone(lfq.scale) - self.assertIsNone(lfq.zero_point) - self.assertEqual(lfq.fake_quant_enabled[0], 1) - self.assertEqual(lfq.observer_enabled[0], 0) - - def test_disable_range_learning(self): - """Test disabling range learning functionality.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Enable range learning first - lfq.enable_range_learning() - - # Run a forward pass to initialize scale and zero_point - X = torch.randn(4, 4) - lfq(X) - - # Then disable range learning - lfq.disable_range_learning() - - # Check that learning is disabled - self.assertEqual(lfq.learning_enabled[0], 0) - self.assertFalse(lfq.scale.requires_grad) - self.assertFalse(lfq.zero_point.requires_grad) - - def test_enable_observer(self): - """Test enabling observer functionality.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Enable observer - 
lfq.enable_observer(True) - - # Check that observer is enabled and learning is disabled - self.assertEqual(lfq.observer_enabled[0], 1) - self.assertEqual(lfq.learning_enabled[0], 0) - - # Test disable observer - lfq.disable_observer() - self.assertEqual(lfq.observer_enabled[0], 0) - - def test_fake_quant_control(self): - """Test fake quantization control functionality.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Test enable_fake_quant - lfq.enable_fake_quant(True) - self.assertEqual(lfq.fake_quant_enabled[0], 1) - - # Test disable_fake_quant - lfq.disable_fake_quant() - self.assertEqual(lfq.fake_quant_enabled[0], 0) - - def test_calculate_qparams(self): - """Test calculation of quantization parameters.""" - observer = MovingAverageMinMaxObserver - scale_val = 0.1 - zero_point_val = 128.0 - quant_min = 0 - quant_max = 255 - - lfq = LearnableFakeQuantize( - observer=observer, quant_min=quant_min, quant_max=quant_max - ) - - # Initialize parameters by running a forward pass first - X = torch.randn(4, 4) - lfq(X) - - # Manually set the scale and zero_point values for testing - lfq.scale.data.fill_(scale_val) - lfq.zero_point.data.fill_(zero_point_val) - - scale, zero_point = lfq.calculate_qparams() - - # Check that scale is properly clamped and zero_point is properly rounded/clamped - self.assertGreaterEqual(scale.item(), lfq.eps.item()) - self.assertGreaterEqual(zero_point.item(), quant_min) - self.assertLessEqual(zero_point.item(), quant_max) - self.assertEqual(zero_point.dtype, torch.long) - - def test_forward_observer_enabled(self): - """Test forward pass with observer enabled.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Enable observer - lfq.enable_observer(True) - - # Create test input - X = torch.randn(4, 4) * 10 - - # Forward pass - output = lfq(X) - - # Check that output has correct shape and type - self.assertEqual(output.shape, X.shape) - self.assertEqual(output.dtype, X.dtype) - - # Check that scale and zero_point have been initialized - self.assertIsNotNone(lfq.scale) - self.assertIsNotNone(lfq.zero_point) - - def test_forward_learning_enabled(self): - """Test forward pass with range learning enabled.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Enable range learning - lfq.enable_range_learning() - - # Create test input that requires grad - X = torch.randn(4, 4, requires_grad=True) * 10 - - # Run forward pass to initialize learnable fake quantizers - output = lfq(X) - - # Check that output has correct shape and type - self.assertEqual(output.shape, X.shape) - self.assertEqual(output.dtype, X.dtype) - - # Check that gradients can flow through - loss = output.sum() - loss.backward() - # Note: X may not have grad if it's not a leaf tensor; focus on testing quantizer gradients - self.assertIsNotNone(lfq.scale.grad) - self.assertIsNotNone(lfq.zero_point.grad) - - def test_forward_fake_quant_disabled(self): - """Test forward pass with fake quantization disabled.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Disable fake quantization - lfq.disable_fake_quant() - - # Create test input - X = torch.randn(4, 4) * 10 - - # Forward pass - output = lfq(X) - - # Output should be identical to input when fake quantization is disabled - torch.testing.assert_close(output, X) - - def test_symmetric_quantization(self): - """Test symmetric quantization behavior.""" - observer = 
MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Enable fake quantization - lfq.enable_fake_quant(True) - - # Create test input - X = torch.randn(4, 4) * 10 - - # Forward pass to initialize parameters - lfq(X) - - # For symmetric quantization, zero_point should be zero - # (Note: This test assumes the implementation handles symmetric mode) - self.assertIsNotNone(lfq.zero_point) - - def test_per_channel_quantization(self): - """Test per-channel quantization functionality.""" - observer = MovingAveragePerChannelMinMaxObserver - channel_len = 4 - - lfq = LearnableFakeQuantize(observer=observer, ch_axis=0) - - # Enable fake quantization - lfq.enable_fake_quant(True) - - # Create test input with correct channel dimension - X = torch.randn(channel_len, 8) * 10 - - # Forward pass - output = lfq(X) - - # Check that output has correct shape - self.assertEqual(output.shape, X.shape) - self.assertEqual(lfq.scale.shape[0], channel_len) - self.assertEqual(lfq.zero_point.shape[0], channel_len) - - def test_gradient_scaling(self): - """Test gradient scaling functionality.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer, use_grad_scaling=True) - - # Enable range learning - lfq.enable_range_learning() - - # Create test input that requires grad - X = torch.randn(4, 4, requires_grad=True) * 10 - - # Run forward pass to initialize learnable fake quantizers - output = lfq(X) - - # Check that gradients can flow through - loss = output.sum() - loss.backward() - self.assertIsNotNone(lfq.scale.grad) - self.assertIsNotNone(lfq.zero_point.grad) - - def test_error_conditions(self): - """Test error conditions during initialization.""" - observer = MovingAverageMinMaxObserver - - # Test quant_min >= quant_max - with self.assertRaises(AssertionError): - LearnableFakeQuantize(observer=observer, quant_min=255, quant_max=0) - - def test_state_persistence(self): - """Test that module state is properly maintained across forward passes.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer) - - # Initial state - initial_fake_quant_enabled = lfq.fake_quant_enabled[0].item() - initial_observer_enabled = lfq.observer_enabled[0].item() - initial_learning_enabled = lfq.learning_enabled[0].item() - - # Forward pass - X = torch.randn(4, 4) - lfq(X) # We don't need to store the output, just run the forward pass - - # State should be preserved - self.assertEqual(lfq.fake_quant_enabled[0].item(), initial_fake_quant_enabled) - self.assertEqual(lfq.observer_enabled[0].item(), initial_observer_enabled) - self.assertEqual(lfq.learning_enabled[0].item(), initial_learning_enabled) - - def test_learnable_forward_per_tensor(self): - """Test learnable forward pass for per-tensor quantization.""" - X = torch.randn(5, 5, dtype=torch.float32) - scale_base = torch.tensor([0.1], dtype=torch.float32) - zero_point_base = torch.tensor([128.0], dtype=torch.float32) - - for n_bits in (4, 8): - quant_min, quant_max = 0, 2**n_bits - 1 - - X_test = X.clone().float() - scale = scale_base.clone() - zero_point = zero_point_base.clamp(quant_min, quant_max) - - Y = _fake_quantize_per_tensor_affine_reference( - X_test, scale, zero_point, quant_min, quant_max - ) - - for grad_factor in [0.1, 1.0, 10.0]: - Y_prime = torch._fake_quantize_learnable_per_tensor_affine( - X_test, scale, zero_point, quant_min, quant_max, grad_factor - ) - self.assertTrue( - torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance), - "Expected kernel forward function to 
have results match the reference forward function", - ) - - def test_learnable_backward_per_tensor(self): - """Test learnable backward pass for per-tensor quantization.""" - X = torch.randn(5, 5, dtype=torch.float32) - scale_base = torch.tensor([0.1], dtype=torch.float32) - zero_point_base = torch.tensor([128.0], dtype=torch.float32) - device = "cpu" - - for n_bits in (4, 8): - quant_min, quant_max = 0, 2**n_bits - 1 - - X_test = X.clone().float() - X_test.requires_grad_() - scale = scale_base.clone() - scale.requires_grad_() - zero_point = zero_point_base.clone().clamp(quant_min, quant_max) - zero_point.requires_grad_() - - for grad_factor in [0.1, 1.0, 10.0]: - Y_prime = torch._fake_quantize_learnable_per_tensor_affine( - X_test, scale, zero_point, quant_min, quant_max, grad_factor - ) - dout = torch.rand_like(X_test, dtype=torch.float) - dX, dScale, dZeroPoint = ( - _fake_quantize_learnable_per_tensor_affine_grad_reference( - dout, X_test, scale, zero_point, quant_min, quant_max, device - ) - ) - Y_prime.backward(dout) - - expected_dX = dX.detach() - actual_dX = X_test.grad.detach() - expected_dScale = dScale.detach() - actual_dScale = scale.grad.detach() - expected_dZeroPoint = dZeroPoint.detach() - actual_dZeroPoint = zero_point.grad.detach() - - self.assertTrue( - torch.allclose( - expected_dX, actual_dX, rtol=tolerance, atol=tolerance - ), - "Expected dX to match X.grad", - ) - self.assertTrue( - torch.allclose( - expected_dScale * grad_factor, - actual_dScale, - rtol=tolerance, - atol=tolerance, - ), - "Expected dScale to match scale.grad", - ) - self.assertTrue( - torch.allclose( - expected_dZeroPoint * grad_factor, - actual_dZeroPoint, - rtol=tolerance, - atol=tolerance, - ), - "Expected dZeroPoint to match zero_point.grad", - ) - X_test.grad.data.zero_() - scale.grad.data.zero_() - zero_point.grad.data.zero_() - - def test_fake_quant_and_observer_control(self): - """Test fake quantization and observer control functionality.""" - observer = MovingAverageMinMaxObserver - lfq = LearnableFakeQuantize(observer=observer, quant_min=0, quant_max=255) - - torch.manual_seed(42) - X = torch.rand(20, 10, dtype=torch.float32) - - # Output of fake quant should not be identical to input initially - Y = lfq(X) - # Note: Initially output might be close to input if scale is 1.0 and zero_point is 0.0 - # Let's just check the shape and type are correct - self.assertEqual(Y.shape, X.shape) - self.assertEqual(Y.dtype, X.dtype) - - # Disable fake quantization - lfq.disable_fake_quant() - X = torch.rand(20, 10, dtype=torch.float32) - Y = lfq(X) - # Fake quant is disabled, output should be identical to input - torch.testing.assert_close(Y, X) - - # Disable observer and enable fake quant - lfq.disable_observer() - lfq.enable_fake_quant(True) - - # Store current scale and zero_point - scale = lfq.scale.detach().clone() - zero_point = lfq.zero_point.detach().clone() - - X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0 - Y = lfq(X) - self.assertNotEqual(Y.shape, torch.Size([0])) # Output should exist - # Observer is disabled, scale and zero-point should not change - torch.testing.assert_close(lfq.scale, scale) - torch.testing.assert_close(lfq.zero_point, zero_point) - - # Enable observer - lfq.enable_observer(True) - Y = lfq(X) - self.assertNotEqual(Y.shape, torch.Size([0])) # Output should exist - # Observer is enabled, scale and zero-point may be different - # (though they might not change significantly with this data) - - -class TestLearnableFakeQuantizeIntegration(TestCase): - """Integration 
tests for LearnableFakeQuantize with neural network modules.""" - - def setUp(self): - """Set up test fixtures before each test method.""" - torch.manual_seed(42) - - def test_integration_with_linear_layer(self): - """Test LearnableFakeQuantize integration with linear layer.""" - - class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(10, 5) - self.fake_quant = LearnableFakeQuantize( - observer=MovingAverageMinMaxObserver - ) - - def forward(self, x): - x = self.linear(x) - x = self.fake_quant(x) - return x - - model = SimpleModel() - model.fake_quant.enable_range_learning() - - x = torch.randn(4, 10) - # Run model forward to initialize learnable fake quantizers - output = model(x) - - self.assertEqual(output.shape, (4, 5)) - - # Test backward pass - loss = output.sum() - loss.backward() - - # Check that all gradients exist - self.assertIsNotNone(model.linear.weight.grad) - self.assertIsNotNone(model.fake_quant.scale.grad) - self.assertIsNotNone(model.fake_quant.zero_point.grad) - - def test_multiple_fake_quant_modules(self): - """Test multiple LearnableFakeQuantize modules in one model.""" - - class MultiQuantModel(nn.Module): - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(10, 8) - self.fake_quant1 = LearnableFakeQuantize( - observer=MovingAverageMinMaxObserver - ) - self.linear2 = nn.Linear(8, 5) - self.fake_quant2 = LearnableFakeQuantize( - observer=MovingAverageMinMaxObserver - ) - - def forward(self, x): - x = self.linear1(x) - x = self.fake_quant1(x) - x = self.linear2(x) - x = self.fake_quant2(x) - return x - - model = MultiQuantModel() - model.fake_quant1.enable_range_learning() - model.fake_quant2.enable_range_learning() - - x = torch.randn(4, 10) - # Run model forward to initialize learnable fake quantizers - output = model(x) - - self.assertEqual(output.shape, (4, 5)) - - # Test backward pass - loss = output.sum() - loss.backward() - - # Check that all gradients exist - self.assertIsNotNone(model.linear1.weight.grad) - self.assertIsNotNone(model.linear2.weight.grad) - self.assertIsNotNone(model.fake_quant1.scale.grad) - self.assertIsNotNone(model.fake_quant1.zero_point.grad) - self.assertIsNotNone(model.fake_quant2.scale.grad) - self.assertIsNotNone(model.fake_quant2.zero_point.grad) - - def test_training_mode_switching(self): - """Test switching between training and evaluation modes.""" - - class TrainableModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(5, 3) - self.fake_quant = LearnableFakeQuantize( - observer=MovingAverageMinMaxObserver - ) - - def forward(self, x): - x = self.linear(x) - x = self.fake_quant(x) - return x - - model = TrainableModel() - x = torch.randn(2, 5) - - # Test in training mode - model.train() - model.fake_quant.enable_range_learning() - # Run model forward to initialize learnable fake quantizers - output_train = model(x) - self.assertEqual(output_train.shape, (2, 3)) - - # Test in evaluation mode - model.eval() - model.fake_quant.enable_observer(True) - output_eval = model(x) - self.assertEqual(output_eval.shape, (2, 3)) - - def test_device_compatibility(self): - """Test LearnableFakeQuantize with different devices.""" - devices = ["cpu"] - if torch.cuda.is_available(): - devices.append("cuda") - - for device in devices: - with self.subTest(device=device): - lfq = LearnableFakeQuantize(observer=MovingAverageMinMaxObserver).to( - device - ) - - x = torch.randn(4, 4, device=device) - output = lfq(x) - - self.assertEqual(output.device, x.device) - 
self.assertEqual(output.shape, x.shape) - - def test_optimizer_updates_scale_and_zero_point(self): - """Test that optimizer.step() actually updates scale and zero_point parameters.""" - - class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(10, 5) - self.fake_quant = LearnableFakeQuantize( - observer=MovingAverageMinMaxObserver - ) - - def forward(self, x): - x = self.linear(x) - x = self.fake_quant(x) - return x - - model = SimpleModel() - model.fake_quant.enable_range_learning() - - x = torch.randn(4, 10) - output = model(x) - - initial_scale = model.fake_quant.scale.data.clone() - initial_zero_point = model.fake_quant.zero_point.data.clone() - - optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - - for _ in range(5): - optimizer.zero_grad() - output = model(x) - loss = output.sum() - loss.backward() - optimizer.step() - - final_scale = model.fake_quant.scale.data - final_zero_point = model.fake_quant.zero_point.data - - self.assertFalse( - torch.allclose(initial_scale, final_scale, atol=1e-6), - "Scale should change after optimizer.step()", - ) - self.assertFalse( - torch.allclose(initial_zero_point, final_zero_point, atol=1e-6), - "Zero point should change after optimizer.step()", - ) - - -class TestLearnableFakeQuantizeComparison(TestCase): - """Test cases comparing LearnableFakeQuantize with reference implementations.""" - - def setUp(self): - """Set up test fixtures before each test method.""" - torch.manual_seed(42) - np.random.seed(NP_RANDOM_SEED) - - def test_serialization(self): - """Test serialization and deserialization of LearnableFakeQuantize.""" - observer = MovingAverageMinMaxObserver - quant_min = 0 - quant_max = 127 - - lfq_module = LearnableFakeQuantize(observer, quant_min, quant_max) - X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32) - lfq_module(X) # Run forward pass to initialize parameters - - # Get state dict and test serialization - state_dict = lfq_module.state_dict() - self.assertIn("scale", state_dict) - self.assertIn("zero_point", state_dict) - - # Create new module and load state dict - loaded_lfq_module = LearnableFakeQuantize(observer, quant_min, quant_max) - # Initialize parameters first before loading state dict - loaded_lfq_module(X) - loaded_lfq_module.load_state_dict(state_dict) - - # Compare qparams - original_qparams = lfq_module.calculate_qparams() - loaded_qparams = loaded_lfq_module.calculate_qparams() - self.assertEqual(original_qparams[0], loaded_qparams[0]) # scale - self.assertEqual(original_qparams[1], loaded_qparams[1]) # zero_point - - def test_numerical_consistency_per_tensor(self): - """Test numerical consistency of per-tensor quantization.""" - torch_types = [torch.qint8, torch.quint8] - float_types = [torch.float, torch.float16, torch.bfloat16, torch.float64] - devices = [torch.device("cpu")] - if torch.cuda.is_available(): - devices.append(torch.device("cuda")) - - for torch_type, float_type, device in itertools.product( - torch_types, float_types, devices - ): - with self.subTest( - torch_type=torch_type, float_type=float_type, device=device - ): - X = torch.randn(3, 3, device=device).to(float_type) - scale = (10 * torch.randn(1, device=device)).abs().item() - zero_point = (10 * torch.randn(1, device=device)).abs().item() - quant_min = torch.iinfo(torch_type).min - quant_max = torch.iinfo(torch_type).max - - # Quantize/dequantize operation - Y = ( - torch.dequantize( - torch.quantize_per_tensor( - X.to("cpu").to(torch.float), - scale, - int(zero_point), - torch_type, 
- ) - ) - .to(device) - .to(float_type) - ) - - # Fake quantize operation - Y_prime = torch.fake_quantize_per_tensor_affine( - X, scale, int(zero_point), quant_min, quant_max - ) - - torch.testing.assert_close( - Y, - Y_prime, - rtol=tolerance, - atol=tolerance, - msg="Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/torchao/quantization/pt2e/__init__.py b/torchao/quantization/pt2e/__init__.py index 101838c1ae..c7030023dc 100644 --- a/torchao/quantization/pt2e/__init__.py +++ b/torchao/quantization/pt2e/__init__.py @@ -55,11 +55,6 @@ enable_fake_quant, enable_observer, ) -from .learnable_fake_quantize import ( - LearnableFakeQuantize, - disable_range_learning, - enable_range_learning, -) from .observer import ( AffineQuantizedObserverBase, FixedQParamsObserver, @@ -108,7 +103,6 @@ "FusedMovingAvgObsFakeQuantize", # old observers "HistogramObserver", - "LearnableFakeQuantize", "MinMaxObserver", "MovingAverageMinMaxObserver", "MovingAveragePerChannelMinMaxObserver", @@ -127,8 +121,6 @@ "enable_observer", "disable_fake_quant", "disable_observer", - "enable_range_learning", - "disable_range_learning", # export_utils "move_exported_model_to_eval", "move_exported_model_to_train", diff --git a/torchao/quantization/pt2e/learnable_fake_quantize.py b/torchao/quantization/pt2e/learnable_fake_quantize.py deleted file mode 100644 index 43f8698a96..0000000000 --- a/torchao/quantization/pt2e/learnable_fake_quantize.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -from typing import Optional - -import torch -from torch.nn.parameter import Parameter - -from torchao.quantization.pt2e.fake_quantize import FakeQuantizeBase - -__all__ = ["LearnableFakeQuantize"] - - -class LearnableFakeQuantize(FakeQuantizeBase): - r"""Generalized extension of the FakeQuantize module. - - This is an extension of the FakeQuantize module, which - supports more generalized lower-bit quantization and supports learning of the scale - and zero point parameters through backpropagation. - - In addition to the attributes in the original FakeQuantize module, the LearnableFakeQuantize - module also includes the following attributes to support quantization parameter learning. - - * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are - normalized by the constant, which is proportional to the square root of the number of - elements in the tensor. The related literature justifying the use of this particular constant - can be found here: https://openreview.net/pdf?id=rkgO66VKDS. - - * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point. - """ - - def __init__( - self, - observer, - quant_min=0, - quant_max=255, - use_grad_scaling=False, - **observer_kwargs, - ): - super().__init__() - assert quant_min < quant_max, "quant_min must be strictly less than quant_max." 
- self.quant_min = quant_min - self.quant_max = quant_max - # also pass quant_min and quant_max to observer - observer_kwargs["quant_min"] = quant_min - observer_kwargs["quant_max"] = quant_max - self.use_grad_scaling = use_grad_scaling - - # Initialize scale and zero_point as None, will be initialized during first forward pass - self.scale: Optional[torch.nn.Parameter] = None - self.zero_point: Optional[torch.nn.Parameter] = None - - self.activation_post_process = observer(**observer_kwargs) - assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, ( - "quant_min out of bound" - ) - assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, ( - "quant_max out of bound" - ) - self.dtype = self.activation_post_process.dtype - self.qscheme = self.activation_post_process.qscheme - self.ch_axis = ( - self.activation_post_process.ch_axis - if hasattr(self.activation_post_process, "ch_axis") - else -1 - ) - self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8)) - self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps])) - - self._initialized = False - - @torch.jit.export - def enable_range_learning(self) -> None: - r"""Enable quantization parameter learning. - - Enables learning of quantization parameters and - disables observer estimates. - """ - self.learning_enabled[0] = 1 - self.disable_observer() - if self.scale is not None: - self.scale.requires_grad = True - if self.zero_point is not None: - self.zero_point.requires_grad = True - - @torch.jit.export - def disable_range_learning(self) -> None: - r"""Disable quantization parameter learning. - - Disables learning of quantization parameters - """ - self.learning_enabled[0] = 0 - if self.scale is not None: - self.scale.requires_grad = False - if self.zero_point is not None: - self.zero_point.requires_grad = False - - @torch.jit.export - def enable_observer(self, enabled: bool = True) -> None: - r"""Enable observer. - - Enables observer estimates and disables learning of - quantization parameters. - """ - self.observer_enabled[0] = 1 if enabled else 0 - if enabled: - self.disable_range_learning() - - @torch.jit.export - def disable_observer(self): - self.enable_observer(False) - - @torch.jit.export - def enable_fake_quant(self, enabled: bool = True) -> None: - self.fake_quant_enabled[0] = 1 if enabled else 0 - - @torch.jit.export - def disable_fake_quant(self): - self.enable_fake_quant(False) - - @torch.jit.export - def observe_quant_params(self): - print(f"LearnableFakeQuantize Scale: {self.scale.detach()}") - print(f"LearnableFakeQuantize Zero Point: {self.zero_point.detach()}") - - @torch.jit.export - def calculate_qparams(self): - self.scale.data.clamp_(min=self.eps.item()) - scale = self.scale.detach() - zero_point = ( - self.zero_point.detach() - .round() - .clamp(self.quant_min, self.quant_max) - .long() - ) - return scale, zero_point - - @torch.jit.export - def extra_repr(self): - return ( - f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " - f"learning_enabled={self.learning_enabled}, quant_min={self.quant_min}, quant_max={self.quant_max}, " - f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, " - f"use_grad_scaling={self.use_grad_scaling}, scale={self.scale}, zero_point={self.zero_point}" - ) - - def _initialize_or_update_qparams( - self, scale: torch.Tensor, zero_point: torch.Tensor - ) -> None: - """ - Initialize scale and zero_point parameters if they are not initialized yet. 
- Update them if they are already initialized. - """ - if not self._initialized: - self.scale = Parameter(scale) - # Convert zero_point to float for learnable parameters - self.zero_point = Parameter(zero_point.float()) - # Set requires_grad based on current learning state - if self.learning_enabled[0] == 1: - self.scale.requires_grad = True - self.zero_point.requires_grad = True - else: - self.scale.requires_grad = False - self.zero_point.requires_grad = False - self._initialized = True - else: - self.scale.data.copy_(scale) - self.zero_point.data.copy_(zero_point.float()) - - def forward(self, X): - if self.observer_enabled[0] == 1 or not self._initialized: - self.activation_post_process(X.detach()) - _scale, _zero_point = self.activation_post_process.calculate_qparams() - self._initialize_or_update_qparams(_scale, _zero_point) - - if self.fake_quant_enabled[0] == 1: - if self.qscheme in ( - torch.per_channel_symmetric, - torch.per_tensor_symmetric, - ): - self.zero_point.data.zero_() - - if self.use_grad_scaling: - grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5 - else: - grad_factor = 1.0 - if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine): - X = torch._fake_quantize_learnable_per_channel_affine( - X, - self.scale, - self.zero_point, - self.ch_axis, - self.quant_min, - self.quant_max, - grad_factor, - ) - else: - X = torch._fake_quantize_learnable_per_tensor_affine( - X, - self.scale, - self.zero_point, - self.quant_min, - self.quant_max, - grad_factor, - ) - - return X - - -def enable_range_learning(mod): - """Enable quantization parameter learning. - - Enable fake quantization for this module, if applicable. Example usage:: - - # model is any PyTorch model - model.apply(torchao.quantization.pt2e.enable_range_learning) - - """ - if isinstance(mod, LearnableFakeQuantize): - mod.enable_range_learning() - - -def disable_range_learning(mod): - """Enable quantization parameter learning. - - Enable fake quantization for this module, if applicable. Example usage:: - - # model is any PyTorch model - model.apply(torchao.quantization.pt2e.disable_range_learning) - - """ - if isinstance(mod, LearnableFakeQuantize): - mod.disable_range_learning()
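
For context, the sketch below (adapted from the deleted test
test_integration_with_linear_layer above) shows how the removed
LearnableFakeQuantize API was driven before this revert. It assumes the
pre-revert torchao.quantization.pt2e.learnable_fake_quantize module is still
importable; once this patch lands that import path no longer exists, so the
snippet is illustrative only::

    import torch
    import torch.nn as nn

    # These imports are removed by this revert; shown only to illustrate the old API.
    from torchao.quantization.pt2e.learnable_fake_quantize import LearnableFakeQuantize
    from torchao.quantization.pt2e.observer import MovingAverageMinMaxObserver


    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 5)
            self.fake_quant = LearnableFakeQuantize(observer=MovingAverageMinMaxObserver)

        def forward(self, x):
            return self.fake_quant(self.linear(x))


    model = SimpleModel()
    # Enable learning of scale/zero_point via backprop; observer updates are disabled.
    model.fake_quant.enable_range_learning()

    x = torch.randn(4, 10)
    loss = model(x).sum()
    loss.backward()

    # With range learning enabled, scale and zero_point receive gradients and can be
    # updated by a regular optimizer (e.g. SGD over model.parameters()).
    assert model.fake_quant.scale.grad is not None
    assert model.fake_quant.zero_point.grad is not None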