In [1]:
import torch
import torchvision
import torch.nn as nn
from torchvision.utils import _log_api_usage_once

In [2]:
import random
import numpy as np

def manual_seed(seed):
    np.random.seed(seed) #1
    random.seed(seed) #2
    torch.manual_seed(seed) #3
    torch.cuda.manual_seed(seed) #4.1
    torch.cuda.manual_seed_all(seed) #4.2
    torch.backends.cudnn.benchmark = False #5 
    torch.backends.cudnn.deterministic = True #6

manual_seed(42)

In [3]:
class Quantized(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        _log_api_usage_once(self)
        self.model = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        pass


In [4]:
# filepath = './checkpoint/vgg16.pth'
# checkpoint = torch.load(filepath)
# model = checkpoint['model']
# print(model)
# model.load_state_dict(checkpoint['model_state_dict'])
model = torchvision.models.vgg.vgg16(pretrained=True)
model = Quantized(model)
pretrain_transforms = torchvision.models.VGG16_Weights.DEFAULT.transforms()
print(pretrain_transforms)



ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [5]:
from torchvision import transforms
from tqdm.auto import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_transform = transforms.Compose([
    pretrain_transforms,
    ])

val_data = torchvision.datasets.ImageNet(root="./dataset/ImageNet", split="val", transform=test_transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=32,
                                          shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import psutil

def memory_usage(message = "debug"):
    p = psutil.Process()
    rss = p.memory_info().rss / 2**20
    print(f"{message} memory usage : {rss:10.5f} MB")

def test(model, test_loader, num_calib=None):
    model.eval()
    model.to('cpu')
    test_acc = 0
    with torch.inference_mode():
        for i,data in enumerate(tqdm(test_loader,leave=True)):
            imgs, target = data[0], data[1]
            # imgs, target = data[0].to(device), data[1].to(device)
            output = model(imgs)
            _, preds = torch.max(output.data, 1)
            test_acc += (preds==target).detach().sum().item()
            if num_calib and (i > num_calib):
                break
            # if i % 10 == 0:
            #     memory_usage()

    test_acc = 100. * test_acc/len(test_loader.dataset)
    

    return test_acc

In [7]:
modules = []
before_l = []
after_l = []
hooks = []

def hook_fn(module, input, output):
    modules.append(module)
    before_l.append(input[0])
    after_l.append(output)

def add_forward_hook(net, hooks):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential) or isinstance(layer, torchvision.models.vgg.VGG):
            add_forward_hook(layer, hooks)
        else:
            hook = layer.register_forward_hook(hook_fn)
            hooks.append(hook)
            
    return hooks

def remove_forward_hook(hooks):
    for i in hooks:
        i.remove()
# out = model((torch.randn(1,3,32,32)))

In [8]:
model.to('cpu')
model.eval()
model.fuse_model()
backend = "fbgemm"
qconfig = torch.ao.quantization.get_default_qconfig(backend)
model.qconfig = qconfig
print(f"defualt qconfig ; {model.qconfig}")
torch.backends.quantized.engine = backend
print(torch.backends.quantized.engine)
model_static_quantized = torch.ao.quantization.prepare(model, inplace = False)
# calibration
test(model_static_quantized,val_loader,10)

# make quantized model
torch.ao.quantization.convert(model_static_quantized, inplace = True) 

# make hook
hooks = add_forward_hook(model_static_quantized, hooks)
sample = torch.randn(1,3,256,256)
model_static_quantized(sample)
print(len(hooks), len(modules), len(before_l), len(after_l))
# remove hook, hook works at once
remove_forward_hook(hooks)
# for _ in range(5):
#     sample = (torch.randn(1,3,224,224))
#     model_static_quantized(sample)
# print(len(modules), len(before_l), len(after_l))
print(model_static_quantized)



defualt qconfig ; QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
x86


  1%|          | 11/1563 [00:23<54:29,  2.11s/it]


41 41 41 41
Quantized(
  (model): VGG(
    (features): Sequential(
      (0): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.17084623873233795, zero_point=58, padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.35520821809768677, zero_point=73, padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.5221890211105347, zero_point=82, padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.7225201725959778, zero_point=71, padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): QuantizedConv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.9027827978134155, zero_point=74, padding=(1, 1))

In [9]:
test_acc = test(model_static_quantized, val_loader)
print(f"test result : {test_acc:.2f}")

 68%|██████▊   | 1060/1563 [08:21<03:57,  2.12it/s]

In [None]:
for i in modules:
    print(type(i))
print(len(modules), len(before_l), len(after_l))
for i in range(len(modules)):
    if not isinstance(type(modules[i]), torch.quantization.QuantStub) and not isinstance(type(modules[i]), torch.quantization.DeQuantStub):
        print(type(modules[i]))
        print(f"before : {type(before_l[i])}, {before_l[i].dtype}")
        print(f"after : {type(after_l[i])}, {after_l[i].dtype}")
        break

print(before_l[1])
print(after_l[1])

<class 'torch.ao.nn.quantized.modules.Quantize'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<cla

In [None]:
print(type(modules[2]))
print(before_l[2].dtype,before_l[2].int_repr())
print(after_l[2].dtype,after_l[2])

<class 'torch.nn.modules.activation.ReLU'>
torch.quint8 tensor([[[[68, 85, 61,  ..., 61, 59, 59],
          [81, 69, 59,  ..., 73, 59, 59],
          [80, 67, 59,  ..., 75, 83, 59],
          ...,
          [69, 74, 59,  ..., 59, 61, 92],
          [59, 63, 69,  ..., 59, 59, 79],
          [59, 59, 81,  ..., 59, 59, 78]],

         [[68, 59, 59,  ..., 71, 74, 66],
          [59, 67, 84,  ..., 83, 70, 59],
          [68, 59, 59,  ..., 87, 75, 61],
          ...,
          [59, 59, 79,  ..., 59, 68, 63],
          [59, 60, 87,  ..., 81, 67, 59],
          [59, 68, 80,  ..., 65, 59, 59]],

         [[70, 62, 62,  ..., 98, 59, 85],
          [59, 63, 65,  ..., 78, 59, 59],
          [59, 59, 59,  ..., 75, 69, 59],
          ...,
          [59, 59, 92,  ..., 70, 59, 59],
          [72, 90, 69,  ..., 81, 59, 73],
          [59, 74, 64,  ..., 59, 59, 82]],

         ...,

         [[59, 62, 65,  ..., 61, 62, 59],
          [60, 67, 66,  ..., 62, 66, 62],
          [59, 59, 69,  ..., 59, 61, 5

In [None]:
print(type(modules[-1]))
print(before_l[-1].dtype,before_l[-1])
print(after_l[-1].dtype,after_l[-1])

<class 'torch.ao.nn.quantized.modules.DeQuantize'>
torch.quint8 tensor([[ 0.3106,  2.7956, -1.2425, -1.0872,  0.0777,  2.2520,  1.3978,  0.6212,
          0.4659,  0.3106,  0.9319,  2.1743,  1.2425,  0.6989,  0.3883,  1.8637,
         -1.6308, -1.4754,  0.8542, -0.2330, -0.5436, -0.6989,  0.6989,  0.3883,
         -1.0095, -1.1648,  0.7766, -0.0777, -0.9319,  0.5436, -1.6308, -1.4754,
         -1.9414, -0.1553,  0.0777, -0.6989, -0.3106, -1.5531, -0.8542, -1.2425,
         -0.4659, -1.5531, -0.6989, -0.8542, -1.7861,  0.1553, -0.2330, -1.4754,
         -0.8542, -0.8542,  0.0000, -2.1743, -0.6212, -1.1648, -1.5531, -1.5531,
         -2.1743, -2.5626,  0.0000, -1.9414, -1.0095, -2.5626, -2.1743, -0.3106,
         -1.0095,  0.8542, -1.3978, -1.7084, -2.6403, -0.1553,  0.9319,  0.3883,
          0.3106,  1.3978,  1.4754,  1.5531,  0.6212,  2.0190,  2.7956,  1.8637,
          0.3106, -0.4659, -0.0777,  0.2330, -0.2330,  0.2330,  1.2425, -1.2425,
          1.3201,  0.0000,  0.2330,  0.4659, 

In [None]:
for child in model_static_quantized.children():
    if isinstance(child, nn.Sequential) or isinstance(child, torchvision.models.vgg.VGG):
        for n,c in child.named_children():
            for name, param in c.named_parameters():
                print(name)

state = model.state_dict()
for names in model_static_quantized.state_dict():
    print(names)
print(model_static_quantized.model.features[2].weight().int_repr())
print(state.keys())
print(state['model.features.2.weight'])


model.features.0.weight
model.features.0.bias
model.features.0.scale
model.features.0.zero_point
model.features.2.weight
model.features.2.bias
model.features.2.scale
model.features.2.zero_point
model.features.5.weight
model.features.5.bias
model.features.5.scale
model.features.5.zero_point
model.features.7.weight
model.features.7.bias
model.features.7.scale
model.features.7.zero_point
model.features.10.weight
model.features.10.bias
model.features.10.scale
model.features.10.zero_point
model.features.12.weight
model.features.12.bias
model.features.12.scale
model.features.12.zero_point
model.features.14.weight
model.features.14.bias
model.features.14.scale
model.features.14.zero_point
model.features.17.weight
model.features.17.bias
model.features.17.scale
model.features.17.zero_point
model.features.19.weight
model.features.19.bias
model.features.19.scale
model.features.19.zero_point
model.features.21.weight
model.features.21.bias
model.features.21.scale
model.features.21.zero_point
model.