
[quant] Use PlaceholderObserver as default dynamic quant observer #45343

Closed
jerryzh168 wants to merge 16 commits
Commits (16)
7acbe20  [quant] Use PlaceholderObserver as default dynamic quant observer (jerryzh168, Sep 25, 2020)
18de52a  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 25, 2020)
e1a0d5f  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 25, 2020)
5571ef0  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 25, 2020)
bb36d3b  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 25, 2020)
c2d65de  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 26, 2020)
bf3e163  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 28, 2020)
016aa53  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 29, 2020)
073a222  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 29, 2020)
4a720b9  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 29, 2020)
34b9a1a  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 29, 2020)
4153e46  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 29, 2020)
dc9ae2f  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 29, 2020)
9df4fd9  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 30, 2020)
3981258  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 30, 2020)
6e9abe7  Update on "[quant] Use PlaceholderObserver as default dynamic quant observer" (jerryzh168, Sep 30, 2020)
torch/quantization/observer.py: 1 addition & 77 deletions
@@ -475,82 +475,6 @@ def forward(self, x_orig):
        self.max_val.copy_(max_val)
        return x_orig


class MinMaxDynamicQuantObserver(MinMaxObserver):
    r"""Observer module for computing the quantization parameters based on the
    tensor min and max values in dynamic quantization.

    This observer mimics the quantization steps performed in the operator
    to compute the activation tensor quantization parameters at run time.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    .. warning:: Only works with ``torch.per_tensor_symmetric`` quantization scheme

    .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.

    .. note:: If the running minimum equals the running maximum, the scale
              and zero_point are set to 0.1 and 0.
    """

    @torch.jit.export
    def calculate_qparams(self):
        r"""Calculates the quantization parameters."""

        if self.max_val == float('-inf') and self.min_val == float('inf'):
            return torch.tensor([1.0]), torch.tensor([0])

        assert self.min_val <= self.max_val, "min {} should be less than max {}".format(
            self.min_val, self.max_val
        )

        if self.dtype == torch.qint8:
            if self.reduce_range:
                qmin, qmax = -64, 63
            else:
                qmin, qmax = -128, 127
        else:  # dtype == torch.quint8
            if self.reduce_range:
                qmin, qmax = 0, 127
            else:
                qmin, qmax = 0, 255

        max_val, min_val = self.max_val.to(dtype=torch.float), self.min_val.to(dtype=torch.float)

        # Extend min_val and max_val to ensure that the range contains 0.
        min_val = torch.min(min_val, torch.tensor(0.).to(dtype=torch.float))
        max_val = torch.max(max_val, torch.tensor(0.).to(dtype=torch.float))

        scale = (max_val.to(dtype=torch.double) - min_val) / float(qmax - qmin)

        if scale == 0.0 or torch.isinf(1.0 / scale):
            scale = torch.tensor(0.1).to(dtype=torch.float)
            zero_point = 0

        zero_point_from_min = qmin - min_val / scale.to(dtype=torch.double)
        zero_point_from_max = qmax - max_val / scale.to(dtype=torch.double)
        zero_point_from_min_error = abs(qmin) - abs(min_val / scale.to(dtype=torch.double))
        zero_point_from_max_error = abs(qmax) - abs(max_val / scale.to(dtype=torch.double))

        if zero_point_from_min_error < zero_point_from_max_error:
            initial_zero_point = zero_point_from_min
        else:
            initial_zero_point = zero_point_from_max

        nudged_zero_point = 0

        if initial_zero_point < qmin:
            nudged_zero_point = qmin
        elif initial_zero_point > qmax:
            nudged_zero_point = qmax
        else:
            nudged_zero_point = int(initial_zero_point.round())

        return scale.to(dtype=torch.float), torch.tensor([nudged_zero_point])

class PerChannelMinMaxObserver(_ObserverBase):
    r"""Observer module for computing the quantization parameters based on the
    running per channel min and max values.

@@ -1099,7 +1023,7 @@ def calculate_qparams(self):
default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
default_histogram_observer = HistogramObserver.with_args(reduce_range=True)
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
-default_dynamic_quant_observer = MinMaxDynamicQuantObserver
+default_dynamic_quant_observer = PlaceholderObserver.with_args(dtype=torch.float)
default_float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=torch.quint8,
                                                                    qscheme=torch.per_channel_affine_float_qparams,
                                                                    ch_axis=0)
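
The net effect of the final hunk: the removed `MinMaxDynamicQuantObserver` reproduced the operator's run-time qparams math in Python (for example, with running min -1.0 and max 2.0 under `torch.quint8` and no `reduce_range`, it yields scale = 3/255 ≈ 0.0118 and zero_point = 85), whereas `PlaceholderObserver` is an identity pass-through that only carries the target dtype. Below is a small sketch of the new default's behavior, assuming `PlaceholderObserver` at this revision forwards its input unchanged and defines no quantization parameters; that is a reading of the class's intent, not something stated in this diff.

```python
import torch
from torch.quantization.observer import PlaceholderObserver

# Build the new default observer exactly as the changed line does.
observer_cls = PlaceholderObserver.with_args(dtype=torch.float)
obs = observer_cls()

x = torch.randn(4)
y = obs(x)                 # forward is the identity: no min/max tracking
assert torch.equal(x, y)

# Unlike the removed observer, there are no qparams to compute here;
# dynamic ops derive scale and zero_point from each input internally.
try:
    obs.calculate_qparams()
except Exception as err:   # expected: the placeholder defines no qparams
    print(err)
```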