import math
import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.qat as nnqat
import torch.nn.functional as F
from torch.nn import init
from torch.nn.modules.utils import _single, _pair
from torch.nn.parameter import Parameter
_BN_CLASS_MAP = {
1: nn.BatchNorm1d,
2: nn.BatchNorm2d,
3: nn.BatchNorm3d,
}
class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
_version = 2
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups,
bias,
padding_mode,
# BatchNormNd args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
dim=2):
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
stride, padding, dilation, transposed,
output_padding, groups, False, padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.freeze_bn = freeze_bn if self.training else True
self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
self.weight_fake_quant = self.qconfig.weight()
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_bn_parameters()
# this needs to be called after reset_bn_parameters,
# as they modify the same state
if self.training:
if freeze_bn:
self.freeze_bn_stats()
else:
self.update_bn_stats()
else:
self.freeze_bn_stats()
def reset_running_stats(self):
self.bn.reset_running_stats()
def reset_bn_parameters(self):
self.bn.reset_running_stats()
init.uniform_(self.bn.weight)
init.zeros_(self.bn.bias)
        # note: below is actually for the conv bias, not BN
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def reset_parameters(self):
super(_ConvBnNd, self).reset_parameters()
def update_bn_stats(self):
self.freeze_bn = False
self.bn.training = True
return self
def freeze_bn_stats(self):
self.freeze_bn = True
self.bn.training = False
return self
def _forward(self, input):
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
scale_factor = self.bn.weight / running_std
weight_shape = [1] * len(self.weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(self.weight.shape)
bias_shape[1] = -1
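        # Note on the folding math (explanatory): scale_factor = bn.weight / running_std,
        # so the weight observer sees the folded kernel W * scale_factor, and dividing the
        # conv output by scale_factor below keeps the overall computation equivalent to the
        # float Conv followed by BatchNorm.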
scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
# using zero bias here since the bias for original conv
# will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device)
conv = self._conv_forward(input, scaled_weight, zero_bias)
conv_orig = conv / scale_factor.reshape(bias_shape)
if self.bias is not None:
conv_orig = conv_orig + self.bias.reshape(bias_shape)
conv = self.bn(conv_orig)
return conv
def extra_repr(self):
# TODO(jerryzh): extend
return super(_ConvBnNd, self).extra_repr()
def forward(self, input):
return self._forward(input)
def train(self, mode=True):
"""
Batchnorm's training behavior is using the self.training flag. Prevent
changing it if BN is frozen. This makes sure that calling `model.train()`
on a model with a frozen BN will behave properly.
"""
self.training = mode
if not self.freeze_bn:
for module in self.children():
module.train(mode)
return self
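    # Illustrative interplay of train() and BN freezing (a sketch; assumes a valid
    # qconfig such as torch.quantization.get_default_qat_qconfig('fbgemm')):
    #
    #   m = torch.nn.intrinsic.qat.ConvBn2d(3, 8, 3, qconfig=qconfig)
    #   m.freeze_bn_stats()   # BN goes to eval mode and running stats stop updating
    #   m.train()             # self.training flips back to True, but the frozen BN
    #                         # (and the other child modules) are left untouched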
# ===== Serialization version history =====
#
# Version 1/None
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- gamma : Tensor
# |--- beta : Tensor
# |--- running_mean : Tensor
# |--- running_var : Tensor
# |--- num_batches_tracked : Tensor
#
# Version 2
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- bn : Module
# |--- weight : Tensor (moved from v1.self.gamma)
# |--- bias : Tensor (moved from v1.self.beta)
# |--- running_mean : Tensor (moved from v1.self.running_mean)
# |--- running_var : Tensor (moved from v1.self.running_var)
# |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
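    #
    # For example, when loading a v1 checkpoint, an entry named '<prefix>gamma' is
    # renamed to '<prefix>bn.weight' (and similarly for the other keys) in
    # _load_from_state_dict below before delegating to the parent implementation.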
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version == 1:
# BN related parameters and buffers were moved into the BN module for v2
v2_to_v1_names = {
'bn.weight': 'gamma',
'bn.bias': 'beta',
'bn.running_mean': 'running_mean',
'bn.running_var': 'running_var',
'bn.num_batches_tracked': 'num_batches_tracked',
}
for v2_name, v1_name in v2_to_v1_names.items():
if prefix + v1_name in state_dict:
state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
state_dict.pop(prefix + v1_name)
elif prefix + v2_name in state_dict:
# there was a brief period where forward compatibility
# for this module was broken (between
# https://github.com/pytorch/pytorch/pull/38478
# and https://github.com/pytorch/pytorch/pull/38820)
# and modules emitted the v2 state_dict format while
# specifying that version == 1. This patches the forward
# compatibility issue by allowing the v2 style entries to
# be used.
pass
elif strict:
missing_keys.append(prefix + v2_name)
super(_ConvBnNd, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@classmethod
def from_float(cls, mod):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]
qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
conv.stride, conv.padding, conv.dilation,
conv.groups, conv.bias is not None,
conv.padding_mode,
bn.eps, bn.momentum,
False,
qconfig)
qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias
qat_convbn.bn.weight = bn.weight
qat_convbn.bn.bias = bn.bias
qat_convbn.bn.running_mean = bn.running_mean
qat_convbn.bn.running_var = bn.running_var
qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn
class ConvBn1d(_ConvBnNd, nn.Conv1d):
r"""
A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv1d` and
:class:`torch.nn.BatchNorm1d`.
Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
to default.
Attributes:
        freeze_bn: when True, BatchNorm statistics are not updated during training
weight_fake_quant: fake quant module for weight
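
    Example (illustrative sketch; assumes the ``torch.quantization.get_default_qat_qconfig`` helper)::

        >>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        >>> m = torch.nn.intrinsic.qat.ConvBn1d(16, 32, 3, qconfig=qconfig)
        >>> output = m(torch.randn(2, 16, 50))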
"""
_FLOAT_MODULE = nni.ConvBn1d
def __init__(self,
# Conv1d args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm1d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
dilation = _single(dilation)
_ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _single(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig, dim=1)
class ConvBnReLU1d(ConvBn1d):
r"""
A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
attached with FakeQuantize modules for weight,
used in quantization aware training.
    We combined the interface of :class:`torch.nn.Conv1d`,
    :class:`torch.nn.BatchNorm1d`, and :class:`torch.nn.ReLU`.
    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.
Attributes:
weight_fake_quant: fake quant module for weight
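
    Example (illustrative sketch; converts a fused float ``nni.ConvBnReLU1d`` with
    the default QAT qconfig helper)::

        >>> float_mod = torch.nn.intrinsic.ConvBnReLU1d(
        ...     torch.nn.Conv1d(16, 32, 3), torch.nn.BatchNorm1d(32), torch.nn.ReLU())
        >>> float_mod.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        >>> qat_mod = torch.nn.intrinsic.qat.ConvBnReLU1d.from_float(float_mod)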
"""
_FLOAT_MODULE = nni.ConvBnReLU1d
def __init__(self,
# Conv1d args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm1d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
super().__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias,
padding_mode, eps, momentum,
freeze_bn,
qconfig)
def forward(self, input):
return F.relu(ConvBn1d._forward(self, input))
@classmethod
def from_float(cls, mod):
return super(ConvBnReLU1d, cls).from_float(mod)
class ConvBn2d(_ConvBnNd, nn.Conv2d):
r"""
A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv2d` and
:class:`torch.nn.BatchNorm2d`.
Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
to default.
Attributes:
        freeze_bn: when True, BatchNorm statistics are not updated during training
weight_fake_quant: fake quant module for weight
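
    Example (illustrative sketch; assumes the ``torch.quantization.get_default_qat_qconfig`` helper)::

        >>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        >>> m = torch.nn.intrinsic.qat.ConvBn2d(3, 16, 3, qconfig=qconfig)
        >>> output = m(torch.randn(2, 3, 24, 24))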
"""
_FLOAT_MODULE = nni.ConvBn2d
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
_ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _pair(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig, dim=2)
class ConvBnReLU2d(ConvBn2d):
r"""
A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
attached with FakeQuantize modules for weight,
used in quantization aware training.
    We combined the interface of :class:`torch.nn.Conv2d`,
    :class:`torch.nn.BatchNorm2d`, and :class:`torch.nn.ReLU`.
    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.
Attributes:
weight_fake_quant: fake quant module for weight
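
    Example (illustrative sketch; converts a fused float ``nni.ConvBnReLU2d`` with
    the default QAT qconfig helper)::

        >>> float_mod = torch.nn.intrinsic.ConvBnReLU2d(
        ...     torch.nn.Conv2d(3, 16, 3), torch.nn.BatchNorm2d(16), torch.nn.ReLU())
        >>> float_mod.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        >>> qat_mod = torch.nn.intrinsic.qat.ConvBnReLU2d.from_float(float_mod)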
"""
_FLOAT_MODULE = nni.ConvBnReLU2d
def __init__(self,
# Conv2d args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
super(ConvBnReLU2d, self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias,
padding_mode, eps, momentum,
freeze_bn,
qconfig)
def forward(self, input):
return F.relu(ConvBn2d._forward(self, input))
@classmethod
def from_float(cls, mod):
return super(ConvBnReLU2d, cls).from_float(mod)
class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
FakeQuantize modules for weight for
quantization aware training.
We combined the interface of :class:`~torch.nn.Conv2d` and
:class:`~torch.nn.BatchNorm2d`.
Attributes:
weight_fake_quant: fake quant module for weight
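
    Example (illustrative sketch; assumes the ``torch.quantization.get_default_qat_qconfig`` helper)::

        >>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        >>> m = torch.nn.intrinsic.qat.ConvReLU2d(3, 16, 3, qconfig=qconfig)
        >>> output = m(torch.randn(1, 3, 24, 24))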
"""
_FLOAT_MODULE = nni.ConvReLU2d
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros',
qconfig=None):
super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias, padding_mode=padding_mode,
qconfig=qconfig)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.weight_fake_quant = self.qconfig.weight()
def forward(self, input):
return F.relu(
self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
def from_float(cls, mod):
return super(ConvReLU2d, cls).from_float(mod)
def update_bn_stats(mod):
    if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBn1d, ConvBn2d}:
        mod.update_bn_stats()

def freeze_bn_stats(mod):
    if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBn1d, ConvBn2d}:
        mod.freeze_bn_stats()
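
# Illustrative usage of the helpers above (a sketch): during QAT they are typically
# applied recursively via ``nn.Module.apply``, e.g.
#
#   model.apply(freeze_bn_stats)
#
# to stop BatchNorm statistics updates partway through fine-tuning, or
# ``model.apply(update_bn_stats)`` to re-enable them.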