[reland][quant] Remove nn.quantized.ReLU module and nn.quantized.functional.relu (#47415) (#48038)

Summary:
Pull Request resolved: #48038

nn.ReLU works for both float and quantized input, so we don't want to define an nn.quantized.ReLU
that does the same thing as nn.ReLU; the same applies to nn.quantized.functional.relu.

This also removes the numerical inconsistency for models that quantize nn.ReLU independently in QAT mode.
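For context (not part of the diff), a minimal sketch of why the dedicated quantized ReLU is redundant: the float nn.ReLU module and torch.nn.functional.relu already accept quantized tensors. The scale and zero_point below are arbitrary example values.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Quantize a float tensor (scale/zero_point are arbitrary example values).
x = torch.randn(2, 3)
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=5, dtype=torch.quint8)

# The float module and the float functional both handle the quantized tensor
# directly, so a separate nn.quantized.ReLU / nn.quantized.functional.relu
# adds nothing.
qy_module = nn.ReLU()(qx)
qy_functional = F.relu(qx)
assert torch.equal(qy_module.dequantize(), qy_functional.dequantize())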

Test Plan:
Imported from OSS

Reviewed By: vkuzo

Differential Revision: D25000462

fbshipit-source-id: e3609a3ae4a3476a42f61276619033054194a0d2
jerryzh168 authored and facebook-github-bot committed Nov 17, 2020
1 parent a03f05f commit 8aaca4b
Showing 18 changed files with 96 additions and 127 deletions.
2 changes: 0 additions & 2 deletions docs/source/quantization-support.rst
@@ -255,7 +255,6 @@ Quantized version of standard NN layers.
* :class:`~torch.nn.quantized.Conv3d` — 3D convolution
* :class:`~torch.nn.quantized.Linear` — Linear (fully-connected) layer
* :class:`~torch.nn.MaxPool2d` — 2D max pooling
* :class:`~torch.nn.quantized.ReLU` — Rectified linear unit
* :class:`~torch.nn.quantized.ReLU6` — Rectified linear unit with cut-off at
quantized representation of 6
* :class:`~torch.nn.quantized.ELU` — ELU
@@ -294,7 +293,6 @@ quantization output parameters)
* :func:`~torch.nn.quantized.functional.interpolate` — Down-/up- sampler
* :func:`~torch.nn.quantized.functional.linear` — Linear (fully-connected) op
* :func:`~torch.nn.quantized.functional.max_pool2d` — 2D max pooling
* :func:`~torch.nn.quantized.functional.relu` — Rectified linear unit
* :func:`~torch.nn.quantized.functional.elu` — ELU
* :func:`~torch.nn.quantized.functional.hardsigmoid` — Hardsigmoid
* :func:`~torch.nn.quantized.functional.hardswish` — Hardswish
11 changes: 1 addition & 10 deletions docs/source/torch.nn.quantized.rst
@@ -1,14 +1,12 @@
torch.nn.quantized
------------------

This module implements the quantized versions of the nn layers such as
~`torch.nn.Conv2d` and `torch.nn.ReLU`.
This module implements the quantized versions of the nn modules and functionals.

Functional interface
~~~~~~~~~~~~~~~~~~~~
.. automodule:: torch.nn.quantized.functional

.. autofunction:: relu
.. autofunction:: linear
.. autofunction:: conv1d
.. autofunction:: conv2d
@@ -25,11 +23,6 @@ Functional interface

.. automodule:: torch.nn.quantized

ReLU
~~~~~~~~~~~~~~~
.. autoclass:: ReLU
:members:

ReLU6
~~~~~~~~~~~~~~~
.. autoclass:: ReLU6
@@ -119,5 +112,3 @@ InstanceNorm3d
~~~~~~~~~~~~~~~
.. autoclass:: InstanceNorm3d
:members:


26 changes: 24 additions & 2 deletions test/quantization/test_quantize.py
@@ -307,8 +307,8 @@ def checkQuantized(model):
self.checkQuantDequant(model.sub)
self.checkQuantizedLinear(model.sub.module.fc1)
self.checkQuantizedLinear(model.sub.module.fc2)
self.assertEqual(type(model.sub.module.relu1), nnq.ReLU)
self.assertEqual(type(model.sub.module.relu2), nnq.ReLU)
self.assertEqual(type(model.sub.module.relu1), nn.ReLU)
self.assertEqual(type(model.sub.module.relu2), nn.ReLU)
self.checkScriptable(model, self.calib_data)
self.checkNoQconfig(model)

@@ -1249,6 +1249,9 @@ def forward(self, x):
def test_leaky_relu(self):
self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False})

def test_relu(self):
self._test_activation_op_impl(nn.ReLU, nn.ReLU, {'inplace': False})


class TestEagerModeQATOps(QuantizationTestCase):
def _test_activation_convert_numerics_impl(self, Act, data):
@@ -1326,6 +1329,25 @@ def test_leaky_relu(self):
data = torch.randn(1, 3, 2, 4)
self._test_activation_convert_numerics_impl(nn.LeakyReLU, data)

def test_relu(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = nn.ReLU()

def forward(self, x):
x = self.relu(x)
return x

m = M().train()
m.qconfig = default_qconfig
m = prepare_qat(m)
# make sure no activation_post_process is inserted for relu
self.assertFalse(hasattr(m, "activation_post_process"))
m = convert(m)
# make sure ReLU module is not changed
self.assertTrue(type(m.relu), nn.ReLU)

class TestFunctionalModule(QuantizationTestCase):
# Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
@given(train_mode=st.booleans())
42 changes: 38 additions & 4 deletions test/quantization/test_quantize_fx.py
@@ -16,13 +16,18 @@

from torch.quantization import (
QuantType,
QuantStub,
DeQuantStub,
quant_type_to_str,
default_qconfig,
default_dynamic_qconfig,
default_dynamic_quant_observer,
default_qat_qconfig,
float16_dynamic_qconfig,
float_qparams_dynamic_qconfig,
get_default_qconfig,
get_default_qat_qconfig,
fuse_modules,
prepare,
prepare_qat,
convert,
@@ -331,28 +336,57 @@ def __init__(self, dim, has_relu):
super().__init__()
self.conv = convs[dim](3, 3, 3)
self.bn = bns[dim](3)
self.relu = nn.ReLU()
self.relu = nn.ReLU() if has_relu else nn.Identity()
self.has_relu = has_relu
self.quant = QuantStub()
self.dequant = DeQuantStub()

def forward(self, x):
x = self.quant(x)
x = self.conv(x)
x = self.bn(x)
if self.has_relu:
x = self.relu(x)
x = self.dequant(x)
return x

options = itertools.product([1, 2], [True, False], self.static_quant_types)
options = itertools.product([2], [True, False], self.static_quant_types)
for dim, has_relu, quant_type in options:
expected_node = ns.call_module(
quantized_conv_relus[dim] if has_relu
else quantized_convs[dim])
self.checkGraphModeFxOp(
M(dim, has_relu),
m = M(dim, has_relu)
m_eager = copy.deepcopy(m)
result = self.checkGraphModeFxOp(
m,
self.img_data_dict[dim],
quant_type,
expected_node=expected_node,
)

# check numerics
qengine = torch.backends.quantized.engine
if quant_type == QuantType.STATIC:
m_eager.eval()
qconfig = get_default_qconfig(qengine)
prepare_fn = prepare
else:
m_eager.train()
qconfig = get_default_qat_qconfig(qengine)
prepare_fn = prepare_qat

fuse_list = ["conv", "bn"]
if has_relu:
fuse_list.append("relu")
fuse_modules(m_eager, fuse_list, inplace=True)
m_eager.qconfig = qconfig
m_eager = prepare_fn(m_eager)
m_eager(*self.img_data_dict[dim][0])
m_eager = convert(m_eager)
result_eager = m_eager(*self.img_data_dict[dim][0])
self.assertEqual(result, result_eager)


@skipIfNoFBGEMM
def test_dynamic_quant_fp16(self):
class Linear(torch.nn.Module):
2 changes: 1 addition & 1 deletion test/quantization/test_quantized_functional.py
@@ -26,7 +26,7 @@ def test_relu_api(self):
zero_point = 1
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
qY = torch.relu(qX)
qY_hat = qF.relu(qX)
qY_hat = F.relu(qX)
self.assertEqual(qY, qY_hat)

def _test_conv_api_impl(
11 changes: 6 additions & 5 deletions test/quantization/test_quantized_module.py
@@ -39,7 +39,7 @@

class TestStaticQuantizedModule(QuantizationTestCase):
def test_relu(self):
relu_module = nnq.ReLU()
relu_module = nn.ReLU()
relu6_module = nnq.ReLU6()

x = torch.arange(-10, 10, dtype=torch.float)
@@ -304,10 +304,11 @@ def _test_conv_api_impl(
check_save_load=True)

# Test from_float
conv_module.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(conv_module, inplace=True)
conv_module(X.float())
converted_qconv_module = torch.nn.Sequential(conv_module)
fused_conv_module = torch.nn.intrinsic._FusedModule(conv_module)
fused_conv_module.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(fused_conv_module, inplace=True)
fused_conv_module(X.float())
converted_qconv_module = fused_conv_module
torch.quantization.convert(converted_qconv_module, inplace=True)

# Smoke test to make sure the module actually runs
4 changes: 2 additions & 2 deletions test/quantization/test_quantized_op.py
@@ -227,14 +227,14 @@ def test_qrelu(self, X):
torch.relu,
torch.relu_,
torch.nn.functional.relu,
torch.nn.quantized.functional.relu,
torch.nn.functional.relu,
],
'reference_fn': torch.nn.functional.relu
},
{
'quantized_fn': [
torch.nn.functional.relu,
torch.nn.quantized.functional.relu,
torch.nn.functional.relu,
],
'reference_fn': torch.nn.functional.relu,
'extra_kwargs': {
17 changes: 0 additions & 17 deletions torch/nn/quantized/functional.py
@@ -410,23 +410,6 @@ def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.) -> Ten
return torch.ops.quantized.celu(input, scale, zero_point, alpha)


def relu(input: Tensor, inplace: bool = False) -> Tensor:
r"""relu(input, inplace=False) -> Tensor
Applies the rectified linear unit function element-wise.
See :class:`~torch.nn.quantized.ReLU` for more details.
Args:
input: quantized input
inplace: perform the computation inplace
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.relu' must be quantized!")
if inplace:
return torch.relu_(input)
else:
return torch.relu(input)

def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False,
scale: Optional[float] = None, zero_point: Optional[int] = None):
r"""
3 changes: 1 addition & 2 deletions torch/nn/quantized/modules/__init__.py
@@ -2,7 +2,7 @@
import torch
from torch.nn.modules.pooling import MaxPool2d

from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid
from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid
from .batchnorm import BatchNorm2d, BatchNorm3d
from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
InstanceNorm2d, InstanceNorm3d
@@ -106,7 +106,6 @@ def from_float(mod):
'Linear',
'MaxPool2d',
'Quantize',
'ReLU',
'ReLU6',
'Sigmoid',
# Wrapper modules
38 changes: 0 additions & 38 deletions torch/nn/quantized/modules/activation.py
@@ -1,44 +1,6 @@
import torch
import torch.nn.quantized.functional

class ReLU(torch.nn.ReLU):
r"""Applies quantized rectified linear unit function element-wise:
:math:`\text{ReLU}(x)= \max(x_0, x)`, where :math:`x_0` is the zero point.
Please see https://pytorch.org/docs/stable/nn.html#torch.nn.ReLU
for more documentation on ReLU.
Args:
inplace: (Currently not supported) can optionally do the operation in-place.
Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input
Examples::
>>> m = nn.quantized.ReLU()
>>> input = torch.randn(2)
>>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
>>> output = m(input)
"""
def __init__(self, inplace=False):
super(ReLU, self).__init__(inplace)
self.inplace = inplace

def forward(self, input):
return torch.nn.quantized.functional.relu(input, inplace=self.inplace)

def _get_name(self):
return 'QuantizedReLU'

@staticmethod
def from_float(mod):
return ReLU(mod.inplace)


class ReLU6(torch.nn.ReLU):
r"""Applies the element-wise function:
10 changes: 3 additions & 7 deletions torch/nn/quantized/modules/batchnorm.py
@@ -21,11 +21,9 @@ def _get_name(self):

@classmethod
def from_float(cls, mod):
activation_post_process = mod.activation_post_process
if type(mod) == nni.BNReLU2d:
activation_post_process = mod[1].activation_post_process
mod = mod[0]
else:
activation_post_process = mod.activation_post_process
scale, zero_point = activation_post_process.calculate_qparams()
new_mod = cls(mod.num_features, mod.eps)
new_mod.weight = mod.weight
@@ -36,6 +34,7 @@ def from_float(cls, mod):
new_mod.zero_point = int(zero_point)
return new_mod

# TODO: dedup with BatchNorm2d
class BatchNorm3d(torch.nn.BatchNorm3d):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
"""
@@ -55,12 +54,9 @@ def _get_name(self):

@classmethod
def from_float(cls, mod):
activation_post_process = mod.activation_post_process
if type(mod) == nni.BNReLU3d:
activation_post_process = mod[1].activation_post_process
mod = mod[0]
else:
activation_post_process = mod.activation_post_process

scale, zero_point = activation_post_process.calculate_qparams()
new_mod = cls(mod.num_features, mod.eps)
new_mod.weight = mod.weight
10 changes: 2 additions & 8 deletions torch/nn/quantized/modules/conv.py
@@ -182,11 +182,9 @@ def from_float(cls, mod):
cls._FLOAT_MODULE.__name__
assert hasattr(mod, "qconfig"), \
"Input float module must have qconfig defined."
activation_post_process = mod.activation_post_process
if type(mod) == cls._NNI_CONV_RELU_MODULE:
activation_post_process = mod[1].activation_post_process
mod = mod[0]
else:
activation_post_process = mod.activation_post_process
weight_post_process = mod.qconfig.weight()
return cls.get_qconv(mod, activation_post_process, weight_post_process)

@@ -449,13 +447,9 @@ def from_float(cls, mod):
cls._FLOAT_MODULE.__name__
assert hasattr(mod, 'qconfig'), \
'Input float module must have qconfig defined.'
# Workaround for sequential, ConvReLU3d should probably inherit from
# Conv3d instead
activation_post_process = mod.activation_post_process
if type(mod) == nni.ConvReLU3d:
activation_post_process = mod[1].activation_post_process
mod = mod[0]
else:
activation_post_process = mod.activation_post_process
return cls.get_qconv(mod, activation_post_process)

# === Transposed Convolutions ===
4 changes: 1 addition & 3 deletions torch/nn/quantized/modules/linear.py
@@ -252,11 +252,9 @@ def from_float(cls, mod):
assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
activation_post_process = mod.activation_post_process
if type(mod) == nni.LinearReLU:
activation_post_process = mod[1].activation_post_process
mod = mod[0]
else:
activation_post_process = mod.activation_post_process
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
