# 一、 PTQ
$$
r = S(q-Z)\\
S = \frac{r_{max}-r_{rmin}}{q_{max}-q_{min}} \\ \\

Z = round(q_{max}-\frac{r_{max}}{S})
$$

## 1. core formula

In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os

In [18]:
def calcScaleZeroPoint(rmin, rmax, num_bits=8, signed=False):
    '''calculate S, Z parameters for certain range real numbers'''
    if signed:
        qmin = - 2. **(num_bits - 1)
        qmax = 2. **(num_bits - 1) - 1
    else:
        qmin = 0.
        qmax = 2. **(num_bits - 1)
    scale = float((rmax - rmin)/(qmax - qmin))
    zero_point = qmax - rmax / scale
    # handle overflow
    if zero_point < qmin:
        zero_point = qmin
    if zero_point > qmax:
        zero_point = qmax
    zero_point = int(zero_point)
    return scale, zero_point


In [19]:
rmax = 1
rmin = -1
scale, zero_point = calcScaleZeroPoint(rmin,rmax,num_bits=8,signed=False)
scale, zero_point

(0.015625, 64)

## 2. Quantizer for tensor

In [196]:
def quantize_tensor(x, scale, zero_point, num_bits=8, signed=False):
    '''input: x, real value tensor
       output: q_x, quantized value tensor'''
    if signed:
        qmin = - 2. **(num_bits - 1)
        qmax = 2. **(num_bits - 1) - 1
    else:
        qmin = 0.
        qmax = 2. **(num_bits - 1)
    q_x = x / scale + zero_point
    q_x.clamp_(qmin, qmax).round_()
    if signed:
        return q_x.char()
    else:
        return q_x.byte() # convert to int32
def dequantize_tensor(q,scale,zero_point):
    r = scale*(q - zero_point)
    return r.float()
    

In [197]:
a = torch.randn(2,3,dtype=torch.float)
rmax = a.max()
rmin = a.min()
print(a)
print(rmax)
print(rmin)

tensor([[ 0.7974, -0.6490, -2.2510],
        [-1.3110, -0.8741, -1.9719]])
tensor(0.7974)
tensor(-2.2510)


In [198]:
scale, zero_point = calcScaleZeroPoint(rmin,rmax)
print(scale)
print(zero_point)

0.02381531521677971
94


In [200]:
q_a = quantize_tensor(a,scale,zero_point,signed=False)
print('origin value: ')
print(a)
print('quantized value: ')
print(q_a)
print('dequantized value: ')
print(dequantize_tensor(q_a,scale,zero_point))


origin value: 
tensor([[ 0.7974, -0.6490, -2.2510],
        [-1.3110, -0.8741, -1.9719]])
quantized value: 
tensor([[127,  67,   0],
        [ 39,  57,  11]], dtype=torch.uint8)
dequantized value: 
tensor([[0.7859, 5.4537, 3.8581],
        [4.7869, 5.2156, 4.1200]])


## 3. package as a class

In [192]:
class Qparam:
    def __init__(self, num_bits=8):
        self.num_bits = num_bits
        self.scale = None
        self.zero_point = None
        self.rmin = None
        self.rmax = None
    
    def calcScaleZeroPoint(self, rmin, rmax, num_bits=8, signed=False):
        '''calculate S, Z parameters for certain range real numbers'''
        if signed:
            qmin = - 2. **(num_bits - 1)
            qmax = 2. **(num_bits - 1) - 1
        else:
            qmin = 0.
            qmax = 2. **(num_bits - 1)
        scale = float((rmax - rmin)/(qmax - qmin))
        zero_point = qmax - rmax / scale
        # handle overflow( have been handled in Qparam)
        # if zero_point < qmin:
        #     zero_point = qmin
        # if zero_point > qmax:
        #     zero_point = qmax

        zero_point = int(zero_point)
        return scale, zero_point

    def update(self, tensor):
        '''update all the params via input tensor'''
        if self.rmax is None or self.rmax < tensor.max():
            self.rmax = tensor.max()
        self.rmax = 0 if self.rmax < 0 else self.rmax # guarantee that zero_point won't overflow
        if self.rmin is None or self.rmin > tensor.min():
            self.rmin = tensor.min()
        self.rmin = 0 if self.rmin > 0 else self.rmin
        
        self.scale, self.zero_point = self.calcScaleZeroPoint(self.rmin,self.rmax,self.num_bits)
    
    def quantize_tensor(self, x, signed=False):
        '''input: x, real value tensor
           output: q_x, quantized value tensor'''
        if signed:
            qmin = - 2. **(self.num_bits - 1)
            qmax = 2. **(self.num_bits - 1) - 1
        else:
            qmin = 0.
            qmax = 2. **(self.num_bits - 1)
        q_x = x / self.scale + self.zero_point
        q_x.clamp_(qmin, qmax).round_()
        if signed:
            return q_x.char()
        else:
            return q_x.byte() # convert to int32
    
    def dequantize_tensor(self, q):
        r = self.scale*(q - self.zero_point)
        return r.float()


## 4. QNet base class

In [193]:
class QModule(nn.Module):
    def __init__(self, qi=True, qo=True, num_bits=8):
        super().__init__()
        if qi:
            self.qi = Qparam(num_bits) # quantizer for input tensor
        if qo:
            self.qo = Qparam(num_bits) # quantizer for output tensor
    
    def freeze(self):
        '''freeze params of the network'''
        pass

    def quantize_inference(self, x):
        raise NotImplementedError('quantize_inference must be implemented.')

## 5. Implement important QNet component

In [202]:
class QConv2d(QModule):

    def __init__(self, conv_module, qi=True, qo=True, num_bits=8):
        super(QConv2d,self).__init__(qi=qi, qo=qo, num_bits=num_bits) # quantizer for input output
        self.num_bits = num_bits
        self.conv_module = conv_module
        self.qw = Qparam(num_bits) # quantizer for weights

    def freeze(self,qi=None, qo=None):
        '''some hidden module may not need to calculate rmin/rmax but reuse former
        qo as its own qi. for conv2d layer, it has its own qi qo'''

        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')

        if hasattr(self, 'qo') and qo is not None:
            raise ValueError('qo has been provided in init function.')
        if not hasattr(self, 'qo') and qo is None:
            raise ValueError('qo is not existed, should be provided.')
        
        if qi is not None:
            self.qi = qi
        if qo is not None:
            self.qo = qo
        
        # TODO implement it by bit shift
        self.M = self.qw.scale * self.qi.scale / self.qo.scale # actually, it should be implement by bit shift
        self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
        # TODO expand the product
        self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point

        self.conv_module.bias.data = quantize_tensor(self.conv_module.bias.data,self.qw.scale * self.qi.scale, zero_point=0,signed=True) # since Z = 0, r=Sq, the value range of q must contain negative numbers

    # used in QAT
    def forward(self,x):
        # statistics and update
        if hasattr(self, 'qi'):
            self.qi.update(x)
            # simulate quantization effects of input
            x = self.qi.quantize_tensor(x)
            x = self.qi.dequantize_tensor(x)
        self.qw.update(self.conv_module.weight.data)
        # simulate quantization effects
        self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
        self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)

        x = self.conv_module(x)
        
        # qo's params maybe useful for latter layers
        if hasattr(self,'qo'):
            self.qo.update(x)
        
        return x

    def quantize_inference(self, x):
        # use original formula to calculate
        x = x - self.qi.zero_point
        # calculate in 8 bits integer
        x = self.conv_module(x)
        x = (self.M * x).round().int() # self.M is still in float format, need to be concert to int
        x = x +  self.qo.zero_point
        return x
        

In [155]:
class QLinear(QModule):

    def __init__(self, fc_module, qi, qo, num_bits=8):
        super().__init__(qi=qi, qo=qo, num_bits=num_bits)
        self.num_bits = num_bits
        self.fc_module = fc_module
        self.qw = Qparam(num_bits)

    def freeze(self,qi=None, qo=None):
        '''some hidden module may not need to calculate rmin/rmax but reuse former
        qo as its own qi. for conv2d layer, it has its own qi qo'''

        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')

        if hasattr(self, 'qo') and qo is not None:
            raise ValueError('qo has been provided in init function.')
        if not hasattr(self, 'qo') and qo is None:
            raise ValueError('qo is not existed, should be provided.')
        
        if qi is not None:
            self.qi = qi
        if qo is not None:
            self.qo = qo
        
        # TODO implement it by bit shift
        self.M = self.qw.scale * self.qi.scale / self.qo.scale # actually, it should be implement by bit shift
        self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data)
        # TODO expand the product
        self.fc_module.weight.data = self.fc_module.weight.data - self.qw.zero_point

        self.fc_module.bias.data = quantize_tensor(self.fc_module.bias.data,self.qw.scale * self.qi.scale, zero_point=0,signed=True) # since Z = 0, r=Sq, the value range of q must contain negative numbers

    # used in QAT
    def forward(self,x):
        # statistics and update
        if hasattr(self, 'qi'):
            self.qi.update(x)
            # simulate quantization effects of input
            x = self.qi.quantize_tensor(x)
            x = self.qi.deuantize_tensor(x)
        self.qw.update(self.fc_module.weight.data)
        # simulate quantization effects
        self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data)
        self.fc_module.weight.data = self.qw.dequantize_tensor(self.fc_module.weight.data)

        x = self.fc_module(x) # no need to quantize bias 
        
        # qo's params maybe useful for latter layers
        if hasattr(self,'qo'):
            self.qo.update(x)
        
        return x

    def quantize_inference(self, x):
        # use original formula to calculate
        x = x - self.qi.zero_point
        # calculate in integer
        x = self.fc_module(x)
        x = (self.M * x).round().int()
        x = x + self.qo.zero_point
        return x
        

In [156]:
class QReLU(QModule):

    def __init__(self, relu_module, qi=False, num_bits=None):
        super(QReLU, self).__init__(qi=qi, num_bits=num_bits)
        self.relu_module = relu_module

    def freeze(self, qi=None):
        
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')

        if qi is not None:
            self.qi = qi # relu module reuse former layers' qo

    def forward(self, x):
        if hasattr(self, 'qi'):
            self.qi.update(x)
            x = self.qi.quantize_tensor(x)
            x = self.qi.dequantize_tensor(x)
        
        x = self.relu_module(x)

        return x
    
    def quantize_inference(self, x):
       x = x.clone()
       x[x < self.qi.zero_point] = self.qi.zero_point
       return x

In [183]:
class QMaxPooling2d(QModule):

    def __init__(self, maxpooling2d_module, kernel_size=3, stride=1, padding=0, qi=False, num_bits=None):
        super(QMaxPooling2d, self).__init__(qi=qi, num_bits=num_bits)
        self.maxpooling2d_module = maxpooling2d_module
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def freeze(self, qi=None):
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')
        if qi is not None:
            self.qi = qi

    def forward(self, x):
        if hasattr(self, 'qi'):
            self.qi.update(x)
            x = self.qi.quantize_tensor(x)
            x = self.qi.deuantize_tensor(x)

        x = self.maxpooling2d_module(x)

        return x
    
    def quantize_inference(self, x):
        x = self.maxpooling2d_module(x.float()).int() # max_pool2d_with_indices_cpu" not implemented for 'Int
        return x

## 6. Complete QNet

In [209]:
class Net(nn.Module):

    def __init__(self, num_channels=1):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1,)
        self.fc = nn.Linear(5*5*40, 10)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.maxpool2d_1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.maxpool2d_2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.maxpool2d_1(x)
        x = self.relu2(self.conv2(x))
        x = self.maxpool2d_2(x)
        x = x.view(-1, 5*5*40)
        x = self.fc(x)
        return x
    
    def quantize(self, num_bits=8):
        self.qconv1 = QConv2d(self.conv1,qi=True,qo=True,num_bits=num_bits)
        self.qrelu1 = QReLU(self.relu1)
        self.qmaxpool2d_1 = QMaxPooling2d(self.maxpool2d_1)
        self.qconv2 = QConv2d(self.conv2,qi=False,qo=True,num_bits=num_bits)
        self.qrelu2 = QReLU(self.relu2)
        self.qmaxpool2d_2 = QMaxPooling2d(self.maxpool2d_2)
        self.qfc = QLinear(self.fc,qi=False,qo=True,num_bits=num_bits)
    
    # forward and update QParams
    def quantize_forward(self,x):
        x = self.qrelu1(self.qconv1(x))
        x = self.qmaxpool2d_1(x)
        x = self.qrelu2(self.qconv2(x))
        x = self.qmaxpool2d_2(x)
        x = x.view(-1, 5*5*40)
        x = self.qfc(x)
        return x
    
    def freeze(self):
        self.qconv1.freeze()
        self.qrelu1.freeze(qi=self.qconv1.qo)
        self.qmaxpool2d_1.freeze(qi=self.qconv1.qo)
        self.qconv2.freeze(qi=self.qconv1.qo)
        self.qrelu2.freeze(qi=self.qconv2.qo)
        self.qmaxpool2d_2.freeze(qi=self.qconv2.qo)
        self.qfc.freeze(qi=self.qconv2.qo)
    
    def quantize_inference(self, x):
        # input should be quantized, and then all the calculations are performed on integer
        qx = self.qconv1.qi.quantize_tensor(x)
        print('qx dtype: ', qx.dtype) 
        qx = self.qconv1.quantize_inference(qx)
        print('qx dtype: ', qx.dtype) 
        qx = self.qrelu1.quantize_inference(qx)
        print('qx dtype: ', qx.dtype) 
        qx = self.qmaxpool2d_1.quantize_inference(qx)
        print('qx dtype: ', qx.dtype) 
        qx = self.qconv2.quantize_inference(qx)
        print('qx dtype: ', qx.dtype) 
        qx = self.qrelu2.quantize_inference(qx)
        qx = self.qmaxpool2d_2.quantize_inference(qx)
        qx = qx.view(-1, 5*5*40)
        qx = self.qfc.quantize_inference(qx)
        qx = self.qfc.qo.dequantize_tensor(qx)
        return qx

## 7. Train the Net

In [33]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    lossLayer = torch.nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = lossLayer(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

In [34]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += lossLayer(output, target).item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
        test_loss, 100. * correct / len(test_loader.dataset)
    ))

In [35]:
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集，然后将其加载到内存中。"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = datasets.FashionMNIST(root="../Data",
                                                    train=True,
                                                    transform=trans,
                                                    download=True)
    mnist_test =  datasets.FashionMNIST(root="../Data",
                                                   train=False,
                                                   transform=trans,
                                                   download=True)
    return (torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=4),
            torch.utils.data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=4))

In [36]:
batch_size = 64
test_batch_size = 64
seed = 1
epochs = 15
lr = 0.01
momentum = 0.9
save_model = True
torch.manual_seed(seed) # fix the seed
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

In [37]:
model = Net(num_channels=1)
optimizer = optim.SGD(model.parameters(),lr=lr,momentum=momentum)

In [38]:
for epoch in range(1, epochs + 1):
        train(model, device, train_iter, optimizer, epoch)
        test(model, device, test_iter)
if save_model:
    if not os.path.exists('ckpt'):
        os.makedirs('ckpt')
    torch.save(model.state_dict(), 'ckpt/mnist_cnn.pt')

Train Epoch: 1 [0/60000]	Loss: 2.314084
Train Epoch: 1 [3200/60000]	Loss: 1.864382
Train Epoch: 1 [6400/60000]	Loss: 1.077437
Train Epoch: 1 [9600/60000]	Loss: 0.674443
Train Epoch: 1 [12800/60000]	Loss: 1.032035
Train Epoch: 1 [16000/60000]	Loss: 0.687428
Train Epoch: 1 [19200/60000]	Loss: 0.851558
Train Epoch: 1 [22400/60000]	Loss: 0.774539
Train Epoch: 1 [25600/60000]	Loss: 0.802545
Train Epoch: 1 [28800/60000]	Loss: 0.730229
Train Epoch: 1 [32000/60000]	Loss: 0.795217
Train Epoch: 1 [35200/60000]	Loss: 0.853380
Train Epoch: 1 [38400/60000]	Loss: 0.704790
Train Epoch: 1 [41600/60000]	Loss: 0.655076
Train Epoch: 1 [44800/60000]	Loss: 0.727883
Train Epoch: 1 [48000/60000]	Loss: 0.612061
Train Epoch: 1 [51200/60000]	Loss: 0.579158
Train Epoch: 1 [54400/60000]	Loss: 0.561521
Train Epoch: 1 [57600/60000]	Loss: 0.893712

Test set: Average loss: 0.6126, Accuracy: 78%

Train Epoch: 2 [0/60000]	Loss: 0.524307
Train Epoch: 2 [3200/60000]	Loss: 0.446216
Train Epoch: 2 [6400/60000]	Loss: 0.6639

## 8. Post Training Quantization

In [162]:
import time

In [165]:
def direct_quantize(model, test_loader):
    for i, (data, target) in enumerate(test_loader, 1):
        output = model.quantize_forward(data)
        if i % 200 == 0:
            break
    print('direct quantization finish')

def quantize_inference(model, test_loader):
    correct = 0
    for i, (data, target) in enumerate(test_loader, 1):
        output = model.quantize_inference(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    print('\nTest set: Quant Model Accuracy: {:.0f}%\n'.format(100. * correct / len(test_loader.dataset)))

def full_inference(model, test_loader):
    correct = 0
    for i, (data, target) in enumerate(test_loader, 1):
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    print('\nTest set: Full Model Accuracy: {:.0f}%\n'.format(100. * correct / len(test_loader.dataset)))

In [210]:
model = Net()
model.load_state_dict(torch.load('./ckpt/mnist_cnn.pt'))
model.quantize(num_bits=8)

In [211]:
begin = time.time()
full_inference(model,test_iter)
end = time.time()
print('runtime: ', end- begin)


Test set: Full Model Accuracy: 88%

runtime:  4.376999616622925


In [212]:
direct_quantize(model,train_iter) # statistics the value of rmin rmax and updates scale zero_point
model.freeze()

direct quantization finish


In [207]:
model.qconv1.conv_module.bias.data.dtype

torch.int8

In [213]:
begin = time.time()
with torch.no_grad():
    quantize_inference(model,test_iter)
end = time.time()
print('runtime: ',end-begin)

qx dtype:  torch.uint8
qx dtype:  torch.int32
qx dtype:  torch.int32
qx dtype:  torch.int32


RuntimeError: expected scalar type Int but found Byte