[quant] Add quantized Sigmoid module (#45883)
Summary: Pull Request resolved: #45883

Test Plan:
python test/test_quantization.py TestStaticQuantizedModule.test_sigmoid

Imported from OSS

Reviewed By: z-a-f

Differential Revision: D24129116

fbshipit-source-id: aa960549509c60374012f35b1f5be39e90418099
jerryzh168 authored and facebook-github-bot committed Oct 7, 2020
1 parent 30bf799 commit 83d2c9a
Showing 3 changed files with 26 additions and 1 deletion.
3 changes: 3 additions & 0 deletions test/quantization/test_quantized_module.py
@@ -716,6 +716,9 @@ def test_elu(self):
     def test_leaky_relu(self):
         self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2})
 
+    def test_sigmoid(self):
+        self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {})
+
     @given(
         num_embeddings=st.integers(10, 50),
         embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
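The `_test_activation_module_impl` helper is not shown in this diff. As a rough illustration of the property such a test is expected to check, a standalone comparison between the new quantized module and a quantize-after-float reference might look like the sketch below; the scale, zero point, and tolerance values are illustrative assumptions, not values from the PR.

    import torch
    import torch.nn as nn
    import torch.nn.quantized as nnq

    # Illustrative qparams; the real test derives them from observers.
    in_scale, in_zp = 0.05, 128            # input range roughly [-6.4, 6.35]
    out_scale, out_zp = 1.0 / 256, 0       # sigmoid outputs lie in [0, 1]

    x = torch.randn(4, 8)
    qx = torch.quantize_per_tensor(x, in_scale, in_zp, torch.quint8)

    # Quantized module under test vs. a quantize-after-float reference.
    qy = nnq.Sigmoid(out_scale, out_zp)(qx)
    ref = torch.quantize_per_tensor(nn.Sigmoid()(qx.dequantize()), out_scale, out_zp, torch.quint8)

    assert torch.allclose(qy.dequantize(), ref.dequantize(), atol=2 * out_scale)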
3 changes: 2 additions & 1 deletion torch/nn/quantized/modules/__init__.py
@@ -2,7 +2,7 @@
 import torch
 from torch.nn.modules.pooling import MaxPool2d
 
-from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU
+from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid
 from .batchnorm import BatchNorm2d, BatchNorm3d
 from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
     InstanceNorm2d, InstanceNorm3d
@@ -100,6 +100,7 @@ def from_float(mod):
     'Hardswish',
     'ELU',
     'LeakyReLU',
+    'Sigmoid',
     'LayerNorm',
     'GroupNorm',
     'InstanceNorm1d',
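With the import and `__all__` entries above, the new class should be reachable from the package namespace. A quick sanity check, assuming a build that contains this commit:

    import torch.nn.quantized as nnq
    from torch.nn.quantized.modules import Sigmoid  # exported by the __init__ edited above

    assert nnq.Sigmoid is Sigmoid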
21 changes: 21 additions & 0 deletions torch/nn/quantized/modules/activation.py
@@ -149,3 +149,24 @@ def _get_name(self):
     def from_float(cls, mod):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
+
+class Sigmoid(torch.nn.Sigmoid):
+    r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
+
+    Args:
+        output_scale: quantization scale of the output tensor
+        output_zero_point: quantization zero point of the output tensor
+    """
+
+    def __init__(self, output_scale: float, output_zero_point: int):
+        super().__init__()
+        self.output_scale = output_scale
+        self.output_zero_point = output_zero_point
+
+    def forward(self, input):
+        return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)
+
+    @classmethod
+    def from_float(cls, mod):
+        output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
+        return cls(float(output_scale), int(output_zero_point))
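The usual entry point is `from_float` rather than the constructor: in eager-mode static quantization, `prepare()` attaches an `activation_post_process` observer to the float module, and `convert()` later calls `from_float`. Below is a hand-driven sketch of that path; the observer choice and input qparams are illustrative, and the swap from `nn.Sigmoid` to this module is done manually here rather than through `convert()`.

    import torch
    import torch.nn as nn
    import torch.nn.quantized as nnq
    from torch.quantization import MinMaxObserver

    # Float module with an observer attached, as prepare() would do in eager mode.
    float_sigmoid = nn.Sigmoid()
    float_sigmoid.activation_post_process = MinMaxObserver(dtype=torch.quint8)

    # "Calibrate": let the observer see the outputs the float module produces.
    x = torch.randn(16, 32)
    float_sigmoid.activation_post_process(float_sigmoid(x))

    # from_float reads the observed qparams and builds the quantized module.
    qsigmoid = nnq.Sigmoid.from_float(float_sigmoid)

    # Apply it to a quantized input tensor.
    qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=128, dtype=torch.quint8)
    qy = qsigmoid(qx)
    print(qy.dequantize()[0, :4])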
