[quant][graphmode][fx] Support sigmoid/hardsigmoid/tanh in qat #46738

Closed · wants to merge 7 commits
test/quantization/test_quantize.py: 32 additions & 1 deletion
@@ -25,7 +25,8 @@
float_qparams_dynamic_qconfig,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer
default_dynamic_quant_observer,
FixedQParamsFakeQuantize,
)

from torch.testing._internal.common_quantization import (
@@ -1247,6 +1248,36 @@ def forward(self, x):
def test_leaky_relu(self):
self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False})


class TestEagerModeQATOps(QuantizationTestCase):
def test_fixed_qparam_ops(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigmoid = torch.nn.Sigmoid()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.tanh = torch.nn.Tanh()
self.quant = QuantStub()
self.dequant = DeQuantStub()

def forward(self, x):
x = self.quant(x)
Review comment (Contributor):

Does the approach here also work for cases where the scale and zero-point are not fixed? E.g. for a Hardtanh, the scale and zero-point depend on the arguments to __init__.

Reply (Author):

Maybe. This approach allows us to provide a special fake quantize (activation_post_process) for a specific pattern (module/functional/torch op). Could you write down the details? Is it something like the following?

fake_quant = HardTanhFakeQuantize(hardtanh_instance.params)

I don't think this would be too hard to support in the current implementation.

Reply (Contributor):

Actually, for hardtanh our implementation calls qclamp:

Tensor qnnpack_clamp(Tensor input, Scalar min, Scalar max) {

which doesn't need LUTs. In this case we should treat it like ReLU, i.e. there is no fake-quant or observer needed at the output, since no quantization occurs inside the op.


x = self.sigmoid(x)
x = self.hardsigmoid(x)
x = self.tanh(x)
x = self.dequant(x)
return x

m = M().train()
m.qconfig = default_qat_qconfig
m = prepare_qat(m)
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
m = convert(m)
# make sure activation post process is removed
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))

class TestFunctionalModule(QuantizationTestCase):
# Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
@given(train_mode=st.booleans())
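For context on why these three ops get special treatment: their output ranges are known ahead of time (sigmoid and hardsigmoid produce values in [0, 1], tanh in [-1, 1]), so the fake quantizer attached to their output can use fixed quantization parameters instead of an observer that learns the range during QAT. The snippet below is only an illustrative sketch of that idea, not the actual constants, which are defined by default_affine_fixed_qparams_fake_quant and default_symmetric_fixed_qparams_fake_quant in torch/quantization/fake_quantize.py.

    # Illustrative sketch only: derive fixed quint8 qparams from a known output range.
    def fixed_qparams(out_min, out_max, quant_min=0, quant_max=255):
        scale = (out_max - out_min) / (quant_max - quant_min)
        zero_point = round(quant_min - out_min / scale)
        return scale, zero_point

    print(fixed_qparams(0.0, 1.0))    # sigmoid/hardsigmoid range [0, 1] -> about (1/255, 0), affine
    print(fixed_qparams(-1.0, 1.0))   # tanh range [-1, 1] -> about (2/255, 128), symmetric around 0
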
test/quantization/test_quantize_fx.py: 61 additions & 12 deletions
@@ -19,6 +19,7 @@
prepare,
prepare_qat,
convert,
FixedQParamsFakeQuantize,
)

# test utils
@@ -1406,17 +1407,14 @@ def test_general_value_ops(self):
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.avg_pool1d = torch.nn.AvgPool1d(3)
self.avg_pool2d = torch.nn.AvgPool2d(3)
self.avg_pool3d = torch.nn.AvgPool3d(3)
self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
self.hardsigmoid = torch.nn.Hardsigmoid()
self.sigmoid = torch.nn.Sigmoid()
self.tanh = torch.nn.Tanh()

def forward(self, x):
x = self.conv(x)
@@ -1438,16 +1436,59 @@ def forward(self, x):
x = x.mean([2, 3], True)
x = F.interpolate(x, 4, mode='nearest')
x = F.interpolate(x, 4, mode='linear')
x = self.conv(x)
return x

# This model is not executable since we just put all ops
# in the same forward
m = M().eval()
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
prepared = prepare_fx(m, qconfig_dict)
# not runnable
quantized = convert_fx(prepared)

# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers
# check exact counts of quantize and dequantize
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_method('dequantize') : 1
}
order_check = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_module(nnq.Conv2d),
ns.call_method('dequantize'),
]
self.checkGraphModuleNodes(
quantized,
expected_node_occurrence=count_check,
expected_node_list=order_check)

@skipIfNoFBGEMM
def test_fixed_qparams_ops(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.sigmoid = torch.nn.Sigmoid()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.tanh = torch.nn.Tanh()

def forward(self, x):
x = self.conv(x)
# F.sigmoid is deprecated
x = self.sigmoid(x)
x = torch.sigmoid(x)
x = x.sigmoid()
x.sigmoid_()
x = self.hardsigmoid(x)
x = F.hardsigmoid(x)
x = F.hardsigmoid(x, inplace=True)
x = x.hardsigmoid()
x.hardsigmoid_()
x = self.sigmoid(x)
x = torch.sigmoid(x)
# F.sigmoid is deprecated
x = x.sigmoid()
x.sigmoid_()
x = self.tanh(x)
# F.tanh is deprecated
x = torch.tanh(x)
@@ -1458,10 +1499,17 @@ def forward(self, x):

# This model is not executable since we just put all ops
# in the same forward
m = M().eval()
m = M().train()
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
prepared = prepare_fx(m, qconfig_dict)
qconfig_dict = {'': default_qat_qconfig}
prepared = prepare_qat_fx(m, qconfig_dict)
# check the correct number of activation_post_process is inserted
count_check = {
ns.call_module(FixedQParamsFakeQuantize) : 13,
}
self.checkGraphModuleNodes(
prepared,
expected_node_occurrence=count_check)
# not runnable
quantized = convert_fx(prepared)

@@ -1476,6 +1524,7 @@ def forward(self, x):
order_check = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_module(nn.Sigmoid),
ns.call_module(nnq.Conv2d),
ns.call_method('dequantize'),
]
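To poke at the prepared model outside of the test harness, a sketch like the following (reusing the M defined in test_fixed_qparams_ops, with the torch.quantization API as of the PyTorch version this PR targets) shows where the fixed-qparams fake quants end up; the generated submodule names are implementation details and may differ.

    import torch
    from torch.quantization import default_qat_qconfig, FixedQParamsFakeQuantize
    from torch.quantization.quantize_fx import prepare_qat_fx

    m = M().train()  # the module defined in test_fixed_qparams_ops above
    prepared = prepare_qat_fx(m, {'': default_qat_qconfig})
    fixed_fakequants = [name for name, mod in prepared.named_modules()
                        if isinstance(mod, FixedQParamsFakeQuantize)]
    print(len(fixed_fakequants))  # expected to match the count asserted in the test (13)
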
test/test_quantization.py: 2 additions & 1 deletion
@@ -44,6 +44,7 @@
from quantization.test_quantize import TestPostTrainingDynamic # noqa: F401
from quantization.test_quantize import TestQuantizationAwareTraining # noqa: F401
from quantization.test_quantize import TestEagerModeOps # noqa: F401
from quantization.test_quantize import TestEagerModeQATOps # noqa: F401

# TODO: merge with other tests in test_quantize.py?
from quantization.test_quantize import TestFunctionalModule # noqa: F401
@@ -64,7 +65,7 @@
from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401
from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401

# Tooling: numric_suite
# Tooling: numeric_suite
from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401

# Backward Compatibility
torch/quantization/fake_quantize.py: 16 additions & 6 deletions
@@ -159,12 +159,12 @@ def forward(self, X):

@torch.jit.export
def extra_repr(self):
return 'fake_quant_enabled={}, observer_enabled={},\
quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, \
scale={}, zero_point={}'.format(
self.fake_quant_enabled, self.observer_enabled,
self.quant_min, self.quant_max,
self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)
return 'fake_quant_enabled={}, observer_enabled={}, ' \
'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
'scale={}, zero_point={}'.format(
self.fake_quant_enabled, self.observer_enabled,
self.quant_min, self.quant_max,
self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

def _save_to_state_dict(self, destination, prefix, keep_vars):
# We cannot currently register scalar values as buffers, so need to manually
@@ -226,9 +226,19 @@ def forward(self, X):
self.quant_max)
return X

@torch.jit.export
def calculate_qparams(self):
return self.scale, self.zero_point

@torch.jit.export
def extra_repr(self):
return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
self.fake_quant_enabled, self.observer_enabled,
self.scale, self.zero_point, self.dtype,
self.quant_min, self.quant_max, self.qscheme)


default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
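For readers who have not seen FixedQParamsFakeQuantize before, its essence is sketched below under simplifying assumptions; the real class also handles the fake_quant_enabled/observer_enabled flags, state_dict round-tripping, and scripting. MiniFixedFakeQuant is a made-up name for illustration only.

    import torch

    class MiniFixedFakeQuant(torch.nn.Module):
        # Simplified stand-in: fake-quantize with a constant scale/zero_point
        # that is never updated from observed data.
        def __init__(self, scale, zero_point, quant_min=0, quant_max=255):
            super().__init__()
            self.scale = scale
            self.zero_point = zero_point
            self.quant_min = quant_min
            self.quant_max = quant_max

        def calculate_qparams(self):
            return self.scale, self.zero_point

        def forward(self, x):
            # Simulates quantize -> dequantize in the forward pass so QAT sees
            # the rounding/clamping error of the fixed qparams.
            return torch.fake_quantize_per_tensor_affine(
                x, self.scale, self.zero_point, self.quant_min, self.quant_max)

    fq = MiniFixedFakeQuant(scale=1.0 / 256.0, zero_point=0)
    out = fq(torch.sigmoid(torch.randn(4)))
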
torch/quantization/fx/pattern_utils.py: 17 additions & 1 deletion
@@ -14,17 +14,30 @@ def get_default_fusion_patterns():
return DEFAULT_FUSION_PATTERNS

DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()
# a map from pattern to activation_post_process (observer/fake_quant) constructor for output activation
# e.g. pattern: torch.sigmoid,
# output_activation_post_process: default_affine_fixed_qparams_fake_quant
DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP = dict()

# Register pattern for both static quantization and qat
def register_quant_pattern(pattern):
def register_quant_pattern(pattern, output_activation_post_process=None):
def insert(fn):
DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
if output_activation_post_process is not None:
DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP[pattern] = output_activation_post_process
return fn
return insert

# Get patterns for both static quantization and qat
def get_default_quant_patterns():
return DEFAULT_QUANTIZATION_PATTERNS

# a map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_affine_fixed_qparams_fake_quant
def get_default_output_activation_post_process_map():
return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP


# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
@@ -62,6 +75,9 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize):
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif isinstance(self_match, str):
if node.op != 'call_method' or node.target != self_match:
return False
elif node.target != self_match:
return False

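A hedged sketch of how the new map is intended to be consumed: once the registrations in quantization_patterns.py have run, the prepare step can ask for a pattern-specific output fake quant constructor and fall back to the qconfig's activation otherwise. The exact call sites live in the quantizer code, which is not part of this diff, so treat the lookup below as illustrative only.

    import torch
    # Importing the patterns module triggers the @register_quant_pattern registrations.
    import torch.quantization.fx.quantization_patterns  # noqa: F401
    from torch.quantization.fx.pattern_utils import (
        get_default_output_activation_post_process_map,
    )

    output_map = get_default_output_activation_post_process_map()
    ctor = output_map.get(torch.sigmoid)  # e.g. default_affine_fixed_qparams_fake_quant
    activation_post_process = ctor() if ctor is not None else None  # else use qconfig.activation()
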
torch/quantization/fx/quantization_patterns.py: 21 additions & 13 deletions
@@ -4,6 +4,10 @@
)
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import (
default_affine_fixed_qparams_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
)

from ..quantization_mappings import (
get_static_quant_module_class,
@@ -464,6 +468,22 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
return quantizer.quantized_graph.create_node(
'call_function', quantized_op, args, kwargs)

@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('sigmoid', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant)
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

# these ops have quantized equivalents that do not need any extra information
@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@@ -472,20 +492,16 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.AvgPool3d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.Hardsigmoid)
@register_quant_pattern(torch.nn.Hardtanh)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.nn.Sigmoid)
@register_quant_pattern(torch.nn.Tanh)
@register_quant_pattern(torch.adaptive_avg_pool1d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.hardsigmoid)
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
@register_quant_pattern(torch.nn.functional.interpolate)
@@ -505,11 +521,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.min)
@register_quant_pattern(torch.repeat_interleave)
@register_quant_pattern(torch.sigmoid)
@register_quant_pattern(torch.sort)
@register_quant_pattern(torch.squeeze)
@register_quant_pattern(torch.stack)
@register_quant_pattern(torch.tanh)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.getitem)
@register_quant_pattern(operator.floordiv)
@@ -518,8 +532,6 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern('contiguous')
@register_quant_pattern('detach')
@register_quant_pattern('detach_')
@register_quant_pattern('hardsigmoid')
@register_quant_pattern('hardsigmoid_')
@register_quant_pattern('mean')
@register_quant_pattern('numel')
@register_quant_pattern('permute')
@@ -530,13 +542,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern('reshape')
@register_quant_pattern('resize_')
@register_quant_pattern('shape')
@register_quant_pattern('sigmoid')
@register_quant_pattern('sigmoid_')
@register_quant_pattern('size')
@register_quant_pattern('squeeze')
@register_quant_pattern('squeeze_')
@register_quant_pattern('tanh')
@register_quant_pattern('tanh_')
@register_quant_pattern('transpose')
@register_quant_pattern('unsqueeze')
@register_quant_pattern('unsqueeze_')
@@ -547,7 +555,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_

# Default quantization handler, used for quantization of input and output
# of quantizable objects (e.g. modules and functionals)
class DefaultQuant(QuantizeHandler):
class DefaultQuantizeHandler(QuantizeHandler):
def convert(self, quantizer, node):
assert self.all_nodes
root_module = quantizer.modules['']
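Why FixedQParamsOpQuantizeHandler.convert can simply copy the node: the quantized kernels for sigmoid/hardsigmoid/tanh accept a quantized tensor and produce a quantized tensor, with the output qparams chosen by the kernel (the fixed values this PR trains against), so no quantize/dequantize pair needs to be inserted around the op. A small sanity check along those lines, assuming the behavior of the PyTorch version this PR targets:

    import torch

    x = torch.randn(2, 3)
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
    qy = torch.sigmoid(qx)  # stays quantized end to end; no dequantize/requantize needed
    print(qy.is_quantized, qy.q_scale(), qy.q_zero_point())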