-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
quantization_mappings.py
197 lines (179 loc) · 7.24 KB
/
quantization_mappings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
from .stubs import QuantStub, DeQuantStub
from .fake_quantize import (
default_affine_fixed_qparams_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
)
from .utils import get_combined_dict
# Default map for swapping float module classes to their statically quantized
# counterparts (post training static quantization). QAT and fused-QAT module
# classes (nnqat / nniqat) are keys as well, so the same map also converts a
# model that was prepared for quantization aware training.
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.ELU: nnq.ELU,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.GroupNorm: nnq.GroupNorm,
    nn.Hardswish: nnq.Hardswish,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.Linear: nnq.Linear,
    nn.ReLU6: nnq.ReLU6,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic (fused) modules:
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    # Fused QAT modules. The Conv+BN variants map to a quantized conv with no
    # separate BatchNorm module; presumably BN is folded into the conv during
    # conversion -- confirm against the nniqat implementation.
    nniqat.ConvBn1d: nnq.Conv1d,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}
# Default map for swapping float (and fused float) module classes to their
# quantization aware training (QAT) counterparts.
DEFAULT_QAT_MODULE_MAPPINGS = {
    nn.Conv2d: nnqat.Conv2d,
    nn.Linear: nnqat.Linear,
    # Intrinsic (fused) modules:
    nni.ConvBn1d: nniqat.ConvBn1d,
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU
}
# Default map for swapping float module classes to their dynamically
# quantized counterparts (post training dynamic quantization).
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS = {
    nn.GRUCell: nnqd.GRUCell,
    nn.Linear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
}
# Adjustments to the module-type sets derived from the default mappings
# (used by get_default_qconfig_propagation_list and
# get_default_compare_output_module_list).
# Types in the exclude set are removed from the derived set even though they
# appear as mapping keys (e.g. DeQuantStub).
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
    DeQuantStub,
}
# Types in the include set are added even though they are not a key of any
# default mapping (e.g. nn.Sequential).
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
    nn.Sequential,
}
# Default mapping from floating point functionals / torch ops to the
# corresponding ops in the `quantized` op namespace (torch._ops.ops.quantized).
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
    F.elu: torch._ops.ops.quantized.elu,
    F.hardswish: torch._ops.ops.quantized.hardswish,
    F.instance_norm: torch._ops.ops.quantized.instance_norm,
    F.layer_norm: torch._ops.ops.quantized.layer_norm,
    F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
}
# Mapping from module type to the special activation post process used for
# that module's output: Sigmoid and Hardsigmoid use the affine fixed-qparams
# fake quantize, and Tanh uses the symmetric one.
# NOTE(review): the values come from .fake_quantize and are presumably
# constructors (e.g. `with_args` partials) rather than instances -- confirm.
DEFAULT_MODULE_TO_ACT_POST_PROCESS = {
    nn.Hardsigmoid: default_affine_fixed_qparams_fake_quant,
    nn.Sigmoid: default_affine_fixed_qparams_fake_quant,
    nn.Tanh: default_symmetric_fixed_qparams_fake_quant,
}
def get_default_static_quant_module_mappings():
    ''' Get module mapping for post training static quantization

    Returns a fresh copy of ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS`` so that
    callers who add or remove entries cannot silently change the defaults
    seen by every other user of this module.
    '''
    # A shallow copy is sufficient: keys and values are classes, and classes
    # are shared by reference (copy.deepcopy would return them unchanged too).
    return dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
def get_static_quant_module_class(float_module_class, additional_static_quant_mapping=None):
    r"""Get the statically quantized module class corresponding to
    the floating point module class

    Args:
        float_module_class: the float (or QAT / fused) module class to look up.
        additional_static_quant_mapping: optional extra
            ``{float_class: quantized_class}`` entries, combined with
            ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS`` via ``get_combined_dict``
            for this lookup only (presumably these entries take precedence --
            confirm in ``.utils``).

    Returns:
        The quantized module class mapped to ``float_module_class``.

    Raises:
        AssertionError: if no mapping exists for ``float_module_class``.
    """
    if additional_static_quant_mapping is None:
        additional_static_quant_mapping = {}
    all_mappings = get_combined_dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
    static_quant_module_class = all_mappings.get(float_module_class)
    if static_quant_module_class is None:
        # Raised explicitly instead of via `assert` so the check still fires
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError(
            "Floating point module class {}".format(str(float_module_class)) +
            " does not have a corresponding quantized module class")
    return static_quant_module_class
def get_dynamic_quant_module_class(float_module_class, additional_dynamic_quant_mapping=None):
    r"""Get the dynamically quantized module class corresponding to
    the floating point module class

    Args:
        float_module_class: the float module class to look up.
        additional_dynamic_quant_mapping: optional extra
            ``{float_class: dynamic_quantized_class}`` entries, combined with
            ``DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS`` via ``get_combined_dict``
            for this lookup only (presumably these entries take precedence --
            confirm in ``.utils``).

    Returns:
        The dynamically quantized module class mapped to ``float_module_class``.

    Raises:
        AssertionError: if no mapping exists for ``float_module_class``.
    """
    if additional_dynamic_quant_mapping is None:
        additional_dynamic_quant_mapping = {}
    all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping)
    dynamic_quant_module_class = all_mappings.get(float_module_class)
    if dynamic_quant_module_class is None:
        # Raised explicitly instead of via `assert` so the check still fires
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError(
            "Floating point module class {}".format(str(float_module_class)) +
            " does not have a corresponding quantized module class")
    return dynamic_quant_module_class
def get_default_qat_module_mappings():
    ''' Get default module mapping for quantization aware training

    Returns a fresh copy of ``DEFAULT_QAT_MODULE_MAPPINGS`` so that callers
    who modify the result cannot silently change the module-wide defaults.
    '''
    # Shallow copy: keys and values are classes, shared by reference.
    return dict(DEFAULT_QAT_MODULE_MAPPINGS)
def get_default_dynamic_quant_module_mappings():
    ''' Get module mapping for post training dynamic quantization

    Returns a fresh copy of ``DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS`` so that
    callers who modify the result cannot silently change the module-wide
    defaults.
    '''
    # Shallow copy: keys and values are classes, shared by reference.
    return dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)
def get_default_qconfig_propagation_list():
    ''' Get the default set of module types that we'll attach a qconfig
    attribute to in prepare

    The result is a new set: every key of the default static, QAT and
    dynamic mappings, plus the types in _INCLUDE_QCONFIG_PROPAGATE_LIST,
    minus the types in _EXCLUDE_QCONFIG_PROPAGATE_LIST.
    '''
    propagate_types = set()
    for default_mapping in (DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
                            DEFAULT_QAT_MODULE_MAPPINGS,
                            DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS):
        # Iterating a dict yields its keys, i.e. the float-side module types.
        propagate_types.update(default_mapping)
    propagate_types.update(_INCLUDE_QCONFIG_PROPAGATE_LIST)
    propagate_types.difference_update(_EXCLUDE_QCONFIG_PROPAGATE_LIST)
    return propagate_types
def get_default_compare_output_module_list():
    ''' Get the set of module class types whose output we will record
    in numeric suite

    Both sides of the default static, QAT and dynamic mappings are included
    (the source module types and the module types they map to), plus the
    types in _INCLUDE_QCONFIG_PROPAGATE_LIST, minus the types in
    _EXCLUDE_QCONFIG_PROPAGATE_LIST.
    '''
    default_mappings = (
        DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
        DEFAULT_QAT_MODULE_MAPPINGS,
        DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    )
    recorded_types = {target for mapping in default_mappings for target in mapping.values()}
    recorded_types |= {source for mapping in default_mappings for source in mapping}
    recorded_types |= _INCLUDE_QCONFIG_PROPAGATE_LIST
    return recorded_types - _EXCLUDE_QCONFIG_PROPAGATE_LIST
# TODO: merge with get_static_quant_module_class
def get_quantized_operator(float_op):
    ''' Get the quantized operator corresponding to the float operator

    Args:
        float_op: a floating point functional or torch op (e.g. ``F.elu``).

    Returns:
        The op registered for ``float_op`` in
        ``DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS``.

    Raises:
        AssertionError: if ``float_op`` has no quantized counterpart.
    '''
    quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
    if quantized_op is None:
        # Raised explicitly instead of via `assert` so the check still fires
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError(
            'Operator {} does not have corresponding quantized op'.format(str(float_op)))
    return quantized_op
def _get_special_act_post_process(module):
    r""" Get the special activation post process for `module`, this has
    higher priority than the activation post process in `qconfig`
    e.g.
    input: an instance of torch.nn.Sigmoid
    output: default_affine_fixed_qparams_fake_quant

    The lookup uses the exact type of ``module`` (``type(module)``), so
    subclasses of the listed modules are not matched. Returns ``None`` when
    ``type(module)`` has no entry in ``DEFAULT_MODULE_TO_ACT_POST_PROCESS``.
    """
    return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type(module))
def _has_special_act_post_process(module):
    """Tell whether `module` should use a special activation post process.

    True only when ``module.training`` is truthy and the exact type of
    ``module`` is a key of ``DEFAULT_MODULE_TO_ACT_POST_PROCESS``. When
    ``module.training`` is falsy, that value itself is returned.
    """
    training = module.training
    if not training:
        return training
    return type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS