In [1]:
import torch
import torchvision
import torch.nn as nn
from torchvision.utils import _log_api_usage_once

In [2]:
import random
import numpy as np

def manual_seed(seed):
    np.random.seed(seed) #1
    random.seed(seed) #2
    torch.manual_seed(seed) #3
    torch.cuda.manual_seed(seed) #4.1
    torch.cuda.manual_seed_all(seed) #4.2
    torch.backends.cudnn.benchmark = False #5 
    torch.backends.cudnn.deterministic = True #6

manual_seed(42)

In [3]:
class Quantized(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        _log_api_usage_once(self)
        self.model = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        pass


In [4]:
# filepath = './checkpoint/vgg16.pth'
# checkpoint = torch.load(filepath)
# model = checkpoint['model']
# print(model)
# model.load_state_dict(checkpoint['model_state_dict'])
model = torchvision.models.vgg.vgg16(pretrained=True)
model = Quantized(model)



In [5]:
modules = []
before_l = []
after_l = []
hooks = []

def hook_fn(module, input, output):
    modules.append(module)
    before_l.append(input[0])
    after_l.append(output)

def add_forward_hook(net, hooks):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential) or isinstance(layer, torchvision.models.vgg.VGG):
            add_forward_hook(layer, hooks)
        else:
            hook = layer.register_forward_hook(hook_fn)
            hooks.append(hook)
            
    return hooks

def remove_forward_hook(hooks):
    for i in hooks:
        i.remove()
# out = model((torch.randn(1,3,32,32)))

In [6]:
model.to('cpu')
model.eval()
model.fuse_model()
backend = "x86"
qconfig = torch.ao.quantization.get_default_qconfig(backend)
model.qconfig = qconfig
print(f"defualt qconfig ; {model.qconfig}")
torch.backends.quantized.engine = backend
print(torch.backends.quantized.engine)
model_static_quantized = torch.ao.quantization.prepare(model, inplace = False)
# calibration
for _ in range(10):
    sample = (torch.randn(1,3,224,224))
    model_static_quantized(sample)
# sample = (torch.randint(0,255,(1,3,32,32),dtype=torch.uint8))/255
torch.ao.quantization.convert(model_static_quantized, inplace = True) 

## test, quantized modules, input, output
# make hook
hooks = add_forward_hook(model_static_quantized, hooks)
model_static_quantized(sample)
print(len(hooks), len(modules), len(before_l), len(after_l))
# remove hook, hook works at once
remove_forward_hook(hooks)
# for _ in range(5):
#     sample = (torch.randn(1,3,224,224))
#     model_static_quantized(sample)
# print(len(modules), len(before_l), len(after_l))
print(model_static_quantized)

defualt qconfig ; QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
x86




41 41 41 41
41 41 41
Quantized(
  (model): VGG(
    (features): Sequential(
      (0): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.1237657368183136, zero_point=59, padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.38565802574157715, zero_point=65, padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.6498428583145142, zero_point=80, padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.8637077808380127, zero_point=62, padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): QuantizedConv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.9496995806694031, zero_point=72, padding

In [7]:
from torchvision import transforms
from tqdm.auto import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.CenterCrop((224,224)),
    transforms.ToTensor(),
    # transforms.PILToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std= (0.229, 0.224, 0.225)),
])

val_data = torchvision.datasets.ImageNet(root="./dataset/ImageNet", split="val", transform=test_transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=32,
                                          shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import psutil

def memory_usage(message = "debug"):
    p = psutil.Process()
    rss = p.memory_info().rss / 2**20
    print(f"{message} memory usage : {rss:10.5f} MB")

def test(model, test_loader):
    model.eval()
    model.to('cpu')
    test_acc = 0
    with torch.no_grad():
        for i,data in enumerate(tqdm(test_loader,leave=True)):
            imgs, target = data[0], data[1]
            # imgs, target = data[0].to(device), data[1].to(device)
            output = model(imgs)
            _, preds = torch.max(output.data, 1)
            test_acc += (preds==target).detach().sum().item()
            # if i % 10 == 0:
            #     memory_usage()

    test_acc = 100. * test_acc/len(test_loader.dataset)
    

    return test_acc

test_acc = test(model_static_quantized, val_loader)
print(f"test result : {test_acc:.2f}")

  0%|          | 0/1563 [00:00<?, ?it/s]

100%|██████████| 1563/1563 [23:01<00:00,  1.13it/s]

test result : 51.45





In [9]:
for i in modules:
    print(type(i))
print(len(modules), len(before_l), len(after_l))
for i in range(len(modules)):
    if not isinstance(type(modules[i]), torch.quantization.QuantStub) and not isinstance(type(modules[i]), torch.quantization.DeQuantStub):
        print(type(modules[i]))
        print(f"before : {type(before_l[i])}, {before_l[i].dtype}")
        print(f"after : {type(after_l[i])}, {after_l[i].dtype}")
        break

print(before_l[1])
print(after_l[1])

<class 'torch.ao.nn.quantized.modules.Quantize'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<cla

In [10]:
print(type(modules[2]))
print(before_l[2].dtype,before_l[2].int_repr())
print(after_l[2].dtype,after_l[2])

<class 'torch.nn.modules.activation.ReLU'>
torch.quint8 tensor([[[[ 59,  59,  59,  ...,  59,  68,  80],
          [ 59,  78,  61,  ...,  59,  65,  66],
          [ 59,  87,  59,  ...,  59,  66,  81],
          ...,
          [ 67,  70,  75,  ...,  59,  92,  65],
          [ 61,  59,  71,  ...,  59,  94,  59],
          [ 60,  64,  64,  ...,  77,  90,  59]],

         [[ 59,  66,  59,  ...,  59,  61,  79],
          [ 80,  77,  71,  ...,  59,  59,  75],
          [ 68,  59,  62,  ...,  65,  83,  75],
          ...,
          [ 69,  59,  59,  ...,  75,  59,  59],
          [ 59,  74,  67,  ...,  67,  59,  72],
          [ 59,  76,  60,  ...,  59,  59,  59]],

         [[ 79,  70,  63,  ...,  59,  88,  81],
          [ 59,  60,  70,  ...,  59,  65,  67],
          [ 59,  59,  59,  ...,  83,  65,  59],
          ...,
          [ 59,  59,  59,  ...,  59,  61,  62],
          [ 68, 105,  75,  ...,  60,  83,  59],
          [ 70,  62,  59,  ...,  59,  66,  59]],

         ...,

         [[ 59

In [11]:
print(type(modules[-1]))
print(before_l[-1].dtype,before_l[-1])
print(after_l[-1].dtype,after_l[-1])

<class 'torch.ao.nn.quantized.modules.DeQuantize'>
torch.quint8 tensor([[-0.1601,  3.0425, -1.6013, -1.1209,  0.5605,  2.2418,  2.3219,  0.9608,
          0.2402,  0.0801, -0.0801,  1.8415,  0.8807,  0.4804,  0.8006,  1.2010,
         -2.2418, -1.5212,  1.5212, -0.9608, -1.2010, -1.5212,  0.2402,  0.1601,
         -0.9608, -1.3611,  0.2402, -0.3203, -1.1209,  0.0801, -1.4412, -1.5212,
         -2.4019, -0.4003, -0.1601, -0.7206, -0.1601, -1.5212, -1.3611, -1.5212,
          0.0000, -2.1617, -1.0408, -1.2810, -1.9215, -0.1601,  0.4003, -1.4412,
         -0.8807, -0.8006,  0.4804, -2.2418, -1.0408, -1.8415, -1.9215, -1.4412,
         -3.0425, -2.9624,  0.2402, -1.6013, -0.9608, -2.8823, -2.8823, -1.4412,
         -1.0408,  1.0408, -2.3219, -1.5212, -3.5228, -0.7206,  0.8006,  0.4003,
          0.4003,  0.9608,  1.7614,  0.8006,  0.4003,  1.8415,  2.2418,  2.0817,
         -0.3203, -0.6405, -0.3203, -0.6405,  0.6405,  0.0801,  1.1209, -1.6814,
          1.7614,  0.0801, -0.0801,  0.0000, 

In [46]:
for child in model_static_quantized.children():
    if isinstance(child, nn.Sequential) or isinstance(child, torchvision.models.vgg.VGG):
        print("---")
        for n,c in child.named_children():
            print("===")
            for name, param in c.named_parameters():
                print(name)

state = model.state_dict()
for names in model_static_quantized.state_dict():
    print(names)
print(model_static_quantized.model.features[2].weight().int_repr())
print(state.keys())
print(state['model.features.2.weight'])


---
===
===
===
model.features.0.weight
model.features.0.bias
model.features.0.scale
model.features.0.zero_point
model.features.2.weight
model.features.2.bias
model.features.2.scale
model.features.2.zero_point
model.features.5.weight
model.features.5.bias
model.features.5.scale
model.features.5.zero_point
model.features.7.weight
model.features.7.bias
model.features.7.scale
model.features.7.zero_point
model.features.10.weight
model.features.10.bias
model.features.10.scale
model.features.10.zero_point
model.features.12.weight
model.features.12.bias
model.features.12.scale
model.features.12.zero_point
model.features.14.weight
model.features.14.bias
model.features.14.scale
model.features.14.zero_point
model.features.17.weight
model.features.17.bias
model.features.17.scale
model.features.17.zero_point
model.features.19.weight
model.features.19.bias
model.features.19.scale
model.features.19.zero_point
model.features.21.weight
model.features.21.bias
model.features.21.scale
model.features.21.z