diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py
index 3ce7b020e38a..866e1971ab19 100644
--- a/test/quantization/test_workflow_module.py
+++ b/test/quantization/test_workflow_module.py
@@ -13,6 +13,7 @@
     FixedQParamsFakeQuantize,
     default_debug_qconfig,
     default_observer,
+    default_histogram_observer,
     default_per_channel_weight_observer,
     default_affine_fixed_qparams_fake_quant,
     get_observer_dict,
@@ -696,6 +697,29 @@ def test_observer_scriptable(self, qdtype, qscheme):
         loaded = torch.jit.load(buf)
         self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))
 
+class TestHistogramObserver(QuantizationTestCase):
+    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
+           qscheme=st.sampled_from(
+               (torch.per_tensor_affine, torch.per_tensor_symmetric))
+           )
+    def test_observer_scriptable(self, qdtype, qscheme):
+        ob_list = [
+            HistogramObserver(dtype=qdtype, qscheme=qscheme),
+            default_histogram_observer()
+        ]
+        for obs in ob_list:
+            scripted = torch.jit.script(obs)
+
+            x = torch.rand(3, 4)
+            obs(x)
+            scripted(x)
+            self.assertTrue(torch.equal(obs.histogram, scripted.histogram))
+            buf = io.BytesIO()
+            torch.jit.save(scripted, buf)
+            buf.seek(0)
+            loaded = torch.jit.load(buf)
+            self.assertTrue(torch.equal(obs.histogram, loaded.histogram))
+
 @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
        qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
        reduce_range=st.booleans())
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 1c370913c6d0..8966bd052560 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -37,6 +37,7 @@
 # TODO: merge with TestObserver
 # TODO: some tests belong to test_quantize.py, e.g. test_record_observer
 from quantization.test_workflow_module import TestRecordHistogramObserver  # noqa: F401
+from quantization.test_workflow_module import TestHistogramObserver  # noqa: F401
 from quantization.test_workflow_module import TestDistributed  # noqa: F401
 
 # Workflow
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index 2cc579f66087..8bab89e5b90c 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -697,8 +697,14 @@ class HistogramObserver(_ObserverBase):
     min_val: torch.Tensor
     max_val: torch.Tensor
 
-    def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
-                 qscheme=torch.per_tensor_affine, reduce_range=False):
+    def __init__(
+        self,
+        bins: int = 2048,
+        upsample_rate: int = 128,
+        dtype: torch.dtype = torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=False
+    ):
         # bins: The number of bins used for histogram calculation.
         super(HistogramObserver, self).__init__(dtype=dtype,
                                                 qscheme=qscheme,
@@ -710,8 +716,75 @@ def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
         self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
         self.upsample_rate = upsample_rate
 
-    @torch.jit.ignore
-    def _non_linear_param_search(self):
+    def _get_norm(
+        self,
+        delta_begin: torch.Tensor,
+        delta_end: torch.Tensor,
+        density: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Compute the norm of the values uniformly distributed between
+        delta_begin and delta_end.
+        Currently only L2 norm is supported.
+
+        norm = density * (integral_{begin, end} x^2)
+             = density * (end^3 - begin^3) / 3
+        """
+        norm = (
+            delta_end * delta_end * delta_end
+            - delta_begin * delta_begin * delta_begin
+        ) / 3
+        return density * norm
+
+    def _compute_quantization_error(
+        self, next_start_bin: int, next_end_bin: int
+    ):
+        r"""
+        Compute the quantization error if we use start_bin to end_bin as the
+        min and max to do the quantization.
+        """
+        bin_width = (self.max_val.item() - self.min_val.item()) / self.bins
+
+        dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
+        if dst_bin_width == 0.0:
+            return 0.0
+
+        src_bin = torch.arange(self.bins)
+        # distances from the beginning of first dst_bin to the beginning and
+        # end of src_bin
+        src_bin_begin = (src_bin - next_start_bin) * bin_width
+        src_bin_end = src_bin_begin + bin_width
+
+        # which dst_bins the beginning and end of src_bin belong to?
+        dst_bin_of_begin = torch.clamp(src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1)
+        dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width
+
+        dst_bin_of_end = torch.clamp(src_bin_end // dst_bin_width, 0, self.dst_nbins - 1)
+        dst_bin_of_end_center = (dst_bin_of_end + 0.5) * dst_bin_width
+
+        density = self.histogram / bin_width
+
+        norm = torch.zeros(self.bins)
+
+        delta_begin = src_bin_begin - dst_bin_of_begin_center
+        delta_end = dst_bin_width / 2
+        norm += self._get_norm(delta_begin, torch.ones(self.bins) * delta_end, density)
+
+        norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm(
+            torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density
+        )
+
+        dst_bin_of_end_center = (
+            dst_bin_of_end * dst_bin_width + dst_bin_width / 2
+        )
+
+        delta_begin = -dst_bin_width / 2
+        delta_end = src_bin_end - dst_bin_of_end_center
+        norm += self._get_norm(torch.tensor(delta_begin), delta_end, density)
+
+        return norm.sum().item()
+
+    def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""Non-linear parameter search.
 
         An approximation for L2 error minimization for selecting min/max.
@@ -719,74 +792,11 @@ def _non_linear_param_search(self):
         This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
         caffe2/quantization/server/norm_minimization.cc
         """
-        def _get_norm(delta_begin, delta_end, density, norm_type):
-            r"""
-            Compute the norm of the values uniformaly distributed between
-            delta_begin and delta_end.
-
-            norm = density * (integral_{begin, end} x^2)
-                 = density * (end^3 - begin^3) / 3
-            """
-            assert norm_type == "L2", "Only L2 norms are currently supported"
-            norm = 0.0
-            if norm_type == "L2":
-                norm = (
-                    delta_end * delta_end * delta_end
-                    - delta_begin * delta_begin * delta_begin
-                ) / 3
-            return density * norm
-
-        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
-            r"""
-            Compute the quantization error if we use start_bin to end_bin as the
-            min and max to do the quantization.
-            """
-            bin_width = (self.max_val.item() - self.min_val.item()) / self.bins
-
-            dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
-            if dst_bin_width == 0.0:
-                return 0.0
-
-            src_bin = torch.arange(self.bins)
-            # distances from the beginning of first dst_bin to the beginning and
-            # end of src_bin
-            src_bin_begin = (src_bin - next_start_bin) * bin_width
-            src_bin_end = src_bin_begin + bin_width
-
-            # which dst_bins the beginning and end of src_bin belong to?
-            dst_bin_of_begin = torch.clamp(src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1)
-            dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width
-
-            dst_bin_of_end = torch.clamp(src_bin_end // dst_bin_width, 0, self.dst_nbins - 1)
-            dst_bin_of_end_center = (dst_bin_of_end + 0.5) * dst_bin_width
-
-            density = self.histogram / bin_width
-
-            norm = torch.zeros(self.bins)
-
-            delta_begin = src_bin_begin - dst_bin_of_begin_center
-            delta_end = dst_bin_width / 2
-            norm += _get_norm(delta_begin, delta_end, density, norm_type)
-
-            norm += (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
-                -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
-            )
-
-            dst_bin_of_end_center = (
-                dst_bin_of_end * dst_bin_width + dst_bin_width / 2
-            )
-
-            delta_begin = -dst_bin_width / 2
-            delta_end = src_bin_end - dst_bin_of_end_center
-            norm += _get_norm(delta_begin, delta_end, density, norm_type)
-
-            return norm.sum()
-
         assert self.histogram.size()[0] == self.bins, "bins mistmatch"
         bin_width = (self.max_val - self.min_val) / self.bins
 
         # cumulative sum
-        total = sum(self.histogram)
+        total = torch.sum(self.histogram).item()
         cSum = torch.cumsum(self.histogram, dim=0)
 
         stepsize = 1e-5  # granularity
@@ -825,7 +835,7 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
                 continue
 
             # calculate the quantization error using next_start_bin and next_end_bin
-            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
+            norm = self._compute_quantization_error(next_start_bin, next_end_bin)
 
             if norm > norm_min:
                 break
@@ -837,11 +847,12 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
         new_max = self.min_val + bin_width * (end_bin + 1)
         return new_min, new_max
 
-    @torch.jit.ignore
-    def _adjust_min_max(self,
-                        combined_min: torch.Tensor,
-                        combined_max: torch.Tensor,
-                        upsample_rate: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
+    def _adjust_min_max(
+        self,
+        combined_min: torch.Tensor,
+        combined_max: torch.Tensor,
+        upsample_rate: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
         # We ensure that:
         # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
         # This allows us to have a common grid of resolution s, where we can align
@@ -849,7 +860,8 @@ def _adjust_min_max(self,
         # the input histogram
         # start_idx maps min_val to the histogram bin index.
         hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate)
-        downsample_rate = int(torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).item())
+        downsample_rate = int(torch.ceil(
+            (combined_max - combined_min) / (self.bins * hist_bin_width)).item())
         e = downsample_rate * (self.bins * hist_bin_width) - (combined_max - combined_min)
         # Relax only the max, not the min, so that for one sided distributions, min stays at zero
         combined_max = combined_max + e
@@ -857,7 +869,6 @@ def _adjust_min_max(self,
         start_idx = int(torch.round((self.min_val - combined_min) / hist_bin_width).item())
         return combined_min, combined_max, downsample_rate, start_idx
 
-    @torch.jit.ignore
     def _combine_histograms(self,
                             orig_hist: torch.Tensor,
                             new_hist: torch.Tensor,
@@ -915,7 +926,8 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
             assert combined_min.numel() == 1 and combined_max.numel() == 1, (
                 "histogram min/max values must be scalar."
             )
-            combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max))
+            combined_histogram = torch.histc(
+                x, self.bins, min=int(combined_min), max=int(combined_max))
             if combined_min == min_val and combined_max == max_val:
                 combined_histogram += self.histogram
             else:
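
The closed form in `_get_norm` is just `density * integral_{begin..end} x^2 dx`. A quick standalone check against a numerical integral (the names mirror the method's arguments, but nothing below depends on the observer class; values are made up):

```python
import torch

delta_begin, delta_end, density = -0.3, 0.7, 2.5

# Closed form used by _get_norm: density * (end^3 - begin^3) / 3
closed_form = density * (delta_end ** 3 - delta_begin ** 3) / 3

# Trapezoidal approximation of density * integral of x^2 over the same range
xs = torch.linspace(delta_begin, delta_end, 100001)
numeric = torch.trapz(density * xs ** 2, xs).item()

assert abs(closed_form - numeric) < 1e-4  # agrees to numerical precision
```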
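The three `_get_norm` calls in `_compute_quantization_error` split each source bin into the stretch before its first destination-bin center, the destination bins it fully covers, and the tail past the last center. The same error model can be approximated by brute force: spread each bin's mass uniformly, snap every value to the center of its (clamped) destination bin, and accumulate squared error. A sketch with invented toy sizes (all names local to the sketch):

```python
import torch

bins, dst_nbins = 64, 16           # toy sizes, not the real 2048 / 256
hist = torch.rand(bins) * 100      # toy histogram
bin_width = 0.05                   # width of one source bin
start_bin, end_bin = 8, 55         # candidate clip range, in source bins

dst_bin_width = bin_width * (end_bin - start_bin + 1) / dst_nbins

error = 0.0
samples = 1000
for i in range(bins):
    # sample positions within source bin i, relative to the clip-range start
    offsets = (torch.arange(samples) + 0.5) / samples * bin_width
    v = (i - start_bin) * bin_width + offsets
    # snap to the center of the destination bin, clamping at the range ends
    idx = torch.clamp(torch.floor(v / dst_bin_width), 0, dst_nbins - 1)
    q = (idx + 0.5) * dst_bin_width
    error += hist[i].item() * torch.mean((v - q) ** 2).item()

print(error)  # plays the role of norm.sum() in the method above
```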
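The point of dropping `@torch.jit.ignore` is that the parameter-search and histogram helpers now compile instead of bailing out to Python, so a scripted observer can be serialized and reloaded, which is what the new test exercises. A minimal usage sketch (toy shapes and calibration loop, mirroring the test rather than any prescribed workflow):

```python
import io
import torch
from torch.quantization import HistogramObserver

obs = HistogramObserver(bins=256)
scripted = torch.jit.script(obs)

for _ in range(4):                 # toy calibration loop
    scripted(torch.randn(8, 16))

buf = io.BytesIO()
torch.jit.save(scripted, buf)      # round-trip through TorchScript serialization
buf.seek(0)
loaded = torch.jit.load(buf)

assert torch.equal(scripted.histogram, loaded.histogram)
```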