[quant][graphmode][fx] Support sigmoid/hardsigmoid/tanh in qat (#46738)
Summary: Pull Request resolved: #46738

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D24486972

fbshipit-source-id: c9f139bfdd54973da1a93a45e32937595dbe67fc
jerryzh168 authored and facebook-github-bot committed Oct 26, 2020
1 parent b5662ba commit e927b62
Showing 9 changed files with 292 additions and 98 deletions.
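Sigmoid, Hardsigmoid, and Tanh have outputs with a fixed, known range ([0, 1] for the first two, [-1, 1] for tanh), so QAT can fake-quantize them with fixed quantization parameters instead of a learned observer. A minimal eager-mode sketch of the flow this commit enables, mirroring the test added below (the toy module is illustrative, not part of the diff):

import torch
from torch.quantization import (
    QuantStub, DeQuantStub, default_qat_qconfig,
    prepare_qat, convert, FixedQParamsFakeQuantize,
)

# Toy module for illustration; mirrors the eager-mode test added in this commit.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.tanh(self.sigmoid(self.quant(x))))

m = M().train()
m.qconfig = default_qat_qconfig
m = prepare_qat(m)
# fixed-qparams ops get FixedQParamsFakeQuantize, not the qconfig's default fake quant
assert type(m.sigmoid.activation_post_process) == FixedQParamsFakeQuantize
assert type(m.tanh.activation_post_process) == FixedQParamsFakeQuantize
m = convert(m)
# after convert, the activation post process is removed
assert not hasattr(m.sigmoid, 'activation_post_process')

The FX graph mode path (prepare_qat_fx / convert_fx) follows the same idea and is exercised in test_quantize_fx.py below.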
33 changes: 32 additions & 1 deletion test/quantization/test_quantize.py
@@ -25,7 +25,8 @@
float_qparams_dynamic_qconfig,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer
default_dynamic_quant_observer,
FixedQParamsFakeQuantize,
)

from torch.testing._internal.common_quantization import (
@@ -1247,6 +1248,36 @@ def forward(self, x):
def test_leaky_relu(self):
self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False})


class TestEagerModeQATOps(QuantizationTestCase):
def test_fixed_qparam_ops(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigmoid = torch.nn.Sigmoid()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.tanh = torch.nn.Tanh()
self.quant = QuantStub()
self.dequant = DeQuantStub()

def forward(self, x):
x = self.quant(x)
x = self.sigmoid(x)
x = self.hardsigmoid(x)
x = self.tanh(x)
x = self.dequant(x)
return x

m = M().train()
m.qconfig = default_qat_qconfig
m = prepare_qat(m)
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
m = convert(m)
# make sure activation post process is removed
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))

class TestFunctionalModule(QuantizationTestCase):
# Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
@given(train_mode=st.booleans())
99 changes: 80 additions & 19 deletions test/quantization/test_quantize_fx.py
Expand Up @@ -23,6 +23,7 @@
convert,
PerChannelMinMaxObserver,
QConfigDynamic,
FixedQParamsFakeQuantize,
)

# test utils
@@ -1410,17 +1411,14 @@ def test_general_value_ops(self):
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.avg_pool1d = torch.nn.AvgPool1d(3)
self.avg_pool2d = torch.nn.AvgPool2d(3)
self.avg_pool3d = torch.nn.AvgPool3d(3)
self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
self.hardsigmoid = torch.nn.Hardsigmoid()
self.sigmoid = torch.nn.Sigmoid()
self.tanh = torch.nn.Tanh()

def forward(self, x):
x = self.conv(x)
@@ -1442,21 +1440,6 @@ def forward(self, x):
x = x.mean([2, 3], True)
x = F.interpolate(x, 4, mode='nearest')
x = F.interpolate(x, 4, mode='linear')
x = self.hardsigmoid(x)
x = F.hardsigmoid(x)
x = F.hardsigmoid(x, inplace=True)
x = x.hardsigmoid()
x.hardsigmoid_()
x = self.sigmoid(x)
x = torch.sigmoid(x)
# F.sigmoid is deprecated
x = x.sigmoid()
x.sigmoid_()
x = self.tanh(x)
# F.tanh is deprecated
x = torch.tanh(x)
x = x.tanh()
x.tanh_()
x = self.conv(x)
return x

@@ -1488,6 +1471,84 @@ def forward(self, x):
expected_node_occurrence=count_check,
expected_node_list=order_check)

@skipIfNoFBGEMM
def test_fixed_qparams_ops(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.sigmoid = torch.nn.Sigmoid()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.tanh = torch.nn.Tanh()

def forward(self, x):
x = self.conv(x)
# F.sigmoid is deprecated
x = self.sigmoid(x)
x = torch.sigmoid(x)
x = x.sigmoid()
x.sigmoid_()
x = self.hardsigmoid(x)
x = F.hardsigmoid(x)
x = F.hardsigmoid(x, inplace=True)
x = x.hardsigmoid()
x.hardsigmoid_()
x = self.tanh(x)
# F.tanh is deprecated
x = torch.tanh(x)
x = x.tanh()
x.tanh_()
x = self.conv(x)
return x

for eval_mode in [True, False]:
# This model is not executable since we just put all ops
# in the same forward
m = M()
if eval_mode:
m.eval()
qconfig = default_qconfig
prepare = prepare_fx
fq_count = 0
else:
m.train()
qconfig = default_qat_qconfig
prepare = prepare_qat_fx
fq_count = 13

# nothing to fuse so skipping the fuse step
qconfig_dict = {'': qconfig}
prepared = prepare(m, qconfig_dict)
# check the correct number of activation_post_process is inserted
count_check = {
ns.call_module(FixedQParamsFakeQuantize) : fq_count,
}
self.checkGraphModuleNodes(
prepared,
expected_node_occurrence=count_check)
# not runnable
quantized = convert_fx(prepared)

# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers
# check exact counts of quantize and dequantize
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_method('dequantize') : 1
}
order_check = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_module(nn.Sigmoid),
ns.call_module(nnq.Conv2d),
ns.call_method('dequantize'),
]
self.checkGraphModuleNodes(
quantized,
expected_node_occurrence=count_check,
expected_node_list=order_check)

def test_float_functional(self):
class TorchAdd(nn.Module):
"""Wrapper around torch.add so that all ops can be found at build"""
3 changes: 2 additions & 1 deletion test/test_quantization.py
@@ -44,6 +44,7 @@
from quantization.test_quantize import TestPostTrainingDynamic # noqa: F401
from quantization.test_quantize import TestQuantizationAwareTraining # noqa: F401
from quantization.test_quantize import TestEagerModeOps # noqa: F401
from quantization.test_quantize import TestEagerModeQATOps # noqa: F401

# TODO: merge with other tests in test_quantize.py?
from quantization.test_quantize import TestFunctionalModule # noqa: F401
@@ -64,7 +65,7 @@
from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401
from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401

# Tooling: numric_suite
# Tooling: numeric_suite
from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401

# Backward Compatibility
22 changes: 16 additions & 6 deletions torch/quantization/fake_quantize.py
@@ -159,12 +159,12 @@ def forward(self, X):

@torch.jit.export
def extra_repr(self):
return 'fake_quant_enabled={}, observer_enabled={},\
quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, \
scale={}, zero_point={}'.format(
self.fake_quant_enabled, self.observer_enabled,
self.quant_min, self.quant_max,
self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)
return 'fake_quant_enabled={}, observer_enabled={}, ' \
'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
'scale={}, zero_point={}'.format(
self.fake_quant_enabled, self.observer_enabled,
self.quant_min, self.quant_max,
self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

def _save_to_state_dict(self, destination, prefix, keep_vars):
# We cannot currently register scalar values as buffers, so need to manually
@@ -226,9 +226,19 @@ def forward(self, X):
self.quant_max)
return X

@torch.jit.export
def calculate_qparams(self):
return self.scale, self.zero_point

@torch.jit.export
def extra_repr(self):
return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
self.fake_quant_enabled, self.observer_enabled,
self.scale, self.zero_point, self.dtype,
self.quant_min, self.quant_max, self.qscheme)


default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
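FixedQParamsFakeQuantize carries its scale and zero_point as constants fixed at construction; the calculate_qparams and extra_repr added above expose them through the usual fake-quantize interface. A small sketch using the affine default partial (the same constructor that the new patterns register for sigmoid/hardsigmoid later in this diff):

import torch
from torch.quantization import default_affine_fixed_qparams_fake_quant

# The fixed-qparams fake quant stores its scale/zero_point as constants,
# and the new calculate_qparams() simply returns them.
fq = default_affine_fixed_qparams_fake_quant()
scale, zero_point = fq.calculate_qparams()
print(fq)  # the new extra_repr shows scale, zero_point, dtype, quant_min/max, qscheme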
18 changes: 17 additions & 1 deletion torch/quantization/fx/pattern_utils.py
@@ -14,17 +14,30 @@ def get_default_fusion_patterns():
return DEFAULT_FUSION_PATTERNS

DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()
# a map from pattern to activation_post_process (observer/fake_quant) constructor for output activation
# e.g. pattern: torch.sigmoid,
# output_activation_post_process: default_affine_fixed_qparams_fake_quant
DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP = dict()

# Register pattern for both static quantization and qat
def register_quant_pattern(pattern):
def register_quant_pattern(pattern, output_activation_post_process=None):
def insert(fn):
DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
if output_activation_post_process is not None:
DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP[pattern] = output_activation_post_process
return fn
return insert

# Get patterns for both static quantization and qat
def get_default_quant_patterns():
return DEFAULT_QUANTIZATION_PATTERNS

# a map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_affine_fixed_qparams_fake_quant
def get_default_output_activation_post_process_map():
return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP


# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
@@ -62,6 +75,9 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize):
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif isinstance(self_match, str):
if node.op != 'call_method' or node.target != self_match:
return False
elif node.target != self_match:
return False

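register_quant_pattern now optionally records an output activation post process constructor alongside the handler, and get_default_output_activation_post_process_map exposes that table to the quantizer. The lookup below is a hedged sketch of how a prepare step could consume the map (output_observer_ctor is a hypothetical helper, not code from this diff); the real consumer lives in torch/quantization/fx/quantize.py:

from torch.quantization.fx.pattern_utils import (
    get_default_output_activation_post_process_map,
)

def output_observer_ctor(pattern, qconfig):
    # Hypothetical helper: fixed-qparams patterns override the qconfig's
    # activation observer/fake-quant with their registered constructor.
    override = get_default_output_activation_post_process_map().get(pattern, None)
    return override if override is not None else qconfig.activation

The concrete registrations that populate the map follow in quantization_patterns.py.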
34 changes: 21 additions & 13 deletions torch/quantization/fx/quantization_patterns.py
Expand Up @@ -4,6 +4,10 @@
)
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import (
default_affine_fixed_qparams_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
)

from ..quantization_mappings import (
get_static_quant_module_class,
@@ -487,6 +491,22 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
return quantizer.quantized_graph.create_node(
'call_function', quantized_op, args, kwargs)

@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('sigmoid', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant)
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

# these ops have quantized equivalents that do not need any extra information
@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@@ -495,20 +515,16 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.AvgPool3d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.Hardsigmoid)
@register_quant_pattern(torch.nn.Hardtanh)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.nn.Sigmoid)
@register_quant_pattern(torch.nn.Tanh)
@register_quant_pattern(torch.adaptive_avg_pool1d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.hardsigmoid)
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
@register_quant_pattern(torch.nn.functional.interpolate)
@@ -528,11 +544,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.min)
@register_quant_pattern(torch.repeat_interleave)
@register_quant_pattern(torch.sigmoid)
@register_quant_pattern(torch.sort)
@register_quant_pattern(torch.squeeze)
@register_quant_pattern(torch.stack)
@register_quant_pattern(torch.tanh)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.getitem)
@register_quant_pattern(operator.floordiv)
@@ -541,8 +555,6 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern('contiguous')
@register_quant_pattern('detach')
@register_quant_pattern('detach_')
@register_quant_pattern('hardsigmoid')
@register_quant_pattern('hardsigmoid_')
@register_quant_pattern('mean')
@register_quant_pattern('numel')
@register_quant_pattern('permute')
@@ -553,13 +565,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern('reshape')
@register_quant_pattern('resize_')
@register_quant_pattern('shape')
@register_quant_pattern('sigmoid')
@register_quant_pattern('sigmoid_')
@register_quant_pattern('size')
@register_quant_pattern('squeeze')
@register_quant_pattern('squeeze_')
@register_quant_pattern('tanh')
@register_quant_pattern('tanh_')
@register_quant_pattern('transpose')
@register_quant_pattern('unsqueeze')
@register_quant_pattern('unsqueeze_')
@@ -570,7 +578,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_

# Default quantization handler, used for quantization of input and output
# of quantizable objects (e.g. modules and functionals)
class DefaultQuant(QuantizeHandler):
class DefaultQuantizeHandler(QuantizeHandler):
def convert(self, quantizer, node):
assert self.all_nodes
root_module = quantizer.modules['']
