In [1]:
import copy
import math
import random
from collections import OrderedDict, defaultdict

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
from tqdm.auto import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchvision.datasets import *
from torchvision.transforms import *

assert torch.cuda.is_available(), \
"The current runtime does not have CUDA support." \
"Please go to menu bar (Runtime - Change runtime type) and select GPU"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fab89ba86f0>

In [3]:
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True) # pretrained cifar10 model load
print(model)

CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

Using cache found in /home/jsw/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [4]:
def train(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
  callbacks = None
) -> None:
  model.train()
  model.cuda()

  for inputs, targets in tqdm(dataloader, desc='train', leave=False):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    targets = targets.cuda()

    # Reset the gradients (from the last iteration)
    optimizer.zero_grad()

    # Forward inference
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward propagation
    loss.backward()

    # Update weight using optimizer
    optimizer.step()

    if callbacks is not None:
        for callback in callbacks:
            callback()

In [5]:
@torch.inference_mode()
def evaluate(
  model: nn.Module,
  dataloader: DataLoader,
  extra_preprocess = None
) -> float:
  model.eval()
  model.cuda()
  num_samples = 0
  num_correct = 0

  for inputs, targets in tqdm(dataloader, desc="eval", leave=False):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    if extra_preprocess is not None:
        for preprocess in extra_preprocess:
            inputs = preprocess(inputs)

    targets = targets.cuda()

    # Inference
    outputs = model(inputs)

    # Convert logits to class indices
    outputs = outputs.argmax(dim=1)

    # Update metrics
    num_samples += targets.size(0)
    num_correct += (outputs == targets).sum()

  return (num_correct / num_samples * 100).item()

In [6]:
image_size = 32
transforms = {
    "train": Compose([
        RandomCrop(image_size, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ]),
    "test": Compose([ToTensor(), Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
}
dataset = {}
for split in ["train", "test"]:
  dataset[split] = CIFAR10(
    root="/home/jsw/data/cifar10",
    train=(split == "train"),
    download=False,
    transform=transforms[split],
  )
dataloader = {}
for split in ['train', 'test']:
  dataloader[split] = DataLoader(
    dataset[split],
    batch_size=512,
    shuffle=(split == 'train'),
    num_workers=0,
    pin_memory=True,
  )

In [7]:
fp32_model_accuracy = evaluate(model, dataloader['test'])
print(f"fp32 model has accuracy={fp32_model_accuracy:.2f}%")

                                                     

fp32 model has accuracy=92.59%




In [8]:
# Let's First Evaluate the Accuracy and Model Size of the FP32 Pretrained Model

In [9]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models
import torch.nn.functional as F


class _quantize_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, step_size, half_lvls):
        # ctx is a context object that can be used to stash information
        # for backward computation
        ctx.step_size = step_size
        ctx.half_lvls = half_lvls
        output = F.hardtanh(input,
                            min_val=-ctx.half_lvls * ctx.step_size.item(),
                            max_val=ctx.half_lvls * ctx.step_size.item())

        output = torch.round(output / ctx.step_size)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone() / ctx.step_size

        return grad_input, None, None


quantize = _quantize_func.apply


class QuantConv2d(nn.Conv2d):
    N_bits = 8
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(QuantConv2d, self).__init__(in_channels,
                                          out_channels,
                                          kernel_size,
                                          stride=stride,
                                          padding=padding,
                                          dilation=dilation,
                                          groups=groups,
                                          bias=bias)
        self.full_lvls = 2**self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2
        # Initialize the step size
        self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
        self.__reset_stepsize__()
        # flag to enable the inference with quantized weight or self.weight
        self.inf_with_weight = False  # disabled by default

        # create a vector to identify the weight to each bit
        self.b_w = nn.Parameter(2**torch.arange(start=self.N_bits - 1,
                                                end=-1,
                                                step=-1).unsqueeze(-1).float(),
                                requires_grad=False)

        self.b_w[0] = -self.b_w[0]  #in-place change MSB to negative

    def forward(self, input):
        if self.inf_with_weight:
            return F.conv2d(input, self.weight * self.step_size, self.bias,
                            self.stride, self.padding, self.dilation,
                            self.groups)
        else:
            self.__reset_stepsize__()
            weight_quan = quantize(self.weight, self.step_size,
                                   self.half_lvls) * self.step_size
            input_step_size = input.max() / (self.full_lvls -1) # unsigned input
            input_quan = quantize(input, input_step_size, (self.full_lvls -1)) * input_step_size
            
            #input_quan = input
            return F.conv2d(input_quan, weight_quan, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

    def __reset_stepsize__(self):
        with torch.no_grad():
            self.step_size.data = self.weight.abs().max() / self.half_lvls

    def __reset_weight__(self):
        '''
        This function will reconstruct the weight stored in self.weight.
        Replacing the original floating-point with the quantized fix-point
        weight representation.
        '''
        # replace the weight with the quantized version
        with torch.no_grad():
            self.weight.data = quantize(self.weight, self.step_size,
                                        self.half_lvls)
        # enable the flag, thus now computation does not invovle weight quantization
        self.inf_with_weight = True
    
    def __reset_half_lvls__(self):
        '''
        recunstruct half_lvls
        '''
        self.full_lvls = 2**self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2

class QuantLinear(nn.Linear):
    N_bits = 8
    def __init__(self, in_features, out_features, bias=True):
        super(QuantLinear, self).__init__(in_features, out_features, bias=bias)

        self.full_lvls = 2**self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2
        # Initialize the step size
        self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
        self.__reset_stepsize__()
        # flag to enable the inference with quantized weight or self.weight
        self.inf_with_weight = False  # disabled by default

        # create a vector to identify the weight to each bit
        self.b_w = nn.Parameter(2**torch.arange(start=self.N_bits - 1,
                                                end=-1,
                                                step=-1).unsqueeze(-1).float(),
                                requires_grad=False)

        self.b_w[0] = -self.b_w[0]  #in-place reverse

    def forward(self, input):
        if self.inf_with_weight:
            return F.linear(input, self.weight * self.step_size, self.bias)
        else:
            self.__reset_stepsize__()
            weight_quan = quantize(self.weight, self.step_size,
                                   self.half_lvls) * self.step_size
            input_step_size = input.max() / (self.full_lvls -1) # unsigned input
            input_quan = quantize(input, input_step_size, (self.full_lvls -1)) * input_step_size
            
            #input_quan = input
            return F.linear(input_quan, weight_quan, self.bias)

    def __reset_stepsize__(self):
        with torch.no_grad():
            self.step_size.data = self.weight.abs().max() / self.half_lvls

    def __reset_weight__(self):
        '''
        This function will reconstruct the weight stored in self.weight.
        Replacing the orginal floating-point with the quantized fix-point
        weight representation.
        '''
        # replace the weight with the quantized version
        with torch.no_grad():
            self.weight.data = quantize(self.weight, self.step_size,
                                        self.half_lvls)
        # enable the flag, thus now computation does not invovle weight quantization
        self.inf_with_weight = True

    def __reset_half_lvls__(self):
        '''
        recunstruct half_lvls
        '''
        self.full_lvls = 2**self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2


In [10]:
class Hook:
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.input = input
        self.weight = module.weight
        self.output = output

        self.approx_weight = quantize(self.weight, module.step_size, module.half_lvls) * module.step_size
        self.input_step_size = self.input[0].max() / (module.full_lvls -1) # unsigned input
        self.approx_input = quantize(self.input[0], self.input_step_size, module.full_lvls -1) * self.input_step_size
        self.approx_output = F.conv2d(self.approx_input, self.approx_weight, module.bias, module.stride,
                            module.padding, module.dilation, module.groups)
        
    def close(self):
        self.hook.remove()

In [11]:
testConv = QuantConv2d(3, 64, kernel_size=3, stride=1, padding=1)
inputs = torch.randn(4, 3, 32, 32)
hook_list = []
for name, module in testConv.named_modules():
    if isinstance(module, QuantConv2d):
        hook_list.append(Hook(module))

testConv(inputs)
for hook in hook_list:
    print(hook.input[0].shape)
    print(hook.weight.shape)
    print(hook.output.shape)
    print(hook.approx_weight.shape)
    print(hook.approx_input.shape)
    print(hook.approx_output - hook.output)

torch.Size([4, 3, 32, 32])
torch.Size([64, 3, 3, 3])
torch.Size([4, 64, 32, 32])
torch.Size([64, 3, 3, 3])
torch.Size([4, 3, 32, 32])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
  

In [12]:
print(model)

def set_deep_attr(obj, attrs, value):
    for attr in attrs.split(".")[:-1]:
        obj = getattr(obj, attr)
    setattr(obj, attrs.split(".")[-1], value)


def change_model(model):
    copy_model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) and 'downsample' not in name:
            set_deep_attr(copy_model, name, QuantConv2d(module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, module.bias))
            for p_name, p in module.named_parameters():
                set_deep_attr(copy_model, name + '.' + p_name, p)
        elif isinstance(module, nn.Linear):
            set_deep_attr(copy_model, name, QuantLinear(module.in_features, module.out_features, True if module.bias is not None else False))
            for p_name, p in module.named_parameters():
                set_deep_attr(copy_model, name + '.' + p_name, p)
    return copy_model

CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

In [13]:
c_model = change_model(model)
print(c_model)

CifarResNet(
  (conv1): QuantConv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): QuantConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): QuantConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): QuantConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): QuantConv2d(16, 16, kernel_size=(3, 3), stride=(1,

In [14]:
ptq_int8_model_accuracy = evaluate(c_model, dataloader['test'])
print(f"int8 ptq model has accuracy={ptq_int8_model_accuracy:.2f}%")

                                                     

int8 ptq model has accuracy=92.41%




In [15]:
QuantConv2d.N_bits = 6
c_model = change_model(model)
ptq_int6_model_accuracy = evaluate(c_model, dataloader['test'])
print(f"int6 ptq model has accuracy={ptq_int6_model_accuracy:.2f}%")

eval:   0%|          | 0/20 [00:00<?, ?it/s]

                                                     

int6 ptq model has accuracy=91.86%




In [16]:
class Quantized_Hook:
    def __init__(self, name, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.name = name
    def hook_fn(self, module, input, output):
        self.module = module
        self.input = input
        self.weight = module.weight
        self.output = output
        self.quan_weight = quantize(self.weight, module.step_size, module.half_lvls)
        self.input_step_size = self.input[0].max() / (module.full_lvls -1) # unsigned input
        self.weight_step_size = module.step_size
        self.quan_input = quantize(self.input[0], self.input_step_size, module.full_lvls -1)
        self.quan_output = F.conv2d(self.quan_input, self.quan_weight, module.bias, module.stride,
                            module.padding, module.dilation, module.groups)
        
    def close(self):
        self.hook.remove()

In [17]:
QuantConv2d.N_bits=8
inputs =torch.randn(1, 3, 8, 8)
inputs = inputs.abs()
print(len(inputs[inputs==0]))
quant_conv = QuantConv2d(3, 4, kernel_size=3, stride=1, padding=1, bias=False)
hook_list = []
quant_hook = Quantized_Hook("test",quant_conv)
out = quant_conv(inputs)
print(quant_hook.quan_output * quant_hook.input_step_size * quant_hook.weight_step_size - out)

0
tensor([[[[ 8.9407e-08, -7.4506e-09,  2.9802e-08,  2.9802e-08,  0.0000e+00,
            0.0000e+00, -1.4901e-08,  2.9802e-08],
          [ 0.0000e+00,  1.4901e-07,  3.1665e-08,  0.0000e+00,  1.1921e-07,
           -5.2154e-08,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -1.4901e-08, -8.9407e-08,  0.0000e+00,  8.9407e-08,
            2.9802e-08, -5.9605e-08,  5.9605e-08],
          [ 0.0000e+00, -5.9605e-08,  0.0000e+00,  2.9802e-08,  5.9605e-08,
            5.9605e-08,  1.1921e-07,  5.9605e-08],
          [ 2.9802e-08, -5.9605e-08,  0.0000e+00,  5.9605e-08, -5.9605e-08,
            0.0000e+00, -1.7881e-07,  0.0000e+00],
          [-3.7253e-09,  8.9407e-08, -4.8429e-08,  3.7253e-09, -7.4506e-09,
            5.9605e-08, -5.9605e-08,  4.4703e-08],
          [-1.4901e-08,  0.0000e+00, -9.3132e-10, -1.4901e-08,  0.0000e+00,
           -1.1921e-07,  0.0000e+00,  0.0000e+00],
          [ 2.2352e-08, -1.1921e-07, -7.4506e-09,  1.4901e-08,  5.9605e-08,
            5.9605e-08, -5.9605e-

In [18]:
weight_tensor = copy.deepcopy(quant_hook.quan_weight.data)
input_tensor = copy.deepcopy(quant_hook.quan_input.data)
print(input_tensor.shape)
print(weight_tensor.shape)
fold_param = dict(kernel_size=quant_hook.module.kernel_size, dilation=quant_hook.module.dilation, padding=quant_hook.module.padding, stride=quant_hook.module.stride)
unfold_module = nn.Unfold(**fold_param)
unfold_out = unfold_module(input_tensor)
print(unfold_out.shape)
weight_2d = weight_tensor.reshape(weight_tensor.shape[0], -1)
print(weight_tensor.shape)
out = F.conv2d(input_tensor, weight_tensor, stride=fold_param['stride'], padding=fold_param['padding'])
print(out - F.linear(unfold_out.transpose(1, 2), weight_2d, bias=None).transpose(1,2).reshape(out.shape))
F.linear(unfold_out.transpose(1, 2), weight_2d, bias=None).shape

torch.Size([1, 3, 8, 8])
torch.Size([4, 3, 3, 3])
torch.Size([1, 27, 64])
torch.Size([4, 3, 3, 3])
tensor([[[[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0.

torch.Size([1, 64, 4])

In [19]:
### Y = W*X ==> Y = (X * W).transpose(1, 2)

print("input matrix shape : ", unfold_out.shape) # (batch size, kernel_size * kernel_size * in_channels, output_size * output_size) 
unfold_input = unfold_out.transpose(1, 2) # (batch size, output_size * output_size, kernel_size * kernel_size * in_channels)
print("weight matrix shape : ", weight_tensor.shape) # (out_channels, kernel_size * kernel_size * in_channels)

weight_2d = weight_tensor.reshape(weight_tensor.shape[0], -1) 
print("weight matrix shape : ", weight_2d.shape) # (out_channels, kernel_size * kernel_size * in_channels)
weight_2d_dot = weight_2d.transpose(0,1) # (kernel_size * kernel_size * in_channels, out_channels)
# dot product (input matrix, weight matrix)
matmul_result = torch.matmul(unfold_input, weight_2d_dot)
print("matmul_result shape : ", matmul_result.shape) # (batch size, output_size * output_size, out_channels)
print("re generate conv2d shape : ", matmul_result.transpose(1,2).reshape(out.shape).shape) # (batch size, out_channels, output_size, output_size)
print("compare conv2d and im2col dot product result : ",out - matmul_result.transpose(1,2).reshape(out.shape))

input matrix shape :  torch.Size([1, 27, 64])
weight matrix shape :  torch.Size([4, 3, 3, 3])
weight matrix shape :  torch.Size([4, 27])
matmul_result shape :  torch.Size([1, 64, 4])
re generate conv2d shape :  torch.Size([1, 4, 8, 8])
compare conv2d and im2col dot product result :  tensor([[[[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]],

        

In [20]:
QuantConv2d.N_bits = 8
c_model = change_model(model)
for data, label in dataloader['test']:
    break

inputs = data[:1].cuda()

hook_list = []

for name, module in c_model.named_modules():
    if isinstance(module, QuantConv2d):
        hook_list.append(Quantized_Hook(name, module))
 
c_model(inputs)

tensor([[-2.8645, -2.7106, -1.7416, 21.1250, -2.9348,  1.5684, -2.6374, -2.5502,
         -4.6618, -2.6044]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [21]:
for hook in hook_list:
    print(hook.name)
    print(f"input shape {hook.input[0].shape}, weight shape {hook.weight.shape}")
    wq = hook.quan_weight.data.view(hook.quan_weight.shape[0], -1).transpose(0, 1) #  output_channels, in_channels, kernel_size, kernel_size) => (kernel_size * kernel_size * in_c, output_channels)
    wq = wq + 128 # -128 ~ 127 to 0 ~ 255
    wq = wq.type(torch.int32).to(device='cpu')
    inq = hook.quan_input.data
    unfold_param = dict(kernel_size=hook.module.kernel_size, dilation=hook.module.dilation, padding=hook.module.padding, stride=hook.module.stride)
    unfold_inq = nn.Unfold(**unfold_param)(inq)
    unfold_inq = unfold_inq.transpose(1, 2).squeeze(0) # (batch_size (=1), kernel_size * kernel_size * in_channels, output_size * output_size) => (output_size * output_size, kernel_size * kernel_size * in_channels)
    unfold_inq = unfold_inq.type(torch.int32).to(device='cpu')
    matmul_out = torch.matmul(unfold_inq, wq).to(torch.int32) # (output_size * output_size, kernel_size * kernel_size * in_channels) * (kernel_size * kernel_size * in_channels, output_channels) => (output_size * output_size, output_channels) 
    print(f"input 2d shape :", unfold_inq.shape)
    print(f"weight 2d shape :", wq.shape)
    print(f"matmul out shape :", matmul_out.shape)
    unfold_inq_np = unfold_inq.numpy()
    wq_np = wq.numpy()
    matmul_out_np = matmul_out.numpy()
    
    np.savetxt(f"/app/tensor_result/{hook.name}_input.txt", unfold_inq_np, fmt='%d')
    np.savetxt(f"/app/tensor_result/{hook.name}_weight.txt", wq_np, fmt='%d')
    np.savetxt(f"/app/tensor_result/{hook.name}_matmul_out.txt", matmul_out_np, fmt='%d')

conv1
input shape torch.Size([1, 3, 32, 32]), weight shape torch.Size([16, 3, 3, 3])
input 2d shape : torch.Size([1024, 27])
weight 2d shape : torch.Size([27, 16])
matmul out shape : torch.Size([1024, 16])
layer1.0.conv1
input shape torch.Size([1, 16, 32, 32]), weight shape torch.Size([16, 16, 3, 3])
input 2d shape : torch.Size([1024, 144])
weight 2d shape : torch.Size([144, 16])
matmul out shape : torch.Size([1024, 16])
layer1.0.conv2
input shape torch.Size([1, 16, 32, 32]), weight shape torch.Size([16, 16, 3, 3])
input 2d shape : torch.Size([1024, 144])
weight 2d shape : torch.Size([144, 16])
matmul out shape : torch.Size([1024, 16])
layer1.1.conv1
input shape torch.Size([1, 16, 32, 32]), weight shape torch.Size([16, 16, 3, 3])
input 2d shape : torch.Size([1024, 144])
weight 2d shape : torch.Size([144, 16])
matmul out shape : torch.Size([1024, 16])
layer1.1.conv2
input shape torch.Size([1, 16, 32, 32]), weight shape torch.Size([16, 16, 3, 3])
input 2d shape : torch.Size([1024, 144])


In [22]:
matmul_out.max()

tensor(996891, dtype=torch.int32)

In [23]:
g = torch.IntTensor([100, 200])
g = g.type(torch.uint8)
print(g)
print(g*g)

tensor([100, 200], dtype=torch.uint8)
tensor([16, 64], dtype=torch.uint8)
