diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py
index e54eb33770c2..fb2f57282d79 100644
--- a/test/quantization/test_quantize.py
+++ b/test/quantization/test_quantize.py
@@ -1198,6 +1198,36 @@ def checkHooksIsPresent(model, before_convert=True):
         torch.quantization.convert(model, inplace=True)
         checkHooksIsPresent(model, False)
 
+class TestEagerModeOps(QuantizationTestCase):
+    def _test_activation_op_impl(
+            self, float_module_class, quantized_module_class, extra_module_kwargs):
+        """ Implementation for testing common activation ops like leaky relu
+        Args:
+            extra_module_kwargs: keyword args to instantiate the float module
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.activation_op = float_module_class(**extra_module_kwargs)
+                self.quant = QuantStub()
+                self.dequant = DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.activation_op(x)
+                x = self.dequant(x)
+                return x
+
+        m = M().eval()
+        m.qconfig = default_qconfig
+        m = prepare(m)
+        self.checkObservers(m)
+        m = convert(m)
+        self.assertEqual(type(m.activation_op), quantized_module_class)
+
+    def test_leaky_relu(self):
+        self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False})
+
 class TestFunctionalModule(QuantizationTestCase):
     # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
     @given(train_mode=st.booleans())
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 53551efb7c0f..835dc6bf3083 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1158,6 +1158,9 @@ def test_hardswish(self):
     def test_elu(self):
         self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu)
 
+    def test_leaky_relu(self):
+        self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu)
+
     def _test_norm_impl(
             self, float_module, float_op, op_args, data, quantized_module, quantized_op,
             skip_op_arg_for_functional=False):
@@ -1392,7 +1395,6 @@ def __init__(self):
                 self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
                 self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                 self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
-                self.leaky_relu = torch.nn.LeakyReLU()
                 self.hardsigmoid = torch.nn.Hardsigmoid()
                 self.sigmoid = torch.nn.Sigmoid()
                 self.tanh = torch.nn.Tanh()
@@ -1417,11 +1419,6 @@ def forward(self, x):
                 x = x.mean([2, 3], True)
                 x = F.interpolate(x, 4, mode='nearest')
                 x = F.interpolate(x, 4, mode='linear')
-                x = self.leaky_relu(x)
-                x = F.leaky_relu(x)
-                x = F.leaky_relu(x, inplace=True)
-                x = x.leaky_relu()
-                x.leaky_relu_()
                 x = self.hardsigmoid(x)
                 x = F.hardsigmoid(x)
                 x = F.hardsigmoid(x, inplace=True)
diff --git a/test/test_quantization.py b/test/test_quantization.py
index fc67891c24fe..d7d7e27fff53 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -45,6 +45,7 @@
 from quantization.test_quantize import TestPostTrainingStatic  # noqa: F401
 from quantization.test_quantize import TestPostTrainingDynamic  # noqa: F401
 from quantization.test_quantize import TestQuantizationAwareTraining  # noqa: F401
+from quantization.test_quantize import TestEagerModeOps  # noqa: F401
 
 # TODO: merge with other tests in test_quantize.py?
 from quantization.test_quantize import TestFunctionalModule  # noqa: F401
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 844351a30def..5a995e103759 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -390,6 +390,7 @@ def convert(self, quantizer, node, load_arg, debug=False):
         ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
 }
 @register_quant_pattern(torch.nn.ELU)
+@register_quant_pattern(torch.nn.LeakyReLU)
 @register_quant_pattern(torch.nn.Hardswish)
 @register_quant_pattern(torch.nn.InstanceNorm1d)
 @register_quant_pattern(torch.nn.InstanceNorm2d)
@@ -398,6 +399,7 @@ def convert(self, quantizer, node, load_arg, debug=False):
 @register_quant_pattern(torch.nn.functional.hardswish)
 @register_quant_pattern(torch.nn.functional.instance_norm)
 @register_quant_pattern(torch.nn.functional.layer_norm)
+@register_quant_pattern(torch.nn.functional.leaky_relu)
 class DefaultNode(QuantizeHandler):
     ''' Common quantized op, first input and first output will be quantized
     '''
@@ -463,7 +465,6 @@ def convert(self, quantizer, node, load_arg, debug=False):
 @register_quant_pattern(torch.nn.Dropout)
 @register_quant_pattern(torch.nn.Hardsigmoid)
 @register_quant_pattern(torch.nn.Hardtanh)
-@register_quant_pattern(torch.nn.LeakyReLU)
 @register_quant_pattern(torch.nn.MaxPool1d)
 @register_quant_pattern(torch.nn.MaxPool2d)
 @register_quant_pattern(torch.nn.MaxPool3d)
@@ -479,7 +480,6 @@ def convert(self, quantizer, node, load_arg, debug=False):
 @register_quant_pattern(torch.nn.functional.hardtanh)
 @register_quant_pattern(torch.nn.functional.hardtanh_)
 @register_quant_pattern(torch.nn.functional.interpolate)
-@register_quant_pattern(torch.nn.functional.leaky_relu)
 @register_quant_pattern(torch.nn.functional.max_pool1d)
 @register_quant_pattern(torch.nn.functional.max_pool2d)
 @register_quant_pattern(torch.nn.functional.max_pool3d)
@@ -511,8 +511,6 @@ def convert(self, quantizer, node, load_arg, debug=False):
 @register_quant_pattern('detach_')
 @register_quant_pattern('hardsigmoid')
 @register_quant_pattern('hardsigmoid_')
-@register_quant_pattern('leaky_relu')
-@register_quant_pattern('leaky_relu_')
 @register_quant_pattern('mean')
 @register_quant_pattern('numel')
 @register_quant_pattern('permute')
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
index 60d166ae4896..4102cef718fd 100644
--- a/torch/quantization/quantization_mappings.py
+++ b/torch/quantization/quantization_mappings.py
@@ -18,6 +18,7 @@
     nn.ReLU6: nnq.ReLU6,
     nn.Hardswish: nnq.Hardswish,
     nn.ELU: nnq.ELU,
+    nn.LeakyReLU: nnq.LeakyReLU,
     nn.Conv1d: nnq.Conv1d,
     nn.Conv2d: nnq.Conv2d,
     nn.Conv3d: nnq.Conv3d,
@@ -83,6 +84,7 @@
 # mapping from floating point function or torch ops to quantized ops
 FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
     F.elu: torch._ops.ops.quantized.elu,
+    F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
     F.hardswish: torch._ops.ops.quantized.hardswish,
     F.instance_norm: torch._ops.ops.quantized.instance_norm,
     F.layer_norm: torch._ops.ops.quantized.layer_norm,
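
Usage sketch (not part of the patch): with the mappings above in place, eager-mode post-training static quantization swaps nn.LeakyReLU for nn.quantized.LeakyReLU, mirroring the new TestEagerModeOps test. The module name M, the negative_slope value, and the calibration tensor below are illustrative assumptions, not taken from the patch.

import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import QuantStub, DeQuantStub, default_qconfig, prepare, convert

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()        # quantizes the float input
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)
        self.dequant = DeQuantStub()    # dequantizes the output

    def forward(self, x):
        return self.dequant(self.leaky_relu(self.quant(x)))

m = M().eval()
m.qconfig = default_qconfig             # default post-training static qconfig
m = prepare(m)                          # attach observers
m(torch.randn(2, 3))                    # calibrate with sample data
m = convert(m)                          # swap float modules for quantized ones
assert type(m.leaky_relu) == nnq.LeakyReLU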