From 0da6730f029863b3c812e05ddb3cec8a316c7c1c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 6 Oct 2020 12:07:57 -0700 Subject: [PATCH] [quant][graphmode][fx][eagermode] Add leaky relu support in quantization workflows (#45712) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45712 Eager mode will still be able to use functional leaky relu, but it will be less accurate than LeakyReLU module. FX graph mode will support both leaky relu functional and module Test Plan: Imported from OSS Reviewed By: z-a-f Differential Revision: D24069961 fbshipit-source-id: 8d91c3c50c0bcd068ba3072378ebb4da9549be3b --- test/quantization/test_quantize.py | 30 +++++++++++++++++++ test/quantization/test_quantize_fx.py | 9 ++---- test/test_quantization.py | 1 + .../quantization/fx/quantization_patterns.py | 6 ++-- torch/quantization/quantization_mappings.py | 2 ++ 5 files changed, 38 insertions(+), 10 deletions(-) diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py index e54eb33770c2..fb2f57282d79 100644 --- a/test/quantization/test_quantize.py +++ b/test/quantization/test_quantize.py @@ -1198,6 +1198,36 @@ def checkHooksIsPresent(model, before_convert=True): torch.quantization.convert(model, inplace=True) checkHooksIsPresent(model, False) +class TestEagerModeOps(QuantizationTestCase): + def _test_activation_op_impl( + self, float_module_class, quantized_module_class, extra_module_kwargs): + """ Implementation for testing common activation ops like leaky relu + Args: + extra_module_kwargs: keyword args to instantiate the float module + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation_op = float_module_class(**extra_module_kwargs) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.activation_op(x) + x = self.dequant(x) + return x + + m = M().eval() + m.qconfig = default_qconfig + m = prepare(m) + self.checkObservers(m) + m = convert(m) + self.assertEqual(type(m.activation_op), quantized_module_class) + + def test_leaky_relu(self): + self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False}) + class TestFunctionalModule(QuantizationTestCase): # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out @given(train_mode=st.booleans()) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 53551efb7c0f..835dc6bf3083 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -1158,6 +1158,9 @@ def test_hardswish(self): def test_elu(self): self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu) + def test_leaky_relu(self): + self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu) + def _test_norm_impl( self, float_module, float_op, op_args, data, quantized_module, quantized_op, skip_op_arg_for_functional=False): @@ -1392,7 +1395,6 @@ def __init__(self): self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1)) self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) - self.leaky_relu = torch.nn.LeakyReLU() self.hardsigmoid = torch.nn.Hardsigmoid() self.sigmoid = torch.nn.Sigmoid() self.tanh = torch.nn.Tanh() @@ -1417,11 +1419,6 @@ def forward(self, x): x = x.mean([2, 3], True) x = F.interpolate(x, 4, mode='nearest') x = F.interpolate(x, 4, mode='linear') - x = 
self.leaky_relu(x) - x = F.leaky_relu(x) - x = F.leaky_relu(x, inplace=True) - x = x.leaky_relu() - x.leaky_relu_() x = self.hardsigmoid(x) x = F.hardsigmoid(x) x = F.hardsigmoid(x, inplace=True) diff --git a/test/test_quantization.py b/test/test_quantization.py index fc67891c24fe..d7d7e27fff53 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -45,6 +45,7 @@ from quantization.test_quantize import TestPostTrainingStatic # noqa: F401 from quantization.test_quantize import TestPostTrainingDynamic # noqa: F401 from quantization.test_quantize import TestQuantizationAwareTraining # noqa: F401 +from quantization.test_quantize import TestEagerModeOps # noqa: F401 # TODO: merge with other tests in test_quantize.py? from quantization.test_quantize import TestFunctionalModule # noqa: F401 diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 844351a30def..5a995e103759 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -390,6 +390,7 @@ def convert(self, quantizer, node, load_arg, debug=False): ['running_mean', 'running_var', 'use_input_stats', 'momentum'], } @register_quant_pattern(torch.nn.ELU) +@register_quant_pattern(torch.nn.LeakyReLU) @register_quant_pattern(torch.nn.Hardswish) @register_quant_pattern(torch.nn.InstanceNorm1d) @register_quant_pattern(torch.nn.InstanceNorm2d) @@ -398,6 +399,7 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern(torch.nn.functional.hardswish) @register_quant_pattern(torch.nn.functional.instance_norm) @register_quant_pattern(torch.nn.functional.layer_norm) +@register_quant_pattern(torch.nn.functional.leaky_relu) class DefaultNode(QuantizeHandler): ''' Common quantized op, first input and first output will be quantized ''' @@ -463,7 +465,6 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern(torch.nn.Dropout) @register_quant_pattern(torch.nn.Hardsigmoid) @register_quant_pattern(torch.nn.Hardtanh) -@register_quant_pattern(torch.nn.LeakyReLU) @register_quant_pattern(torch.nn.MaxPool1d) @register_quant_pattern(torch.nn.MaxPool2d) @register_quant_pattern(torch.nn.MaxPool3d) @@ -479,7 +480,6 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern(torch.nn.functional.hardtanh) @register_quant_pattern(torch.nn.functional.hardtanh_) @register_quant_pattern(torch.nn.functional.interpolate) -@register_quant_pattern(torch.nn.functional.leaky_relu) @register_quant_pattern(torch.nn.functional.max_pool1d) @register_quant_pattern(torch.nn.functional.max_pool2d) @register_quant_pattern(torch.nn.functional.max_pool3d) @@ -511,8 +511,6 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern('detach_') @register_quant_pattern('hardsigmoid') @register_quant_pattern('hardsigmoid_') -@register_quant_pattern('leaky_relu') -@register_quant_pattern('leaky_relu_') @register_quant_pattern('mean') @register_quant_pattern('numel') @register_quant_pattern('permute') diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 60d166ae4896..4102cef718fd 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -18,6 +18,7 @@ nn.ReLU6: nnq.ReLU6, nn.Hardswish: nnq.Hardswish, nn.ELU: nnq.ELU, + nn.LeakyReLU: nnq.LeakyReLU, nn.Conv1d: nnq.Conv1d, nn.Conv2d: nnq.Conv2d, nn.Conv3d: nnq.Conv3d, @@ -83,6 +84,7 @@ # mapping from floating 
point function or torch ops to quantized ops FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = { F.elu: torch._ops.ops.quantized.elu, + F.leaky_relu: torch._ops.ops.quantized.leaky_relu, F.hardswish: torch._ops.ops.quantized.hardswish, F.instance_norm: torch._ops.ops.quantized.instance_norm, F.layer_norm: torch._ops.ops.quantized.layer_norm,
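
For reference, below is a minimal eager-mode sketch of the workflow that the new `TestEagerModeOps.test_leaky_relu` test exercises after this patch. The module layout, the `negative_slope` value, and the calibration input shape are illustrative assumptions, not part of the patch itself.

```python
# Sketch (not part of the patch): eager-mode post-training static quantization
# of nn.LeakyReLU. convert() swaps the float module for nnq.LeakyReLU via the
# DEFAULT_MODULE_MAPPING entry added in quantization_mappings.py.
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import (
    QuantStub, DeQuantStub, default_qconfig, prepare, convert)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)  # illustrative slope
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.leaky_relu(x)
        x = self.dequant(x)
        return x

m = M().eval()
m.qconfig = default_qconfig
m = prepare(m)
m(torch.randn(1, 3, 8, 8))  # calibration pass to populate the observers
m = convert(m)
assert isinstance(m.leaky_relu, nnq.LeakyReLU)
```

As the summary notes, eager mode can still call functional leaky relu, but only the `nn.LeakyReLU` module is swapped for its quantized counterpart, which is why the functional form is less accurate. FX graph mode registers both `torch.nn.LeakyReLU` and `torch.nn.functional.leaky_relu` under `DefaultNode` in this patch, so both forms are quantized there.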