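Most quantization methods coming out of academia are too elaborate and rarely make it into production. Industry still relies mainly on Google's TFLite quantization scheme, described in the paper [Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference](https://arxiv.org/abs/1712.05877).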

This note covers the basic principles of network quantization and how a quantized model performs inference.

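Quantization is nothing new: we already use it when preprocessing images. We usually normalize a uint8 image with values in 0~255 into a float32 tensor with values in 0.0~1.0, and that step is **dequantization**. Conversely, we often convert a network output in the range 0.0~1.0 into uint8 image data with values in 0~255, and that step is **quantization**. Quantization is therefore essentially a remapping of the value range, and can "roughly" be understood as a linear mapping. ("Roughly", because some papers use non-linear quantization; what is deployed in industry today is still linear quantization, so this note only discusses linear schemes.)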

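It is easy to see, however, that dequantization generally loses no information while quantization generally loses precision. The reason is simple: float32 can represent far more distinct values than uint8, so most float32 values have no exact uint8 counterpart and must be rounded to the nearest one (and clipped when they fall outside the range). The error between a quantized model and the full-precision model comes from exactly these rounding and clipping operations.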
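The image round trip described above can be sketched in NumPy. This is only an illustration: the helper names `dequantize_image` and `quantize_image` and the sample values are made up for this example and do not come from the TFLite paper.

```python
import numpy as np

def dequantize_image(img_u8):
    """uint8 in [0, 255] -> float32 in [0.0, 1.0]."""
    return img_u8.astype(np.float32) / 255.0

def quantize_image(img_f32):
    """float32 in [0.0, 1.0] -> uint8 in [0, 255], with rounding and clipping."""
    return np.clip(np.round(img_f32 * 255.0), 0, 255).astype(np.uint8)

# uint8 -> float32 -> uint8: every pixel value survives the round trip.
pixels = np.array([0, 17, 128, 255], dtype=np.uint8)        # made-up pixels
back = quantize_image(dequantize_image(pixels))
print(np.array_equal(pixels, back))

# float32 -> uint8 -> float32: values between the 256 levels are rounded away.
values = np.array([0.1234, 0.5, 0.9999], dtype=np.float32)  # arbitrary floats
restored = dequantize_image(quantize_image(values))
max_error = float(np.abs(values - restored).max())
print(max_error)
```

The second error is bounded by half a quantization step, i.e. about 0.5/255.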

In [2]:
import numpy as np

a = np.array(12)


Let $r$ denote a floating-point real number and $q$ the quantized fixed-point integer. The conversion between the two is:

$$
r = S(q-Z) \tag{1}
$$


$$
q = round(\frac{r}{S}+Z) \tag{2}
$$

$S$ is the scale, the ratio between the real values and the integers:

$$
S = \frac{r_{max}-r_{min}}{q_{max}-q_{min}}\tag{3}
$$


$Z$ is the zero point, the integer that the real value 0 maps to after quantization:

$$
Z = round(q_{max} - \frac{r_{max}}{S})\tag{4}
$$



In [9]:
# Example: the floating-point range is [-0.2, 0.2].
# For int8 quantization (symmetric range): Qmax = 127, Qmin = -127.
Qmax = 127
Qmin = -127
Rmin = -0.2
Rmax = 0.2

scale = (Rmax - Rmin) / (Qmax - Qmin)
zero_point = np.round(Qmax - Rmax / scale)

print("scale:", scale)
print("zeropoint:", zero_point)
r1 = 0.135
q1 = np.round(r1 / scale + zero_point)
print(r1, "quantizes to:", q1)


scale: 0.0015748031496062994
zeropoint: 0.0
0.135 quantizes to: 86.0
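Equations (1) to (4) can be wrapped into small helpers that also clip out-of-range values, which the cell above does not do. This is a minimal sketch; the function names are hypothetical, and the test values -0.5 and 0.7 are chosen only to show saturation.

```python
import numpy as np

def calc_scale_zero_point(r_min, r_max, q_min, q_max):
    scale = (r_max - r_min) / (q_max - q_min)       # equation (3)
    zero_point = np.round(q_max - r_max / scale)    # equation (4)
    return scale, zero_point

def quantize(r, scale, zero_point, q_min, q_max):
    q = np.round(r / scale + zero_point)            # equation (2)
    return np.clip(q, q_min, q_max)                 # saturate out-of-range values

def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)                 # equation (1)

q_min, q_max = -127, 127
scale, zero_point = calc_scale_zero_point(-0.2, 0.2, q_min, q_max)

# -0.5 and 0.7 lie outside [-0.2, 0.2] and will be clipped.
r = np.array([-0.5, -0.2, 0.0, 0.135, 0.2, 0.7])
q = quantize(r, scale, zero_point, q_min, q_max)
r_hat = dequantize(q, scale, zero_point)
print(q)
print(r_hat)
```

For values inside the range the reconstruction error is at most half a step ($S/2$); for clipped values it can be arbitrarily large.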


Let $S_1$, $Z_1$ be the scale and zero point of $r_1$, and likewise $S_2$, $Z_2$ for $r_2$ and $S_3$, $Z_3$ for $r_3$. For the product $r_3 = r_1 r_2$ we have

$$
S_{3}(q_{3}-Z_{3})=S_1(q_{1}-Z_{1})S_{2}(q_{2}-Z_{2})  \tag{6}
$$

so that
$$
q_{3}=\frac{S_{1} S_{2}}{S_{3}}(q_{1}-Z_{1})(q_{2}-Z_{2}) + Z_{3}  \tag{7}
$$



In [66]:
def quant(r,s,z):
    q = np.round( r/s  +  z)
    return q

Qmax = 127
Qmin = -127


R1min = -0.1
R1max = 0.2
r1 = 0.19
s1 =  (R1max - R1min) / (Qmax - Qmin)
z1 =  np.round(Qmax - R1max/s1)
q1 =quant(r1,s1,z1)



R2min = -0.15
R2max = 0.12
r2 = -0.112
s2 =  (R2max - R2min) / (Qmax - Qmin)
z2 =  np.round(Qmax - R2max/s2)
q2 =quant(r2,s2,z2)



R3min = -0.05
R3max = 0.05
r3 = r1 * r2
s3 =  (R3max - R3min) / (Qmax - Qmin)
z3 =  np.round(Qmax - R3max/s3)
q3 =quant(r3,s3,z3)

print(q1)
print(q2)
print(q3)


q3_1 = (s1 * s2 / s3) * ( q1 - z1 ) * ( q2 - z2 ) + z3

print(q3_1)
print(np.round(q3_1))


# Can the operands use different quantization ranges? For example, q1 in a signed range and q2 in an unsigned one?



119.0
-91.0
-54.0
-53.90964566929136
-54.0


In [56]:
Q1max = 127
Q1min = -127


R1min = -0.1
R1max = 0.2
r1 = 0.19
s1 =  (R1max - R1min) / (Q1max - Q1min)
z1 =  np.round(Q1max - R1max/s1)
q1 =quant(r1,s1,z1)


# Unsigned 8-bit range for the second operand (max and min were swapped, which gave a negative scale).
Q2max = 255
Q2min = 0

R2min = -0.15
R2max = 0.12
r2 = 0.112
s2 =  (R2max - R2min) / (Q2max - Q2min)
z2 =  np.round(Q2max - R2max/s2)
q2 =quant(r2,s2,z2)


Q3max = 127
Q3min = -127

R3min = -0.05
R3max = 0.05
r3 = r1 * r2
s3 =  (R3max - R3min) / (Q3max - Q3min)
z3 =  np.round(Q3max - R3max/s3)
q3 =quant(r3,s3,z3)

print(q1)
print(q2)
print(q3)


q3_1 = (s1 * s2 / s3) * ( q1 - z1 ) * ( q2 - z2 ) + z3
print(q3_1)
print(np.round(q3_1))

119.0
248.0
54.0
54.20964705882354
54.0


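Looking at the formula above, everything except $\frac{S_1 S_2}{S_3}$ is fixed-point integer arithmetic; only that factor is a floating-point operation. Some dedicated accelerators do not support floating point at all, so how can $\frac{S_1 S_2}{S_3}$ be turned into a fixed-point operation as well?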

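Let $M=\frac{S_1 S_2}{S_3}$. Since $M$ usually lies in (0, 1) (an empirical observation from many experiments), it can be written as $M=2^{-n}M_0$, where $M_0$ is a fixed-point real number. Note that a fixed-point number is not necessarily an integer: "fixed point" means that the position of the radix point, and hence the number of fractional bits, is fixed. So if $M=2^{-n}M_0$ holds, $2^{-n}M_0$ can be computed by bit-shifting $M_0$, and the whole computation stays in fixed point.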




Suppose $P=7091$ and $M=0.0072474273418460$ ($M$ can be computed ahead of time from the scales). We then need an $M_0$ and an $n$ such that $MP \approx 2^{-n}M_0 P$. A short piece of code can search for them:

In [6]:
M = 0.0072474273418460
P = 7091

def multiply(n, M, P):
    result = M * P
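    # Rounding to an integer is not strictly required; it is done here only
    # because fixed-point numbers are awkward to represent in Python.
    Mo = int(round(2 ** n * M))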

    approx_result = (Mo * P) >> n
    print("n=%d, Mo=%d, approx=%f, error=%f"%\
          (n, Mo, approx_result, result-approx_result))

for n in range(1, 16):
    multiply(n, M, P)

n=1, Mo=0, approx=0.000000, error=51.391507
n=2, Mo=0, approx=0.000000, error=51.391507
n=3, Mo=0, approx=0.000000, error=51.391507
n=4, Mo=0, approx=0.000000, error=51.391507
n=5, Mo=0, approx=0.000000, error=51.391507
n=6, Mo=0, approx=0.000000, error=51.391507
n=7, Mo=1, approx=55.000000, error=-3.608493
n=8, Mo=2, approx=55.000000, error=-3.608493
n=9, Mo=4, approx=55.000000, error=-3.608493
n=10, Mo=7, approx=48.000000, error=3.391507
n=11, Mo=15, approx=51.000000, error=0.391507
n=12, Mo=30, approx=51.000000, error=0.391507
n=13, Mo=59, approx=51.000000, error=0.391507
n=14, Mo=119, approx=51.000000, error=0.391507
n=15, Mo=237, approx=51.000000, error=0.391507


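As the output shows, at $n=11$, $M_0=15$ the error is already below 1. So as long as $M_0P$ fits in 21 (= 32 - 11) bits, $MP$ can be approximated by shifting $M_0P$ right by $n$ bits, with an error that is acceptable. Equation (7) can then be evaluated entirely in fixed-point arithmetic; in other words, the floating-point (matrix) multiplication has been quantized.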

$$M=2^{-n}M_0 $$

In [3]:
M = 0.5
multiplier = 16384
shift = 15

# M is recovered as multiplier * 2**-shift.
print(multiplier * 2**-shift)

0.5


A "fixed-point number" is a number whose "point" does not move. Which point? The radix (decimal) point.
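To make "the point does not move" concrete, here is a minimal sketch of a fixed-point format with 15 fractional bits, matching the `0.5 = 16384 * 2**-15` example. The choice of 15 bits and the helper names are assumptions made for this illustration.

```python
FRAC_BITS = 15  # assumed format: 15 fractional bits

def to_fixed(x, frac_bits=FRAC_BITS):
    """Store x as the integer round(x * 2**frac_bits)."""
    return int(round(x * (1 << frac_bits)))

def to_float(x_fixed, frac_bits=FRAC_BITS):
    """Read the stored integer back with the point frac_bits from the right."""
    return x_fixed / (1 << frac_bits)

def fixed_mul_int(x_fixed, p, frac_bits=FRAC_BITS):
    """Multiply a fixed-point value by an integer using only integer ops."""
    return (x_fixed * p) >> frac_bits

half = to_fixed(0.5)
print(half, to_float(half))
print(fixed_mul_int(half, 100))
```

The hardware only ever sees the integer `half`; the factor $2^{-15}$ is implied by the format and applied as a right shift.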


### add

In [None]:
# How are multiplier and shift computed?

For addition, $r_3 = r_1 + r_2$, so

$$
S_{3}(q_{3}-Z_{3})=S_1(q_{1}-Z_{1}) + S_{2}(q_{2}-Z_{2})  \tag{8}
$$

This case is a little harder to follow. Factoring out $S_1$ gives

$$
q_{3}=\frac{S_{1}}{S_{3}}[(q_{1}-Z_{1})  + \frac{S_{2}}{S_{1}}(q_{2}-Z_{2})] + Z_{3}  \tag{9}
$$



https://zhuanlan.zhihu.com/p/336682366

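The rescale is needed because eltwisesum requires both operands to be expressed on the same scale before they can be added. For example, suppose r1 has range [-10, 10] and r2 has range [-1, 1], and both are quantized to uint8 [0, 255]. If r1=5, then q1 ≈ 191; if r2=0.5, then q2 ≈ 191 as well (both sit three quarters of the way through their range). Without a rescale, the fixed-point sum would be q1+q2 = 191+191, even though the two codes stand for very different real values. Adding them directly is clearly wrong, so one of the inputs has to be rescaled.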

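This comes down to the nature of multiplication and addition. Continuing the example, r1*r2 = S1(q1-Z1)S2(q2-Z2): both S1 and S2 are multiplied in, so the quantized result maps straight back to the floating-point result. For addition, r1+r2 = S1(q1-Z1)+S2(q2-Z2), and there is no way to operate on q1 and q2 directly without first accounting for S1 and S2. That is why a rescale is required.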
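The eltwisesum example can be checked numerically. The sketch below reuses the scale and zero-point formulas from earlier; the output range [-11, 11] is an assumption made here (the sum of the two input ranges), and the helper names are hypothetical.

```python
import numpy as np

def calc_params(r_min, r_max, q_min=0, q_max=255):
    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = np.round(q_max - r_max / scale)
    return scale, zero_point

def quantize(r, scale, zero_point):
    return np.round(r / scale + zero_point)

s1, z1 = calc_params(-10.0, 10.0)
s2, z2 = calc_params(-1.0, 1.0)
s3, z3 = calc_params(-11.0, 11.0)   # assumed output range

r1, r2 = 5.0, 0.5
q1, q2 = quantize(r1, s1, z1), quantize(r2, s2, z2)
print(q1, q2)                        # very similar codes, very different reals

# Rescale q2 into the scale of q1, add, then requantize to the output scale.
q3 = np.round((s1 / s3) * ((q1 - z1) + (s2 / s1) * (q2 - z2)) + z3)
r3_hat = s3 * (q3 - z3)
print(q3, r3_hat)                    # r3_hat should be close to r1 + r2 = 5.5
```

Adding `q1 + q2` directly would treat the step of r2 (about 0.008) as if it were the step of r1 (about 0.08), which is the mismatch the rescale factor $S_2/S_1$ removes.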

In [12]:
import numpy as np

def quant(r,s,z):
    q = np.round( r/s  +  z)
    return q


Q1max = 127
Q1min = -127


R1min = -0.1
R1max = 0.2
r1 = 0.19
s1 =  (R1max - R1min) / (Q1max - Q1min)
z1 =  np.round(Q1max - R1max/s1)
q1 =quant(r1,s1,z1)


# Unsigned 8-bit range for the second operand (max and min were swapped, which gave a negative scale).
Q2max = 255
Q2min = 0

R2min = -0.15
R2max = 0.12
r2 = 0.112
s2 =  (R2max - R2min) / (Q2max - Q2min)
z2 =  np.round(Q2max - R2max/s2)
q2 =quant(r2,s2,z2)


Q3max = 127
Q3min = -127

R3min = -0.25
R3max = 0.32
r3 = r1 + r2
s3 =  (R3max - R3min) / (Q3max - Q3min)
z3 =  np.round(Q3max - R3max/s3)
q3 =quant(r3,s3,z3)

print(q1)
print(q2)
print(q3)


q3_1 =(s1/s3)*(q1 - z1 + (s2/s1)*(q2-z2))  + z3
print(q3_1)
print(np.round(q3_1))


print(r3)
print((q3_1-z3) * s3)


# Note: the ranges of r1 and r2 determine the range of r3.

119.0
248.0
119.0
118.75046439628483
119.0
0.302
0.3023927744326077


For a quantized conv whose input is uint8, whose weight is uint8, and whose output should also be uint8, what does the computation in between look like?
$$
q_{3}=\frac{S_{1} S_{2}}{S_{3}}(q_{1}-Z_{1})(q_{2}-Z_{2}) + Z_{3} 
$$


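Look at the formula again. $(q_{1}-Z_{1})(q_{2}-Z_{2})$ can be computed directly in integers. It is an 8-bit x 8-bit multiplication, so each product needs about 16 bits of storage (a conv accumulates many such products, which is why implementations such as TFLite use a 32-bit accumulator). The result is then multiplied by the multiplier, shifted right, and finally the zero point is added to obtain the output.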
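The steps just described can be sketched for a single output element, i.e. one dot product of a conv or fc layer. All scales, zero points, activations and weights below are made-up values, and the 15-bit multiplier normalisation is one possible choice for illustration, not necessarily what a particular runtime or chip does.

```python
import numpy as np

def quantize(r, scale, zero_point, q_min=0, q_max=255):
    q = np.clip(np.round(r / scale + zero_point), q_min, q_max)
    return q.astype(np.int32)

# Made-up parameters for input (1), weight (2) and output (3).
s1, z1 = 0.02, 128
s2, z2 = 0.005, 120
s3, z3 = 0.05, 100

x = np.array([0.5, -1.0, 1.5, 0.24])   # made-up activations
w = np.array([0.3, 0.2, -0.1, 0.4])    # made-up weights
q1 = quantize(x, s1, z1)
q2 = quantize(w, s2, z2)

# Integer-only part: subtract zero points, multiply, accumulate (int32).
acc = int(np.sum((q1 - z1) * (q2 - z2)))

# Express M = S1*S2/S3 as multiplier * 2**-shift with a 15-bit multiplier.
M = s1 * s2 / s3
shift = 15
while M * (1 << shift) < (1 << 14):
    shift += 1
multiplier = int(round(M * (1 << shift)))

# Multiply, shift right with rounding, add the output zero point, saturate.
q3 = ((acc * multiplier + (1 << (shift - 1))) >> shift) + z3
q3 = max(0, min(255, q3))

q3_float = int(np.round(M * acc + z3))  # same formula evaluated in floating point
print(acc, multiplier, shift)
print(q3, q3_float)
```

Python integers do not overflow, so `acc * multiplier` is safe here; on real hardware this product needs a wider register than the 32-bit accumulator.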



### add

In [17]:
import numpy as np

scale0 = 0.26118746399879456
zeropoint0 = 126.0

scale1 = 0.3498002290725708
zeropoint1 = 140.0

output_scale = 0.40072473883628845
output_zeropoint = 139

max_input_scale = max(scale0, scale1)

# Rescale both inputs relative to twice the larger input scale.
input_multiplier0 = scale0 / (2 * max_input_scale)
input_multiplier1 = scale1 / (2 * max_input_scale)

# The inputs are left-shifted by data_left_shift bits to keep precision,
# so the output multiplier divides that factor out again.
data_left_shift = 4
real_output_multiplier = (max_input_scale * 2) / ((1 << data_left_shift) * output_scale)
print(input_multiplier0)
print(input_multiplier1)
print(real_output_multiplier)

0.3733380402455478
0.5
0.1091148721215705


In [26]:
import math


def calcRescaleMultiAndshift(rescale):
    # rescale = fr * 2**exp with 0.5 <= fr < 1
    fr, exp = math.frexp(rescale)

    multiplier = np.round(fr * (1 << 15))  # uint16

    # Rounding can push the multiplier up to 2**15; renormalise in that case.
    if multiplier == (1 << 15):
        multiplier = multiplier / 2
        exp = exp + 1

    # assert INT8_MIN <= exp <= INT8_MAX
    shift = exp  # static_cast<int8_t>
    if shift < -15:
        shift = 0
        multiplier = 0
    # The add operator does not subtract 8 from the shift;
    # currently only conv and fc operators do.
    shift = -shift + 15

    return (multiplier, shift)
    
print(calcRescaleMultiAndshift(0.3733380402455478))
print(calcRescaleMultiAndshift(0.5))
print(calcRescaleMultiAndshift(0.1091148721215705))

(24467.0, 16)
(16384.0, 15)
(28604.0, 18)


sub   fm0_requantparam_zeropoint_in   subtract the zero point

mul   requantparam0_multiplier_in     multiply by the multiplier

srl   requantparam0_shift_in - 4      shift right



sub  fm1_requantparam_zeropoint_in    subtract the zero point

mul  requantparam1_multiplier_in      multiply by the multiplier

srl  requantparam1_shift_in - 4       shift right



add  add the two values


mul requantparam_multiplier_out     multiply by the multiplier

srl requantparam_shift_out          shift right

add requantparam_zerpoint_out       add the zero point

In [7]:
def compute(a, b):
    a =((a - 140) * 16384) >> 15
    b =((b - 126) * 24467) >> 16
    return hex(( ((a + b) * 28604) >> 18) + 139)

input0 = [0x86, 0x8c, 0x90, 0x97, 0x9c, 0xb2, 0xa6, 0x83, 0x70, 0x7e, 0x91, 0x86]
input1 = [0x93, 0x4d, 0x91, 0x68, 0x74, 0x85, 0x94, 0x98, 0x88, 0x50, 0x8f, 0x67]

resul = [compute(a,b) for a,b in zip(input0,input1)]

expect_result = ['0x93','0x6b', '0x9b', '0x86', '0x92', '0xb1', '0xb0', '0x94', '0x79', '0x61', '0x9a', '0x77' ]

print("   python sim",resul)
print("expect_result",expect_result)

   python sim ['0x8b', '0x88', '0x8b', '0x8a', '0x8b', '0x8d', '0x8d', '0x8b', '0x89', '0x88', '0x8b', '0x89']
expect_result ['0x93', '0x6b', '0x9b', '0x86', '0x92', '0xb1', '0xb0', '0x94', '0x79', '0x61', '0x9a', '0x77']



In [21]:
def compute(a,b):
    a =((a - 140) * 16384) >> 15
    b =((b - 126) * 24467) >> 16
    return hex(( ((a + b) * 28604) >> 18) + 139)

a_list = [0x93, 0x4d, 0x91, 0x68, 0x74, 0x85, 0x94, 0x98, 0x88, 0x50, 0x8f, 0x67]
b_list = [0x86, 0x8c, 0x90, 0x97, 0x9c, 0xb2, 0xa6, 0x83, 0x70, 0x7e, 0x91, 0x86]
resul1 = [compute(a,b) for a,b in zip(b_list,a_list)]

expect_result = ['0x93','0x6b', '0x9b', '0x86', '0x92', '0xb1', '0xb0', '0x94', '0x79', '0x61', '0x9a', '0x77' ]

print("   python sim",resul1)
print("expect_result",expect_result)

   python sim ['0x8b', '0x88', '0x8b', '0x8a', '0x8b', '0x8d', '0x8d', '0x8b', '0x89', '0x88', '0x8b', '0x89']
expect_result ['0x93', '0x6b', '0x9b', '0x86', '0x92', '0xb1', '0xb0', '0x94', '0x79', '0x61', '0x9a', '0x77']
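One possible explanation for the mismatch between the simulation and `expect_result` is offered here as an assumption, not as a confirmed description of the hardware. The instruction listing shifts each input by `shift_in - 4`, while `compute` shifts by the full 15 and 16 bits, and the hardware may round rather than truncate when shifting. The sketch below applies both changes; `rounding_shift`, `add_compute` and `DATA_LEFT_SHIFT` are names introduced for this example.

```python
def rounding_shift(x, n):
    """Arithmetic right shift by n bits, rounding half up instead of truncating."""
    return (x + (1 << (n - 1))) >> n

DATA_LEFT_SHIFT = 4  # the "- 4" that appears in the instruction listing

def add_compute(a, b):
    # Same zero points, multipliers and shifts as compute(), but the two
    # input shifts are reduced by DATA_LEFT_SHIFT and every shift rounds.
    a = rounding_shift((a - 140) * 16384, 15 - DATA_LEFT_SHIFT)
    b = rounding_shift((b - 126) * 24467, 16 - DATA_LEFT_SHIFT)
    return hex(rounding_shift((a + b) * 28604, 18) + 139)

input0 = [0x86, 0x8c, 0x90, 0x97, 0x9c, 0xb2, 0xa6, 0x83, 0x70, 0x7e, 0x91, 0x86]
input1 = [0x93, 0x4d, 0x91, 0x68, 0x74, 0x85, 0x94, 0x98, 0x88, 0x50, 0x8f, 0x67]
expect_result = ['0x93', '0x6b', '0x9b', '0x86', '0x92', '0xb1',
                 '0xb0', '0x94', '0x79', '0x61', '0x9a', '0x77']

result = [add_compute(a, b) for a, b in zip(input0, input1)]
print("   python sim", result)
print("expect_result", expect_result)
print(result == expect_result)
```

With these two changes the simulation agrees with `expect_result` on these twelve samples. That is consistent with the assumptions but does not prove that the hardware works this way.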


### sub

sub fm0_requantparam_zeropoint_in

mul requantparam0_multiplier_in

srl (requantparam0_shift_in - 4)


sub fm1_requantparam_zeropoint_in

mul requantparam1_multiplier_in

srl (requantparam1_shift_in - 4)


sub  subtract the two values


mul  requantparam_multiplier_out

srl  requantparam_shift_out

add  requantparam_zerpoint_out


In [None]:
def sub_compute(a,b):
    a =((a - 140) * 16384) >> 15
    b =((b - 126) * 24467) >> 16
    return hex(( ((a - b) * 28604) >> 18) + 139)

### mul



sub fm1 - fm1_requantparam_zeropoin_in


sub fm0 - fm0_requantparam_zeropoin_in

mul  r1 r2 

mul requantparam_multiplier_out

srl requantparam_shift_out

add requantparam_zerpoin_out


In [24]:
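def mul_compute(fm0, fm1,
                fm0_requantparam_zeropoin_in, fm1_requantparam_zeropoin_in,
                requantparam_multiplier_out, requantparam_shift_out,
                requantparam_zerpoin_out):
    # Subtract the zero points, multiply the two inputs, requantize, add the output zero point.
    product = (fm0 - fm0_requantparam_zeropoin_in) * (fm1 - fm1_requantparam_zeropoin_in)
    return ((product * requantparam_multiplier_out) >> requantparam_shift_out) + requantparam_zerpoin_out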
    
    

In [76]:
import math

import numpy as np


def calcRescaleMultiAndshift1(rescale):
    # rescale = fr * 2**exp with 0.5 <= fr < 1
    fr, exp = math.frexp(rescale)
    multiplier = np.round(fr * (1 << 15))  # uint16
    if multiplier == (1 << 15):
        multiplier = multiplier / 2
        exp = exp + 1
    # assert INT8_MIN <= exp <= INT8_MAX
    shift = exp  # static_cast<int8_t>
    if shift < -15:
        shift = 0
        multiplier = 0
    shift = -shift + 15
    return (multiplier, shift)


def quant(r, s, z):
    return np.round(r / s + z)


Q1max = 127
Q1min = -127

R1min = -0.1
R1max = 0.2
r1 = 0.19
s1 = (R1max - R1min) / (Q1max - Q1min)
z1 = np.round(Q1max - R1max / s1)
q1 = quant(r1, s1, z1)

Q2max = 127
Q2min = -127

R2min = -0.15
R2max = 0.12
r2 = 0.112
s2 = (R2max - R2min) / (Q2max - Q2min)
z2 = np.round(Q2max - R2max / s2)
q2 = quant(r2, s2, z2)

Q3max = 127
Q3min = -127

# The range of a product is spanned by the four corner products.
corners = [R1min * R2min, R1min * R2max, R1max * R2min, R1max * R2max]
R3min = min(corners)
R3max = max(corners)
r3 = r1 * r2
s3 = (R3max - R3min) / (Q3max - Q3min)
z3 = np.round(Q3max - R3max / s3)
q3 = quant(r3, s3, z3)

# Floating-point reference; s3 and z3 must be defined before this line.
q3_1 = (s1 * s2 / s3) * (q1 - z1) * (q2 - z2) + z3
print("q3_1:", round(float(q3_1), 4))
print("q3:", q3)

# Fixed-point version: M = s1 * s2 / s3 as multiplier * 2**-shift.
multiplier_float = s1 * s2 / s3
multiplier, shift = calcRescaleMultiAndshift1(multiplier_float)
print(multiplier, shift)

# The mul listing applies the full shift, so nothing is subtracted from it here.
acc = int((q1 - z1) * (q2 - z2))
result = ((acc * int(multiplier)) >> shift) + int(z3)
# The truncating shift turns 99.83 into 99, so the result is one below q3.
print("fixed-point result:", result)

q3_1: 113.8327
q3: 114.0
24770.0 22
fixed-point result: 113