# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2019, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Optimization code to fold batch-norm layers """
import contextlib
import math
from typing import List, Tuple, Union, Dict, Iterable, Set, Any
import numpy as np
import torch
import torch.nn
from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d
from torch.nn.modules.conv import _ConvTransposeNd
import aimet_common.libpymo as libpymo
from aimet_common.bias_correction import ConvBnPatternHandler, CONV_OP_TYPES, LINEAR_OP_TYPES, BN_OP_TYPES
from aimet_common.graph_pattern_matcher import PatternType
from aimet_common.graph_searcher import GraphSearcher
from aimet_common.utils import AimetLogger
# pylint: disable=unused-import
from aimet_torch.defs import PassThroughOp
from aimet_torch import utils
from aimet_torch.meta.connectedgraph import ConnectedGraph
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch.tensor_quantizer import LearnedGridTensorQuantizer
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.BatchNormFolding)

LayerType = Union[
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.ConvTranspose2d,
]
_supported_layers = LayerType.__args__

BatchNormType = Union[BatchNorm1d, BatchNorm2d]
_supported_batchnorms = BatchNormType.__args__


def _delete_bn_from_model(model: torch.nn.Module, bn_layer_list: Iterable[BatchNormType]):
    utils.replace_modules_with_instances_of_new_type(model, bn_layer_list, torch.nn.Identity)


@contextlib.contextmanager
def _expand_shape_to_4d(weight_tensor: libpymo.TensorParams):
    """ Expand the shape of the weight into 4d. """
    dims = len(weight_tensor.shape)

    if dims > 5:
        raise RuntimeError("_expand_shape_to_4d only supports weight tensors of up to 5 dimensions")

    if dims == 4:
        yield weight_tensor
    else:
        orig_shape = weight_tensor.shape
        if dims < 4:
            # If we have fewer dimensions, we append 1s to make 4 dimensions
            _4d_shape = np.append(orig_shape, [1 for _ in range(4-dims)]).astype(int)
        else:
            # If we have more dimensions, we concatenate all the dimensions beyond 3 into one dimension
            _4d_shape = np.array(orig_shape[:3] + [math.prod(orig_shape[3:])])

        try:
            weight_tensor.shape = _4d_shape
            yield weight_tensor
        finally:
            weight_tensor.shape = orig_shape
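

# Illustrative sketch (hypothetical helper, not part of the AIMET API): shows
# how _expand_shape_to_4d pads a 2D Linear-style weight shape [out, in] to
# [out, in, 1, 1] inside the context and restores the original shape on exit.
# The exact container type returned by TensorParams.shape is a pymo binding
# detail, so the asserts below are written defensively via list().
def _example_expand_shape_to_4d():
    weight_tensor = libpymo.TensorParams()
    weight_tensor.data = np.zeros(6, dtype=np.float32)
    weight_tensor.shape = np.array([3, 2])               # e.g. a Linear weight
    with _expand_shape_to_4d(weight_tensor) as wt:
        assert list(wt.shape) == [3, 2, 1, 1]            # padded with trailing 1s
    assert list(weight_tensor.shape) == [3, 2]           # original shape restored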


def _call_mo_batch_norm_fold(weight: torch.Tensor,
                             bias: torch.Tensor,
                             bn: BatchNormType,
                             fold_backward: bool):
    """
    Calls C++ batch norm folding API.

    :param weight: Weight or scale tensor to fold BN into.
    :param bias: Bias tensor to fold BN into.
    :param bn: Batch Norm layer
    :param fold_backward: True if BatchNorm comes after Conv/Linear layer
    """
    with torch.no_grad():
        bn_params = libpymo.BNParams()
        bn_params.gamma = bn.weight.detach().cpu().numpy().reshape(-1)
        bn_params.beta = bn.bias.detach().cpu().numpy().reshape(-1)
        bn_params.runningMean = bn.running_mean.detach().cpu().numpy().reshape(-1)
        sigma = torch.sqrt(bn.running_var + bn.eps)
        bn_params.runningVar = sigma.detach().cpu().numpy().reshape(-1)

        weight_tensor = libpymo.TensorParams()
        weight_tensor.data = weight.detach().cpu().numpy().reshape(-1)
        weight_tensor.shape = np.array(weight.shape)

        bias_tensor = libpymo.TensorParams()
        bias_tensor.data = bias.detach().cpu().numpy().reshape(-1)
        bias_tensor.shape = np.array(bias.shape)
        is_bias_valid = True

        with _expand_shape_to_4d(weight_tensor):
            _bias = libpymo.fold(bn_params, weight_tensor, bias_tensor, is_bias_valid, fold_backward)

        bias.copy_(torch.tensor(_bias, device=bias.device, dtype=bias.dtype)
                   .reshape_as(bias))

        weight.copy_(torch.tensor(weight_tensor.data, device=weight.device, dtype=weight.dtype)
                     .reshape_as(weight))
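

# Reference sketch (hypothetical helper, not the AIMET API): algebraically,
# the backward fold performed by libpymo.fold above computes, per output
# channel,
#     W' = W * gamma / sigma
#     b' = beta + (b - mu) * gamma / sigma
# with sigma = sqrt(running_var + eps). A minimal pure-PyTorch equivalent for
# a Conv2d weight of shape [out_ch, in_ch, kH, kW]:
def _reference_backward_fold(weight: torch.Tensor, bias: torch.Tensor, bn: BatchNormType):
    sigma = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / sigma                            # per-output-channel factor
    new_weight = weight * scale.reshape(-1, 1, 1, 1)
    new_bias = bn.bias + (bias - bn.running_mean) * scale
    return new_weight, new_bias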


class _BatchNormFoldingNotSupported(RuntimeError):
    pass


def _fold_to_scale(conv_wrapper: QcQuantizeWrapper, bn_wrapper: QcQuantizeWrapper):
    """
    Fold BatchNorm into the scale and bias of the given layer.

    :param conv_wrapper: QcQuantizeWrapper that wraps conv or linear layer.
    :param bn_wrapper: QcQuantizeWrapper that wraps bn.
    """
    # pylint: disable=protected-access, too-many-locals, too-many-branches, too-many-statements
    conv = conv_wrapper._module_to_wrap
    bn = bn_wrapper._module_to_wrap

    weight_quantizer = conv_wrapper.param_quantizers["weight"]

    if not isinstance(weight_quantizer, LearnedGridTensorQuantizer):
        raise _BatchNormFoldingNotSupported(
            "BatchNorm folding to scale supports LearnedGridTensorQuantizer only; "
            f"got {type(weight_quantizer)}."
        )

    output_quantizer = conv_wrapper.output_quantizers[0]

    if output_quantizer.enabled:
        raise _BatchNormFoldingNotSupported(
            "BatchNorm should belong to the same supergroup with the layer to be folded to."
        )

    if "bias" in conv_wrapper.param_quantizers:
        bias_quantizer = conv_wrapper.param_quantizers["bias"]
        if bias_quantizer.enabled:
            raise _BatchNormFoldingNotSupported(
                "Can't fold BatchNorm to scale if bias quantizer is enabled."
            )

    encodings = weight_quantizer.encoding
    if encodings is None:
        raise RuntimeError

    if isinstance(encodings, libpymo.TfEncoding):
        encodings = [encodings]

    if isinstance(conv, _ConvTransposeNd) and conv.groups != 1:
        raise _BatchNormFoldingNotSupported(
            "BatchNorm folding to scale is not supported for grouped ConvTransposeNd."
        )

    # Add quantization noise to the BN params (bn weight & bn bias) before folding.
    # NOTE: Quantization of foldable batchnorms is automatically disabled when
    #       initializing quantsim. However, it is still safer to call _quantize_params here
    #       as we can't guarantee this is always the case.
    #       For example, the user can manually enable quantization of batchnorms, etc...
    #       (FYI: _quantize_params takes effect only when the parameter quantizers are enabled)
    with bn_wrapper._quantize_params():
        _fold_to_weight(conv, bn, fold_backward=True)

        gamma = bn.weight
        sigma = torch.sqrt(bn.running_var + bn.eps)

        new_encodings = []
        for old_encoding, c in zip(encodings, gamma/sigma):
            new_encoding = libpymo.TfEncoding()
            new_encoding.delta = old_encoding.delta * abs(c)
            if c >= 0:
                new_encoding.max = old_encoding.max * c
                new_encoding.min = old_encoding.min * c
            else:
                new_encoding.max = old_encoding.min * c
                new_encoding.min = old_encoding.max * c
            new_encoding.offset = old_encoding.offset
            new_encoding.bw = old_encoding.bw
            new_encodings.append(new_encoding)

        weight_quantizer.encoding = new_encodings

    # Copy batchnorm's output quantizers to conv output quantizers
    for conv_output_quantizer, bn_output_quantizer in \
            zip(conv_wrapper.output_quantizers, bn_wrapper.output_quantizers):
        conv_output_quantizer.enabled = bn_output_quantizer.enabled

        if bn_output_quantizer.encoding is not None:
            encoding = libpymo.TfEncoding()
            encoding.delta = bn_output_quantizer.encoding.delta
            encoding.max = bn_output_quantizer.encoding.max
            encoding.min = bn_output_quantizer.encoding.min
            encoding.offset = bn_output_quantizer.encoding.offset
            encoding.bw = bn_output_quantizer.encoding.bw
            conv_output_quantizer.encoding = encoding

        bn_output_quantizer.enabled = False

    if "bias" not in conv_wrapper.param_quantizers:
        bias_quantizer = LearnedGridTensorQuantizer(weight_quantizer.bitwidth,
                                                    weight_quantizer.round_mode,
                                                    weight_quantizer.quant_scheme,
                                                    weight_quantizer.use_symmetric_encodings,
                                                    enabled_by_default=False,
                                                    data_type=weight_quantizer.data_type)
        bias_quantizer._ch_axis = weight_quantizer._ch_axis
        conv_wrapper.param_quantizers["bias"] = bias_quantizer
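

# Worked numeric sketch (hypothetical values, not the AIMET API): the grid
# update above scales delta by |c| and, for a negative per-channel factor c,
# swaps the roles of min and max because the range is mirrored.
def _example_encoding_rescale():
    old_min, old_max, old_delta = -1.0, 1.0, 2.0 / 255
    c = -0.5                                             # gamma/sigma for one channel
    new_delta = old_delta * abs(c)                       # 1/255
    new_max = old_min * c                                # 0.5
    new_min = old_max * c                                # -0.5
    return new_min, new_max, new_delta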


def _fold_to_weight(conv_linear: LayerType, bn: BatchNormType, fold_backward: bool):
    """
    Fold BatchNorm into the weight and bias of the given layer.

    :param conv_linear: Conv or linear layer to fold BN into.
    :param bn: BatchNorm to fold.
    :param fold_backward: True if BatchNorm comes after the Conv/Linear layer
    """
    # Transpose weights to C, N, H, W from N, C, H, W since axes are flipped for transposed conv.
    # However, depthwise conv layers are always N, 1, H, W whether transposed-conv or not, so no need to transpose.
    if isinstance(conv_linear, torch.nn.ConvTranspose2d) and conv_linear.groups == 1:
        conv_linear.weight.data = conv_linear.weight.data.permute(1, 0, 2, 3)

    if conv_linear.bias is None:
        out_channels = conv_linear.out_features if isinstance(conv_linear, torch.nn.Linear) \
                       else conv_linear.out_channels
        bias = torch.zeros(out_channels,
                           device=conv_linear.weight.device,
                           dtype=conv_linear.weight.dtype)
        conv_linear.bias = torch.nn.Parameter(bias)

    _call_mo_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)

    # Transpose weight back to N, C, H, W for transposed Conv2D, for non-depthwise layers
    if isinstance(conv_linear, torch.nn.ConvTranspose2d) and conv_linear.groups == 1:
        conv_linear.weight.data = conv_linear.weight.data.permute(1, 0, 2, 3)
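

# Usage sketch (hypothetical modules): folding a single Conv2d/BN pair in
# place. The BN statistics are absorbed into conv's weight and bias, so the
# conv alone reproduces the original conv + BN output.
def _example_fold_to_weight():
    conv = torch.nn.Conv2d(3, 8, kernel_size=3)
    bn = torch.nn.BatchNorm2d(8).eval()
    x = torch.randn(1, 3, 16, 16)
    before = bn(conv(x))
    _fold_to_weight(conv, bn, fold_backward=True)
    after = conv(x)
    assert torch.allclose(before, after, atol=1e-5)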


def fold_given_batch_norms(model, layer_pairs):
    """
    Fold a given set of batch_norm layers into conv layers

    :param model: Model
    :param layer_pairs: Pairs of conv and batch_norm layers to use for folding
    :return: None
    """
    # pylint: disable=protected-access
    conv_bn_pairs = []
    bn_conv_pairs = []

    def is_batchnorm(module: torch.nn.Module) -> bool:
        if isinstance(module, QcQuantizeWrapper):
            module = module._module_to_wrap
        return isinstance(module, _supported_batchnorms)

    def is_conv_linear(module: torch.nn.Module) -> bool:
        if isinstance(module, QcQuantizeWrapper):
            module = module._module_to_wrap
        return isinstance(module, _supported_layers)

    for x, y in layer_pairs:
        if is_batchnorm(x):
            assert is_conv_linear(y)
            bn = x
            conv = y
            bn_conv_pairs.append((bn, conv))
        else:
            assert is_conv_linear(x)
            assert is_batchnorm(y)
            conv = x
            bn = y
            conv_bn_pairs.append((conv, bn))

    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)
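

# Usage sketch (hypothetical model): fold an explicit list of pairs. Each pair
# may be ordered (conv, bn) for backward folding or (bn, conv) for forward
# folding; the direction is inferred from the order. Folded BNs are replaced
# with torch.nn.Identity.
def _example_fold_given_batch_norms():
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.BatchNorm2d(8),
    ).eval()
    fold_given_batch_norms(model, [(model[0], model[1])])
    assert isinstance(model[1], torch.nn.Identity)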


def _fold_given_batch_norms(model,
                            conv_bn_pairs: Iterable[Tuple[torch.nn.Module, torch.nn.Module]],
                            bn_conv_pairs: Iterable[Tuple[torch.nn.Module, torch.nn.Module]]):
    """
    Fold a given set of batch_norm layers into conv layers

    :param model: Model
    :param conv_bn_pairs: List of (conv, bn) pairs to fold
    :param bn_conv_pairs: List of (bn, conv) pairs to fold
    :return: None
    """
    # pylint: disable=protected-access
    for bn, conv in bn_conv_pairs:
        if isinstance(conv, QcQuantizeWrapper):
            raise RuntimeError(f"Forward folding to scale is not possible. Got {conv}")

    bn_modules = []

    def _fold(conv, bn, fold_backward):
        is_wrapped = isinstance(conv, QcQuantizeWrapper) or isinstance(bn, QcQuantizeWrapper)
        try:
            if is_wrapped:
                assert isinstance(conv, QcQuantizeWrapper) and isinstance(bn, QcQuantizeWrapper)
                _fold_to_scale(conv, bn)
            else:
                _fold_to_weight(conv, bn, fold_backward=fold_backward)
        except _BatchNormFoldingNotSupported as e:
            bn_name = utils.get_layer_name(model, bn)
            conv_name = utils.get_layer_name(model, conv)
            _logger.warning(
                "Failed to fold %s to %s. [Reason] %s", bn_name, conv_name, str(e)
            )
        else:
            # Record each successfully folded BN exactly once
            bn_modules.append(bn._module_to_wrap if is_wrapped else bn)

    with utils.in_eval_mode(model), torch.no_grad():
        for conv, bn in conv_bn_pairs:
            _fold(conv, bn, fold_backward=True)

        for bn, conv in bn_conv_pairs:
            _fold(conv, bn, fold_backward=False)

        _delete_bn_from_model(model, bn_modules)


def find_all_batch_norms_to_fold(model, input_shapes, dummy_input: Union[torch.Tensor, Tuple] = None):
    """
    Find all possible batch norm layers that can be folded, and return them as a list of pairs
    such that (bn, layer) means bn will be forward-folded into layer and (layer, bn) means
    bn will be backward-folded into layer.

    :param model: Model to search
    :param input_shapes: Input shapes to use for the model (can be one or multiple inputs)
    :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors
    :return: List of pairs of bn and layers to fold bn into
    """
    if dummy_input is not None:
        connected_graph = ConnectedGraph(model, dummy_input)
    else:
        device = utils.get_device(model)
        inp_tensor_list = utils.create_rand_tensors_given_shapes(input_shapes, device)
        connected_graph = ConnectedGraph(model, inp_tensor_list)

    conv_bn_pairs, bn_conv_pairs, _ = _find_all_batch_norms_to_fold(connected_graph)
    return conv_bn_pairs + bn_conv_pairs


def _find_all_batch_norms_to_fold(connected_graph: ConnectedGraph) -> Tuple[
        List[Tuple[LayerType, BatchNormType]], List[Tuple[BatchNormType, LayerType]], Set]:
    """
    Find all possible batch norm layers that can be folded, and return them as a list of pairs
    such that (layer, bn) means bn will be backward-folded into layer and (bn, layer) means
    bn will be forward-folded into layer.

    :param connected_graph: Connected graph associated with the model.
    :return: A list of (layer, bn) pairs and a list of (bn, layer) pairs, where `bn` can be
             folded into `layer`, plus the set of bn ops picked for folding.
    """
    conv_bn_pairs, bn_conv_pairs, bn_to_fold = _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph)
    return conv_bn_pairs, bn_conv_pairs, bn_to_fold


def _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph: ConnectedGraph) -> Tuple[
        List[Tuple[LayerType, BatchNormType]], List[Tuple[BatchNormType, LayerType]], Set]:
    """
    Find all possible batch norm layers that can be folded, and return them as a list of pairs
    such that (layer, bn) means bn will be backward-folded into layer and (bn, layer) means
    bn will be forward-folded into layer.

    :param connected_graph: Connected graph associated with the model.
    :return: A list of (layer, bn) pairs and a list of (bn, layer) pairs, where `bn` can be
             folded into `layer`, plus the set of bn ops that were picked for folding into
             their immediately adjacent convs.
    """
    conv_linear_bn_activation_info_dict = find_all_conv_bn_with_activation_in_graph(connected_graph)

    # To mark BNs already picked for backward folding
    bn_picked_for_folding = set()

    _conv_linear_optypes = CONV_OP_TYPES + LINEAR_OP_TYPES
    ordered_conv_fc_modules = [op.get_module() for op in connected_graph.ordered_ops
                               if op.type in _conv_linear_optypes]

    conv_bn_pairs = []
    # Backward fold is given priority over forward fold
    for module in ordered_conv_fc_modules:
        if module in conv_linear_bn_activation_info_dict.keys() and _is_valid_bn_fold(module, True):
            bn_info = conv_linear_bn_activation_info_dict[module]
            if bn_info.output_bn and bn_info.output_bn not in bn_picked_for_folding:
                conv_bn_pairs.append((module, bn_info.output_bn.get_module()))
                bn_picked_for_folding.add(bn_info.output_bn)

    bn_conv_pairs = []
    for module in ordered_conv_fc_modules:
        if module in conv_linear_bn_activation_info_dict.keys() and _is_valid_bn_fold(module, False):
            bn_info = conv_linear_bn_activation_info_dict[module]
            if bn_info.input_bn and bn_info.input_bn not in bn_picked_for_folding:
                bn_conv_pairs.append((bn_info.input_bn.get_module(), module))
                bn_picked_for_folding.add(bn_info.input_bn)

    return conv_bn_pairs, bn_conv_pairs, bn_picked_for_folding


def find_standalone_batchnorm_ops(connected_graph: ConnectedGraph) -> set:
    """
    Find all batchnorm ops that cannot be folded.

    :param connected_graph: Connected graph associated with the model.
    :return stand_alone_bn_ops: Set of batchnorm ops that cannot be folded.
    """
    _, _, bn_picked_for_folding = _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph)
    bn_ops = {op for op in connected_graph.get_all_ops().values() if op.type in BN_OP_TYPES}
    stand_alone_bn_ops = bn_ops - bn_picked_for_folding

    return stand_alone_bn_ops


def _is_valid_bn_fold(conv: LayerType, fold_backward: bool) -> bool:
    """
    Determine if a given layer can successfully absorb a BatchNorm given the layer type and parameters

    :param conv: The Conv/Linear layer to fold a BatchNorm into.
    :param fold_backward: True if BatchNorm comes after Conv/Linear layer
    :return: True if a BatchNorm layer can be folded without causing output error.
    """
    valid = True
    if not fold_backward:
        # Cannot fold BN -> Conv with padding. AIMET does not support forward folding to grouped or DW Conv
        if isinstance(conv, (torch.nn.Conv2d, torch.nn.Conv1d, torch.nn.Conv3d)):
            valid &= all(item == 0 for item in conv.padding)
            valid &= conv.groups == 1
        # AIMET does not support forward folding to ConvTranspose
        elif isinstance(conv, torch.nn.ConvTranspose2d):
            valid = False
    else:
        # AIMET does not support backwards folding to grouped ConvTranspose
        if isinstance(conv, torch.nn.ConvTranspose2d):
            valid &= conv.groups in (1, conv.in_channels)
    return valid
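

# Illustrative checks (hypothetical layers): forward folding
# (fold_backward=False) rejects padded or grouped convolutions, while backward
# folding accepts regular convs and depthwise ConvTranspose2d.
def _example_is_valid_bn_fold():
    assert _is_valid_bn_fold(torch.nn.Conv2d(8, 8, 3), fold_backward=True)
    assert not _is_valid_bn_fold(torch.nn.Conv2d(8, 8, 3, padding=1), fold_backward=False)
    assert not _is_valid_bn_fold(torch.nn.Conv2d(8, 8, 3, groups=8), fold_backward=False)
    assert _is_valid_bn_fold(torch.nn.ConvTranspose2d(8, 8, 3), fold_backward=True)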


def fold_all_batch_norms_to_weight(
        model: torch.nn.Module,
        input_shapes: Union[Tuple, List[Tuple]],
        dummy_input: Union[torch.Tensor, Tuple] = None
) -> List[Tuple[LayerType, BatchNormType]]:
    """
    Fold all batch_norm layers in a model into the weight of the corresponding conv layers

    :param model: Model
    :param input_shapes: Input shapes for the model (can be one or multiple inputs)
    :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors
    :return: A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
    """
    if isinstance(model, torch.nn.DataParallel):
        return fold_all_batch_norms_to_weight(model.module, input_shapes, dummy_input)

    device = utils.get_device(model)
    if dummy_input is None:
        inp_tensor_list = utils.create_rand_tensors_given_shapes(input_shapes, device)
    else:
        inp_tensor_list = dummy_input
    connected_graph = ConnectedGraph(model, inp_tensor_list)

    conv_bn_pairs, bn_conv_pairs, bn_to_fold = _find_all_batch_norms_to_fold(connected_graph)
    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)

    # Convert the standalone BNs which are not folded
    bn_converted = convert_standalone_batchnorms(model, inp_tensor_list, bn_to_fold)
    _logger.info("%d BatchNorms' weights got converted", len(bn_converted))

    return conv_bn_pairs + [(conv, bn) for bn, conv in bn_conv_pairs]
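

# Usage sketch (hypothetical model): the main entry point for folding to
# weights. Returns the (Conv/Linear, BN) pairs that were folded; the folded
# BNs in the model are replaced with Identity.
def _example_fold_all_batch_norms_to_weight():
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    ).eval()
    pairs = fold_all_batch_norms_to_weight(model, input_shapes=(1, 3, 16, 16))
    for conv, _ in pairs:
        assert isinstance(conv, torch.nn.Conv2d)
    return pairs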


def convert_standalone_batchnorms(model: torch.nn.Module,
                                  dummy_input: Union[torch.Tensor, Tuple],
                                  folded_bn: set) -> List[Tuple[Any, BatchNorm2d]]:
    """
    Convert the weights of all the standalone batchnorms of a model which didn't get folded.

    :param model: torch model for which batch norm folding is being performed
    :param dummy_input: dummy input for the model
    :param folded_bn: set of BN ops that were picked for folding
    :return: List of tuple(name, bn_module) whose weights got converted
    """
    module_list = utils.get_ordered_list_of_modules(model, dummy_input)
    bn_converted = []
    for name, module in module_list:
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)) and module not in folded_bn:
            convert_batchnorm_parameters(model, module)
            _logger.debug("%s weights got converted", name)
            bn_converted.append((name, module))
    return bn_converted


def convert_batchnorm_parameters(model: torch.nn.Module, bn: Union[torch.nn.BatchNorm1d, torch.nn.BatchNorm2d]):
    """
    Convert the parameters of a batchnorm so that it computes y = weight * input + bias

    :param model: torch model for which batch norm folding is being performed
    :param bn: BatchNorm module whose weights need to be converted
    """
    with utils.in_eval_mode(model), torch.no_grad():
        gamma = bn.weight
        beta = bn.bias
        running_mean = bn.running_mean
        inv_sigma = torch.rsqrt(bn.running_var + bn.eps)

        weight = gamma * inv_sigma
        bias = beta - running_mean * weight

        # Update the values
        bn.eps = 0
        bn.track_running_stats = False
        bn.weight.copy_(weight.clone().detach())
        bn.bias.copy_(bias.clone().detach())
        bn.running_mean = torch.zeros(bn.running_mean.shape, device=bn.running_mean.device, dtype=bn.running_mean.dtype)
        bn.running_var = torch.ones(bn.running_var.shape, device=bn.running_var.device, dtype=bn.running_var.dtype)


fold_all_batch_norms = fold_all_batch_norms_to_weight
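

# Verification sketch (hypothetical model): after conversion the BN is a pure
# per-channel affine map y = weight * x + bias with neutral running stats and
# eps = 0, so its output on a fixed input is unchanged.
def _example_convert_batchnorm_parameters():
    model = torch.nn.Sequential(torch.nn.BatchNorm2d(4)).eval()
    bn = model[0]
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    x = torch.randn(1, 4, 8, 8)
    before = bn(x)
    convert_batchnorm_parameters(model, bn)
    assert torch.allclose(before, bn(x), atol=1e-5)
    assert torch.all(bn.running_mean == 0) and torch.all(bn.running_var == 1)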


def fold_all_batch_norms_to_scale(
        sim: QuantizationSimModel,
) -> List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]:
    """
    Fold all batch_norm layers in a model into the quantization scale parameter
    of the corresponding conv layers

    :param sim: QuantizationSimModel
    :return: A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
    """
    # pylint: disable=protected-access
    assert sim.model is not None
    assert sim.connected_graph is not None

    model = sim.model
    connected_graph = sim.connected_graph

    quant_wrappers = {
        quant_wrapper._module_to_wrap: quant_wrapper
        for _, quant_wrapper in sim.quant_wrappers()
    }
    conv_bn_pairs, bn_conv_pairs, _ = _find_all_batch_norms_to_fold(connected_graph)
    conv_bn_pairs = [
        (quant_wrappers[conv], quant_wrappers[bn]) for conv, bn in conv_bn_pairs
    ]
    bn_conv_pairs = [
        (quant_wrappers[bn], quant_wrappers[conv]) for bn, conv in bn_conv_pairs
    ]

    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)

    return conv_bn_pairs + [(conv, bn) for bn, conv in bn_conv_pairs]
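

# Usage sketch (assumptions: the QuantScheme import path and range-learning
# scheme shown here; a hedged illustration, not a definitive recipe). Folding
# to scale requires LearnedGrid quantizers with computed encodings, hence the
# range-learning quant scheme and the compute_encodings call.
def _example_fold_all_batch_norms_to_scale():
    from aimet_common.defs import QuantScheme       # assumed import path
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.BatchNorm2d(8),
    ).eval()
    dummy_input = torch.randn(1, 3, 16, 16)
    sim = QuantizationSimModel(
        model, dummy_input,
        quant_scheme=QuantScheme.training_range_learning_with_tf_init)
    sim.compute_encodings(lambda m, _: m(dummy_input), None)
    return fold_all_batch_norms_to_scale(sim)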


def find_all_conv_bn_with_activation(model: torch.nn.Module, input_shape: Tuple) -> Dict:
    """
    Uses searcher to find preceding and next bn layers for a conv/linear layer

    :param model: PyTorch model
    :param input_shape: shape of input to the model
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """
    device = utils.get_device(model)
    inp_tensor_list = utils.create_rand_tensors_given_shapes(input_shape, device)
    connected_graph = ConnectedGraph(model, inp_tensor_list)
    return find_all_conv_bn_with_activation_in_graph(connected_graph)


def find_all_conv_bn_with_activation_in_graph(connected_graph: ConnectedGraph) -> Dict:
    """
    Uses searcher to find preceding and next bn layers for a conv/linear layer

    :param connected_graph: ConnectedGraph object.
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """
    # initialize all patterns to be matched and associated call back functions
    patterns_with_callbacks = []
    layer_select_handler = ConvBnPatternHandler()
    conv_types = ['Conv1d', 'Conv', 'ConvTranspose']
    linear_types = ['Gemm']

    for op_type in conv_types + linear_types:
        patterns_with_callbacks.append(PatternType(pattern=['BatchNormalization', op_type],
                                                   action=layer_select_handler))
        patterns_with_callbacks.append(PatternType(pattern=[op_type, 'BatchNormalization'],
                                                   action=layer_select_handler))
    patterns_with_callbacks.append(PatternType(pattern=['Conv3d', 'BatchNorm3d'], action=layer_select_handler))
    patterns_with_callbacks.append(PatternType(pattern=['BatchNorm3d', 'Conv3d'], action=layer_select_handler))

    # create graph searcher instance with connected graph and patterns to search
    graph_searcher = GraphSearcher(connected_graph, patterns_with_callbacks)

    # get all conv/linear and bn info
    graph_searcher.find_all_patterns_in_graph_apply_actions()
    convs_bn_activation_dict = layer_select_handler.get_conv_linear_bn_info_dict()

    return convs_bn_activation_dict
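

# Usage sketch (hypothetical model): retrieve the conv/linear -> BN mapping
# consumed by the folding passes above. Keys are the conv/linear modules;
# values carry the surrounding BN / activation info.
def _example_find_conv_bn_info():
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    ).eval()
    info = find_all_conv_bn_with_activation(model, input_shape=(1, 3, 16, 16))
    assert model[0] in info
    return info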