[quant][graphmode][fx][eagermode] Add leaky relu support in quantization workflows (#45712)

Summary:
Pull Request resolved: #45712

Eager mode will still be able to use the functional leaky relu, but it will be less accurate than the LeakyReLU module.
FX graph mode will support both the functional and the module variants of leaky relu.
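
As a quick illustration, here is a minimal sketch of the eager-mode flow for the module variant after this change. The model, module placement, and calibration data below are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub, default_qconfig, prepare, convert

class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()                            # quantizes the input
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)
        self.dequant = DeQuantStub()                        # dequantizes the output

    def forward(self, x):
        x = self.quant(x)
        x = self.leaky_relu(x)   # swapped for a quantized LeakyReLU by convert()
        return self.dequant(x)

m = SmallModel().eval()
m.qconfig = default_qconfig
m = prepare(m)              # attach observers
m(torch.randn(2, 8))        # calibrate on sample data
m = convert(m)              # swap float modules for quantized ones
print(type(m.leaky_relu))   # expected: torch.nn.quantized.LeakyReLU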

Test Plan: Imported from OSS

Reviewed By: z-a-f

Differential Revision: D24069961

fbshipit-source-id: 8d91c3c50c0bcd068ba3072378ebb4da9549be3b
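
For the FX graph mode path, a comparable sketch covers both the module and the functional form. It assumes the torch.quantization.quantize_fx prepare_fx/convert_fx API as it existed around this change; the model and calibration tensor are again illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        x = self.leaky_relu(x)       # module variant
        return F.leaky_relu(x, 0.1)  # functional variant

m = SmallModel().eval()
m = prepare_fx(m, {"": default_qconfig})  # symbolically trace and insert observers
m(torch.randn(2, 8))                      # calibrate
m = convert_fx(m)                         # module -> quantized LeakyReLU, functional -> torch.ops.quantized.leaky_relu
print(m.code)                             # generated forward of the quantized GraphModule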
jerryzh168 authored and facebook-github-bot committed Oct 6, 2020
1 parent fb50fca commit 0da6730
Showing 5 changed files with 38 additions and 10 deletions.
30 changes: 30 additions & 0 deletions test/quantization/test_quantize.py
@@ -1198,6 +1198,36 @@ def checkHooksIsPresent(model, before_convert=True):
torch.quantization.convert(model, inplace=True)
checkHooksIsPresent(model, False)

class TestEagerModeOps(QuantizationTestCase):
    def _test_activation_op_impl(
            self, float_module_class, quantized_module_class, extra_module_kwargs):
        """ Implementation for testing common activation ops like leaky relu
        Args:
            extra_module_kwargs: keyword args to instantiate the float module
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.activation_op = float_module_class(**extra_module_kwargs)
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.activation_op(x)
                x = self.dequant(x)
                return x

        m = M().eval()
        m.qconfig = default_qconfig
        m = prepare(m)
        self.checkObservers(m)
        m = convert(m)
        self.assertEqual(type(m.activation_op), quantized_module_class)

    def test_leaky_relu(self):
        self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False})

class TestFunctionalModule(QuantizationTestCase):
# Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
@given(train_mode=st.booleans())
9 changes: 3 additions & 6 deletions test/quantization/test_quantize_fx.py
@@ -1158,6 +1158,9 @@ def test_hardswish(self):
    def test_elu(self):
        self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu)

    def test_leaky_relu(self):
        self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu)

    def _test_norm_impl(
            self, float_module, float_op, op_args, data, quantized_module, quantized_op,
            skip_op_arg_for_functional=False):
@@ -1392,7 +1395,6 @@ def __init__(self):
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
                self.leaky_relu = torch.nn.LeakyReLU()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()
@@ -1417,11 +1419,6 @@ def forward(self, x):
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                x = self.leaky_relu(x)
                x = F.leaky_relu(x)
                x = F.leaky_relu(x, inplace=True)
                x = x.leaky_relu()
                x.leaky_relu_()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
1 change: 1 addition & 0 deletions test/test_quantization.py
@@ -45,6 +45,7 @@
from quantization.test_quantize import TestPostTrainingStatic # noqa: F401
from quantization.test_quantize import TestPostTrainingDynamic # noqa: F401
from quantization.test_quantize import TestQuantizationAwareTraining # noqa: F401
from quantization.test_quantize import TestEagerModeOps # noqa: F401

# TODO: merge with other tests in test_quantize.py?
from quantization.test_quantize import TestFunctionalModule # noqa: F401
6 changes: 2 additions & 4 deletions torch/quantization/fx/quantization_patterns.py
@@ -390,6 +390,7 @@ def convert(self, quantizer, node, load_arg, debug=False):
    ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
}
@register_quant_pattern(torch.nn.ELU)
@register_quant_pattern(torch.nn.LeakyReLU)
@register_quant_pattern(torch.nn.Hardswish)
@register_quant_pattern(torch.nn.InstanceNorm1d)
@register_quant_pattern(torch.nn.InstanceNorm2d)
@@ -398,6 +399,7 @@ def convert(self, quantizer, node, load_arg, debug=False):
@register_quant_pattern(torch.nn.functional.hardswish)
@register_quant_pattern(torch.nn.functional.instance_norm)
@register_quant_pattern(torch.nn.functional.layer_norm)
@register_quant_pattern(torch.nn.functional.leaky_relu)
class DefaultNode(QuantizeHandler):
    ''' Common quantized op, first input and first output will be quantized
    '''
@@ -463,7 +465,6 @@ def convert(self, quantizer, node, load_arg, debug=False):
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.Hardsigmoid)
@register_quant_pattern(torch.nn.Hardtanh)
@register_quant_pattern(torch.nn.LeakyReLU)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@@ -479,7 +480,6 @@ def convert(self, quantizer, node, load_arg, debug=False):
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
@register_quant_pattern(torch.nn.functional.interpolate)
@register_quant_pattern(torch.nn.functional.leaky_relu)
@register_quant_pattern(torch.nn.functional.max_pool1d)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch.nn.functional.max_pool3d)
@@ -511,8 +511,6 @@ def convert(self, quantizer, node, load_arg, debug=False):
@register_quant_pattern('detach_')
@register_quant_pattern('hardsigmoid')
@register_quant_pattern('hardsigmoid_')
@register_quant_pattern('leaky_relu')
@register_quant_pattern('leaky_relu_')
@register_quant_pattern('mean')
@register_quant_pattern('numel')
@register_quant_pattern('permute')
2 changes: 2 additions & 0 deletions torch/quantization/quantization_mappings.py
@@ -18,6 +18,7 @@
    nn.ReLU6: nnq.ReLU6,
    nn.Hardswish: nnq.Hardswish,
    nn.ELU: nnq.ELU,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
@@ -83,6 +84,7 @@
# mapping from floating point function or torch ops to quantized ops
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
    F.elu: torch._ops.ops.quantized.elu,
    F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
    F.hardswish: torch._ops.ops.quantized.hardswish,
    F.instance_norm: torch._ops.ops.quantized.instance_norm,
    F.layer_norm: torch._ops.ops.quantized.layer_norm,
