In [1]:
import numpy as np
import tensorflow as tf
import shutil
import numpy as np
import torch
from scipy import special

In [10]:
# Load the object stored in the checkpoint file and bind it to `model`.
# NOTE(review): torch.load unpickles the file, which can run arbitrary code;
# only load checkpoints that come from a trusted source.
model = torch.load('./model-with-bn.pth')

### quantize model

In [149]:
def fix_quant(w, BITWIDTH = 7.0):
    '''Quantize a tensor onto a signed fixed-point grid with stochastic rounding.

    The number of fractional bits is derived from the largest magnitude in
    ``w``: ``frac_digits = BITWIDTH - ceil(log2(max|w|))``. After scaling by
    ``2**frac_digits`` every value has magnitude at most ``2**BITWIDTH``.
    Each scaled value is rounded down, then bumped up by one with probability
    equal to its fractional part, saturated to
    ``[-2**BITWIDTH, 2**BITWIDTH - 1]`` and scaled back.

    Args:
        w: floating-point tensor to quantize (any shape).
        BITWIDTH: exponent of the largest code magnitude, i.e. the number of
            magnitude bits next to the sign bit.

    Returns:
        A new tensor with the same shape as ``w`` whose values are multiples
        of ``2**-frac_digits``. An empty input is returned as an empty copy
        and an all-zero input as zeros.
    '''
    with torch.no_grad():
        if w.numel() == 0:
            # torch.max has no result for an empty tensor; nothing to quantize.
            return w.clone()

        fp_range = torch.max(torch.absolute(w))
        if fp_range == 0:
            # log2(0) is -inf, which would make the scale inf and the result
            # NaN (0 * inf). Zero is representable on every grid.
            return torch.zeros_like(w)

        frac_digits = BITWIDTH - torch.ceil(torch.log2(fp_range))
        scale = 2.0 ** frac_digits

        # Stochastic rounding: floor, then add 1 with probability == fraction.
        scaled = w * scale
        int_value = torch.floor(scaled)
        frac_value = scaled - int_value
        # rand_like keeps the sample on the same device/dtype as the input.
        # The comparison is strict so a value that is already on the grid
        # (fraction 0) can never be moved, even when the sample is exactly 0.
        round_up = torch.rand_like(frac_value) < frac_value
        int_value = int_value + round_up.to(int_value.dtype)

        # When max|w| is a power of two (or rounds up at the top) the code
        # reaches +2**BITWIDTH, one past the largest signed code. Saturate it.
        int_value = torch.clamp(int_value,
                                min=-(2.0 ** BITWIDTH),
                                max=2.0 ** BITWIDTH - 1.0)

        fix_point_value = int_value / scale
        return fix_point_value
        
def layer_quant_with_bn(conv_layer, bn_layer):
    '''Quantize a conv layer and the BatchNorm that follows it as if folded.

    The BN scale ``gamma / sqrt(running_var + eps)`` and shift
    ``beta - running_mean * scale`` are what a folded conv would carry as
    weight multiplier and bias. Both folded quantities are quantized with
    ``fix_quant`` and the result is mapped back onto the unfolded parameters,
    so the layers keep their structure but behave like the quantized folded
    layer. Both layers are modified in place.

    Args:
        conv_layer: layer whose ``weight`` has the output channel on dim 0
            and three further dims (assumes a 2-D convolution - TODO confirm).
        bn_layer: BatchNorm-style layer exposing ``weight``, ``bias``,
            ``running_mean``, ``running_var`` and ``eps``.
    '''
    with torch.no_grad():
        bn_gamma = bn_layer.weight.data
        bn_beta = bn_layer.bias.data
        bn_mean = bn_layer.running_mean
        bn_var = bn_layer.running_var
        eps = bn_layer.eps

        # Per-output-channel multiplier and offset of the folded layer.
        fold_factor = bn_gamma/torch.sqrt(bn_var+eps)
        fold_bias = bn_beta-bn_mean*fold_factor

        # Transfer the quantization effect to the BN shift. The shift is read
        # from `bias` above, so it has to be written back to `bias` as well;
        # there is no `beta` attribute to assign to.
        bn_layer.bias.data = fix_quant(fold_bias)+bn_mean*fold_factor

        # Shape (C, 1, 1, 1) so the factor broadcasts over each output filter.
        fold_factor = fold_factor.reshape(-1, 1, 1, 1)
        fold_weight = conv_layer.weight.data*fold_factor
        quant_weight = fix_quant(fold_weight)

        # Transfer the quantization effect to the conv weight. A channel whose
        # BN scale is exactly 0 would divide by zero (0/0 -> NaN); its output
        # is multiplied by 0 anyway, so its weights are left untouched.
        has_scale = fold_factor != 0
        safe_factor = torch.where(has_scale, fold_factor,
                                  torch.ones_like(fold_factor))
        conv_layer.weight.data = torch.where(has_scale,
                                             quant_weight/safe_factor,
                                             conv_layer.weight.data)

def model_quant_with_bn(m):
    '''Quantize every conv/BN pair of the model in place, then its classifier.

    Walks the stem, each block (optional expand conv, depthwise conv, project
    conv), the head, and finally the fully connected layer. The attribute
    names used here (``_conv_stem``, ``_blocks``, ``_conv_head``, ``_fc``, ...)
    presumably follow an EfficientNet-style model - TODO confirm.

    Args:
        m: the model, either bare or wrapped in a container that exposes the
            real model as ``.module`` (e.g. a data-parallel wrapper).
    '''
    # Unwrap when a `.module` wrapper is present; a bare model is used as is.
    m = getattr(m, 'module', m)

    layer_quant_with_bn(m._conv_stem, m._bn0)

    for block in m._blocks:
        # Not every block has an expansion conv, hence the attribute check.
        if hasattr(block, '_expand_conv'):
            layer_quant_with_bn(block._expand_conv, block._bn0)
        layer_quant_with_bn(block._depthwise_conv, block._bn1)
        layer_quant_with_bn(block._project_conv, block._bn2)

    layer_quant_with_bn(m._conv_head, m._bn1)

    # NOTE(review): `layer_quant` is not defined in the visible part of this
    # notebook - confirm it exists in the kernel before running this cell.
    layer_quant(m._fc)

In [12]:
# Small probe tensor with positive and negative non-integer values, used by
# the scratch cells that follow.
a=torch.Tensor([1.1,-1.1,1.9,-1.9])

In [13]:
# Fractional part of `a` obtained by subtracting its int-cast value; negative
# entries keep a negative fraction (see the recorded output).
a-a.type(torch.int).type(torch.float)

tensor([ 0.1000, -0.1000,  0.9000, -0.9000])

In [18]:
# Float mask: 1.0 where the sign of `a` is +1, 0.0 elsewhere.
(torch.sign(a)==1).type(torch.float)

tensor([1., 0., 1., 0.])

In [19]:
# Round every element of `a` down to the next lower integer (-1.1 -> -2.).
torch.floor(a)

tensor([ 1., -2.,  1., -2.])

In [21]:
# Limit every element of `a` to the interval [-1, 1].
torch.clamp(a,-1,1)

tensor([ 1., -1.,  1., -1.])

In [172]:
# Empirical check of the stochastic rounding: quantize `a` 100 times and count
# how often element 2 comes out as -1.0 (t1) or as -2.0 (t2).
# NOTE(review): with `a` as defined in this notebook, element 2 is 1.9, which
# is positive, so neither counter can increment; the recorded "92 8" looks
# like it came from a different kernel state - re-run to confirm.
t1 = t2 = 0
for i in range(100):
    q = fix_quant(a)[2].numpy()
    t1 += int(q == -1.0)
    t2 += int(q == -2.0)

print(t1,t2)

92 8


In [24]:
# Random tensor with the same shape as `a`.
# NOTE(review): the recorded output has 2 elements although `a` is defined
# with 4 elements in this notebook - the output looks stale; re-run to confirm.
torch.rand(a.shape)

tensor([0.0098, 0.3799])

In [88]:
# Scratch arithmetic; presumably two steps of a grid with 7 fractional bits
# (step 1/128) - confirm.
2/128

0.015625