In [1]:
## This source code was provided by the TA.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])
def quantize_tensor(x, num_bits=8, min_val=None, max_val=None,imposePow=None,roundType='round'):
    """Quantize ``x`` onto a signed integer grid whose step is a power of two.

    Args:
        x: Tensor to quantize. It is left untouched; a new tensor is returned.
        num_bits: Bit width. Results are clamped to
            ``[-2**(num_bits-1), 2**(num_bits-1) - 1]``.
        min_val, max_val: Optional bounds (tensor or plain number) used to
            derive the scale. A bound left as ``None`` is taken from ``x``.
            Both are ignored when ``imposePow`` is given.
        imposePow: When not ``None`` the scale is fixed to ``2**(-imposePow)``
            instead of being derived from the value range.
        roundType: ``'round'`` or ``'floor'``; how clamped values are snapped
            to the integer grid.

    Returns:
        QTensor: ``tensor`` holds the integer-valued result of ``x / scale``
        (still in floating point), ``scale`` the step size and ``zero_point``
        is always 0.

    Raises:
        ValueError: if ``roundType`` is neither ``'round'`` nor ``'floor'``.

    NOTE(review): the derived scale is sized for ``2**num_bits - 1`` steps
    across ``max_val - min_val`` while values are clamped to a zero-centred
    signed range, so strongly one-sided inputs can saturate. Confirm this is
    the intended scheme.
    """
    # Reject unknown modes up front instead of silently returning unrounded values.
    if roundType not in ('round', 'floor'):
        raise ValueError("roundType must be 'round' or 'floor', got {!r}".format(roundType))

    if imposePow is not None:
        scale=2**(-1*imposePow)
    else:
        # Resolve each bound on its own: `is None` keeps a legitimate 0 bound
        # and lets callers supply only one of the two.
        if min_val is None:
            min_val = x.min()
        if max_val is None:
            max_val = x.max()
        if min_val!=max_val:
            span = max_val-min_val
        else:
            # Degenerate (constant) range: fall back to a tiny nominal span so
            # log2 stays finite.
            span = 1e-4
        # Keep the span as a tensor on x's device so plain-number bounds work
        # and the resulting scale never lives on a different device than x.
        span = torch.as_tensor(span, device=x.device)
        scale=2**(torch.ceil(torch.log2(span/(2.**(num_bits)-1.))))
    zero_point=0
    # `zero_point + x / scale` allocates a fresh tensor, so the in-place ops
    # below never modify the caller's data.
    q_x = zero_point + x / scale
    qmin = (2.**(num_bits-1))*(-1.)
    qmax = 2.**(num_bits-1) - 1.
    q_x.clamp_(qmin, qmax)
    if roundType=='round':
        q_x.round_()
    else:
        q_x.floor_()

    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    """Map a quantized triple back to real values: ``scale * (tensor - zero_point)``.

    ``q_x`` is any object exposing ``tensor``, ``scale`` and ``zero_point``
    attributes; the stored tensor is cast with ``.float()`` before the
    arithmetic.
    """
    centered = q_x.tensor.float() - q_x.zero_point
    return q_x.scale * centered

class FakeQuantOp(torch.autograd.Function):
    """Fake quantization: quantize then immediately dequantize in the forward
    pass, and hand the incoming gradient back unchanged in the backward pass.
    """
    @staticmethod
    def forward(ctx, w, numBits=4,imposePow=None,roundType='round'):
        # Round-trip through the quantized representation so the result stays
        # in real units but only takes values on the quantization grid.
        return dequantize_tensor(
            quantize_tensor(w, num_bits=numBits, imposePow=imposePow, roundType=roundType)
        )
    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: the gradient w.r.t. `w` is passed on as
        # is; the three configuration arguments receive no gradient.
        return (grad_output,) + (None,) * 3

In [2]:
class ConvBN_Quant(nn.Module):
    """Conv2d + BatchNorm2d evaluated as one BN-folded, fake-quantized convolution.

    Until ``fold()`` is called, every forward pass folds the BatchNorm affine
    transform (using the *running* statistics) into the convolution weight and
    bias on the fly. Until ``foldAndQuantize()`` is called, the folded weight
    and bias are fake-quantized to ``nBits`` on every forward pass.

    NOTE(review): ``self.bn`` is never invoked as a layer in this class, so
    nothing here updates ``running_mean`` / ``running_var``. Confirm that the
    training code accounts for this.

    Args:
        in_planes: Number of input channels of the convolution.
        out_planes: Number of output channels (and BatchNorm features).
        kernel_size, stride, padding: Forwarded to ``nn.Conv2d``.
        nBits: Bit width used when fake-quantizing weight and bias.
    """
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=1,nBits=8):
        # Initialise nn.Module before assigning any attribute on the instance.
        super(ConvBN_Quant, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.nBits=nBits
        self.conv = nn.Conv2d(in_planes,out_planes,kernel_size,stride,padding)
        self.bn = nn.BatchNorm2d(out_planes)
        self.folded = False      # True once BN has been baked into self.conv
        self.quantized = False   # True once self.conv holds quantized values

    def _folded_params(self):
        """Return the conv weight and bias with the BatchNorm transform folded in."""
        # Reshape per-channel BN terms to (C, 1, 1, 1) so they broadcast over
        # the (C_out, C_in, kH, kW) convolution weight.
        gamma = self.bn.weight.view(-1, 1, 1, 1)
        var = self.bn.running_var.view(-1, 1, 1, 1)
        w = self.conv.weight*gamma/(torch.sqrt(var+self.bn.eps))
        b = (self.conv.bias-self.bn.running_mean)*self.bn.weight/(torch.sqrt(self.bn.running_var+self.bn.eps))+self.bn.bias
        return w, b

    def forward(self,x):
        """Apply the (folded, optionally fake-quantized) convolution to ``x``."""
        if self.folded:
            w, b = self.conv.weight, self.conv.bias
        else:
            w, b = self._folded_params()
        if not self.quantized:
            w=FakeQuantOp.apply(w,self.nBits)
            b=FakeQuantOp.apply(b,self.nBits)

        return F.conv2d(x,w,b,self.stride,self.padding)

    def fold(self):
        """Permanently bake the BatchNorm transform into ``self.conv``.

        Calling it again is a no-op: folding twice would apply the BatchNorm
        scale and shift a second time and corrupt the weights.
        """
        if self.folded:
            return
        # Compute without autograd so the stored parameters carry no graph
        # history from the BatchNorm parameters.
        with torch.no_grad():
            w, b = self._folded_params()
        self.conv.weight.data = w
        self.conv.bias.data = b
        self.folded = True

    def foldAndQuantize(self):
        """Fold BatchNorm into the convolution, then store quantized weight and bias.

        Calling it again is a no-op: re-quantizing would derive a new scale
        from the already-quantized values and may clip them.
        """
        if self.quantized:
            return
        self.fold()
        self.conv.weight.data = FakeQuantOp.apply(self.conv.weight.data,self.nBits)
        self.conv.bias.data = FakeQuantOp.apply(self.conv.bias.data,self.nBits)
        self.quantized = True

class Linear_Quant(nn.Module):
    """Fully connected layer whose weight and bias are fake-quantized to ``nBits``.

    Until ``quantize()`` is called, weight and bias are fake-quantized on every
    forward pass. Afterwards the stored parameters already hold quantized
    values and are used directly.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        nBits: Bit width used when fake-quantizing weight and bias.
    """
    def __init__(self,in_features,out_features,nBits=8):
        super(Linear_Quant,self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.nBits = nBits
        self.fc = nn.Linear(in_features,out_features)
        self.quantized = False   # True once self.fc holds quantized values

    def forward(self,x):
        """Apply the (optionally fake-quantized) affine map to ``x``."""
        w = self.fc.weight
        b = self.fc.bias
        if not self.quantized:
            w = FakeQuantOp.apply(w,self.nBits)
            b = FakeQuantOp.apply(b,self.nBits)
        return F.linear(x,w,b)

    def quantize(self):
        """Replace the stored weight and bias with their quantized values.

        Calling it again is a no-op: re-quantizing would derive a new scale
        from the already-quantized values and may clip them.
        """
        if self.quantized:
            return
        self.fc.weight.data = FakeQuantOp.apply(self.fc.weight.data,self.nBits)
        self.fc.bias.data = FakeQuantOp.apply(self.fc.bias.data,self.nBits)
        # Set the flag only after both tensors were replaced successfully.
        self.quantized = True

In [3]:
class ModelMnist_quant(nn.Module):
    """Small quantization-aware classifier: one conv+BN block and three linear layers.

    Args:
        numNeurons: Width of the two hidden fully connected layers.
        nBits: Bit width for weights, biases and activations.
    """
    def __init__(self,numNeurons=32,nBits=4):
        super(ModelMnist_quant, self).__init__()

        self.numNeurons = numNeurons
        self.nBits = nBits

        # 3x3 conv, stride 2, no padding: a 28x28 input yields 3 maps of 13x13.
        self.conv1 = ConvBN_Quant(1,3,3,stride=2,padding=0,nBits=nBits)#output 3*13*13
        # The hidden width follows numNeurons (the default of 32 matches the
        # previously hard-coded size).
        self.fc1 = Linear_Quant(3*13*13,numNeurons,nBits=nBits)
        self.fc2 = Linear_Quant(numNeurons,numNeurons,nBits=nBits)
        self.fc3 = Linear_Quant(numNeurons,10,nBits=nBits)
        self.dropout=nn.Dropout(p=0.2)

    def _quant_act(self, x):
        """Fake-quantize an activation with a fixed scale of ``2**-(nBits-4)`` and floor rounding."""
        return FakeQuantOp.apply(x,self.nBits,self.nBits-4,'floor')

    def forward(self, x):
        """Return the fake-quantized 10-way output for input ``x``."""
        x = self._quant_act(F.relu6(self.conv1(x)))

        x=x.view(-1,3*13*13)
        x=self.dropout(x)
        x = self._quant_act(F.relu6(self.fc1(x)))
        x = self._quant_act(F.relu6(self.fc2(x)))

        # The final layer has no activation, but its output is still quantized.
        return self._quant_act(self.fc3(x))

    def foldAndQuantize(self):
        """Fold BatchNorm into the conv layer and store quantized parameters in every layer."""
        self.conv1.foldAndQuantize()
        self.fc1.quantize()
        self.fc2.quantize()
        self.fc3.quantize()

In [4]:
from torchvision import datasets, transforms

batch_size = 256
SPLIT_SEED = 0  # fixes the train/validation partition so reruns see the same split

# Shared normalisation: (pixel - 0.5) / 0.5.
normalize = transforms.Normalize((0.5,), (0.5,))

# Training pipeline: light augmentation (random rotation of up to 5 degrees).
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(5),
    normalize,
])
# Evaluation pipeline: no random augmentation, so metrics are deterministic.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

dataset_all = datasets.MNIST('./data', train=True, download=True,
                             transform=train_transform)
# Split the 60000 training images into 50000 for training and 10000 for validation.
train_set, val_set = torch.utils.data.random_split(
    dataset_all, [50000, 10000],
    generator=torch.Generator().manual_seed(SPLIT_SEED))
# Validation keeps the held-out indices but reads them through the
# augmentation-free pipeline.
val_set = torch.utils.data.Subset(
    datasets.MNIST('./data', train=True, download=False, transform=eval_transform),
    val_set.indices)

# Construct the three data loaders.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=eval_transform),
    batch_size=batch_size, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [5]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Evaluate the Model on the Testing Set

In [6]:
# SECURITY: torch.load unpickles arbitrary Python objects; only open
# checkpoints that come from a trusted source.
# NOTE(review): model.pt is used below as a whole module, so it presumably
# stores a pickled nn.Module; recent PyTorch releases may require
# weights_only=False for such files. Confirm against the installed version.
# map_location lets a checkpoint saved on another device deserialize here.
model_q = torch.load("model.pt", map_location=device)
# Inputs are moved to `device` below, so the model has to live there as well.
model_q.to(device)
model_q.eval()
correctCount=0
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for x,y in test_loader:
        x,y=x.to(device),y.to(device)
        # Index of the largest output per sample is the predicted class.
        pred=model_q(x).max(1)[1]
        correctCount+=torch.sum(pred==y).item()
testAcc=correctCount/len(test_loader.dataset)
print("Testing accuracy of the quantized model: ",testAcc)

Testing accuracy of the quantized model:  0.9667
