-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
fake_quantize.py
297 lines (245 loc) · 13.4 KB
/
fake_quantize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import torch
from torch.nn import Module
from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args
import re
from abc import ABC, abstractmethod
def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
class FakeQuantizeBase(ABC, Module):
    r"""Abstract base class for fake quantize modules.

    Every fake quantize implementation should derive from this class and
    expose the same API: ``forward`` updates the statistics of the observed
    Tensor and fake quantizes the input, and ``calculate_qparams`` computes
    the quantization parameters from the collected statistics.

    Two one-element uint8 buffers, ``fake_quant_enabled`` and
    ``observer_enabled``, act as on/off switches (1 = on, 0 = off) and both
    start switched on.
    """

    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self):
        super().__init__()
        # The switches are buffers so that they get replicated in DDP. They are
        # uint8 rather than bool because NCCL does not support bool tensors.
        for flag_name in ('fake_quant_enabled', 'observer_enabled'):
            self.register_buffer(flag_name, torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        """Observe and/or fake quantize ``x``; implemented by subclasses."""

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        """Compute the quantization parameters; implemented by subclasses."""

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        """Set the ``fake_quant_enabled`` switch to 1 if ``enabled`` else 0."""
        if enabled:
            self.fake_quant_enabled[0] = 1
        else:
            self.fake_quant_enabled[0] = 0

    @torch.jit.export
    def disable_fake_quant(self):
        """Set the ``fake_quant_enabled`` switch to 0."""
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        """Set the ``observer_enabled`` switch to 1 if ``enabled`` else 0."""
        if enabled:
            self.observer_enabled[0] = 1
        else:
            self.observer_enabled[0] = 0

    @torch.jit.export
    def disable_observer(self):
        """Set the ``observer_enabled`` switch to 0."""
        self.enable_observer(False)

    # Lets callers build a partially-applied constructor for the class.
    with_args = classmethod(_with_args)
class FakeQuantize(FakeQuantizeBase):
    r""" Simulate the quantize and dequantize operations in training time.

    The output of this module is given by

    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to

    * :attr:`quant_min` specifies the minimum allowable quantized value.

    * :attr:`quant_max` specifies the maximum allowable quantized value.

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and
      quant_max should be chosen to be consistent with the dtype

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
                           and zero-point.
        quant_min (int): The minimum allowable quantized value.
        quant_max (int): The maximum allowable quantized value.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): Instance of the user provided observer that collects
                           statistics on the input tensor and provides a method to calculate
                           scale and zero-point.
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
        super().__init__()
        assert quant_min <= quant_max, \
            'quant_min must be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.activation_post_process = observer(**observer_kwargs)
        # The requested range has to fit into the integer range of the observer's dtype.
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
        self.register_buffer('scale', torch.tensor([1.0]))
        self.register_buffer('zero_point', torch.tensor([0]))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        # Observers that do not define a channel axis get -1.
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in fake quantize' + \
            ' got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        """Return the quantization parameters computed by the wrapped observer."""
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        """Optionally observe ``X``, then optionally fake quantize it.

        When the observer switch is on, the observer is run on ``X.detach()``
        and the ``scale`` / ``zero_point`` buffers are resized and overwritten
        with the freshly computed parameters. When the fake quant switch is on,
        ``X`` is fake quantized per channel or per tensor depending on
        ``is_per_channel``; otherwise ``X`` is returned unchanged.
        """
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            # The buffers are resized first because the new parameters may have
            # a different number of elements than the buffers currently hold.
            self.scale.resize_(_scale.shape)
            self.scale.copy_(_scale)
            self.zero_point.resize_(_zero_point.shape)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point,
                    self.ch_axis, self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, float(self.scale), int(self.zero_point),
                    self.quant_min, self.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        """Return a one-line summary of the switches and quantization settings."""
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.quant_min, self.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save the module state, then store ``scale`` and ``zero_point`` explicitly."""
        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        # scale and zero_point are written explicitly after the parent has run,
        # so these two entries always refer to the module's own tensors.
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Resize ``scale`` / ``zero_point`` to the stored shapes, then load as usual.

        Without this override, loading fails with a size mismatch whenever the
        stored ``scale`` / ``zero_point`` do not have the same shape as the
        current buffers (for example per-channel parameters of size N).
        """
        for name in ('scale', 'zero_point'):
            key = prefix + name
            if key not in state_dict:
                # Nothing to resize. The parent implementation called below
                # already records missing buffers in ``missing_keys`` when
                # ``strict`` is set, so the key is not appended here as well;
                # doing so would list it twice.
                continue
            val = state_dict[key]
            # Custom handling to allow loading scale and zero_point
            # of size N into buffers of a different size. The
            # buffers are resized here, and the values are copied in
            # the default state_dict loading code of the parent.
            if name == 'scale':
                self.scale.resize_(val.shape)
            else:
                assert name == 'zero_point'
                self.zero_point.resize_(val.shape)
            # For torchscript module we need to update the attributes here since we do not
            # call the `_load_from_state_dict` function defined module.py
            if torch.jit.is_scripting():
                if name == 'scale':
                    self.scale.copy_(val)
                else:
                    assert name == 'zero_point'
                    self.zero_point.copy_(val)
        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                        missing_keys, unexpected_keys, error_msgs)
class FixedQParamsFakeQuantize(FakeQuantizeBase):
    """Fake quantize module whose quantization parameters never change.

    Simulates quantize and dequantize in training time using the scale and
    zero point given at construction. Only per tensor quantization is
    supported.

    Args:
        `scale` (float): fixed scale for the fake quantize module
        `zero_point` (int): fixed zero point for the fake quantize module
        `dtype`, `qscheme`, `quant_min`, `quant_max`
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self,
                 scale,
                 zero_point,
                 dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=0,
                 quant_max=255):
        super().__init__()
        assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max'
        self.quant_min, self.quant_max = quant_min, quant_max
        # The fixed parameters are kept as one-element buffers.
        self.register_buffer('scale', torch.tensor([scale]))
        self.register_buffer('zero_point', torch.tensor([zero_point]))
        self.dtype, self.qscheme = dtype, qscheme
        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \
            ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme)

    def forward(self, X):
        """Fake quantize ``X`` with the fixed parameters, or return it untouched when disabled."""
        if self.fake_quant_enabled[0] != 1:
            return X
        fixed_scale = float(self.scale)
        fixed_zero_point = int(self.zero_point)
        return torch.fake_quantize_per_tensor_affine(
            X, fixed_scale, fixed_zero_point, self.quant_min, self.quant_max)

    @torch.jit.export
    def calculate_qparams(self):
        """Return the fixed ``(scale, zero_point)`` buffers."""
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        """Return a one-line summary of the switches and quantization settings."""
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.scale, self.zero_point, self.dtype,
                   self.quant_min, self.quant_max, self.qscheme)
# Per-tensor affine fake quantization over [0, 255] driven by a moving-average
# min/max observer configured with reduce_range=True.
default_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)

# Per-tensor symmetric fake quantization over [-128, 127] driven by a
# moving-average min/max observer.
default_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False,
)

# TODO(future PR): remove these defaults and enforce activation functions
# to explicitly specify their output range
default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
    scale=2.0 / 256.0,
    zero_point=128,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
)
default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
    scale=1.0 / 256.0,
    zero_point=0,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
)

# Per-channel symmetric fake quantization over [-128, 127] along channel axis 0.
default_per_channel_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False,
    ch_axis=0,
)

# Per-tensor affine fake quantization over [0, 255] driven by a histogram observer.
default_histogram_fake_quant = FakeQuantize.with_args(
    observer=HistogramObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)
def _is_fake_quant_script_module(mod):
''' Returns true if given mod is an instance of FakeQuantize script module.
'''
if isinstance(mod, torch.jit.RecursiveScriptModule):
# qualified name looks like '__torch__.torch.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
suffix = mod._c.qualified_name.split('.', 1)[1]
name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
return name == 'torch.quantization.fake_quantize.FakeQuantize'
return False
def disable_fake_quant(mod):
    """Switch fake quantization off on ``mod`` when it is a fake quantize module.

    ``mod`` qualifies if it is a ``FakeQuantizeBase`` instance or a scripted
    FakeQuantize module; any other module is left untouched.
    """
    is_fake_quant = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if not is_fake_quant:
        return
    mod.disable_fake_quant()
def enable_fake_quant(mod):
    """Switch fake quantization on for ``mod`` when it is a fake quantize module.

    ``mod`` qualifies if it is a ``FakeQuantizeBase`` instance or a scripted
    FakeQuantize module; any other module is left untouched.
    """
    is_fake_quant = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if not is_fake_quant:
        return
    mod.enable_fake_quant()
def disable_observer(mod):
    """Switch statistics collection off on ``mod`` when it is a fake quantize module.

    ``mod`` qualifies if it is a ``FakeQuantizeBase`` instance or a scripted
    FakeQuantize module; any other module is left untouched.
    """
    is_fake_quant = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if not is_fake_quant:
        return
    mod.disable_observer()
def enable_observer(mod):
    """Switch statistics collection on for ``mod`` when it is a fake quantize module.

    ``mod`` qualifies if it is a ``FakeQuantizeBase`` instance or a scripted
    FakeQuantize module; any other module is left untouched.
    """
    is_fake_quant = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if not is_fake_quant:
        return
    mod.enable_observer()