## BRECQ
This notebook demonstrates how to quantize a ResNet-50 pretrained on CIFAR-100 with the `trailmet` implementation of [BRECQ: Pushing the Limit of Post-Training Quantization by Block Reconstruction](https://arxiv.org/abs/2102.05426).

### Steps to quantize the pretrained model
- Load the dataset and create the dataloaders. A subset of the training data is used for calibration.
- Load the pretrained full-precision model.
- Load the BRECQ configuration from a YAML file.
- Create a `BRECQ` object, passing it the full-precision model, the dataloaders, and the configuration.
- Quantize the model by calling the `compress_model` method.
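At its core, the last step maps each weight onto a small grid of integer levels. The snippet below is a minimal sketch of per-channel asymmetric "fake quantization" (quantize, then dequantize) in plain Python. The helper `fake_quantize_per_channel` is hypothetical, not part of `trailmet`. It uses plain round-to-nearest and, as is common practice, widens each channel's range to include zero; BRECQ instead learns the rounding direction during block reconstruction.

```python
def fake_quantize_per_channel(weights, n_bits=4):
    """Quantize then dequantize each row (one output channel) of a 2-D weight list.

    Hypothetical helper: asymmetric min/max range that includes zero,
    round-to-nearest, one scale and zero point per channel.
    """
    q_max = 2 ** n_bits - 1  # largest integer code: 15 for 4-bit (16 levels)
    result = []
    for row in weights:
        lo, hi = min(min(row), 0.0), max(max(row), 0.0)
        if hi == lo:  # all-zero channel: nothing to quantize
            result.append(list(row))
            continue
        scale = (hi - lo) / q_max        # step size of this channel's grid
        zero_point = round(-lo / scale)  # integer code that represents 0.0
        dequantized = []
        for w in row:
            q = min(max(round(w / scale) + zero_point, 0), q_max)  # clamp to [0, q_max]
            dequantized.append((q - zero_point) * scale)
        result.append(dequantized)
    return result


weights = [
    [-1.0, -0.2, 0.3, 1.0],    # channel 0: wide range
    [0.01, 0.02, 0.05, 0.04],  # channel 1: narrow range
]
for original, quantized in zip(weights, fake_quantize_per_channel(weights, n_bits=4)):
    max_err = max(abs(a - b) for a, b in zip(original, quantized))
    print(f"max abs error: {max_err:.5f}")
```

Because each output channel gets its own scale, the narrow-range channel keeps a much finer grid than it would under a single per-tensor scale, which is the motivation for channel-wise weight quantization.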

In [1]:
import sys
sys.path.append("./trail/trailmet/")

In [3]:
import yaml
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from trailmet.datasets.classification import DatasetFactory
from trailmet.models import resnet, mobilenet
from trailmet.algorithms import quantize

## Datasets

### Augmentations

In [4]:
stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*stats, inplace=True)
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

input_transforms = {
    'train': train_transform, 
    'val': val_transform, 
    'test': test_transform}

target_transforms = {
    'train': None, 
    'val': None, 
    'test': None}
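`transforms.Normalize(mean, std)` standardizes each channel independently as (x - mean_c) / std_c, using the CIFAR-100 statistics stored in `stats`. The plain-Python sketch below applies the same arithmetic to a single RGB pixel whose values have already been scaled to [0, 1] by `ToTensor` (`normalize_pixel` is an illustrative name, not a torchvision function):

```python
stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
mean, std = stats


def normalize_pixel(rgb, mean, std):
    """Per-channel standardization, as Normalize applies it to every pixel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


print(normalize_pixel(mean, mean, std))             # the mean pixel maps to (0.0, 0.0, 0.0)
print(normalize_pixel((1.0, 1.0, 1.0), mean, std))  # a white pixel lands about 1.6 to 2.0 std above the mean
```

Only the training transform adds random crops and flips; validation and test use the deterministic pipeline so that accuracy is measured on unaltered images.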

### Load Datasets

In [5]:
cifar100_dataset = DatasetFactory.create_dataset(
        name = 'CIFAR100', 
        root = './data',
        split_types = ['train', 'val', 'test'],
        val_fraction = 0.2,
        transform = input_transforms,
        target_transform = target_transforms)

# Print the number of samples in each split
print('Train samples: ', cifar100_dataset['info']['train_size'])
print('Val samples: ', cifar100_dataset['info']['val_size'])
print('Test samples: ', cifar100_dataset['info']['test_size'])

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Train samples:  40000
Val samples:  10000
Test samples:  10000
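These counts follow from CIFAR-100's standard split of 50,000 training and 10,000 test images. As the printed sizes indicate, `val_fraction = 0.2` holds out one fifth of the training set for validation, while the test set is left untouched:

```python
train_total, test_total = 50_000, 10_000  # official CIFAR-100 split sizes
val_fraction = 0.2

val_size = int(train_total * val_fraction)
train_size = train_total - val_size
print(train_size, val_size, test_total)  # 40000 10000 10000
```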


### Create Dataloaders

In [6]:
train_loader = DataLoader(
        cifar100_dataset['train'], batch_size=128, 
        sampler=cifar100_dataset['train_sampler'],
        num_workers=0)
val_loader = DataLoader(
        cifar100_dataset['val'], batch_size=128, 
        sampler=cifar100_dataset['val_sampler'],
        num_workers=0)
test_loader = DataLoader(
        cifar100_dataset['test'], batch_size=128, 
        sampler=cifar100_dataset['test_sampler'],
        num_workers=0)

dataloaders = {"train": train_loader, "val": val_loader, "test": test_loader}

print('No. of training batches: ', len(dataloaders['train']))
print('No. of validation batches: ', len(dataloaders['val']))
print('No. of test batches: ', len(dataloaders['test']))

No. of training batches:  313
No. of validation batches:  79
No. of test batches:  79
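The batch counts follow from the split sizes: `DataLoader` keeps a final, smaller batch unless `drop_last=True`, so each count is the ceiling of samples / 128. The last part of the sketch assumes (this is not confirmed by the notebook) that the `NUM_SAMPLES` calibration images are drawn from the training loader batch by batch:

```python
import math

batch_size = 128
split_sizes = {"train": 40_000, "val": 10_000, "test": 10_000}

# drop_last defaults to False, so a final partial batch is kept
num_batches = {name: math.ceil(n / batch_size) for name, n in split_sizes.items()}
for name, n in split_sizes.items():
    last_batch = n - (num_batches[name] - 1) * batch_size
    print(f"{name}: {num_batches[name]} batches, last batch holds {last_batch} samples")

# Assumption: calibration samples come from the training loader, so the
# reduced run with NUM_SAMPLES = 128 needs exactly one batch.
num_samples = 128
print(math.ceil(num_samples / batch_size))  # 1
```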


### Load Pretrained Model

In [7]:
r50_model = resnet.make_resnet50(100,32)
checkpoint = torch.load("./weights/resnet50_cifar100_pretrained.pth", map_location='cuda:0')
r50_model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

### Load Method Config

In [8]:
with open('./brecq_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
    kwargs = config['GENERAL']
    assert kwargs['W_BUDGET'] in config['W_ARGS'], 'given weight budget not supported'
    kwargs.update(config['W_ARGS'][kwargs['W_BUDGET']])

# This is a sample run with a reduced number of samples and iterations.
# Comment out the following lines to use the values from the YAML file for best results.
kwargs['NUM_SAMPLES'] = 128
kwargs['ITERS_W'] = 1000
kwargs['ITERS_A'] = 1000

kwargs

{'ARCH': 'ResNet50',
 'DATASET': 'CIFAR100',
 'GPU_ID': 0,
 'SEED': 42,
 'W_BUDGET': 0.125,
 'A_BITS': 8,
 'ACT_QUANT': True,
 'CHANNEL_WISE': True,
 'NUM_SAMPLES': 128,
 'ITERS_A': 1000,
 'OPTIMIZER': 'adam',
 'SAVE_PATH': './scales/',
 'W_BITS': 4,
 'ITERS_W': 1000}
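The config cell starts from the `GENERAL` section and overlays the `W_ARGS` entry selected by `W_BUDGET`. Judging from the key order in the printed `kwargs`, that entry supplies at least `W_BITS` and `ITERS_W`. The dictionaries below are a hypothetical stand-in for `brecq_config.yaml`, with `None` marking values that would come from the real file. Two details are worth noting: `kwargs = config['GENERAL']` aliases the parsed dictionary rather than copying it, so later updates also modify `config`; and a budget of 0.125 is consistent with 4-bit weights taking 4/32 = 1/8 of the storage of 32-bit floats (our interpretation of `W_BUDGET`).

```python
# Hypothetical stand-in for the parsed brecq_config.yaml. Keys mirror the
# printed kwargs; None marks values that would come from the real file.
config = {
    "GENERAL": {"W_BUDGET": 0.125, "A_BITS": 8, "NUM_SAMPLES": None, "ITERS_A": None},
    "W_ARGS": {0.125: {"W_BITS": 4, "ITERS_W": None}},
}

# dict(...) makes a shallow copy, so the updates below leave config["GENERAL"]
# intact; the notebook's `kwargs = config['GENERAL']` mutates it in place.
kwargs = dict(config["GENERAL"])
assert kwargs["W_BUDGET"] in config["W_ARGS"], "given weight budget not supported"
kwargs.update(config["W_ARGS"][kwargs["W_BUDGET"]])

# Reduced sample run, as in the notebook
kwargs["NUM_SAMPLES"] = 128
kwargs["ITERS_W"] = 1000
kwargs["ITERS_A"] = 1000

print(kwargs)
print("W_BITS" in config["GENERAL"])  # False: the original section is untouched

# W_BUDGET read as weight storage relative to 32-bit floats (our interpretation)
print(kwargs["W_BITS"] / 32 == kwargs["W_BUDGET"])  # True
```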

### Quantization Method: BRECQ
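BRECQ calibrates the quantized network one unit at a time: a standalone layer (such as `conv1` or `fc`) or a residual block, as the log of the next cell shows. For each unit, it tunes the quantization parameters so that the quantized unit's output matches the full-precision unit's output on the calibration samples. The paper weights this output error with a diagonal Fisher-information approximation; the minimal sketch below uses plain mean squared error on a toy two-layer "block" with round-to-nearest weights, only to show the quantity being minimized. All names and numbers are illustrative.

```python
def matvec(w, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def toy_block(w1, w2, x):
    """Stand-in for a block's body: linear -> ReLU -> linear."""
    hidden = [max(0.0, h) for h in matvec(w1, x)]
    return matvec(w2, hidden)


def quantize_symmetric(w, n_bits):
    """Per-tensor symmetric round-to-nearest (hypothetical helper)."""
    q_max = 2 ** (n_bits - 1) - 1
    scale = max(abs(a) for row in w for a in row) / q_max
    return [[round(a / scale) * scale for a in row] for row in w]


def reconstruction_mse(w1, w2, calib_inputs, n_bits):
    """Mean squared error between full-precision and quantized block outputs."""
    qw1, qw2 = quantize_symmetric(w1, n_bits), quantize_symmetric(w2, n_bits)
    errors = [
        (ref - q) ** 2
        for x in calib_inputs
        for ref, q in zip(toy_block(w1, w2, x), toy_block(qw1, qw2, x))
    ]
    return sum(errors) / len(errors)


w1 = [[0.8, -0.3], [0.1, 0.5], [-0.6, 0.9]]
w2 = [[0.4, -0.7, 0.2]]
calib_inputs = [[1.0, 0.5], [-0.2, 0.3], [0.7, -1.0]]  # stand-in calibration samples
for bits in (2, 4, 8):
    mse = reconstruction_mse(w1, w2, calib_inputs, bits)
    print(f"{bits}-bit block reconstruction MSE: {mse:.6f}")
```

Measuring the error at the block output, rather than layer by layer, lets errors from one layer be compensated by the next layer inside the same block, which is the central idea of block reconstruction.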

In [9]:
quantizer = quantize.brecq.BRECQ(r50_model, dataloaders, **kwargs)

print('testing pretrained model before quantization')
acc1, acc5 = quantizer.test(model=r50_model, dataloader=dataloaders['test'], device=torch.device('cuda:0'))
print(f'top-1 acc: {acc1:.2f}%, top-5 acc: {acc5:.2f}%')

qmodel = quantizer.compress_model()

==> Using seed : 42
testing pretrained model before quantization


100%|██████████████████████████████████████████████| 79/79 [00:09<00:00,  8.18it/s, acc1=72.5, acc5=91.5]


top-1 acc: 72.52%, top-5 acc: 91.53%
==> Initializing weight quantization parameters


100%|██████████████████████████████████████████████| 79/79 [00:05<00:00, 13.96it/s, acc1=70.7, acc5=90.7]


Quantized accuracy before brecq: (70.69818037974683, 90.65466772151899)
==> Starting weight calibration
Reconstruction for layer conv1


100%|███████████████████████████████████████████████| 1000/1000 [00:03<00:00, 259.26it/s, b=2, loss=2.66]


Reconstruction for block 0


100%|█████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 98.60it/s, b=2, loss=111]


Reconstruction for block 1


100%|███████████████████████████████████████████████| 1000/1000 [00:08<00:00, 116.83it/s, b=2, loss=96.6]


Reconstruction for block 2


100%|███████████████████████████████████████████████| 1000/1000 [00:08<00:00, 123.69it/s, b=2, loss=75.2]


Reconstruction for block 0


100%|█████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 69.36it/s, b=2, loss=537]


Reconstruction for block 1


100%|█████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 82.10it/s, b=2, loss=375]


Reconstruction for block 2


100%|█████████████████████████████████████████████████| 1000/1000 [00:11<00:00, 86.03it/s, b=2, loss=385]


Reconstruction for block 3


100%|████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 103.58it/s, b=2, loss=375]


Reconstruction for block 0


100%|█████████████████████████████████████████████| 1000/1000 [00:10<00:00, 92.88it/s, b=2, loss=2.16e+3]


Reconstruction for block 1


100%|████████████████████████████████████████████| 1000/1000 [00:08<00:00, 118.84it/s, b=2, loss=1.54e+3]


Reconstruction for block 2


100%|████████████████████████████████████████████| 1000/1000 [00:08<00:00, 117.26it/s, b=2, loss=1.51e+3]


Reconstruction for block 3


100%|████████████████████████████████████████████| 1000/1000 [00:09<00:00, 107.57it/s, b=2, loss=1.49e+3]


Reconstruction for block 4


100%|████████████████████████████████████████████| 1000/1000 [00:09<00:00, 109.70it/s, b=2, loss=1.47e+3]


Reconstruction for block 5


100%|████████████████████████████████████████████| 1000/1000 [00:08<00:00, 115.09it/s, b=2, loss=1.47e+3]


Reconstruction for block 0


100%|████████████████████████████████████████████| 1000/1000 [00:09<00:00, 101.63it/s, b=2, loss=9.01e+3]


Reconstruction for block 1


100%|█████████████████████████████████████████████| 1000/1000 [00:07<00:00, 129.27it/s, b=2, loss=7.2e+3]


Reconstruction for block 2


100%|████████████████████████████████████████████| 1000/1000 [00:08<00:00, 115.06it/s, b=2, loss=6.98e+3]


Reconstruction for layer fc


100%|████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 280.20it/s, b=2, loss=504]
100%|██████████████████████████████████████████████| 79/79 [00:05<00:00, 13.82it/s, acc1=72.5, acc5=91.4]


Weight quantization accuracy: (72.52768987341773, 91.40625)
Reconstruction for layer conv1


100%|███████████████████████████████████████████| 1000/1000 [00:03<00:00, 295.97it/s, b=0, loss=0.000386]


Reconstruction for block 0


100%|████████████████████████████████████████████| 1000/1000 [00:08<00:00, 116.17it/s, b=0, loss=0.00152]


Reconstruction for block 1


100%|████████████████████████████████████████████| 1000/1000 [00:07<00:00, 127.02it/s, b=0, loss=0.00016]


Reconstruction for block 2


100%|████████████████████████████████████████████| 1000/1000 [00:08<00:00, 122.08it/s, b=0, loss=4.88e-5]


Reconstruction for block 0


100%|████████████████████████████████████████████| 1000/1000 [00:07<00:00, 127.18it/s, b=0, loss=0.00106]


Reconstruction for block 1


100%|███████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.98it/s, b=0, loss=0.000216]


Reconstruction for block 2


100%|███████████████████████████████████████████| 1000/1000 [00:08<00:00, 122.01it/s, b=0, loss=0.000211]


Reconstruction for block 3


100%|███████████████████████████████████████████| 1000/1000 [00:07<00:00, 125.68it/s, b=0, loss=0.000184]


Reconstruction for block 0


100%|████████████████████████████████████████████| 1000/1000 [00:09<00:00, 108.87it/s, b=0, loss=0.00173]


Reconstruction for block 1


100%|███████████████████████████████████████████| 1000/1000 [00:07<00:00, 130.99it/s, b=0, loss=0.000682]


Reconstruction for block 2


100%|███████████████████████████████████████████| 1000/1000 [00:07<00:00, 132.14it/s, b=0, loss=0.000355]


Reconstruction for block 3


100%|███████████████████████████████████████████| 1000/1000 [00:08<00:00, 121.54it/s, b=0, loss=0.000216]


Reconstruction for block 4


100%|███████████████████████████████████████████| 1000/1000 [00:07<00:00, 125.02it/s, b=0, loss=0.000191]


Reconstruction for block 5


100%|███████████████████████████████████████████| 1000/1000 [00:08<00:00, 124.69it/s, b=0, loss=0.000227]


Reconstruction for block 0


100%|████████████████████████████████████████████| 1000/1000 [00:08<00:00, 112.77it/s, b=0, loss=0.00929]


Reconstruction for block 1


100%|█████████████████████████████████████████████| 1000/1000 [00:07<00:00, 132.72it/s, b=0, loss=0.0242]


Reconstruction for block 2


100%|█████████████████████████████████████████████| 1000/1000 [00:07<00:00, 125.60it/s, b=0, loss=0.0196]


Reconstruction for layer fc


100%|█████████████████████████████████████████████| 1000/1000 [00:02<00:00, 430.09it/s, b=0, loss=0.0351]
100%|██████████████████████████████████████████████| 79/79 [00:07<00:00, 11.23it/s, acc1=72.5, acc5=91.4]

Full quantization (W4A8) accuracy: (72.4881329113924, 91.39636075949367)
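The log shows two calibration passes over the same units: the first calibrates the weights (progress bars end at `b=2`), and the second appears to calibrate the activations (`b=0`) before the W4A8 evaluation. BRECQ learns weight rounding with the AdaRound formulation: each weight is rounded down or up according to a learnable variable passed through a rectified sigmoid, and a regularizer with temperature beta, annealed over the iterations (from 20 to 2 in the BRECQ reference code), pushes those soft values to exactly 0 or 1. The `b` in the progress bars is presumably this beta, and a pass without a rounding regularizer would explain `b=0`. The sketch below mirrors these formulas in plain Python; the function names and the 20% warm-up fraction are assumptions, not the `trailmet` API.

```python
import math

ZETA, GAMMA = 1.1, -0.1  # stretch parameters of the rectified sigmoid (AdaRound)


def rectified_sigmoid(v):
    """h(V) = clip(sigmoid(V) * (zeta - gamma) + gamma, 0, 1)."""
    return min(1.0, max(0.0, (1.0 / (1.0 + math.exp(-v))) * (ZETA - GAMMA) + GAMMA))


def rounding_regularizer(vs, beta):
    """Sum of 1 - |2h - 1|^beta; zero only when every h is exactly 0 or 1."""
    return sum(1.0 - abs(2.0 * rectified_sigmoid(v) - 1.0) ** beta for v in vs)


def beta_schedule(step, total_steps, start_b=20.0, end_b=2.0, warmup=0.2):
    """Hold beta at start_b during warm-up (assumed 20%), then decay linearly to end_b."""
    start = warmup * total_steps
    if step < start:
        return start_b
    rel = (step - start) / (total_steps - start)
    return end_b + (start_b - end_b) * max(0.0, 1.0 - rel)


def soft_quantize(w, scale, v):
    """scale * (floor(w / scale) + h(V)); h in [0, 1] chooses rounding down vs. up."""
    return scale * (math.floor(w / scale) + rectified_sigmoid(v))


print([beta_schedule(t, 1000) for t in (0, 200, 600, 1000)])  # [20.0, 20.0, 11.0, 2.0]
print(rounding_regularizer([-10.0, 10.0], beta=2.0))  # both h saturated: 0.0
print(rounding_regularizer([0.0], beta=2.0))          # h = 0.5: maximal penalty 1.0
print(soft_quantize(0.35, 0.25, 10.0), soft_quantize(0.35, 0.25, -10.0))  # 0.5 0.25
```

A high beta early on penalizes only soft values near 0.5 lightly, leaving the optimizer free to reduce the reconstruction loss; as beta decays, the penalty on any non-binary h grows, so by the end of calibration every weight commits to rounding down or up.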

In [10]:
print('testing quantized model')
acc1, acc5 = quantizer.test(model=qmodel, dataloader=dataloaders['test'], device=torch.device('cuda:0'))
print(f'top-1 acc: {acc1:.2f}%, top-5 acc: {acc5:.2f}%')

testing quantized model


100%|██████████████████████████████████████████████| 79/79 [00:07<00:00, 10.66it/s, acc1=72.5, acc5=91.4]

top-1 acc: 72.49%, top-5 acc: 91.40%
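With the reduced calibration budget of this run (128 samples, 1000 iterations per unit), the accuracies logged above can be collected into a quick comparison. The full-precision figures are the rounded values printed before quantization; the others are copied from the log.

```python
# Top-1 / top-5 CIFAR-100 test accuracy (%) as reported in this notebook's output.
results = {
    "full precision": (72.52, 91.53),
    "W4, before BRECQ": (70.69818037974683, 90.65466772151899),
    "W4, after BRECQ": (72.52768987341773, 91.40625),
    "W4A8, after BRECQ": (72.4881329113924, 91.39636075949367),
}

base_top1, base_top5 = results["full precision"]
for name, (top1, top5) in results.items():
    print(f"{name:<20} top-1 {top1:6.2f} ({top1 - base_top1:+.2f})   "
          f"top-5 {top5:6.2f} ({top5 - base_top5:+.2f})")
```

In this run, block reconstruction recovers the roughly 1.8-point top-1 drop of the initial 4-bit weights, and adding 8-bit activations costs about 0.03 points of top-1 accuracy.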
