From 605e466bb1d643707eba76078553d88db94c5e9d Mon Sep 17 00:00:00 2001
From: Raghu Krishnamoorthi
Date: Sun, 22 Sep 2019 01:46:09 -0700
Subject: [PATCH 1/2] Emulate weight and activation only quant with fake quant, numerics test

Differential Revision: [D17520342](https://our.internmc.facebook.com/intern/diff/D17520342/)

[ghstack-poisoned]
---
 test/common_quantization.py   | 30 +++++++++++++++++++
 test/test_quantized_models.py | 56 ++++++++++++++++++++++++++++++++++-
 torch/quantization/QConfig.py |  4 +++
 3 files changed, 89 insertions(+), 1 deletion(-)

diff --git a/test/common_quantization.py b/test/common_quantization.py
index fb73b498de08..a4021ac15ab2 100644
--- a/test/common_quantization.py
+++ b/test/common_quantization.py
@@ -523,3 +523,33 @@ def forward(self, x):
         out = out.view(-1, 3 * 2 * 2)
         out = self.fc(out)
         return out
+
+class ModelMultipleOpsNoAvgPool(torch.nn.Module):
+    def __init__(self):
+        super(ModelMultipleOpsNoAvgPool, self).__init__()
+        norm_layer = nn.BatchNorm2d
+        inplanes = 3
+        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.bn1 = norm_layer(inplanes)
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.skip_add = nn.quantized.FloatFunctional()
+        self.cat = nn.quantized.FloatFunctional()
+        self.maxpool = nn.MaxPool2d((4, 4))
+        self.fc = nn.Linear(12, 6)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        skip = self.conv2(x)
+        out = self.skip_add.add(out, skip)
+        out = self.relu2(out)
+        out = self.maxpool(out)
+        out = self.conv2(out)
+        out = torch.nn.functional.max_pool2d(out, 2, 2)
+        out = self.cat.cat([out, out])
+        out = out.view(-1, 3 * 2 * 2)
+        out = self.fc(out)
+        return out
diff --git a/test/test_quantized_models.py b/test/test_quantized_models.py
index 5360fd3eee28..b9dca9d905d2 100644
--- a/test/test_quantized_models.py
+++ b/test/test_quantized_models.py
@@ -1,7 +1,7 @@
 import torch
 import torch.jit
 from common_utils import run_tests
-from common_quantization import QuantizationTestCase, ModelMultipleOps
+from common_quantization import QuantizationTestCase, ModelMultipleOps, ModelMultipleOpsNoAvgPool
 
 class ModelNumerics(QuantizationTestCase):
     def test_float_quant_compare_per_tensor(self):
@@ -46,5 +46,59 @@ def test_float_quant_compare_per_channel(self):
         # Setting target SQNR to be 35 dB
         self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
 
+    def test_fake_quant_true_quant_compare(self):
+        torch.manual_seed(67)
+        myModel = ModelMultipleOpsNoAvgPool().to(torch.float32)
+        calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
+        eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
+        myModel.eval()
+        out_ref = myModel(eval_data)
+        fqModel = torch.quantization.QuantWrapper(myModel)
+        fqModel.train()
+        fqModel.qconfig = torch.quantization.default_qat_qconfig
+        torch.quantization.fuse_modules(fqModel.module, [['conv1', 'bn1', 'relu1']])
+        torch.quantization.prepare_qat(fqModel)
+        fqModel.eval()
+        fqModel.apply(torch.quantization.disable_fake_quant)
+        fqModel.apply(torch.nn._intrinsic.qat.freeze_bn_stats)
+        fqModel(calib_data)
+        fqModel.apply(torch.quantization.enable_fake_quant)
+        fqModel.apply(torch.quantization.disable_observer)
+        out_fq = fqModel(eval_data)
+        SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
+        # Quantized model output should be close to floating point model output numerically
+        # Setting target SQNR to be 35 dB
+        self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
+        torch.quantization.convert(fqModel)
+        out_q = fqModel(eval_data)
+        SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
+        self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
+
+    def test_weight_only_activation_only_fakequant(self):
+        torch.manual_seed(67)
+        calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
+        eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
+        qconfigset = set([torch.quantization.default_weight_only_quant_qconfig,
+                          torch.quantization.default_activation_only_quant_qconfig])
+        SQNRTarget = [35, 45]
+        for idx, qconfig in enumerate(qconfigset):
+            myModel = ModelMultipleOpsNoAvgPool().to(torch.float32)
+            myModel.eval()
+            out_ref = myModel(eval_data)
+            fqModel = torch.quantization.QuantWrapper(myModel)
+            fqModel.train()
+            fqModel.qconfig = qconfig
+            torch.quantization.fuse_modules(fqModel.module, [['conv1', 'bn1', 'relu1']])
+            torch.quantization.prepare_qat(fqModel)
+            fqModel.eval()
+            fqModel.apply(torch.quantization.disable_fake_quant)
+            fqModel.apply(torch.nn._intrinsic.qat.freeze_bn_stats)
+            fqModel(calib_data)
+            fqModel.apply(torch.quantization.enable_fake_quant)
+            fqModel.apply(torch.quantization.disable_observer)
+            out_fq = fqModel(eval_data)
+            SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
+            self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/quantization/QConfig.py b/torch/quantization/QConfig.py
index b427a6756b7a..0cbac4349608 100644
--- a/torch/quantization/QConfig.py
+++ b/torch/quantization/QConfig.py
@@ -19,3 +19,7 @@
 
 default_qat_qconfig = QConfig(activation=default_fake_quant(),
                               weight=default_weight_fake_quant())
+# Configs for simulating weight only quantization for debugging
+# Use only to analyze accuracy tradeoffs
+default_weight_only_quant_qconfig = QConfig(activation=observer(torch.nn.Identity), weight=default_weight_fake_quant())
+default_activation_only_quant_qconfig = QConfig(activation=default_fake_quant(), weight=observer(torch.nn.Identity))

From 08f894530bf7b12f5dd3193ebd52fd5580eed321 Mon Sep 17 00:00:00 2001
From: Raghu Krishnamoorthi
Date: Fri, 27 Sep 2019 19:31:33 -0700
Subject: [PATCH 2/2] Update on "Emulate weight and activation only quant with fake quant, numerics test"

Differential Revision: [D17520342](https://our.internmc.facebook.com/intern/diff/D17520342/)

[ghstack-poisoned]
---
 test/common_quantization.py   | 4 ++++
 test/test_quantized_models.py | 2 ++
 2 files changed, 6 insertions(+)

diff --git a/test/common_quantization.py b/test/common_quantization.py
index f410b97007b6..838d6e1295ca 100644
--- a/test/common_quantization.py
+++ b/test/common_quantization.py
@@ -524,6 +524,10 @@ def forward(self, x):
         out = self.fc(out)
         return out
 
+# Model to ensure consistency of fake quant with true quant
+# Average pooling and mean operations are not modelled
+# accurately with fake-quant so this model does not
+# contain those operations
 class ModelMultipleOpsNoAvgPool(torch.nn.Module):
     def __init__(self):
         super(ModelMultipleOpsNoAvgPool, self).__init__()
diff --git a/test/test_quantized_models.py b/test/test_quantized_models.py
index 93b6437662ce..5d1cb2ddc990 100644
--- a/test/test_quantized_models.py
+++ b/test/test_quantized_models.py
@@ -78,6 +78,8 @@ def test_fake_quant_true_quant_compare(self):
         SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
         self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
 
+    # Test to compare weight only quantized model numerics and
+    # activation only quantized model numerics with float
    def test_weight_only_activation_only_fakequant(self):
         torch.manual_seed(67)
         calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
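
Reviewer note (not part of the patch): the sketch below shows how the two debug qconfigs added in QConfig.py could be used outside the test suite to attribute accuracy loss to weight versus activation quantization. It mirrors the flow of test_weight_only_activation_only_fakequant; SmallNet, calib_data and eval_data are illustrative stand-ins, not names from this diff.

```python
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    # Hypothetical user model with a fusable conv-bn-relu pattern.
    def __init__(self):
        super(SmallNet, self).__init__()
        self.conv = nn.Conv2d(3, 3, 1, bias=False)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


calib_data = torch.rand(256, 3, 15, 15)  # stand-in calibration batch
eval_data = torch.rand(8, 3, 15, 15)     # stand-in evaluation batch

for qconfig in [torch.quantization.default_weight_only_quant_qconfig,
                torch.quantization.default_activation_only_quant_qconfig]:
    float_model = SmallNet().eval()
    out_ref = float_model(eval_data)

    fq_model = torch.quantization.QuantWrapper(float_model)
    fq_model.train()
    fq_model.qconfig = qconfig
    torch.quantization.fuse_modules(fq_model.module, [['conv', 'bn', 'relu']])
    torch.quantization.prepare_qat(fq_model)

    # Calibrate with observers only, then evaluate with fake quant only,
    # following the same sequence as the tests in this patch.
    fq_model.eval()
    fq_model.apply(torch.quantization.disable_fake_quant)
    fq_model.apply(torch.nn._intrinsic.qat.freeze_bn_stats)
    fq_model(calib_data)
    fq_model.apply(torch.quantization.enable_fake_quant)
    fq_model.apply(torch.quantization.disable_observer)

    out_fq = fq_model(eval_data)
    sqnr_db = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
    print('SQNR vs float (dB):', sqnr_db.item())
```

A noticeably lower SQNR under the weight-only config would point at weight quantization error, while a lower SQNR under the activation-only config would suggest the activation ranges chosen by the observers are the dominant source of error.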