In [2]:
"""
==========================
    Style: static quantize
    Model: VGG-16
    Create by: Han_yz @ 2020/1/29
    Email: 20125169@bjtu.edu.cn
    Github: https://github.com/Forggtensky
==========================
"""

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization


"""
------------------------------
    1、Model architecture
------------------------------
"""

class VGG(nn.Module):
    def __init__(self,features,num_classes=1000,init_weights=False):
        super(VGG,self).__init__()
        self.features = features  # 提取特征部分的网络，也为Sequential格式
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(  # 分类部分的网络
            nn.Linear(512*7*7,4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096,4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096,num_classes)
        )
        # add the quantize part
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        if init_weights:
            self._initialize_weights()

    def forward(self,x):
        x = self.quant(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x,start_dim=1)
        # x = x.mean([2, 3])
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module,nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias,0)
            elif isinstance(module,nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(module.bias,0)

cfgs = {
    'vgg11':[64,'M',128,'M',256,256,'M',512,512,'M',512,512,'M'],
    'vgg13':[64,64,'M',128,128,'M',256,256,'M',512,512,'M',512,512,'M'],
    'vgg16':[64,64,'M',128,128,'M',256,256,256,'M',512,512,512,'M',512,512,512,'M'],
    'vgg19':[64,64,'M',128,128,'M',256,256,256,256,'M',512,512,512,512,'M',512,512,512,512,'M'],
}

def make_features(cfg:list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2,stride=2)]  #vgg采用的池化层均为2,2参数
        else:
            conv2d = nn.Conv2d(in_channels,v,kernel_size=3,padding=1)  #vgg卷积层采用的卷积核均为3,1参数
            layers += [conv2d,nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)  #非关键字的形式输入网络的参数

def vgg(model_name='vgg16',**kwargs):
    try:
        cfg = cfgs[model_name]
    except:
        print("Warning: model number {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg),**kwargs)  # **kwargs为可变长度字典，保存多个输入参数
    return model

"""
------------------------------
    2、Helper functions
------------------------------
"""

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                 return top1, top5

    return top1, top5


def run_benchmark(model_file, img_loader):
    elapsed = 0
    model = torch.jit.load(model_file)
    model.eval()
    num_batches = 5
    # Run the scripted model on a few batches of images
    for i, (images, target) in enumerate(img_loader):
        if i < num_batches:
            start = time.time()
            output = model(images)
            end = time.time()
            elapsed = elapsed + (end-start)
        else:
            break
    num_images = images.size()[0] * num_batches

    print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000))
    return elapsed

def load_model(model_file):
    model_name = "vgg16"
    model = vgg(model_name=model_name,num_classes=1000,init_weights=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')


"""
------------------------------
    3. Define dataset and data loaders
------------------------------
"""

def prepare_data_loaders(data_path):
    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    dataset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    print("dataset_train : %d" % (len(dataset)))

    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    print("dataset_test : %d" % (len(dataset_test)))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

# Specify random seed for repeatable results
torch.manual_seed(191009)

data_path = 'data/imagenet_1k'
saved_model_dir = 'model/'
float_model_file = 'vgg16_pretrained_float.pth'
scripted_float_model_file = 'vgg16_quantization_scripted.pth'
scripted_default_quantized_model_file = 'vgg16_quantization_scripted_default_quantized.pth'
scripted_optimal_quantized_model_file = 'vgg16_quantization_scripted_optimal_quantized.pth'

train_batch_size = 30
eval_batch_size = 30

data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to('cpu')

print('\n Before quantization: \n',float_model)
float_model.eval()

# Note: vgg-16 has no BN layer so that not need to fuse model

num_eval_batches = 10

print("Size of baseline model")
print_size_of_model(float_model)

# to get a “baseline” accuracy, see the accuracy of our un-quantized model
top1, top5 = evaluate(float_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file) # save un_quantized model

"""
------------------------------
    4. Post-training static quantization
------------------------------
"""

num_calibration_batches = 40

myModel = load_model(saved_model_dir + float_model_file).to('cpu')



# Specify quantization configuration
# Start with simple min/max range estimation and per-tensor quantization of weights
myModel.qconfig = torch.quantization.default_qconfig
print('\n myModel.qconfig: \n',myModel.qconfig)
torch.quantization.prepare(myModel, inplace=True)

# Calibrate with the training set
print('\nPost Training Quantization Prepare: Inserting Observers by Calibrate')
evaluate(myModel, criterion, data_loader, neval_batches=num_calibration_batches)
print("Calibrate done")

# Convert to quantized model
torch.quantization.convert(myModel, inplace=True)
print('Post Training Quantization: Convert done')


print('\n After quantization: \n',myModel)



dataset_train : 1000
dataset_test : 1000

 Before quantization: 
 VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)


* ### save quantified data

In [3]:
# print('\n\nmodel_int8 conv keys:',per_channel_quantized_model.state_dict().keys())
# import tqdm
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization

def getrealzero(_scale : float, _zero_point : int) -> int:
    return int(torch.int_repr(torch.quantize_per_tensor(torch.tensor(0.), float(_scale), int(_zero_point), torch.qint8)).numpy())
    # return round(-_zero_point / _scale)

class hbm_channel_mem(): # each channel save 32 bits data
    data_mem = []
    zeropoint_mem = []
    def __init__(self):
        # _list_tmp = list()
        # for iter in range(0, 32):
            # self.data_mem.append(_list_tmp)
        self.data_mem = [[] for x in range(0, 32)]
        self.zeropoint_mem = [[] for x in range(0, 32)]

    # def set(self, tunnel, addr, value):
    #     data_mem
    def autofill(self):
        for position in range(1, 32):
            for addr in range(len(self.data_mem[position]), len(self.data_mem[0])):
                self.data_mem[position].append(self.zeropoint_mem[0][addr])
            # while (len(self.data_mem[0]) > len(self.data_mem[position])):
            #     self.data_mem[position].append(self.zeropoint_mem[0][len(self.data_mem[position])]) # change to real zero point

    def print(self, maxaddrlimit = 64):
        maxaddr = 0
        for position in range(0, 32):
            maxaddr = max(maxaddr, len(self.data_mem[position]))

        if maxaddrlimit != -1:
            maxaddr = min(maxaddr, maxaddrlimit)

        for addr in range(0, maxaddr):
            for position in range(0, 32):
                if len(self.data_mem[position]) <= addr:
                    print('NUL\t', end = '')
                else:
                    print(f'{self.data_mem[position][addr]}\t', end = '')
            print()

    def append(self, position : int, value : int, zeropoint : int):
        # print('channel append enter')
        # print(self)
        # print(position, value)
        # print('channel append exit')
        self.data_mem[position].append(value)
        self.zeropoint_mem[position].append(zeropoint)

    def appends(self, valuelist : list):
        for iter in range(0, 32):
            self.data_mem[iter].append(valuelist[iter])

    def save(self, filepath : str, saveaspcieorder = True):
        self.autofill()
        # self.print()
        if saveaspcieorder:
            with open(filepath, 'wb+') as f:
                for addr_base in range(0, len(self.data_mem[0]), 8):
                    for addr in range(addr_base + 7, addr_base - 1, -1):
                        for position in range(0, 32):
                            # print(type(self.data_mem[addr][position]))
                            try:
                                # print(addr, ' ', position)
                                # print(self.data_mem[position][addr])
                                # print(int(self.data_mem[position][addr]))
                                # print(int(self.data_mem[position][addr]).to_bytes(1, 'little', signed = True))
                                f.write(int(self.data_mem[position][addr]).to_bytes(1, 'little', signed = True))
                            except:
                                
                                print(self.data_mem[position][addr])
                                raise
                        # f.write(int.from_bytes(self.data_mem[addr][position].to_bytes(1, 'little', signed = True), 'little', signed = False))
                    # raise
                    
                f.close()
        else:
            with open(filepath, 'wb+') as f:
                for addr in range(len(self.data_mem[0])):
                    
                    for position in range(31, -1, -1):
                        # print(type(self.data_mem[addr][position]))
                        try:
                            # print(self.data_mem[position][addr])
                            # print(int(self.data_mem[position][addr]))
                            # print(int(self.data_mem[position][addr]).to_bytes(1, 'little', signed = True))
                            f.write(int(self.data_mem[position][addr]).to_bytes(1, 'little', signed = True))
                        except:
                            
                            print(self.data_mem[position][addr])
                            raise
                        # f.write(int.from_bytes(self.data_mem[addr][position].to_bytes(1, 'little', signed = True), 'little', signed = False))
                    # raise
                    
                f.close()

class hbm_mem():
    hbm_data_mem = []
    robin = int()

    def __init__(self):
        # _list_tmp = hbm_channel_mem()
        self.robin = int(0)
        # for iter in range(0, 32):
            # self.hbm_data_mem.append(_list_tmp)
        self.hbm_data_mem = [hbm_channel_mem() for x in range(0, 32)]
        # print("init")
        # for i in self.hbm_data_mem:
        #     print(i)
    
    def append_channel(self, channel : int, tunnel : int, value : int):
        self.hbm_data_mem[channel].append(tunnel, value)

    def appends_channel(self, channel : int, valuelist : list):
        self.hbm_data_mem[channel].appends(valuelist)

    def append(self, valuelist : list, zeropoint : int):
        # print('append enter')
        for channel in range(0, 32):
            # print(valuelist[channel])
            for value in valuelist[channel]:
                self.hbm_data_mem[channel].append(self.robin, value, zeropoint)

        if (self.robin == 31):
            self.robin = 0
        else:
            self.robin = self.robin + 1
        # print('append return')


    def autofill(self):
        while (self.robin != 0):
            for channel in range(0, 32):
                self.hbm_data_mem[channel].append(self.robin, 0) # TODO: change to real zero point

            if (self.robin == 31):
                self.robin = 0
            else:
                self.robin = self.robin + 1

    def print(self):
        for channel in range(0, 32):
            print(f'data in {channel}:')
            # print(self.hbm_data_mem[channel])
            self.hbm_data_mem[channel].print()

    def save(self, filepath : str):
        for channel in range(0, 32):
            self.hbm_data_mem[channel].save("{}_{}.bin".format(filepath, channel))


i_hbm_mem = hbm_mem()

for key in myModel.state_dict().keys():
    print(key)
    # print(layer_numpy.shape)
    
    if (key.find("weight") >= 0):
        pass
    else:
        continue
    
    if (myModel.state_dict()[key].dim() == 4):
        layer_numpy = torch.int_repr(myModel.state_dict()[key]).numpy()
        layer_zero_point = myModel.state_dict()[key].q_zero_point() # myModel.state_dict()[key[0:-6] + "zero_point"].numpy()
        layer_scale = myModel.state_dict()[key].q_scale() # myModel.state_dict()[key[0:-6] + "scale"].numpy()
        realzero = getrealzero(layer_scale, layer_zero_point)
        # print("zero_point: ", layer_zero_point)
        # print("scale: ", layer_scale)
        # print("zero_quan: ", realzero)
        
        print(f"dealing {key}")
        for kernel_base in range(0, layer_numpy.shape[0], 16): # tqdm.trange(0, layer_numpy.shape[0], 16): # 
            multi_channel_valuelist_tmp_channelgroup_1 = []
            multi_channel_valuelist_tmp_channelgroup_2 = []
            for kernel in range(kernel_base, kernel_base + 16):
                single_channel_valuelist_tmp_channelgroup_1 = []
                single_channel_valuelist_tmp_channelgroup_2 = []

                for channel_base in range(0, layer_numpy.shape[1], 16):
                    for channel in range(channel_base, channel_base + 8): # channelgroup 1
                        single_channel_valuelist_tmp = [] # one row 1*9
                        for width in range(layer_numpy.shape[2]): # probably height and width is in this order
                            for height in range(layer_numpy.shape[3]):
                                if (channel < layer_numpy.shape[1]): # in case of conv_1 has only 3 channel
                                    # single_channel_valuelist_tmp.append(per_channel_quantized_model.state_dict()[key][kernel][channel][width][height])
                                    single_channel_valuelist_tmp.append(layer_numpy[kernel][channel][width][height])
                                else:
                                    single_channel_valuelist_tmp.append(realzero) # change to real zero
                        single_channel_valuelist_tmp_channelgroup_1.extend(single_channel_valuelist_tmp) # extends to 8*9 through order is in reverse
                            

                    for channel in range(channel_base + 8, channel_base + 16): # channelgroup 2
                        single_channel_valuelist_tmp = [] # one row 1*9
                        for width in range(layer_numpy.shape[2]): # probably height and width is in this order
                            for height in range(layer_numpy.shape[3]):
                                if (channel < layer_numpy.shape[1]): # in case of conv_1 has only 3 channel
                                    # single_channel_valuelist_tmp.append(per_channel_quantized_model.state_dict()[key][kernel][channel][width][height])
                                    single_channel_valuelist_tmp.append(layer_numpy[kernel][channel][width][height])
                                else:
                                    single_channel_valuelist_tmp.append(realzero)  # change to real zero
                        single_channel_valuelist_tmp_channelgroup_2.extend(single_channel_valuelist_tmp) # extends to 8*9

                multi_channel_valuelist_tmp_channelgroup_1.append(single_channel_valuelist_tmp_channelgroup_1)
                multi_channel_valuelist_tmp_channelgroup_2.append(single_channel_valuelist_tmp_channelgroup_2)

            # print("multi_channel_valuelist_tmp_channelgroup_1")
            # for row in range(0, len(multi_channel_valuelist_tmp_channelgroup_1)):
            #     print(len(multi_channel_valuelist_tmp_channelgroup_1[row]), multi_channel_valuelist_tmp_channelgroup_1[row])
            # print("multi_channel_valuelist_tmp_channelgroup_2")
            # for row in range(0, len(multi_channel_valuelist_tmp_channelgroup_2)):
            #     print(len(multi_channel_valuelist_tmp_channelgroup_2[row]), multi_channel_valuelist_tmp_channelgroup_2[row])
            # print(multi_channel_valuelist_tmp_channelgroup_1)
            # print(multi_channel_valuelist_tmp_channelgroup_2)
            total_channel_valuelist_tmp = []
            total_channel_valuelist_tmp.extend(multi_channel_valuelist_tmp_channelgroup_1)
            total_channel_valuelist_tmp.extend(multi_channel_valuelist_tmp_channelgroup_2)
            i_hbm_mem.append(total_channel_valuelist_tmp, realzero)
            # i_hbm_mem.print()
            # break

    elif (myModel.state_dict()[key].dim() == 3):
        pass
    elif (myModel.state_dict()[key].dim() == 2):
        pass
    elif (myModel.state_dict()[key].dim() == 1):
        pass
    else:
        pass
        # raise

    # i_hbm_mem.print()
    # i_hbm_mem.autofill()
    # print('after auto fill')
    # i_hbm_mem.print()
    # break
i_hbm_mem.save("quan_bin_pertensor/quan_bin_pertensor")
# print('\n\nmodel_int8 features.0.weight:',per_channel_quantized_model.state_dict()['features.0.weight'])
# print('\n\nmodel_int8 features.0.bias:',per_channel_quantized_model.state_dict()['features.0.bias'])
# print('\n\nmodel_int8 features.0.scale:',per_channel_quantized_model.state_dict()['features.0.scale'])
# print('\n\nmodel_int8 features.0.zero_point:',per_channel_quantized_model.state_dict()['features.0.zero_point'])

# print('\n\nmodel_int8 features.2.weight:',per_channel_quantized_model.state_dict()['features.2.weight'])
# print('\n\nmodel_int8 features.2.bias:',per_channel_quantized_model.state_dict()['features.2.bias'])
# print('\n\nmodel_int8 features.2.scale:',per_channel_quantized_model.state_dict()['features.2.scale'])
# print('\n\nmodel_int8 features.2.zero_point:',per_channel_quantized_model.state_dict()['features.2.zero_point'])

# print('\n\nmodel_int8 quant.scale:',per_channel_quantized_model.state_dict()['quant.scale'])
# print('\n\nmodel_int8 quant.zero_point:',per_channel_quantized_model.state_dict()['quant.zero_point'])

features.0.weight
dealing features.0.weight
features.0.scale
features.0.zero_point
features.0.bias
features.2.weight
dealing features.2.weight
features.2.scale
features.2.zero_point
features.2.bias
features.5.weight
dealing features.5.weight
features.5.scale
features.5.zero_point
features.5.bias
features.7.weight
dealing features.7.weight
features.7.scale
features.7.zero_point
features.7.bias
features.10.weight
dealing features.10.weight
features.10.scale
features.10.zero_point
features.10.bias
features.12.weight
dealing features.12.weight
features.12.scale
features.12.zero_point
features.12.bias
features.14.weight
dealing features.14.weight
features.14.scale
features.14.zero_point
features.14.bias
features.17.weight
dealing features.17.weight
features.17.scale
features.17.zero_point
features.17.bias
features.19.weight
dealing features.19.weight
features.19.scale
features.19.zero_point
features.19.bias
features.21.weight
dealing features.21.weight
features.21.scale
features.21.zero_poi

* ### save scale and zero_point value

In [4]:


# return  n, p
def get_mo(M, P):
    last_error = None
    result = M * P

    for n in range(1, 16):
        Mo = int(round(2 ** n * M)) # 这里不一定要四舍五入截断，因为python定点数不好表示才这样处理

        approx_result = (Mo * P) >> n

        cur_error = result - approx_result
        if last_error == None:
            pass
        else:
            if (abs(last_error - cur_error) < 1e-1 and abs(cur_error) < 1) or (n == 15):
                yield n
                break
        last_error = cur_error

    # print("Mo=%d, approx=%f, error=%f"%\
    #     (Mo, approx_result, result-approx_result))
    
    yield Mo

    return

# M = 0.0072474273418460
# P = 7091
# ret_get_mo = iter(get_mo(M, P))
# print(next(ret_get_mo)) # 1
# print(next(ret_get_mo)) # 1


In [14]:
for key in myModel.state_dict().keys():
    print(key, myModel.state_dict()[key].shape)

features.0.weight torch.Size([64, 3, 3, 3])
features.0.scale torch.Size([])
features.0.zero_point torch.Size([])
features.0.bias torch.Size([64])
features.2.weight torch.Size([64, 64, 3, 3])
features.2.scale torch.Size([])
features.2.zero_point torch.Size([])
features.2.bias torch.Size([64])
features.5.weight torch.Size([128, 64, 3, 3])
features.5.scale torch.Size([])
features.5.zero_point torch.Size([])
features.5.bias torch.Size([128])
features.7.weight torch.Size([128, 128, 3, 3])
features.7.scale torch.Size([])
features.7.zero_point torch.Size([])
features.7.bias torch.Size([128])
features.10.weight torch.Size([256, 128, 3, 3])
features.10.scale torch.Size([])
features.10.zero_point torch.Size([])
features.10.bias torch.Size([256])
features.12.weight torch.Size([256, 256, 3, 3])
features.12.scale torch.Size([])
features.12.zero_point torch.Size([])
features.12.bias torch.Size([256])
features.14.weight torch.Size([256, 256, 3, 3])
features.14.scale torch.Size([])
features.14.zero_po

AttributeError: 'torch.dtype' object has no attribute 'shape'

In [26]:

    
# def getSandZ():
scale_chain = dict()
prev_scale_layer = None
zero_point_chain = dict()
prev_zero_point_layer = None
# fetch network order
for key in myModel.state_dict().keys():
    if 'scale' in key:
        if prev_scale_layer is not None:
            scale_chain[prev_scale_layer] = key
        prev_scale_layer = key 
    elif 'zero_point' in key:
        if prev_zero_point_layer is not None:
            zero_point_chain[prev_zero_point_layer] = key
        prev_zero_point_layer = key
    else:
        pass
# print(scale_chain)
# print(zero_point_chain)
input_layer_size_list = [(1, 3, 224 + 2, 224 + 2), (1, 64, 224 + 2, 224 + 2), 
                         (1, 64, 112 + 2, 112 + 2), (1, 128, 112 + 2, 112 + 2), 
                         (1, 128, 56 + 2, 56 + 2), (1, 256, 56 + 2, 56 + 2), (1, 256, 56 + 2, 56 + 2), 
                         (1, 256, 28 + 2, 28 + 2), (1, 512, 28 + 2, 28 + 2), (1, 512, 28 + 2, 28 + 2), 
                         (1, 512, 14 + 2, 14 + 2), (1, 512, 14 + 2, 14 + 2), (1, 512, 14 + 2, 14 + 2)]
input_counter = 0
P = 7091
for key in myModel.state_dict().keys():
    if 'features' in key and 'weight' in key:
        WeightScaleValue = myModel.state_dict()[key].q_scale()
        WeightZeroPointValue = myModel.state_dict()[key].q_zero_point()
        InputScaleValue = myModel.state_dict()[f'{key[0:-6]}scale'].item()
        InputZeroPointValue = myModel.state_dict()[f'{key[0:-6]}zero_point'].item()
        print(key, 'dealing') # , ' ', WeightScaleValue, ' ', WeightZeroPointValue, ' ', InputScaleValue, ' ', InputZeroPointValue)
        assert(WeightZeroPointValue == 0)
        NextInputScaleName = scale_chain[f'{key[0:-6]}scale']
        NextInputScaleValue = myModel.state_dict()[NextInputScaleName].item()
        NextInputZeroPointName = zero_point_chain[f'{key[0:-6]}zero_point']
        NextInputZeroPointValue = myModel.state_dict()[NextInputZeroPointName].item()
        M = WeightScaleValue * InputScaleValue / NextInputScaleValue
        assert(M > 0 and M < 1)
        # NandMo = iter(get_mo(M, P))
        # yield next(NandMo) # N
        # yield next(NandMo) # Mo
        x = myModel.state_dict()[key].dequantize()
        # bias = torch.ops.quantized.conv2d(myModel.state_dict()[key],myModel.state_dict()[f'{key[0:-6]}zero_point'], (3, 3), (1, 1), (1, 1), 0, 0, 1)
        sub_bias = M * myModel.state_dict()[f'{key[0:-6]}bias'] + NextInputZeroPointValue
        bias = torch.conv2d(torch.full(input_layer_size_list[input_counter], M * InputZeroPointValue), x, sub_bias, 1, 0, 1, 1) # no padding, padding zero has been quantified
        # print(x.shape, bias.shape, bias[0][0][0][0].item())
        bias_quantified_list = []
        bias_numpy = bias.detach().numpy()
        for channel in range(0, bias_numpy.shape[1]):
            bias_quantified_list.append(bias_numpy[0][channel][0][0]) # bias
        print(bias_quantified_list)
        # yield bias_quantified_list
        # print(myModel.state_dict()[f'{key[0:-6]}bias'])
        # print(NextInputZeroPointValue)
        # in_channels = 3
        # out_channels = 64
        # m = torch.nn.quantized.Conv2d(in_channels, out_channels, 3, stride=1, padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros')
        # bias = M * InputZeroPointValue * conv_result + M * myModel.state_dict()[f'{key[0:-6]}bias'] + NextInputZeroPointValue
        # bias = M * myModel.state_dict()[key] * InputZeroPointValue + M * myModel.state_dict()[f'{key[0:-6]}bias'] + NextInputZeroPointValue
        # print(f'M: {M} bias: {bias}')
        input_counter = input_counter + 1
    else:
        print(key, 'pass')
        pass
# return

# result = iter(getSandZ())
# next(result)

features.0.weight   0.009980898350477219   0   0.3041377365589142   62
[67.014824, 67.012505, 66.99944, 67.11255, 67.00365, 67.061, 66.9992, 67.20313, 67.012115, 66.98165, 66.80699, 66.99638, 66.67418, 66.99047, 66.99919, 66.99518, 66.95112, 66.986336, 66.97603, 67.00203, 66.99663, 67.00346, 67.08957, 66.99276, 67.01384, 66.98022, 67.006516, 66.897644, 66.99607, 66.88489, 67.004196, 67.00024, 67.00769, 66.998955, 67.0065, 66.77249, 66.990295, 66.81822, 66.971504, 67.000984, 66.998436, 67.16307, 67.00127, 67.062126, 66.755516, 66.865974, 66.99053, 67.00317, 66.988075, 66.975975, 67.030106, 66.99986, 67.22038, 66.99952, 67.17542, 67.00086, 67.00975, 67.00179, 66.99977, 67.017494, 66.88105, 66.97106, 66.94674, 66.98962]
features.0.scale pass
features.0.zero_point pass
features.0.bias pass
features.2.weight   0.004230252001434565   0   0.8425370454788208   67
[78.5343, 78.43863, 78.38007, 78.9694, 78.71542, 78.880165, 78.52515, 78.76662, 78.86514, 77.91949, 78.75966, 78.776344, 77.96307, 7

In [2]:
print('\n\nmodel_int8 conv keys:',myModel.state_dict().keys())
print('\n\nmodel_int8 features.0.weight:',myModel.state_dict()['features.0.weight'])
print('\n\nmodel_int8 features.0.bias:',myModel.state_dict()['features.0.bias'])
print('\n\nmodel_int8 features.0.scale:',myModel.state_dict()['features.0.scale'])
print('\n\nmodel_int8 features.0.zero_point:',myModel.state_dict()['features.0.zero_point'])

print('\n\nmodel_int8 features.2.weight:',myModel.state_dict()['features.2.weight'])
print('\n\nmodel_int8 features.2.bias:',myModel.state_dict()['features.2.bias'])
print('\n\nmodel_int8 features.2.scale:',myModel.state_dict()['features.2.scale'])
print('\n\nmodel_int8 features.2.zero_point:',myModel.state_dict()['features.2.zero_point'])

print('\n\nmodel_int8 quant.scale:',myModel.state_dict()['quant.scale'])
print('\n\nmodel_int8 quant.zero_point:',myModel.state_dict()['quant.zero_point'])



model_int8 conv keys: odict_keys(['features.0.weight', 'features.0.bias', 'features.0.scale', 'features.0.zero_point', 'features.2.weight', 'features.2.bias', 'features.2.scale', 'features.2.zero_point', 'features.5.weight', 'features.5.bias', 'features.5.scale', 'features.5.zero_point', 'features.7.weight', 'features.7.bias', 'features.7.scale', 'features.7.zero_point', 'features.10.weight', 'features.10.bias', 'features.10.scale', 'features.10.zero_point', 'features.12.weight', 'features.12.bias', 'features.12.scale', 'features.12.zero_point', 'features.14.weight', 'features.14.bias', 'features.14.scale', 'features.14.zero_point', 'features.17.weight', 'features.17.bias', 'features.17.scale', 'features.17.zero_point', 'features.19.weight', 'features.19.bias', 'features.19.scale', 'features.19.zero_point', 'features.21.weight', 'features.21.bias', 'features.21.scale', 'features.21.zero_point', 'features.24.weight', 'features.24.bias', 'features.24.scale', 'features.24.zero_point', '

In [3]:



conv_weight0 = myModel.state_dict()['features.0.weight']
conv_weight0.int_repr()
print('\n\nconv_weight0',conv_weight0.int_repr())

conv_weight2 = myModel.state_dict()['features.2.weight']
conv_weight2.int_repr()
print('\n\nconv_weight2',conv_weight2.int_repr())


print("Size of model after quantization")
print_size_of_model(myModel)

top1, top5 = evaluate(myModel, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(myModel), saved_model_dir + scripted_default_quantized_model_file) # save default_quantized model




conv_weight0 tensor([[[[ -55,   14,   53],
          [ -58,   36,   77],
          [ -69,   -5,   49]],

         [[  18,    1,   -8],
          [   4,   -7,  -26],
          [  13,  -17,  -13]],

         [[  31,  -17,  -43],
          [  48,   -8,  -49],
          [  63,    2,  -28]]],


        [[[  23,   13,   19],
          [ -43,  -24,   25],
          [ -25,   14,   -1]],

         [[ -14,  -22,   15],
          [ -84,  -35,   57],
          [ -24,   52,   54]],

         [[ -31,  -37,  -13],
          [ -47,  -16,   35],
          [   5,   59,   50]]],


        [[[  18,   52,    1],
          [ -27,  -72,   31],
          [  -8,  -22,   34]],

         [[  31,   67,    2],
          [ -47, -107,   34],
          [  -8,  -31,   55]],

         [[  32,   42,  -35],
          [   9,  -47,    1],
          [  11,  -15,   -2]]],


        ...,


        [[[   8,   13,    3],
          [  22,   25,   -5],
          [   5,    3,    2]],

         [[ -18,   -7,   -1],
          [  -

In [None]:

"""
------------------------------
    5. optimal
    ·Quantizes weights on a per-channel basis
    ·Uses a histogram observer that collects a histogram of activations and then picks quantization parameters
    in an optimal manner.
------------------------------
"""

per_channel_quantized_model = load_model(saved_model_dir + float_model_file)
per_channel_quantized_model.eval()
# per_channel_quantized_model.fuse_model() # VGG dont need fuse
per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm') # set the quantize config
print('\n optimal quantize config: ')
print(per_channel_quantized_model.qconfig)

torch.quantization.prepare(per_channel_quantized_model, inplace=True) # execute the quantize config
evaluate(per_channel_quantized_model,criterion, data_loader, num_calibration_batches) # calibrate
print("Calibrate done")

torch.quantization.convert(per_channel_quantized_model, inplace=True) # convert to quantize model
print('Post Training Optimal Quantization: Convert done')

print("Size of model after optimal quantization")
print_size_of_model(per_channel_quantized_model)

top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches) # test acc
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_optimal_quantized_model_file) # save quantized model


"""
------------------------------
    6. compare performance
------------------------------
"""

print("\nInference time compare: ")
run_benchmark(saved_model_dir + scripted_float_model_file, data_loader_test)
run_benchmark(saved_model_dir + scripted_default_quantized_model_file, data_loader_test)
run_benchmark(saved_model_dir + scripted_optimal_quantized_model_file, data_loader_test)

""" you can compare the model's size/accuracy/inference time.
    ----------------------------------------------------------------------------------------
                    | origin model | default quantized model | optimal quantized model
    model size:     |    553 MB    |         138 MB          |        138 MB
    test accuracy:  |    79.33     |         76.67           |        78.67
    inference time: |    317 ms    |         254 ms          |        257 ms
    ---------------------------------------------------------------------------------------
"""