In [1]:
"""
==========================
    Style: static quantize
    Model: VGG-16
    Create by: Han_yz @ 2020/1/29
    Email: 20125169@bjtu.edu.cn
    Github: https://github.com/Forggtensky
==========================
"""

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization

In [2]:



"""
------------------------------
    1、Model architecture
------------------------------
"""

class VGG_fcn(nn.Module):
    def __init__(self,num_classes=1000,init_weights=False):
        super(VGG_fcn,self).__init__()
        # self.features = features  # 提取特征部分的网络，也为Sequential格式
        # self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(  # 分类部分的网络
            nn.Linear(512*7*7,4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096,4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096,num_classes)
        )
        # add the quantize part
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        if init_weights:
            self._initialize_weights()

    def forward(self,x):
        x = self.quant(x)
        # x = self.features(x)
        # x = self.avgpool(x)
        x = torch.flatten(x,start_dim=1)
        # x = x.mean([2, 3])
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module,nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias,0)
            elif isinstance(module,nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(module.bias,0)

class VGG(nn.Module):
    def __init__(self,features,num_classes=1000,init_weights=False):
        super(VGG,self).__init__()
        self.features = features  # 提取特征部分的网络，也为Sequential格式
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(  # 分类部分的网络
            nn.Linear(512*7*7,4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096,4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096,num_classes)
        )
        # add the quantize part
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        if init_weights:
            self._initialize_weights()

    def forward(self,x):
        x = self.quant(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x,start_dim=1)
        # x = x.mean([2, 3])
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module,nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias,0)
            elif isinstance(module,nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(module.bias,0)

cfgs = {
    'vgg11':[64,'M',128,'M',256,256,'M',512,512,'M',512,512,'M'],
    'vgg13':[64,64,'M',128,128,'M',256,256,'M',512,512,'M',512,512,'M'],
    'vgg16':[64,64,'M',128,128,'M',256,256,256,'M',512,512,512,'M',512,512,512,'M'],
    'vgg19':[64,64,'M',128,128,'M',256,256,256,256,'M',512,512,512,512,'M',512,512,512,512,'M'],
}

def make_features(cfg:list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2,stride=2)]  #vgg采用的池化层均为2,2参数
        else:
            conv2d = nn.Conv2d(in_channels,v,kernel_size=3,padding=1)  #vgg卷积层采用的卷积核均为3,1参数
            layers += [conv2d,nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)  #非关键字的形式输入网络的参数

def vgg(model_name='vgg16',**kwargs):
    try:
        cfg = cfgs[model_name]
    except:
        print("Warning: model number {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg),**kwargs)  # **kwargs为可变长度字典，保存多个输入参数
    return model

def vgg_fcn(model_name='vgg16_fcn',**kwargs):
    model = VGG_fcn(**kwargs)  # **kwargs为可变长度字典，保存多个输入参数
    return model


In [3]:

"""
------------------------------
    2、Helper functions
------------------------------
"""

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                 return top1, top5

    return top1, top5


def run_benchmark(model_file, img_loader):
    elapsed = 0
    model = torch.jit.load(model_file)
    model.eval()
    num_batches = 5
    # Run the scripted model on a few batches of images
    for i, (images, target) in enumerate(img_loader):
        if i < num_batches:
            start = time.time()
            output = model(images)
            end = time.time()
            elapsed = elapsed + (end-start)
        else:
            break
    num_images = images.size()[0] * num_batches

    print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000))
    return elapsed

def run_benchmark_fcn(model_file, img_loader):
    elapsed = 0
    model = torch.jit.load(model_file)
    model.eval()
    num_batches = 5
    # Run the scripted model on a few batches of images
    for i, (images, target) in enumerate(img_loader):
        if i < num_batches:
            x = torch.rand([images.shape(0), 7, 7, 512], dtype=torch.float32)
            start = time.time()
            output = model(x)
            end = time.time()
            elapsed = elapsed + (end-start)
        else:
            break
    num_images = images.size()[0] * num_batches

    print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000))
    return elapsed

def load_model(model_file):
    model_name = "vgg16"
    model = vgg(model_name=model_name,num_classes=1000,init_weights=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def load_model_fcn(model_file):
    model_name = "vgg16_fcn"
    model = vgg_fcn(model_name=model_name,num_classes=1000,init_weights=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')



In [4]:
modelfcn_name = "vgg16_fcn"
modelfcn = vgg_fcn(model_name=modelfcn_name,num_classes=1000,init_weights=False)
modelfcn.eval()
torch.save(modelfcn.state_dict(), 'model/vgg16fcn_pretrained_float.pth')
# state_dict = torch.load('model/vgg16fcn_pretrained_float.pth')
# modelfcn.load_state_dict(state_dict)
# modelfcn.to('cpu')

In [5]:
# x = torch.ones([30, 7, 7, 512], dtype=torch.float32)
# output = modelfcn(x)
# print(output.shape)

In [6]:
# 
# modelfcn = load_model_fcn('model/vgg16fcn_pretrained_float.pth')
# print_size_of_model(modelfcn)


num_calibration_batches = 10

modelfcn = load_model_fcn('model/vgg16fcn_pretrained_float.pth').to('cpu')
modelfcn.eval()

torch.jit.save(torch.jit.script(modelfcn), "model/vgg16fcn_quantization_scripted.pth") # save un_quantized model

# Specify quantization configuration
# Start with simple min/max range estimation and per-tensor quantization of weights
modelfcn.qconfig = torch.quantization.default_qconfig
print(modelfcn.qconfig)
torch.quantization.prepare(modelfcn, inplace=True)

# Calibrate with the training set
print('\nPost Training Quantization Prepare: Inserting Observers by Calibrate')
# evaluate(modelfcn, criterion, data_loader, neval_batches=num_calibration_batches)
print("Calibrate done")


torch.quantization.convert(modelfcn, inplace=True)
print('Post Training Quantization: Convert done')


print('\n After quantization: \n',modelfcn)

print("Size of model after quantization")
print_size_of_model(modelfcn)

torch.jit.save(torch.jit.script(modelfcn), "model/vgg16fcn_quantization_scripted_default_quantized.pth") 

QConfig(activation=functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

Post Training Quantization Prepare: Inserting Observers by Calibrate
Calibrate done




Post Training Quantization: Convert done

 After quantization: 
 VGG_fcn(
  (classifier): Sequential(
    (0): QuantizedLinear(in_features=25088, out_features=4096, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
    (1): QuantizedReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): QuantizedLinear(in_features=4096, out_features=4096, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
    (4): QuantizedReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): QuantizedLinear(in_features=4096, out_features=1000, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
  )
  (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)
Size of model after quantization
Size (MB): 123.673008


In [7]:

"""
------------------------------
    2、Helper functions
------------------------------
"""

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                 return top1, top5

    return top1, top5


def run_benchmark(model_file, img_loader):
    elapsed = 0
    model = torch.jit.load(model_file)
    model.eval()
    num_batches = 5
    # Run the scripted model on a few batches of images
    for i, (images, target) in enumerate(img_loader):
        if i < num_batches:
            start = time.time()
            output = model(images)
            end = time.time()
            elapsed = elapsed + (end-start)
        else:
            break
    num_images = images.size()[0] * num_batches

    print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000))
    return elapsed

def run_benchmark_fcn(model_file, img_loader):
    elapsed = 0
    model = torch.jit.load(model_file)
    model.eval()
    num_batches = 5
    # Run the scripted model on a few batches of images
    for i, (images, target) in enumerate(img_loader):
        if i < num_batches:
            x = torch.rand([images.shape[0], 7, 7, 512], dtype=torch.float32)
            start = time.time()
            output = model(x)
            end = time.time()
            elapsed = elapsed + (end-start)
        else:
            break
    num_images = images.size()[0] * num_batches

    print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000))
    return elapsed

def load_model(model_file):
    model_name = "vgg16"
    model = vgg(model_name=model_name,num_classes=1000,init_weights=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def load_model_fcn(model_file):
    model_name = "vgg16_fcn"
    model = vgg_fcn(model_name=model_name,num_classes=1000,init_weights=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')



In [8]:

"""
------------------------------
    3. Define dataset and data loaders
------------------------------
"""

def prepare_data_loaders(data_path):
    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    dataset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    print("dataset_train : %d" % (len(dataset)))

    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    print("dataset_test : %d" % (len(dataset_test)))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

# Specify random seed for repeatable results
torch.manual_seed(191009)

data_path = 'data/imagenet_1k'
saved_model_dir = 'model/'
float_model_file = 'vgg16_pretrained_float.pth'
scripted_float_model_file = 'vgg16_quantization_scripted.pth'
scripted_default_quantized_model_file = 'vgg16_quantization_scripted_default_quantized.pth'
scripted_optimal_quantized_model_file = 'vgg16_quantization_scripted_optimal_quantized.pth'

train_batch_size = 30
eval_batch_size = 30

data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to('cpu')

print('\n Before quantization: \n',float_model)

float_model.eval()
# Note: vgg-16 has no BN layer so that not need to fuse model

num_eval_batches = 10

print("Size of baseline model")
print_size_of_model(float_model)

# to get a “baseline” accuracy, see the accuracy of our un-quantized model
top1, top5 = evaluate(float_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file) # save un_quantized model


dataset_train : 1000
dataset_test : 1000

 Before quantization: 
 VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)


In [9]:

"""
------------------------------
    4. Post-training static quantization
------------------------------
"""

num_calibration_batches = 10

myModel = load_model(saved_model_dir + float_model_file).to('cpu')
myModel.eval()

# Specify quantization configuration
# Start with simple min/max range estimation and per-tensor quantization of weights
myModel.qconfig = torch.quantization.default_qconfig
print(myModel.qconfig)
torch.quantization.prepare(myModel, inplace=True)

# Calibrate with the training set
print('\nPost Training Quantization Prepare: Inserting Observers by Calibrate')
evaluate(myModel, criterion, data_loader, neval_batches=num_calibration_batches)
print("Calibrate done")

# Convert to quantized model
torch.quantization.convert(myModel, inplace=True)
print('Post Training Quantization: Convert done')


print('\n After quantization: \n',myModel)

print("Size of model after quantization")
print_size_of_model(myModel)

top1, top5 = evaluate(myModel, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(myModel), saved_model_dir + scripted_default_quantized_model_file) # save default_quantized model


QConfig(activation=functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

Post Training Quantization Prepare: Inserting Observers by Calibrate
..........Calibrate done
Post Training Quantization: Convert done

 After quantization: 
 VGG(
  (features): Sequential(
    (0): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.3200511634349823, zero_point=63, padding=(1, 1))
    (1): QuantizedReLU(inplace=True)
    (2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.8684448599815369, zero_point=68, padding=(1, 1))
    (3): QuantizedReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=1.3496954441070557, zero_point=88, padding=(1, 1))
    (6): QuantizedReLU(inplace

In [10]:

"""
------------------------------
    5. optimal
    ·Quantizes weights on a per-channel basis
    ·Uses a histogram observer that collects a histogram of activations and then picks quantization parameters
    in an optimal manner.
------------------------------
"""
float_model_file = 'vgg16_pretrained_float.pth'
criterion = nn.CrossEntropyLoss()

per_channel_quantized_model = load_model(saved_model_dir + float_model_file)
per_channel_quantized_model.eval()
# per_channel_quantized_model.fuse_model() # VGG dont need fuse
per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm') # set the quantize config
print('\n optimal quantize config: ')
print(per_channel_quantized_model.qconfig)

torch.quantization.prepare(per_channel_quantized_model, inplace=True) # execute the quantize config
evaluate(per_channel_quantized_model,criterion, data_loader, num_calibration_batches) # calibrate
print("Calibrate done")

torch.quantization.convert(per_channel_quantized_model, inplace=True) # convert to quantize model
print('Post Training Optimal Quantization: Convert done')

print("Size of model after optimal quantization")
print_size_of_model(per_channel_quantized_model)

top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches) # test acc
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_optimal_quantized_model_file) # save quantized model




 optimal quantize config: 
QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
..........Calibrate done
Post Training Optimal Quantization: Convert done
Size of model after optimal quantization
Size (MB): 138.626255
..........Evaluation accuracy on 300 images, 78.33


### save quantized bin

In [12]:
for key in per_channel_quantized_model.state_dict().keys():
    print(key)

features.0.weight
features.0.scale
features.0.zero_point
features.0.bias
features.2.weight
features.2.scale
features.2.zero_point
features.2.bias
features.5.weight
features.5.scale
features.5.zero_point
features.5.bias
features.7.weight
features.7.scale
features.7.zero_point
features.7.bias
features.10.weight
features.10.scale
features.10.zero_point
features.10.bias
features.12.weight
features.12.scale
features.12.zero_point
features.12.bias
features.14.weight
features.14.scale
features.14.zero_point
features.14.bias
features.17.weight
features.17.scale
features.17.zero_point
features.17.bias
features.19.weight
features.19.scale
features.19.zero_point
features.19.bias
features.21.weight
features.21.scale
features.21.zero_point
features.21.bias
features.24.weight
features.24.scale
features.24.zero_point
features.24.bias
features.26.weight
features.26.scale
features.26.zero_point
features.26.bias
features.28.weight
features.28.scale
features.28.zero_point
features.28.bias
classifier.0.sc

In [26]:
# print('\n\nmodel_int8 conv keys:',per_channel_quantized_model.state_dict().keys())
# import tqdm
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization

def getrealzero(_scale : float, _zero_point : int) -> int:
    return int(torch.int_repr(torch.quantize_per_tensor(torch.tensor(0.), float(_scale), int(_zero_point), torch.qint8)).numpy())
    # return round(-_zero_point / _scale)

class hbm_channel_mem(): # each channel save 32 bits data
    data_mem = []
    zeropoint_mem = []
    def __init__(self):
        # _list_tmp = list()
        # for iter in range(0, 32):
            # self.data_mem.append(_list_tmp)
        self.data_mem = [[] for x in range(0, 32)]
        self.zeropoint_mem = [[] for x in range(0, 32)]

    # def set(self, tunnel, addr, value):
    #     data_mem
    def autofill(self):
        for position in range(1, 32):
            for addr in range(len(self.data_mem[position]), len(self.data_mem[0])):
                self.data_mem[position].append(self.zeropoint_mem[0][addr])
            # while (len(self.data_mem[0]) > len(self.data_mem[position])):
            #     self.data_mem[position].append(self.zeropoint_mem[0][len(self.data_mem[position])]) # change to real zero point

    def print(self, maxaddrlimit = 64):
        maxaddr = 0
        for position in range(0, 32):
            maxaddr = max(maxaddr, len(self.data_mem[position]))

        if maxaddrlimit != -1:
            maxaddr = min(maxaddr, maxaddrlimit)

        for addr in range(0, maxaddr):
            for position in range(0, 32):
                if len(self.data_mem[position]) <= addr:
                    print('NUL\t', end = '')
                else:
                    print(f'{self.data_mem[position][addr]}\t', end = '')
            print()

    def append(self, position : int, value : int, zeropoint : int):
        # print('channel append enter')
        # print(self)
        # print(position, value)
        # print('channel append exit')
        self.data_mem[position].append(value)
        self.zeropoint_mem[position].append(zeropoint)

    def appends(self, valuelist : list):
        for iter in range(0, 32):
            self.data_mem[iter].append(valuelist[iter])

    def save(self, filepath : str, saveaspcieorder = True):
        self.autofill()
        # self.print()
        if saveaspcieorder:
            with open(filepath, 'wb+') as f:
                for addr_base in range(0, len(self.data_mem[0]), 8):
                    for addr in range(addr_base + 7, addr_base - 1, -1):
                        for position in range(0, 32):
                            # print(type(self.data_mem[addr][position]))
                            try:
                                # print(addr, ' ', position)
                                # print(self.data_mem[position][addr])
                                # print(int(self.data_mem[position][addr]))
                                # print(int(self.data_mem[position][addr]).to_bytes(1, 'little', signed = True))
                                f.write(int(self.data_mem[position][addr]).to_bytes(1, 'little', signed = True))
                            except:
                                
                                print(self.data_mem[position][addr])
                                raise
                        # f.write(int.from_bytes(self.data_mem[addr][position].to_bytes(1, 'little', signed = True), 'little', signed = False))
                    # raise
                    
                f.close()
        else:
            with open(filepath, 'wb+') as f:
                for addr in range(len(self.data_mem[0])):
                    
                    for position in range(31, -1, -1):
                        # print(type(self.data_mem[addr][position]))
                        try:
                            # print(self.data_mem[position][addr])
                            # print(int(self.data_mem[position][addr]))
                            # print(int(self.data_mem[position][addr]).to_bytes(1, 'little', signed = True))
                            f.write(int(self.data_mem[position][addr]).to_bytes(1, 'little', signed = True))
                        except:
                            
                            print(self.data_mem[position][addr])
                            raise
                        # f.write(int.from_bytes(self.data_mem[addr][position].to_bytes(1, 'little', signed = True), 'little', signed = False))
                    # raise
                    
                f.close()

class hbm_mem():
    hbm_data_mem = []
    robin = int()

    def __init__(self):
        # _list_tmp = hbm_channel_mem()
        self.robin = int(0)
        # for iter in range(0, 32):
            # self.hbm_data_mem.append(_list_tmp)
        self.hbm_data_mem = [hbm_channel_mem() for x in range(0, 32)]
        # print("init")
        # for i in self.hbm_data_mem:
        #     print(i)
    
    def append_channel(self, channel : int, tunnel : int, value : int):
        self.hbm_data_mem[channel].append(tunnel, value)

    def appends_channel(self, channel : int, valuelist : list):
        self.hbm_data_mem[channel].appends(valuelist)

    def append(self, valuelist : list, zeropoint : int):
        # print('append enter')
        for channel in range(0, 32):
            # print(valuelist[channel])
            for value in valuelist[channel]:
                self.hbm_data_mem[channel].append(self.robin, value, zeropoint)

        if (self.robin == 31):
            self.robin = 0
        else:
            self.robin = self.robin + 1
        # print('append return')


    def autofill(self):
        while (self.robin != 0):
            for channel in range(0, 32):
                self.hbm_data_mem[channel].append(self.robin, 0) # TODO: change to real zero point

            if (self.robin == 31):
                self.robin = 0
            else:
                self.robin = self.robin + 1

    def print(self):
        for channel in range(0, 32):
            print(f'data in {channel}:')
            # print(self.hbm_data_mem[channel])
            self.hbm_data_mem[channel].print()

    def save(self, filepath : str):
        for channel in range(0, 32):
            self.hbm_data_mem[channel].save("{}_{}.bin".format(filepath, channel))


i_hbm_mem = hbm_mem()

for key in per_channel_quantized_model.state_dict().keys():
    print(key)
    # print(layer_numpy.shape)
    
    if (key.find("weight") >= 0):
        pass
    else:
        continue
    
    if (per_channel_quantized_model.state_dict()[key].dim() == 4):
        layer_numpy = torch.int_repr(per_channel_quantized_model.state_dict()[key]).numpy()
        layer_zero_point = per_channel_quantized_model.state_dict()[key[0:-6] + "zero_point"].numpy()
        layer_scale = per_channel_quantized_model.state_dict()[key[0:-6] + "scale"].numpy()
        realzero = getrealzero(layer_scale, layer_zero_point)
        # print("zero_point: ", layer_zero_point)
        # print("scale: ", layer_scale)
        # print("zero_quan: ", realzero)
        
        print(f"dealing {key}")
        for kernel_base in range(0, layer_numpy.shape[0], 16): # tqdm.trange(0, layer_numpy.shape[0], 16): # 
            multi_channel_valuelist_tmp_channelgroup_1 = []
            multi_channel_valuelist_tmp_channelgroup_2 = []
            for kernel in range(kernel_base, kernel_base + 16):
                single_channel_valuelist_tmp_channelgroup_1 = []
                single_channel_valuelist_tmp_channelgroup_2 = []

                for channel_base in range(0, layer_numpy.shape[1], 16):
                    for channel in range(channel_base, channel_base + 8): # channelgroup 1
                        single_channel_valuelist_tmp = [] # one row 1*9
                        for width in range(layer_numpy.shape[2]): # probably height and width is in this order
                            for height in range(layer_numpy.shape[3]):
                                if (channel < layer_numpy.shape[1]): # in case of conv_1 has only 3 channel
                                    # single_channel_valuelist_tmp.append(per_channel_quantized_model.state_dict()[key][kernel][channel][width][height])
                                    single_channel_valuelist_tmp.append(layer_numpy[kernel][channel][width][height])
                                else:
                                    single_channel_valuelist_tmp.append(realzero) # change to real zero
                        single_channel_valuelist_tmp_channelgroup_1.extend(single_channel_valuelist_tmp) # extends to 8*9 through order is in reverse
                            

                    for channel in range(channel_base + 8, channel_base + 16): # channelgroup 2
                        single_channel_valuelist_tmp = [] # one row 1*9
                        for width in range(layer_numpy.shape[2]): # probably height and width is in this order
                            for height in range(layer_numpy.shape[3]):
                                if (channel < layer_numpy.shape[1]): # in case of conv_1 has only 3 channel
                                    # single_channel_valuelist_tmp.append(per_channel_quantized_model.state_dict()[key][kernel][channel][width][height])
                                    single_channel_valuelist_tmp.append(layer_numpy[kernel][channel][width][height])
                                else:
                                    single_channel_valuelist_tmp.append(realzero)  # change to real zero
                        single_channel_valuelist_tmp_channelgroup_2.extend(single_channel_valuelist_tmp) # extends to 8*9

                multi_channel_valuelist_tmp_channelgroup_1.append(single_channel_valuelist_tmp_channelgroup_1)
                multi_channel_valuelist_tmp_channelgroup_2.append(single_channel_valuelist_tmp_channelgroup_2)

            # print("multi_channel_valuelist_tmp_channelgroup_1")
            # for row in range(0, len(multi_channel_valuelist_tmp_channelgroup_1)):
            #     print(len(multi_channel_valuelist_tmp_channelgroup_1[row]), multi_channel_valuelist_tmp_channelgroup_1[row])
            # print("multi_channel_valuelist_tmp_channelgroup_2")
            # for row in range(0, len(multi_channel_valuelist_tmp_channelgroup_2)):
            #     print(len(multi_channel_valuelist_tmp_channelgroup_2[row]), multi_channel_valuelist_tmp_channelgroup_2[row])
            # print(multi_channel_valuelist_tmp_channelgroup_1)
            # print(multi_channel_valuelist_tmp_channelgroup_2)
            total_channel_valuelist_tmp = []
            total_channel_valuelist_tmp.extend(multi_channel_valuelist_tmp_channelgroup_1)
            total_channel_valuelist_tmp.extend(multi_channel_valuelist_tmp_channelgroup_2)
            i_hbm_mem.append(total_channel_valuelist_tmp, realzero)
            # i_hbm_mem.print()
            # break

    elif (per_channel_quantized_model.state_dict()[key].dim() == 3):
        pass
    elif (per_channel_quantized_model.state_dict()[key].dim() == 2):
        pass
    elif (per_channel_quantized_model.state_dict()[key].dim() == 1):
        pass
    else:
        pass
        # raise

    # i_hbm_mem.print()
    # i_hbm_mem.autofill()
    # print('after auto fill')
    # i_hbm_mem.print()
    # break
i_hbm_mem.save("quan_bin/quan_bin")
# print('\n\nmodel_int8 features.0.weight:',per_channel_quantized_model.state_dict()['features.0.weight'])
# print('\n\nmodel_int8 features.0.bias:',per_channel_quantized_model.state_dict()['features.0.bias'])
# print('\n\nmodel_int8 features.0.scale:',per_channel_quantized_model.state_dict()['features.0.scale'])
# print('\n\nmodel_int8 features.0.zero_point:',per_channel_quantized_model.state_dict()['features.0.zero_point'])

# print('\n\nmodel_int8 features.2.weight:',per_channel_quantized_model.state_dict()['features.2.weight'])
# print('\n\nmodel_int8 features.2.bias:',per_channel_quantized_model.state_dict()['features.2.bias'])
# print('\n\nmodel_int8 features.2.scale:',per_channel_quantized_model.state_dict()['features.2.scale'])
# print('\n\nmodel_int8 features.2.zero_point:',per_channel_quantized_model.state_dict()['features.2.zero_point'])

# print('\n\nmodel_int8 quant.scale:',per_channel_quantized_model.state_dict()['quant.scale'])
# print('\n\nmodel_int8 quant.zero_point:',per_channel_quantized_model.state_dict()['quant.zero_point'])

features.0.weight
dealing features.0.weight
features.0.scale
features.0.zero_point
features.0.bias
features.2.weight
dealing features.2.weight
features.2.scale
features.2.zero_point
features.2.bias
features.5.weight
dealing features.5.weight
features.5.scale
features.5.zero_point
features.5.bias
features.7.weight
dealing features.7.weight
features.7.scale
features.7.zero_point
features.7.bias
features.10.weight
dealing features.10.weight
features.10.scale
features.10.zero_point
features.10.bias
features.12.weight
dealing features.12.weight
features.12.scale
features.12.zero_point
features.12.bias
features.14.weight
dealing features.14.weight
features.14.scale
features.14.zero_point
features.14.bias
features.17.weight
dealing features.17.weight
features.17.scale
features.17.zero_point
features.17.bias
features.19.weight
dealing features.19.weight
features.19.scale
features.19.zero_point
features.19.bias
features.21.weight
dealing features.21.weight
features.21.scale
features.21.zero_poi

In [26]:
for layer in per_channel_quantized_model.state_dict().keys():
    print(torch.int_repr(per_channel_quantized_model.state_dict()[layer]))
    # print(per_channel_quantized_model.state_dict()[layer])
    break

tensor([[[[ -92,   24,   88],
          [ -97,   59,  127],
          [-115,   -8,   81]],

         [[  29,    2,  -14],
          [   7,  -12,  -43],
          [  22,  -29,  -22]],

         [[  52,  -28,  -71],
          [  79,  -14,  -81],
          [ 105,    3,  -46]]],


        [[[  35,   19,   28],
          [ -65,  -37,   37],
          [ -38,   21,   -1]],

         [[ -21,  -33,   23],
          [-128,  -53,   85],
          [ -37,   79,   82]],

         [[ -48,  -56,  -20],
          [ -71,  -23,   52],
          [   8,   89,   75]]],


        [[[  21,   62,    1],
          [ -32,  -85,   37],
          [  -9,  -26,   40]],

         [[  37,   80,    2],
          [ -56, -128,   40],
          [ -10,  -36,   65]],

         [[  38,   50,  -42],
          [  10,  -55,    1],
          [  12,  -17,   -2]]],


        ...,


        [[[  40,   66,   17],
          [ 114,  127,  -24],
          [  24,   15,    9]],

         [[ -95,  -35,   -4],
          [ -25,    4,  -67],

### load quantized bin

In [27]:
int.from_bytes(b'\xfc', 'little', signed = True)

-4

In [31]:
with open("quan_bin/quan_bin_0.bin", "rb") as f:
    a = f.read(512)
    print(a)
    

b'\x06\x0b\x00\x1b\x0c%\xff\xf5\xfd\x08\xe6\xfa\xee*\xf4\x1b\xfc\xec\x0e\n\xf1\x12#\xd9\x05\x0b\xd9\xe9\x0f\xcb\xfc\xa4\xef\x06\x1a,\x0c\x1e\xf1\x01\x17\t\xe9\xe6\xe4\x0b\xe0\xeb\xdf\xe6\xf5\xf9\x0e\x1a)\x07\xf7\x087\xb5\xd7\x81\xf5\x18\xe2\x1c;C\x10\xfc\xf8\x14\x00\x01\xf7\xf1\xeb\xfc\xe9\xf5\xf8\xf7\xf6\xdd\xfe"#\xe7\xdf\x04W\x9b\xac\xdf\xefX\t\xef\x01\x1f\x10!\xf8\xed\x0b\xe6\x00\xe8\xe4\x18\xec\xef\x05\xef\x15\x10\xfd\n\xf0\xe0\x06\xff\x01\x05\xef\x14\x1e\x9f\xef\xe9\xfaA\x1c\x01\xec\xe0\x15\xfc\x03\xf5\xe7\xf5\xe8\xfc\xd6\xe6\xea\x054\xff\xeb\xdd\xf7\x15[\xc0\xb2\xce1;\xf8\x0b\x04U\x1b0\x03\xf7\x03\xea\xf6\xe8\xda\xf4\xdf\t\x06\xea\xeb\xe8\x13\xf7\x0b\xb1\xe1\xfbt\x81\xbf\x14\xf1\x7f\xeb\xef\x02\x16\xfe\xfc\xf3\xcb\xff\xf8\x11\xf3\x04\x02\xff\x04\x01\xf7\x10\x14\xfb\x04\xf0\xf7\x15\xfe\x07\x18\'M\xee\x8d\xe6\xda\xe9\x1a\x01\x0c\xf0\xc9\xfc\x14\x19\x18\xfa\xed\xd6\x14\xee\xf1\xf0\t\xdf\xfd\xea\xde\x05\x0cH\xce\xdbY\x06\xf8\xe1\t\x0f\x0f\x0c\x12\r\xf9\x0f\xe8"\xfd\xeb\xe2\xdf(\x03\x

In [None]:
with open("quan_bin/quan_bin_0.bin", "rb") as f:
    a = f.read(128)
    print(a)
    a = f.read(128)
    print(a)

    f.close()
    

b'\x9f\x1e\x14\xef\x05\x01\xff\x06\xe0\xf0\n\xfd\x10\x15\xef\x05\xef\xec\x18\xe4\xe8\x00\xe6\x0b\xed\xf8!\x10\x1f\x01\xef\tX\xef\xdf\xac\x9bW\x04\xdf\xe7#"\xfe\xdd\xf6\xf7\xf8\xf5\xe9\xfc\xeb\xf1\xf7\x01\x00\x14\xf8\xfc\x10C;\x1c\xe2\x18\xf5\x81\xd7\xb57\x08\xf7\x07)\x1a\x0e\xf9\xf5\xe6\xdf\xeb\xe0\x0b\xe4\xe6\xe9\t\x17\x01\xf1\x1e\x0c,\x1a\x06\xef\xa4\xfc\xcb\x0f\xe9\xd9\x0b\x05\xd9#\x12\xf1\n\x0e\xec\xfc\x1b\xf4*\xee\xfa\xe6\x08\xfd\xf5\xff%\x0c\x1b\x00\x0b\x06\xf8\x06Y\xdb\xceH\x0c\x05\xde\xea\xfd\xdf\t\xf0\xf1\xee\x14\xd6\xed\xfa\x18\x19\x14\xfc\xc9\xf0\x0c\x01\x1a\xe9\xda\xe6\x8d\xeeM\'\x18\x07\xfe\x15\xf7\xf0\x04\xfb\x14\x10\xf7\x01\x04\xff\x02\x04\xf3\x11\xf8\xff\xcb\xf3\xfc\xfe\x16\x02\xef\xeb\x7f\xf1\x14\xbf\x81t\xfb\xe1\xb1\x0b\xf7\x13\xe8\xeb\xea\x06\t\xdf\xf4\xda\xe8\xf6\xea\x03\xf7\x030\x1bU\x04\x0b\xf8;1\xce\xb2\xc0[\x15\xf7\xdd\xeb\xff4\x05\xea\xe6\xd6\xfc\xe8\xf5\xe7\xf5\x03\xfc\x15\xe0\xec\x01\x1cA\xfa\xe9\xef'
b'\xf2Ax\x16\xd9\xc0\xee\xa7\xed\x07\x17\x0c\n\xed\x13\x03

In [None]:
print("\nInference time compare: ")
run_benchmark_fcn('model/vgg16fcn_quantization_scripted_default_quantized.pth', data_loader_test)
run_benchmark_fcn('model/vgg16fcn_quantization_scripted.pth', data_loader_test)
# run_benchmark_fcn(saved_model_dir + scripted_optimal_quantized_model_file, data_loader_test)

In [None]:

"""
------------------------------
    6. compare performance
------------------------------
"""

print("\nInference time compare: ")
run_benchmark(saved_model_dir + scripted_float_model_file, data_loader_test)
run_benchmark(saved_model_dir + scripted_default_quantized_model_file, data_loader_test)
run_benchmark(saved_model_dir + scripted_optimal_quantized_model_file, data_loader_test)

""" you can compare the model's size/accuracy/inference time.
    ----------------------------------------------------------------------------------------
                    | origin model | default quantized model | optimal quantized model
    model size:     |    553 MB    |         138 MB          |        138 MB
    test accuracy:  |    79.33     |         76.67           |        78.67
    inference time: |    317 ms    |         254 ms          |        257 ms
    ---------------------------------------------------------------------------------------
"""