-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
__init__.py
114 lines (94 loc) · 3.17 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
from torch.nn.modules.pooling import MaxPool2d
from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid
from .batchnorm import BatchNorm2d, BatchNorm3d
from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
InstanceNorm2d, InstanceNorm3d
from .conv import Conv1d, Conv2d, Conv3d
from .conv import ConvTranspose1d, ConvTranspose2d
from .linear import Linear
from .embedding_ops import Embedding, EmbeddingBag
from .functional_modules import FloatFunctional, QFunctional
class Quantize(torch.nn.Module):
    r"""Module that turns an incoming tensor into a per-tensor quantized tensor.

    The quantization parameters are kept on the module as registered buffers
    and are passed to :func:`torch.quantize_per_tensor` on every forward call.

    Args:
        scale: scale of the output quantized tensor.
        zero_point: zero point of the output quantized tensor.
        dtype: data type of the output quantized tensor.

    Attributes:
        scale: one-element tensor buffer holding the scale.
        zero_point: one-element ``torch.long`` tensor buffer holding the
            zero point.
        dtype: the target quantized dtype, stored as given.

    Examples::

        >>> t = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> qt = qm(t)
        >>> print(qt)
        tensor([[ 1., -1.],
                [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, scale, zero_point, dtype):
        super().__init__()
        # Wrap the scalars in one-element tensors so they can live as buffers.
        scale_buffer = torch.tensor([scale])
        zero_point_buffer = torch.tensor([zero_point], dtype=torch.long)
        self.register_buffer('scale', scale_buffer)
        self.register_buffer('zero_point', zero_point_buffer)
        self.dtype = dtype

    def forward(self, X):
        """Quantize ``X`` using the stored scale, zero point and dtype."""
        # quantize_per_tensor is called with plain Python scalars, so the
        # one-element buffers are converted first.
        scale_value = float(self.scale)
        zero_point_value = int(self.zero_point)
        return torch.quantize_per_tensor(
            X, scale_value, zero_point_value, self.dtype)

    @staticmethod
    def from_float(mod):
        """Build a ``Quantize`` from a module carrying an observer.

        ``mod`` must expose ``activation_post_process``; its
        ``calculate_qparams()`` result and ``dtype`` supply the parameters
        of the new module.
        """
        assert hasattr(mod, 'activation_post_process')
        observer = mod.activation_post_process
        scale, zero_point = observer.calculate_qparams()
        return Quantize(
            scale.float().item(),
            zero_point.long().item(),
            observer.dtype,
        )

    def extra_repr(self):
        """Return the parameter summary shown inside the module's repr."""
        template = 'scale={}, zero_point={}, dtype={}'
        return template.format(self.scale, self.zero_point, self.dtype)
class DeQuantize(torch.nn.Module):
    r"""Module that dequantizes an incoming quantized tensor.

    The module holds no parameters of its own; the forward pass simply
    calls ``dequantize()`` on its input.

    Examples::

        >>> input = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> quantized_input = qm(input)
        >>> dqm = DeQuantize()
        >>> dequantized = dqm(quantized_input)
        >>> print(dequantized)
        tensor([[ 1., -1.],
                [ 1., -1.]], dtype=torch.float32)
    """

    def __init__(self):
        super().__init__()

    def forward(self, Xq):
        """Return the result of ``Xq.dequantize()``."""
        dequantized = Xq.dequantize()
        return dequantized

    @staticmethod
    def from_float(mod):
        """Return a fresh ``DeQuantize``; ``mod`` is not inspected."""
        return DeQuantize()
# Names exported by ``from <this package> import *``.
# The order of entries is kept exactly as before.
__all__ = [
    # Batch normalization
    'BatchNorm2d', 'BatchNorm3d',
    # Convolutions and transposed convolutions
    'Conv1d', 'Conv2d', 'Conv3d',
    'ConvTranspose1d', 'ConvTranspose2d',
    # Quantize / dequantize, linear and pooling
    'DeQuantize',
    'Linear',
    'MaxPool2d',
    'Quantize',
    # Activations
    'ReLU', 'ReLU6', 'Hardswish', 'ELU', 'LeakyReLU', 'Sigmoid',
    # Normalization layers
    'LayerNorm', 'GroupNorm',
    'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d',
    # Embeddings
    'Embedding', 'EmbeddingBag',
    # Wrapper modules
    'FloatFunctional', 'QFunctional',
]