In [1]:
import torch
import torchvision
import torch.nn as nn
from torchvision.utils import _log_api_usage_once

In [2]:
import random
import numpy as np

def manual_seed(seed):
    np.random.seed(seed) #1
    random.seed(seed) #2
    torch.manual_seed(seed) #3
    torch.cuda.manual_seed(seed) #4.1
    torch.cuda.manual_seed_all(seed) #4.2
    torch.backends.cudnn.benchmark = False #5 
    torch.backends.cudnn.deterministic = True #6

manual_seed(42)

In [3]:
class Quantized(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        _log_api_usage_once(self)
        self.model = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        pass


In [4]:
# filepath = './checkpoint/vgg16.pth'
# checkpoint = torch.load(filepath)
# model = checkpoint['model']
# print(model)
# model.load_state_dict(checkpoint['model_state_dict'])
model = torchvision.models.vgg.vgg16(pretrained=True)
model = Quantized(model)
pretrain_transforms = torchvision.models.VGG16_Weights.DEFAULT.transforms()
print(pretrain_transforms)



ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [5]:
from torchvision import transforms
from tqdm.auto import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_transform = transforms.Compose([
    pretrain_transforms,
    ])

val_data = torchvision.datasets.ImageNet(root="./dataset/ImageNet", split="val", transform=test_transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=32,
                                          shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import psutil

def memory_usage(message = "debug"):
    p = psutil.Process()
    rss = p.memory_info().rss / 2**20
    print(f"{message} memory usage : {rss:10.5f} MB")

def test(model, test_loader, num_calib=None):
    model.eval()
    model.to('cpu')
    test_acc = 0
    with torch.inference_mode():
        for i,data in enumerate(tqdm(test_loader,leave=True)):
            imgs, target = data[0], data[1]
            # imgs, target = data[0].to(device), data[1].to(device)
            output = model(imgs)
            _, preds = torch.max(output.data, 1)
            test_acc += (preds==target).detach().sum().item()
            if num_calib and (i > num_calib):
                break
            # if i % 10 == 0:
            #     memory_usage()

    test_acc = 100. * test_acc/len(test_loader.dataset)
    

    return test_acc

In [7]:
modules = []
before_l = []
after_l = []
hooks = []

def hook_fn(module, input, output):
    modules.append(module)
    before_l.append(input[0])
    after_l.append(output)

def add_forward_hook(net, hooks):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential) or isinstance(layer, torchvision.models.vgg.VGG):
            add_forward_hook(layer, hooks)
        else:
            hook = layer.register_forward_hook(hook_fn)
            hooks.append(hook)
            
    return hooks

def remove_forward_hook(hooks):
    for i in hooks:
        i.remove()
# out = model((torch.randn(1,3,32,32)))

In [8]:
model.to('cpu')
model.eval()
model.fuse_model()
backend = "fbgemm"
qconfig = torch.ao.quantization.get_default_qconfig(backend)
model.qconfig = qconfig
print(f"defualt qconfig ; {model.qconfig}")
torch.backends.quantized.engine = backend
print(torch.backends.quantized.engine)
model_static_quantized = torch.ao.quantization.prepare(model, inplace = False)
# calibration
test(model_static_quantized,val_loader,10)

# make quantized model
torch.ao.quantization.convert(model_static_quantized, inplace = True) 

# make hook
hooks = add_forward_hook(model_static_quantized, hooks)
sample = torch.randn(1,3,256,256)
model_static_quantized(sample)
print(len(hooks), len(modules), len(before_l), len(after_l))
# remove hook, hook works at once
remove_forward_hook(hooks)
# for _ in range(5):
#     sample = (torch.randn(1,3,224,224))
#     model_static_quantized(sample)
# print(len(modules), len(before_l), len(after_l))
print(model_static_quantized)



defualt qconfig ; QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
x86


  1%|          | 11/1563 [00:23<54:29,  2.11s/it]


41 41 41 41
Quantized(
  (model): VGG(
    (features): Sequential(
      (0): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.17084623873233795, zero_point=58, padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.35520821809768677, zero_point=73, padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.5221890211105347, zero_point=82, padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.7225201725959778, zero_point=71, padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): QuantizedConv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.9027827978134155, zero_point=74, padding=(1, 1))

In [9]:
test_acc = test(model_static_quantized, val_loader)
print(f"test result : {test_acc:.2f}")

100%|██████████| 1563/1563 [12:21<00:00,  2.11it/s]

test result : 71.32





In [14]:
# save model
torch.jit.save(torch.jit.script(model_static_quantized),"./checkpoint/static_quant_fbegmm.pth")

In [10]:
for i in modules:
    print(type(i))
print(len(modules), len(before_l), len(after_l))
for i in range(len(modules)):
    if not isinstance(type(modules[i]), torch.quantization.QuantStub) and not isinstance(type(modules[i]), torch.quantization.DeQuantStub):
        print(type(modules[i]))
        print(f"before : {type(before_l[i])}, {before_l[i].dtype}")
        print(f"after : {type(after_l[i])}, {after_l[i].dtype}")
        break

print(before_l[1])
print(after_l[1])

<class 'torch.ao.nn.quantized.modules.Quantize'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.ao.nn.quantized.modules.conv.Conv2d'>
<cla

In [11]:
print(type(modules[2]))
print(before_l[2].dtype,before_l[2].int_repr())
print(after_l[2].dtype,after_l[2])

<class 'torch.nn.modules.activation.ReLU'>
torch.quint8 tensor([[[[66, 58, 58,  ..., 58, 83, 62],
          [71, 58, 58,  ..., 61, 88, 60],
          [64, 58, 64,  ..., 75, 63, 58],
          ...,
          [71, 58, 58,  ..., 70, 62, 58],
          [68, 74, 62,  ..., 64, 68, 60],
          [59, 77, 69,  ..., 59, 69, 61]],

         [[58, 58, 68,  ..., 58, 58, 58],
          [58, 71, 96,  ..., 65, 71, 75],
          [58, 82, 58,  ..., 78, 76, 62],
          ...,
          [70, 61, 58,  ..., 58, 58, 58],
          [58, 60, 65,  ..., 58, 63, 72],
          [65, 80, 71,  ..., 61, 67, 64]],

         [[58, 63, 71,  ..., 58, 67, 65],
          [58, 61, 71,  ..., 89, 58, 73],
          [76, 76, 58,  ..., 70, 60, 58],
          ...,
          [83, 58, 65,  ..., 58, 58, 58],
          [68, 72, 58,  ..., 67, 88, 71],
          [77, 91, 60,  ..., 66, 58, 58]],

         ...,

         [[61, 61, 58,  ..., 60, 58, 63],
          [66, 69, 58,  ..., 58, 59, 64],
          [61, 58, 58,  ..., 58, 62, 6

In [12]:
print(type(modules[-1]))
print(before_l[-1].dtype,before_l[-1])
print(after_l[-1].dtype,after_l[-1])

<class 'torch.ao.nn.quantized.modules.DeQuantize'>
torch.quint8 tensor([[ 0.4287,  3.0011, -0.8575, -0.8575,  0.4287,  3.8586,  2.5724,  0.4287,
          0.4287, -0.4287,  0.8575,  1.7149,  0.8575,  0.8575,  0.4287,  1.2862,
         -1.2862, -1.7149,  0.4287, -0.8575, -0.8575, -1.2862,  0.0000, -0.4287,
         -1.2862, -0.8575,  1.2862,  0.4287,  0.0000,  0.8575, -0.8575, -1.2862,
         -1.2862,  0.4287,  0.4287,  0.0000,  0.4287, -0.8575,  0.0000, -1.2862,
         -0.4287, -1.2862, -0.8575, -0.8575, -1.7149,  0.4287, -0.4287, -1.2862,
         -1.2862, -0.4287,  0.0000, -1.7149,  0.0000, -0.4287, -0.8575, -1.2862,
         -1.7149, -1.7149,  0.8575, -1.2862, -0.4287, -2.1437, -1.7149, -0.4287,
         -0.8575,  1.7149, -0.8575, -1.2862, -2.1437,  0.0000,  1.7149,  0.8575,
          0.8575,  2.1437,  2.1437,  2.1437,  0.8575,  2.1437,  3.4299,  2.5724,
          0.0000, -0.4287, -0.4287,  0.0000, -0.4287,  0.0000,  1.2862, -1.2862,
          0.8575, -0.4287, -0.4287, -0.4287, 

In [13]:
for child in model_static_quantized.children():
    if isinstance(child, nn.Sequential) or isinstance(child, torchvision.models.vgg.VGG):
        for n,c in child.named_children():
            for name, param in c.named_parameters():
                print(name)

state = model.state_dict()
for names in model_static_quantized.state_dict():
    print(names)
print(model_static_quantized.model.features[2].weight().int_repr())
print(state.keys())
print(state['model.features.2.weight'])


model.features.0.weight
model.features.0.bias
model.features.0.scale
model.features.0.zero_point
model.features.2.weight
model.features.2.bias
model.features.2.scale
model.features.2.zero_point
model.features.5.weight
model.features.5.bias
model.features.5.scale
model.features.5.zero_point
model.features.7.weight
model.features.7.bias
model.features.7.scale
model.features.7.zero_point
model.features.10.weight
model.features.10.bias
model.features.10.scale
model.features.10.zero_point
model.features.12.weight
model.features.12.bias
model.features.12.scale
model.features.12.zero_point
model.features.14.weight
model.features.14.bias
model.features.14.scale
model.features.14.zero_point
model.features.17.weight
model.features.17.bias
model.features.17.scale
model.features.17.zero_point
model.features.19.weight
model.features.19.bias
model.features.19.scale
model.features.19.zero_point
model.features.21.weight
model.features.21.bias
model.features.21.scale
model.features.21.zero_point
model.

Non Calibration

In [15]:
model = torchvision.models.vgg.vgg16(pretrained=True)
model = Quantized(model)

model.to('cpu')
model.eval()
model.fuse_model()
backend = "fbgemm"
qconfig = torch.ao.quantization.get_default_qconfig(backend)
model.qconfig = qconfig
print(f"defualt qconfig ; {model.qconfig}")
torch.backends.quantized.engine = backend
print(torch.backends.quantized.engine)
model_static_quantized = torch.ao.quantization.prepare(model, inplace = False)

# make quantized model
torch.ao.quantization.convert(model_static_quantized, inplace = True) 



defualt qconfig ; QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
fbgemm




Quantized(
  (model): VGG(
    (features): Sequential(
      (0): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): QuantizedConv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): QuantizedConv2d(256, 256, kernel_size=(3, 3), stri

In [16]:
nocali_test_acc = test(model_static_quantized, val_loader)
print(f"test result : {nocali_test_acc:.2f}")

100%|██████████| 1563/1563 [13:38<00:00,  1.91it/s]

test result : 5.98





Integer Sample Calibraiton

In [17]:
model = torchvision.models.vgg.vgg16(pretrained=True)
model = Quantized(model)

test_transform = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.Resize(256),
    transforms.PILToTensor(),
    ])
test_dataloader = torch.utils.data.DataLoader(val_data, batch_size=32,
                                          shuffle=False)


model.to('cpu')
model.eval()
model.fuse_model()
backend = "fbgemm"
qconfig = torch.ao.quantization.get_default_qconfig(backend)
model.qconfig = qconfig
print(f"defualt qconfig ; {model.qconfig}")
torch.backends.quantized.engine = backend
print(torch.backends.quantized.engine)
model_static_quantized = torch.ao.quantization.prepare(model, inplace = False)
# calibraiotn
test(model_static_quantized,test_dataloader,10)

# make quantized model
torch.ao.quantization.convert(model_static_quantized, inplace = True) 



defualt qconfig ; QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
fbgemm


  1%|          | 11/1563 [00:22<53:59,  2.09s/it]


Quantized(
  (model): VGG(
    (features): Sequential(
      (0): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.17084623873233795, zero_point=58, padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.35520821809768677, zero_point=73, padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.5221890211105347, zero_point=82, padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.7225201725959778, zero_point=71, padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): QuantizedConv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.9027827978134155, zero_point=74, padding=(1, 1))
      (11):

In [18]:
intcali_test_acc = test(model_static_quantized, test_dataloader)
print(f"test result : {intcali_test_acc:.2f}")

100%|██████████| 1563/1563 [13:32<00:00,  1.92it/s]

test result : 71.30





In [20]:
a = torch.randint(0,255,(3,3))
b = torch.randint(0,255,(3,3))

c_int32 = torch._int_mm(a,b)

AttributeError: module 'torch' has no attribute '_int_mm'