update HistogramObserver to be scriptable (#51081)
Summary:
Pull Request resolved: #51081

Pull Request resolved: #51001

fix tests in TestQuantizeJitOps

Test Plan:
Imported from OSS
python test/test_quantization.py

Reviewed By: raghuramank100

Differential Revision: D26038759

Pulled By: lyoka

fbshipit-source-id: 0977ba7b8b26a9f654f20f5c698a7a20ec078c35
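
For context, a minimal sketch of the workflow this commit enables. This is hypothetical usage, not part of the diff, and assumes a PyTorch build that includes this change:

    import io
    import torch
    from torch.quantization import HistogramObserver

    obs = HistogramObserver()          # defaults: bins=2048, dtype=torch.quint8
    obs(torch.rand(3, 4))              # an eager calibration pass updates the histogram
    scripted = torch.jit.script(obs)   # TorchScript compilation works after this change
    scripted(torch.rand(3, 4))         # the scripted observer keeps collecting statistics

    # The scripted observer also round-trips through TorchScript serialization.
    buf = io.BytesIO()
    torch.jit.save(scripted, buf)
    buf.seek(0)
    loaded = torch.jit.load(buf)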
yanli924 authored and facebook-github-bot committed Jan 27, 2021
1 parent 0a4bc72 commit ada9166
Showing 3 changed files with 114 additions and 77 deletions.
24 changes: 24 additions & 0 deletions test/quantization/test_workflow_module.py
@@ -13,6 +13,7 @@
FixedQParamsFakeQuantize,
default_debug_qconfig,
default_observer,
default_histogram_observer,
default_per_channel_weight_observer,
default_affine_fixed_qparams_fake_quant,
get_observer_dict,
@@ -696,6 +697,29 @@ def test_observer_scriptable(self, qdtype, qscheme):
loaded = torch.jit.load(buf)
self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))

class TestHistogramObserver(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from(
               (torch.per_tensor_affine, torch.per_tensor_symmetric))
           )
    def test_observer_scriptable(self, qdtype, qscheme):
        ob_list = [
            HistogramObserver(dtype=qdtype, qscheme=qscheme),
            default_histogram_observer()
        ]
        for obs in ob_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertTrue(torch.equal(obs.histogram, scripted.histogram))
            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertTrue(torch.equal(obs.histogram, loaded.histogram))

@given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
reduce_range=st.booleans())
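The new test checks scripting and serialization only; the eventual consumer of the collected histogram is calculate_qparams(). A short sketch of that step, not part of this diff, assuming the standard observer API:

    import torch
    from torch.quantization import default_histogram_observer

    obs = default_histogram_observer()
    for _ in range(4):
        obs(torch.randn(16, 16))                 # accumulate the histogram over calibration batches
    scale, zero_point = obs.calculate_qparams()  # min/max come from the non-linear search in observer.py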
1 change: 1 addition & 0 deletions test/test_quantization.py
@@ -37,6 +37,7 @@
# TODO: merge with TestObserver
# TODO: some tests belong to test_quantize.py, e.g. test_record_observer
from quantization.test_workflow_module import TestRecordHistogramObserver # noqa: F401
from quantization.test_workflow_module import TestHistogramObserver # noqa: F401
from quantization.test_workflow_module import TestDistributed # noqa: F401

# Workflow
166 changes: 89 additions & 77 deletions torch/quantization/observer.py
@@ -697,8 +697,14 @@ class HistogramObserver(_ObserverBase):
min_val: torch.Tensor
max_val: torch.Tensor

def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
qscheme=torch.per_tensor_affine, reduce_range=False):
def __init__(
self,
bins: int = 2048,
upsample_rate: int = 128,
dtype: torch.dtype = torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=False
):
# bins: The number of bins used for histogram calculation.
super(HistogramObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
@@ -710,83 +716,87 @@ def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
self.upsample_rate = upsample_rate

@torch.jit.ignore
def _non_linear_param_search(self):
def _get_norm(
self,
delta_begin: torch.Tensor,
delta_end: torch.Tensor,
density: torch.Tensor
) -> torch.Tensor:
r"""
Compute the norm of the values uniformaly distributed between
delta_begin and delta_end.
Currently only L2 norm is supported.
norm = density * (integral_{begin, end} x^2)
= density * (end^3 - begin^3) / 3
"""
norm = (
delta_end * delta_end * delta_end
- delta_begin * delta_begin * delta_begin
) / 3
return density * norm

def _compute_quantization_error(
self, next_start_bin: int, next_end_bin: int
):
r"""
Compute the quantization error if we use start_bin to end_bin as the
min and max to do the quantization.
"""
bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
if dst_bin_width == 0.0:
return 0.0

src_bin = torch.arange(self.bins)
# distances from the beginning of first dst_bin to the beginning and
# end of src_bin
src_bin_begin = (src_bin - next_start_bin) * bin_width
src_bin_end = src_bin_begin + bin_width

# which dst_bins the beginning and end of src_bin belong to?
dst_bin_of_begin = torch.clamp(src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1)
dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width

dst_bin_of_end = torch.clamp(src_bin_end // dst_bin_width, 0, self.dst_nbins - 1)
dst_bin_of_end_center = (dst_bin_of_end + 0.5) * dst_bin_width

density = self.histogram / bin_width

norm = torch.zeros(self.bins)

delta_begin = src_bin_begin - dst_bin_of_begin_center
delta_end = dst_bin_width / 2
norm += self._get_norm(delta_begin, torch.ones(self.bins) * delta_end, density)

norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm(
torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density
)

dst_bin_of_end_center = (
dst_bin_of_end * dst_bin_width + dst_bin_width / 2
)

delta_begin = -dst_bin_width / 2
delta_end = src_bin_end - dst_bin_of_end_center
norm += self._get_norm(torch.tensor(delta_begin), delta_end, density)

return norm.sum().item()

def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Non-linear parameter search.
An approximation for L2 error minimization for selecting min/max.
By selecting new min/max, we filter out outliers in input distribution.
This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
caffe2/quantization/server/norm_minimization.cc
"""
def _get_norm(delta_begin, delta_end, density, norm_type):
r"""
Compute the norm of the values uniformly distributed between
delta_begin and delta_end.
norm = density * (integral_{begin, end} x^2)
= density * (end^3 - begin^3) / 3
"""
assert norm_type == "L2", "Only L2 norms are currently supported"
norm = 0.0
if norm_type == "L2":
norm = (
delta_end * delta_end * delta_end
- delta_begin * delta_begin * delta_begin
) / 3
return density * norm

def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
r"""
Compute the quantization error if we use start_bin to end_bin as the
min and max to do the quantization.
"""
bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
if dst_bin_width == 0.0:
return 0.0

src_bin = torch.arange(self.bins)
# distances from the beginning of first dst_bin to the beginning and
# end of src_bin
src_bin_begin = (src_bin - next_start_bin) * bin_width
src_bin_end = src_bin_begin + bin_width

# which dst_bins the beginning and end of src_bin belong to?
dst_bin_of_begin = torch.clamp(src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1)
dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width

dst_bin_of_end = torch.clamp(src_bin_end // dst_bin_width, 0, self.dst_nbins - 1)
dst_bin_of_end_center = (dst_bin_of_end + 0.5) * dst_bin_width

density = self.histogram / bin_width

norm = torch.zeros(self.bins)

delta_begin = src_bin_begin - dst_bin_of_begin_center
delta_end = dst_bin_width / 2
norm += _get_norm(delta_begin, delta_end, density, norm_type)

norm += (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
-dst_bin_width / 2, dst_bin_width / 2, density, norm_type
)

dst_bin_of_end_center = (
dst_bin_of_end * dst_bin_width + dst_bin_width / 2
)

delta_begin = -dst_bin_width / 2
delta_end = src_bin_end - dst_bin_of_end_center
norm += _get_norm(delta_begin, delta_end, density, norm_type)

return norm.sum()

assert self.histogram.size()[0] == self.bins, "bins mismatch"
bin_width = (self.max_val - self.min_val) / self.bins

# cumulative sum
total = sum(self.histogram)
total = torch.sum(self.histogram).item()
cSum = torch.cumsum(self.histogram, dim=0)

stepsize = 1e-5 # granularity
@@ -825,7 +835,7 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
continue

# calculate the quantization error using next_start_bin and next_end_bin
norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
norm = self._compute_quantization_error(next_start_bin, next_end_bin)

if norm > norm_min:
break
@@ -837,27 +847,28 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
new_max = self.min_val + bin_width * (end_bin + 1)
return new_min, new_max

@torch.jit.ignore
def _adjust_min_max(self,
combined_min: torch.Tensor,
combined_max: torch.Tensor,
upsample_rate: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
def _adjust_min_max(
self,
combined_min: torch.Tensor,
combined_max: torch.Tensor,
upsample_rate: int
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
# We ensure that:
# (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
# This allows us to have a common grid of resolution s, where we can align
# the input histogram
# start_idx maps min_val to the histogram bin index.

hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate)
downsample_rate = int(torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).item())
downsample_rate = int(torch.ceil(
(combined_max - combined_min) / (self.bins * hist_bin_width)).item())
e = downsample_rate * (self.bins * hist_bin_width) - (combined_max - combined_min)
# Relax only the max, not the min, so that for one sided distributions, min stays at zero
combined_max = combined_max + e
combined_min = combined_min
start_idx = int(torch.round((self.min_val - combined_min) / hist_bin_width).item())
return combined_min, combined_max, downsample_rate, start_idx

@torch.jit.ignore
def _combine_histograms(self,
orig_hist: torch.Tensor,
new_hist: torch.Tensor,
@@ -915,7 +926,8 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
assert combined_min.numel() == 1 and combined_max.numel() == 1, (
"histogram min/max values must be scalar."
)
combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max))
combined_histogram = torch.histc(
x, self.bins, min=int(combined_min), max=int(combined_max))
if combined_min == min_val and combined_max == max_val:
combined_histogram += self.histogram
else:
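As a sanity check on the closed-form L2 norm that _get_norm computes, the cubic expression can be compared against a direct numeric integral of density * x^2. A standalone sketch, for illustration only, with arbitrary values:

    import torch

    def closed_form(delta_begin: float, delta_end: float, density: float) -> float:
        # density * integral_{begin}^{end} x^2 dx = density * (end^3 - begin^3) / 3
        return density * (delta_end ** 3 - delta_begin ** 3) / 3

    begin, end, density = -0.5, 1.5, 2.0
    xs = torch.linspace(begin, end, 100001)
    numeric = torch.trapz(density * xs ** 2, xs)  # numeric integration on a dense grid
    print(closed_form(begin, end, density))       # 2.3333...
    print(numeric.item())                         # agrees to high precision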

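The comment in _adjust_min_max encodes a grid-alignment constraint: the combined range is stretched to an integer multiple (downsample_rate) of the fine-grained bin width derived from the current range, so old and new histograms share a common grid. A small worked sketch of that arithmetic, standalone and with illustrative values:

    import math

    bins, upsample_rate = 2048, 128
    min_val, max_val = 0.0, 1.0           # current histogram range
    combined_min, combined_max = 0.0, 3.3  # range after seeing new data

    hist_bin_width = (max_val - min_val) / (bins * upsample_rate)
    downsample_rate = int(math.ceil((combined_max - combined_min) / (bins * hist_bin_width)))
    # Relax only the max so the combined range spans exactly downsample_rate * bins fine bins.
    e = downsample_rate * (bins * hist_bin_width) - (combined_max - combined_min)
    combined_max += e
    start_idx = int(round((min_val - combined_min) / hist_bin_width))
    # The alignment identity from the comment now holds exactly:
    assert math.isclose((combined_max - combined_min) / (downsample_rate * bins),
                        (max_val - min_val) / (upsample_rate * bins))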