This notebook gives a detailed description of MetaQuant through code and mathematical explanation.

## Motivation
Training-based quantization aims at minimizing the following training loss:

$$\min \ell = \text{Loss}(f(Q(\mathbf{W}); \mathbf{x}))$$

where $Q(\cdot)$ quantizes the full-precision weights $\mathbf{W}$ into quantized values $\mathbf{\hat{W}}$, and $f$ is the network evaluated on input $\mathbf{x}$. Because $Q(\cdot)$ is non-differentiable, the gradient of $\ell$ w.r.t. $\mathbf{W}$ cannot be obtained by ordinary backpropagation.
To enable stable quantization training, the Straight-Through Estimator (STE) was proposed, which redefines $\partial Q(r)/\partial r$ as:

$$\frac{\partial Q(r)}{\partial r}=\left\{\begin{matrix}1&\quad\text{if}\quad|r|\leq1,\\0&\quad\text{otherwise.}\end{matrix}\right.$$

However, STE inevitably introduces the problem of **gradient mismatch**: the gradients used to update the weights are computed from the quantized values rather than from the full-precision weights themselves. Although STE provides an end-to-end training method under discrete constraints, few works have investigated how to obtain better gradients for quantization training.
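
The STE rule above can be made concrete with a small dependency-free sketch. The function names `quantize_1bit` and `ste_backward` are hypothetical, and scaling the sign by the mean absolute value is only one common choice of 1-bit quantizer, assumed here for illustration; it is not taken from the MetaQuant code.

```python
def quantize_1bit(weights):
    """Illustrative 1-bit quantizer Q: sign(w) scaled by the mean absolute value (assumed)."""
    scale = sum(abs(w) for w in weights) / len(weights)
    return [scale if w >= 0 else -scale for w in weights]


def ste_backward(weights, grad_wrt_quantized):
    """STE: pass the incoming gradient through where |w| <= 1, zero it elsewhere."""
    return [g if abs(w) <= 1 else 0.0 for w, g in zip(weights, grad_wrt_quantized)]


weights = [0.4, -0.2, 1.5, -0.8]      # full-precision weights W
quantized = quantize_1bit(weights)    # Q(W), used in the forward pass
grad_q = [0.1, -0.3, 0.2, 0.05]       # stand-in values for dl/dQ(W)
grad_w = ste_backward(weights, grad_q)

print(quantized)
print(grad_w)  # [0.1, -0.3, 0.0, 0.05]
```

Note that `grad_w` is computed from `grad_q`, which was obtained at the quantized weights, yet it is applied to the full-precision weights; this is the mismatch described above.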

To overcome gradient mismatch and explore better gradients in training-based methods, we propose to learn $\frac{\partial Q(\mathbf{W})}{\partial \mathbf{W}}$ with a neural network ($\mathcal{M}$) during quantization training. This neural network is called the **meta quantizer** and is trained together with the base quantized model. The process is named **Meta** **Quant**ization (MetaQuant).

Specifically, in each backward propagation, $\mathcal{M}$ takes $\frac{\partial \ell}{\partial Q(\mathbf{W})}$ and $\mathbf{W}$ as inputs in a coordinate-wise manner, and its output is used as $\frac{\partial \ell}{\partial \mathbf{W}}$ for the weight update with common optimization methods such as SGD and Adam. In the forward pass, inference is conducted using the quantized version of the updated weights, which produces the final outputs to be compared with the ground-truth labels for the backward computation. During this process, gradient propagation from the quantized weights to the full-precision weights is handled by $\mathcal{M}$, which avoids the problems of non-differentiability and gradient mismatch. In addition, the gradients generated by the meta quantizer are loss-aware, which contributes to better performance of the quantization training.
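
To illustrate what "coordinate-wise" means, here is a minimal PyTorch sketch. The class name `CoordinateWiseMeta` is hypothetical. Its layer shapes follow the `MetaDesignedMultiFC` printout that appears later in this notebook (three bias-free linear layers with a single input feature), so it takes only the gradient as input. The actual implementation in `meta_utils.meta_network` may differ.

```python
import torch
import torch.nn as nn


class CoordinateWiseMeta(nn.Module):
    """Hypothetical stand-in for a meta quantizer applied per coordinate."""

    def __init__(self, hidden_size=100):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(1, hidden_size, bias=False),
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.Linear(hidden_size, 1, bias=False),
        )

    def forward(self, grad):
        # Treat every gradient entry as an independent sample with one feature
        flat = grad.reshape(-1, 1)
        # Map each entry through the shared network, then restore the original shape
        return self.network(flat).reshape(grad.shape)


torch.manual_seed(0)
meta = CoordinateWiseMeta(hidden_size=8)
grad = torch.randn(4, 3, 3, 3)   # stand-in for dl/dQ(W) of a conv layer
out = meta(grad)
print(out.shape)
```

Because the same small network is shared by all coordinates, its parameter count is independent of the size of the layer it serves.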

## Overview of MetaQuant

<img src="figs/MetaQuant.png">

In [4]:
"""
This block imports packages and initializes key parameters
"""
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import shutil
import pickle
import time
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim

from utils.dataset import get_dataloader
from meta_utils.meta_network import MetaFC, MetaLSTMFC, MetaDesignedMultiFC
from meta_utils.SGD import SGD
from meta_utils.adam import Adam
from meta_utils.helpers import meta_gradient_generation, update_parameters
from utils.recorder import Recorder
from utils.miscellaneous import AverageMeter, accuracy, progress_bar
from utils.miscellaneous import get_layer
from utils.quantize import test

from models_CIFAR.quantized_meta_resnet import resnet20_cifar

# ------------------------------------------
use_cuda = torch.cuda.is_available()
model_name = 'ResNet20'
dataset_name = 'CIFAR10'
meta_method = 'MultiFC' # ['LSTMFC', 'MultiFC', 'FC-Grad']
MAX_EPOCH = 100
optimizer_type = 'SGD' # ['SGD', 'SGD-M', 'adam']
hidden_size = 100
num_lstm = 2
num_fc = 3
lr_adjust = '30'
batch_size = 128
bitW = 1
quantized_type = 'dorefa' # ['dorefa', 'BWN']
save_root = './Results/%s-%s' % (model_name, dataset_name)
meta_nonlinear = None
weight_decay = 0
init_lr = 1e-3
exp_spec = ''
# ------------------------------------------

In [2]:
"""
This block initializes the network and loads the dataset
"""

import utils.global_var as gVar
gVar.meta_count = 0

###################
# Initial Network #
###################
net = resnet20_cifar(bitW=bitW)
pretrain_path = '%s/%s-%s-pretrain.pth' % (save_root, model_name, dataset_name)
net.load_state_dict(torch.load(pretrain_path), strict=False)

# Get layer name list
layer_name_list = net.layer_name_list
# Assert that all required layers are initialized as meta layers
assert (len(layer_name_list) == gVar.meta_count)
print('Layer name list completed.')

if use_cuda:
    net.cuda()
    
################
# Load Dataset #
################
train_loader = get_dataloader(dataset_name, 'train', batch_size)
test_loader = get_dataloader(dataset_name, 'test', 100)

Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized CNN with bit 1
Initial Meta-Quantized Linear with bit 1
Layer name list completed.
[2023-08-21 09:16:01.910870] Loading train from CIFAR10
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /home/c

100.0%


Number of training instances used: 50000
[DATA LOADING] Loading from CIFAR10-train finish. Number of images: 50000, Number of batches: 391
[2023-08-21 09:16:40.696436] Loading test from CIFAR10
Found CIFAR10 in /home/chenxinquan/MetaQuant/datasets/CIFAR10
Files already downloaded and verified
[DATA LOADING] Loading from CIFAR10-test finish. Number of images: 10000, Number of batches: 100


In [5]:
"""
This block initializes the meta network, optimizer and recorder
"""
########################
# Initial Meta Network #
########################
if meta_method == 'LSTMFC':
    meta_net = MetaLSTMFC(hidden_size=hidden_size)
    SummaryPath = '%s/runs-Quant/Meta-%s-Nonlinear-%s-' \
                  'hidden-size-%d-nlstm-1-%s-%s-%dbits-lr-%s' \
                  % (save_root, meta_method, meta_nonlinear, hidden_size,
                     quantized_type, optimizer_type, bitW, lr_adjust)
elif meta_method in ['FC-Grad']:
    meta_net = MetaFC(hidden_size=hidden_size, use_nonlinear=meta_nonlinear)
    SummaryPath = '%s/runs-Quant/Meta-%s-Nonlinear-%s-' \
                  'hidden-size-%d-%s-%s-%dbits-lr-%s' \
                  % (save_root, meta_method, meta_nonlinear, hidden_size,
                     quantized_type, optimizer_type, bitW, lr_adjust)
elif meta_method == 'MultiFC':
    meta_net = MetaDesignedMultiFC(hidden_size=hidden_size,
                                   num_layers = num_fc,
                                   use_nonlinear=meta_nonlinear)
    SummaryPath = '%s/runs-Quant/Meta-%s-Nonlinear-%s-' \
                  'hidden-size-%d-nfc-%d-%s-%s-%dbits-lr-%s' \
                  % (save_root, meta_method, meta_nonlinear, hidden_size, num_fc,
                     quantized_type, optimizer_type, bitW, lr_adjust)
else:
    raise NotImplementedError

print(meta_net)

if use_cuda:
    meta_net.cuda()

meta_optimizer = optim.Adam(meta_net.parameters(), lr=1e-3, weight_decay=weight_decay)
    
#####################
# Initial Optimizee #
#####################
    
# Optimizer for the original network, used only for zeroing gradients and getting refined gradients
if optimizer_type == 'SGD-M':
    optimizee = SGD(net.parameters(), lr=init_lr,
                    momentum=0.9, weight_decay=5e-4)
elif optimizer_type == 'SGD':
    optimizee = SGD(net.parameters(), lr=init_lr)
elif optimizer_type in ['adam', 'Adam']:
    optimizee = Adam(net.parameters(), lr=init_lr,
                     weight_decay=5e-4)
else:
    raise NotImplementedError
    
####################
# Initial Recorder #
####################
if exp_spec != '':
    SummaryPath += ('-' + exp_spec)

print('Save to %s' %SummaryPath)

if os.path.exists(SummaryPath):
    print('Record exists, press Enter to remove it')
    input()  # Wait for confirmation before deleting the existing record
    shutil.rmtree(SummaryPath)
os.makedirs(SummaryPath)

recorder = Recorder(SummaryPath=SummaryPath, dataset_name=dataset_name)

MetaDesignedMultiFC(
  (network): Sequential(
    (Linear0): Linear(in_features=1, out_features=100, bias=False)
    (Linear1): Linear(in_features=100, out_features=100, bias=False)
    (Linear2): Linear(in_features=100, out_features=1, bias=False)
  )
)
Save to ./Results/ResNet20-CIFAR10/runs-Quant/Meta-MultiFC-Nonlinear-None-hidden-size-100-nfc-3-dorefa-SGD-1bits-lr-30


When training MetaQuant, the meta gradient produced by the meta quantizer is incorporated into the inference (forward) process, so that the loss depends on the meta quantizer and can be used to train it:

<img src="figs/MetaQuant-Forward.png">

Therefore, the forward process of the base net is modified to incorporate the meta gradient, which is passed in through `meta_grad_dict`.
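
The following PyTorch sketch shows why placing the meta gradient inside the forward pass makes the meta quantizer trainable. It rests on several assumptions that are not confirmed by this notebook: the meta quantizer is replaced by a single scalar `meta_param`, the layer is assumed to use `weight - lr * meta_grad` before quantizing (suggested by `net(...)` receiving both `meta_grad_dict` and `lr`), and the quantizer uses an identity pass-through in the backward pass.

```python
import torch

torch.manual_seed(0)
lr = 1e-3
weight = torch.randn(4, 4)                           # full-precision weight W
meta_param = torch.tensor(0.5, requires_grad=True)   # scalar stand-in for the meta quantizer
prev_grad = torch.randn(4, 4)                        # dl/dQ(W) from the previous iteration

meta_grad = meta_param * prev_grad                   # output of the (stand-in) meta quantizer
shifted = weight - lr * meta_grad                    # assumed virtual update inside the layer

# 1-bit quantization in the forward pass, identity gradient in the backward pass (assumed)
hard = torch.sign(shifted) * shifted.abs().mean()
quantized = shifted + (hard - shifted).detach()

x = torch.randn(2, 4)
loss = (x @ quantized.t()).pow(2).mean()             # toy loss on a linear layer
loss.backward()

print(meta_param.grad)                               # the loss reaches the meta parameter
```

Since `loss` depends on `meta_param` through `quantized`, a single `backward()` call yields a gradient for the meta quantizer, which is what `meta_optimizer.step()` consumes in the training loop.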

In [6]:
"""
This block runs the training loop
"""
meta_hidden_state_dict = dict() # Dictionary to store hidden states for all layers for memory-based meta network
meta_grad_dict = dict() # Dictionary to store meta net output: gradient for origin network's weight / bias

for epoch in range(MAX_EPOCH):

    if recorder.stop: break

    print('\nEpoch: %d, lr: %e' % (epoch, optimizee.param_groups[0]['lr']))

    net.train()
    end = time.time()

    recorder.reset_performance()

    for batch_idx, (inputs, targets) in enumerate(train_loader):

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        meta_optimizer.zero_grad()

        # In the very first iteration of training, no meta gradient has been generated yet,
        # so the first forward pass is conducted without meta gradient.
        if batch_idx == 0 and epoch == 0:
            pass
        # The meta gradient used in the current iteration is generated from the gradients
        # and weights of the previous iteration.
        else:
            meta_grad_dict, meta_hidden_state_dict = \
                meta_gradient_generation(
                        meta_net, net, meta_method, meta_hidden_state_dict
                )
        # Conduct forward using meta gradient
        outputs = net(inputs, quantized_type=quantized_type,
                      meta_grad_dict=meta_grad_dict,
                      lr=optimizee.param_groups[0]['lr'])

        optimizee.zero_grad()

        # The backward pass generates gradients for the meta quantizer and the base model
        # Gradients of non-meta weights (bias, BN layers) are obtained here
        losses = nn.CrossEntropyLoss()(outputs, targets)
        losses.backward()

        meta_optimizer.step()

        # Assign meta gradient for actual gradients used in update_parameters
        if len(meta_grad_dict) != 0:
            for layer_info in net.layer_name_list:
                layer_name = layer_info[0]
                layer_idx = layer_info[1]
                layer = get_layer(net, layer_idx)
                layer.weight.grad.data = (layer.calibration * layer.pre_quantized_grads)
                # layer.weight.grad.data.copy_(layer.calibration * meta_grad_dict[layer_name][1].data)

        # Get refine gradients for next computation
        optimizee.get_refine_gradient()

        # These gradients should be saved for the next iteration's inference
        if len(meta_grad_dict) != 0:
            update_parameters(net, lr=optimizee.param_groups[0]['lr'])

        recorder.update(loss=losses.data.item(), acc=accuracy(outputs.data, targets.data, (1,5)),
                        batch_size=outputs.shape[0], cur_lr=optimizee.param_groups[0]['lr'], end=end)

        recorder.print_training_result(batch_idx, len(train_loader))
        end = time.time()

    test_acc = test(net, quantized_type=quantized_type, test_loader=test_loader,
                    dataset_name=dataset_name, n_batches_used=None)
    recorder.update(loss=None, acc=test_acc, batch_size=0, end=None, is_train=False)

    # Adjust learning rate
    recorder.adjust_lr(optimizer=optimizee, adjust_type=lr_adjust, epoch=epoch)

best_test_acc = recorder.get_best_test_acc()
if isinstance(best_test_acc, tuple):
    print('Best test top 1 acc: %.3f, top 5 acc: %.3f' % (best_test_acc[0], best_test_acc[1]))
else:
    print('Best test acc: %.3f' % best_test_acc)
recorder.close()


Epoch: 0, lr: 1.000000e-03


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
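
The `RuntimeError` above is raised by `Tensor.view` when the tensor is not contiguous in memory. The output does not show which line of the MetaQuant code triggers it, so the snippet below is only a standalone reproduction with a toy tensor, together with the two usual remedies: `reshape`, or `contiguous()` followed by `view`.

```python
import torch

x = torch.arange(6).reshape(2, 3)
xt = x.t()                         # transposing returns a non-contiguous view
print(xt.is_contiguous())          # False

view_failed = False
try:
    xt.view(-1)                    # view cannot flatten a non-contiguous layout
except RuntimeError as err:
    view_failed = True
    print('view failed:', err)

flat_a = xt.reshape(-1)            # copies the data when necessary
flat_b = xt.contiguous().view(-1)  # explicit copy, then view
print(flat_a.tolist())             # [0, 3, 1, 4, 2, 5]
```

Replacing the failing `.view(...)` call with `.reshape(...)`, as the error message suggests, is usually sufficient.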

In [None]:
meta_grad_dict
{}
meta_grad_dict
'conv1': (['conv1'], tensor([[[[ 5.5641e-...ackward0>), None)
'fc': (['fc'], tensor([[-1.8364e-02...ackward0>), tensor([ 0.1355, -0....='cuda:0'))
'layer1.0.conv1': (['layer1', 0, 'conv1'], tensor([[[[-5.3376e-...ackward0>), None)
'layer1.0.conv2': (['layer1', 0, 'conv2'], tensor([[[[-5.9805e-...ackward0>), None)
'layer1.1.conv1': (['layer1', 1, 'conv1'], tensor([[[[ 5.0621e-...ackward0>), None)
'layer1.1.conv2': (['layer1', 1, 'conv2'], tensor([[[[ 3.4084e-...ackward0>), None)
'layer1.2.conv1': (['layer1', 2, 'conv1'], tensor([[[[ 1.0782e-...ackward0>), None)
'layer1.2.conv2': (['layer1', 2, 'conv2'], tensor([[[[ 1.2057e-...ackward0>), None)
'layer1.3.conv1': (['layer1', 3, 'conv1'], tensor([[[[ 4.1520e-...ackward0>), None)
'layer1.3.conv2': (['layer1', 3, 'conv2'], tensor([[[[-1.1259e-...ackward0>), None)
'layer1.4.conv1': (['layer1', 4, 'conv1'], tensor([[[[ 2.9680e-...ackward0>), None)
'layer1.4.conv2': (['layer1', 4, 'conv2'], tensor([[[[-3.7319e-...ackward0>), None)
'layer2.0.downsample.0': (['layer2', 0, 'downsample', 0], tensor([[[[ 8.5688e-...ackward0>), None)
'layer2.0.conv1': (['layer2', 0, 'conv1'], tensor([[[[-1.9360e-...ackward0>), None)
'layer2.0.conv2': (['layer2', 0, 'conv2'], tensor([[[[ 8.1422e-...ackward0>), None)
'layer2.1.conv1': (['layer2', 1, 'conv1'], tensor([[[[ 2.2410e-...ackward0>), None)
'layer2.1.conv2': (['layer2', 1, 'conv2'], tensor([[[[ 1.9180e-...ackward0>), None)
'layer2.2.conv1': (['layer2', 2, 'conv1'], tensor([[[[ 4.2085e-...ackward0>), None)
'layer2.2.conv2': (['layer2', 2, 'conv2'], tensor([[[[ 5.4455e-...ackward0>), None)
'layer2.3.conv1': (['layer2', 3, 'conv1'], tensor([[[[-3.9597e-...ackward0>), None)
'layer2.3.conv2': (['layer2', 3, 'conv2'], tensor([[[[-1.3790e-...ackward0>), None)
'layer2.4.conv1': (['layer2', 4, 'conv1'], tensor([[[[-8.1013e-...ackward0>), None)
'layer2.4.conv2': (['layer2', 4, 'conv2'], tensor([[[[ 5.9578e-...ackward0>), None)
'layer3.0.downsample.0': (['layer3', 0, 'downsample', 0], tensor([[[[-2.9712e-...ackward0>), None)
'layer3.0.conv1': (['layer3', 0, 'conv1'], tensor([[[[-3.2160e-...ackward0>), None)
'layer3.0.conv2': (['layer3', 0, 'conv2'], tensor([[[[ 2.2531e-...ackward0>), None)
'layer3.1.conv1': (['layer3', 1, 'conv1'], tensor([[[[ 3.7664e-...ackward0>), None)
'layer3.1.conv2': (['layer3', 1, 'conv2'], tensor([[[[-6.5114e-...ackward0>), None)
'layer3.2.conv1': (['layer3', 2, 'conv1'], tensor([[[[ 5.6294e-...ackward0>), None)
'layer3.2.conv2': (['layer3', 2, 'conv2'], tensor([[[[ 1.2066e-...ackward0>), None)
'layer3.3.conv1': (['layer3', 3, 'conv1'], tensor([[[[-7.5203e-...ackward0>), None)
'layer3.3.conv2': (['layer3', 3, 'conv2'], tensor([[[[ 1.8765e-...ackward0>), None)
'layer3.4.conv1': (['layer3', 4, 'conv1'], tensor([[[[ 1.1682e-...ackward0>), None)
'layer3.4.conv2': (['layer3', 4, 'conv2'], tensor([[[[ 2.1639e-...ackward0>), None)
len(): 34


In [12]:
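import torch
A = torch.randn(3, 3, 3)

def to_list(tensor: torch.Tensor):
    """Recursively convert a tensor into nested Python lists (same result as tensor.tolist())."""
    items = list(tensor)
    if isinstance(items[0], torch.Tensor) and items[0].dim() > 0:
        # Sub-tensors still have dimensions: recurse and return the nested result
        return [to_list(item) for item in items]
    # Innermost level: convert 0-dim tensors to Python scalars
    return [a.item() for a in items]

print(to_list(A))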
