From 517e898df77656662e59291ee3c156ce65f93431 Mon Sep 17 00:00:00 2001
From: Supriya Rao
Date: Thu, 6 Aug 2020 11:41:40 -0700
Subject: [PATCH 1/3] [quant] Make PerChannel Observer work with float qparams

Summary:
Add implementation for new qscheme per_channel_affine_float_qparams in observer

Test Plan:
python test/test_quantization.py TestObserver.test_per_channel_observers

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 test/quantization/test_workflow_module.py | 19 ++++++++++++++++---
 torch/quantization/observer.py            | 16 +++++++++++++---
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py
index e4177ced4d6c..1642ecc4eea4 100644
--- a/test/quantization/test_workflow_module.py
+++ b/test/quantization/test_workflow_module.py
@@ -281,11 +281,14 @@ def test_per_tensor_dynamic_quant_observers(self, X, reduce_range):
         self.assertEqual(ref[0], qparams[0])
         self.assertEqual(ref[1], qparams[1])
 
+
     @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
-           qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)),
+           qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric, torch.per_channel_affine_float_qparams)),
            ch_axis=st.sampled_from((0, 1, 2, 3)),
            reduce_range=st.booleans())
     def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
         # reduce_range cannot be true for symmetric quantization with uint8
+        if qscheme == torch.per_channel_affine_float_qparams:
+            reduce_range = False
         if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
             reduce_range = False
         ObserverList = [PerChannelMinMaxObserver(reduce_range=reduce_range,
@@ -338,6 +341,12 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
             [-26, -128],
             [-35, -58],
         ]
+        per_channel_affine_float_qparams_ref_scales = [
+            [0.0196, 0.0471],
+            [0.0353, 0.0196],
+            [0.0392, 0.0235],
+            [0.0431, 0.0431],
+        ]
         per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]
 
         self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
@@ -345,6 +354,9 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
         if qscheme == torch.per_channel_symmetric:
             ref_scales = per_channel_symmetric_ref_scales[ch_axis]
             ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
+        elif qscheme == torch.per_channel_affine_float_qparams:
+            ref_scales = per_channel_affine_float_qparams_ref_scales[ch_axis]
+            ref_zero_points = ref_min_vals[ch_axis]
         else:
             ref_scales = per_channel_affine_ref_scales[ch_axis]
             ref_zero_points = (
@@ -356,10 +368,10 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
         if reduce_range:
             ref_scales = [s * 255 / 127 for s in ref_scales]
             ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
-
-        self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype)))
+        self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), atol=0.0001))
         self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))
+
 
         # Test for serializability
         state_dict = myobs.state_dict()
         b = io.BytesIO()
@@ -375,6 +387,7 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
         self.assertEqual(myobs.max_vals, loaded_obs.max_vals)
         self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())
 
+
     def test_observer_scriptable(self):
         obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver(), MinMaxDynamicQuantObserver()]
         for obs in obs_list:
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index 2a7878078b30..b915dc2bc171 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -109,9 +109,10 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
             torch.per_tensor_symmetric,
             torch.per_channel_affine,
             torch.per_channel_symmetric,
+            torch.per_channel_affine_float_qparams,
         ), "Default Observer only works for per_tensor_affine, \
-                per_tensor_symmetric, per_channel_affine and \
-                per_channel_symmetric quantization scheme"
+                per_tensor_symmetric, per_channel_affine, \
+                per_channel_symmetric and per_channel_float_qparams quantization scheme"
         assert self.dtype in (
             torch.qint8,
             torch.quint8,
@@ -214,7 +215,8 @@ def _calculate_qparams(self, min_val, max_val):
             )
 
         qmin, qmax = self._calculate_qmin_qmax()
-
+        orig_min = min_val
+        orig_max = max_val
         min_val = torch.min(min_val, torch.zeros_like(min_val))
         max_val = torch.max(max_val, torch.zeros_like(max_val))
 
@@ -232,6 +234,11 @@ def _calculate_qparams(self, min_val, max_val):
                 zero_point = zero_point.new_full(zero_point.size(), (qmin + qmax) // 2)
             else:
                 zero_point = zero_point.new_full(zero_point.size(), 128)
+        elif self.qscheme == torch.per_channel_affine_float_qparams:
+            scale = (orig_max - orig_min) / float(qmax - qmin)
+            scale_ones = torch.ones_like(scale)
+            scale = torch.where(scale != 0, scale, scale_ones)
+            zero_point = orig_min
         else:
             scale = (max_val - min_val) / float(qmax - qmin)
             scale = torch.max(scale, torch.tensor(self.eps, device=device, dtype=scale.dtype))
@@ -247,6 +254,9 @@ def _calculate_qparams(self, min_val, max_val):
         if len(zero_point.shape) == 0:
             # TODO: switch to zero_point.item() after adding JIT support
             zero_point = torch.tensor([int(zero_point)], dtype=zero_point.dtype)
+            if self.qscheme == torch.per_channel_affine_float_qparams:
+                zero_point = torch.tensor([float(zero_point)], dtype=zero_point.dtype)
+
         return scale, zero_point

From d96c658f00ae34f459c54d0ec224c1357e01f6e3 Mon Sep 17 00:00:00 2001
From: Supriya Rao
Date: Wed, 12 Aug 2020 10:39:47 -0700
Subject: [PATCH 2/3] Update on "[quant] Make PerChannel Observer work with float qparams"

Summary:
Add implementation for new qscheme per_channel_affine_float_qparams in observer

Test Plan:
python test/test_quantization.py TestObserver.test_per_channel_observers

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23070633](https://our.internmc.facebook.com/intern/diff/D23070633)

[ghstack-poisoned]
---
 test/quantization/test_workflow_module.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py
index 0a9f7ec027c0..fde2d4c29747 100644
--- a/test/quantization/test_workflow_module.py
+++ b/test/quantization/test_workflow_module.py
@@ -369,7 +369,10 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
             ref_scales = [s * 255 / 127 for s in ref_scales]
             ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
         self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), atol=0.0001))
-        self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), atol=1))
+        if qscheme == torch.per_channel_affine_float_qparams:
+            self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), atol=0.1))
+        else:
+            self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))
 
 
         # Test for serializability

From 096d67934cbe0a6ffa50d37352dbf3c919ae561a Mon Sep 17 00:00:00 2001
From: Supriya Rao
Date: Wed, 12 Aug 2020 12:16:31 -0700
Subject: [PATCH 3/3] Update on "[quant] Make PerChannel Observer work with float qparams"

Summary:
Add implementation for new qscheme per_channel_affine_float_qparams in observer

Test Plan:
python test/test_quantization.py TestObserver.test_per_channel_observers

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23070633](https://our.internmc.facebook.com/intern/diff/D23070633)

[ghstack-poisoned]
---
 test/quantization/test_workflow_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py
index fde2d4c29747..c8f2ffc0dfe4 100644
--- a/test/quantization/test_workflow_module.py
+++ b/test/quantization/test_workflow_module.py
@@ -370,7 +370,7 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
             ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
         self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), atol=0.0001))
         if qscheme == torch.per_channel_affine_float_qparams:
-            self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), atol=0.1))
+            self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), atol=1))
         else:
             self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))
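
Note: the following is a minimal standalone sketch of the math that the new
per_channel_affine_float_qparams branch of _calculate_qparams implements; it
is not part of the patches above. It assumes quint8 with reduce_range=False,
so qmin=0 and qmax=255, and the helper name float_qparams and the sample
inputs are illustrative only.

    import torch

    def float_qparams(min_val, max_val, qmin=0, qmax=255):
        # min_val/max_val are 1-D per-channel tensors. Unlike the other
        # schemes, they are used as observed, without first being clamped
        # to include zero (the patch saves them as orig_min/orig_max).
        scale = (max_val - min_val) / float(qmax - qmin)
        # Guard degenerate channels (min == max) with a scale of 1.
        scale = torch.where(scale != 0, scale, torch.ones_like(scale))
        # The zero point stays a float: it is simply the channel minimum.
        zero_point = min_val
        return scale, zero_point

    # Two example channels with ranges [-4, 5] and [-5, 7]:
    scale, zp = float_qparams(torch.tensor([-4.0, -5.0]), torch.tensor([5.0, 7.0]))
    print(scale)  # tensor([0.0353, 0.0471])
    print(zp)     # tensor([-4., -5.])

With zero_point = min_val, dequantization follows x = q * scale + zero_point,
so q = 0 reproduces the channel minimum exactly; this is why the test compares
the returned zero points against ref_min_vals[ch_axis] rather than against
integer references.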