## BRECQ
This notebook demonstrates post-training quantization of a classification model using the method from the paper [BRECQ: Pushing the Limit of Post-Training Quantization by Block Reconstruction](https://arxiv.org/abs/2102.05426).

### Steps to quantize the pretrained model
- Load the dataset and create the dataloaders. A subset of the training data is used for calibration.
- Load the pretrained full-precision model.
- Load the configuration from the YAML file.
- Create a `BRECQ` object, passing it the full-precision model, the dataloaders, and the configuration.
- Quantize the model by calling the `compress_model` method.

In [3]:
import sys
sys.path.append("../../../")

import os
# Order GPUs by PCI bus ID and expose only device 1 to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [4]:
import yaml
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from trailmet.datasets.classification import DatasetFactory
from trailmet.models import ModelsFactory
from trailmet.algorithms import quantize

## Datasets

### Augmentations

In [5]:
stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*stats, inplace=True)
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

input_transforms = {
    'train': train_transform, 
    'val': val_transform, 
    'test': test_transform}

target_transforms = {
    'train': None, 
    'val': None, 
    'test': None}
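
As a small illustration of what `transforms.Normalize(*stats)` does to each pixel, the sketch below applies the per-channel formula `(value - mean) / std` in plain Python. The helper name `normalize_pixel` is hypothetical and exists only for this example; the statistics are the ones defined in `stats` above.

```python
# Per-channel statistics copied from the `stats` tuple in the cell above.
MEAN = (0.5071, 0.4867, 0.4408)
STD = (0.2675, 0.2565, 0.2761)


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))


# A pixel equal to the channel means maps to zero in every channel.
mean_pixel = normalize_pixel(MEAN)
# A pure white pixel ends up a few standard deviations above zero.
white_pixel = normalize_pixel((1.0, 1.0, 1.0))

print(mean_pixel)
print(white_pixel)
```

Only the training transform adds random cropping and flipping; the validation and test transforms apply the same normalization without augmentation, so evaluation is deterministic.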

### Load Datasets

In [6]:
cifar100_dataset = DatasetFactory.create_dataset(
        name = 'CIFAR100', 
        root = './data',
        split_types = ['train', 'val', 'test'],
        val_fraction = 0.2,
        transform = input_transforms,
        target_transform = target_transforms)

# Report the size of each split
print('Train samples: ', cifar100_dataset['info']['train_size'])
print('Val samples: ', cifar100_dataset['info']['val_size'])
print('Test samples: ', cifar100_dataset['info']['test_size'])

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Train samples:  40000
Val samples:  10000
Test samples:  10000


### Create Dataloaders

In [7]:
train_loader = DataLoader(
        cifar100_dataset['train'], batch_size=128, 
        sampler=cifar100_dataset['train_sampler'],
        num_workers=2)
val_loader = DataLoader(
        cifar100_dataset['val'], batch_size=128, 
        sampler=cifar100_dataset['val_sampler'],
        num_workers=2)
test_loader = DataLoader(
        cifar100_dataset['test'], batch_size=128, 
        sampler=cifar100_dataset['test_sampler'],
        num_workers=2)

dataloaders = {"train": train_loader, "val": val_loader, "test": test_loader}

print('No. of training batches: ', len(dataloaders['train']))
print('No. of validation batches: ', len(dataloaders['val']))
print('No. of test batches: ', len(dataloaders['test']))

No. of training batches:  313
No. of validation batches:  79
No. of test batches:  79
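
The batch counts printed above follow from the split sizes and the batch size of 128. A minimal sketch of the arithmetic, assuming the `DataLoader` default of `drop_last=False` (the helper `num_batches` is hypothetical and not part of trailmet or PyTorch):

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields for a given number of samples."""
    if drop_last:
        # The final incomplete batch is discarded.
        return num_samples // batch_size
    # The final incomplete batch is kept, so round up.
    return math.ceil(num_samples / batch_size)


train_batches = num_batches(40000, 128)
val_batches = num_batches(10000, 128)
test_batches = num_batches(10000, 128)

print(train_batches, val_batches, test_batches)
```

The last batch of each split is therefore smaller than 128 samples (64 for training, 16 for validation and test).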


### Load Model

In [8]:
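# NOTE: pretrained=False presumably builds a randomly initialized network, which
# would explain why the accuracies reported below sit at chance level
# (about 1% top-1 on 100 classes). Load trained weights into the model here
# to obtain meaningful quantization results.
res50_model = ModelsFactory.create_model(name='resnet50', num_classes=100, pretrained=False, insize=32)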

### Load Method Config

In [9]:
with open('./brecq_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
    kwargs = config['GENERAL']
    assert kwargs['W_BUDGET'] in config['W_ARGS'], 'given weight budget not supported'
    kwargs.update(config['W_ARGS'][kwargs['W_BUDGET']])

# This is a sample run with a reduced number of calibration samples and iterations.
# Comment out the following lines for best results.
kwargs['NUM_SAMPLES'] = 128
kwargs['ITERS_W'] = 1000
kwargs['ITERS_A'] = 1000

kwargs

{'ARCH': 'ResNet50',
 'DATASET': 'CIFAR100',
 'GPU_ID': 0,
 'SEED': 42,
 'W_BUDGET': 0.125,
 'A_BITS': 8,
 'ACT_QUANT': True,
 'CHANNEL_WISE': True,
 'NUM_SAMPLES': 128,
 'ITERS_A': 1000,
 'WEIGHT': 0.01,
 'LR': 0.0004,
 'OPTIMIZER': 'adam',
 'SAVE_PATH': './scales/',
 'W_BITS': 4,
 'ITERS_W': 1000}
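
The way the final `kwargs` dictionary is assembled can be sketched without the YAML file. The example below is only an illustration: the `config` dictionary is a hypothetical stand-in for the contents of `brecq_config.yaml`, its values (other than those visible in the output above) are made up, and `resolve_kwargs` is a helper name introduced here, not a trailmet function.

```python
# Hypothetical stand-in for the parsed YAML file; only the layout
# (a GENERAL section plus per-budget W_ARGS) mirrors the cell above.
config = {
    "GENERAL": {"W_BUDGET": 0.125, "A_BITS": 8, "NUM_SAMPLES": 1024},
    "W_ARGS": {
        0.125: {"W_BITS": 4, "ITERS_W": 20000},  # illustrative values
        0.25: {"W_BITS": 8, "ITERS_W": 20000},   # illustrative values
    },
}


def resolve_kwargs(config, overrides=None):
    """Merge GENERAL settings, budget-specific settings, then manual overrides."""
    kwargs = dict(config["GENERAL"])  # copy so the input config is not mutated
    budget = kwargs["W_BUDGET"]
    if budget not in config["W_ARGS"]:
        raise ValueError(f"weight budget {budget} not supported")
    kwargs.update(config["W_ARGS"][budget])
    kwargs.update(overrides or {})
    return kwargs


resolved = resolve_kwargs(config, overrides={"NUM_SAMPLES": 128, "ITERS_W": 1000})
print(resolved)
```

Later updates win over earlier ones, so the manual overrides for the sample run replace whatever the file specifies for the same keys.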

### Quantization Method: BRECQ

In [12]:
quantizer = quantize.brecq.BRECQ(res50_model, dataloaders, **kwargs)

print('testing pretrained model before quantization')
_, acc1, acc5 = quantizer.test(model=res50_model, dataloader=dataloaders['test'], loss_fn=torch.nn.CrossEntropyLoss())
print(f'top-1 acc: {acc1:.2f}%, top-5 acc: {acc5:.2f}%')

qmodel = quantizer.compress_model()

==> Using seed : 42


testing pretrained model before quantization


Validating network (79 / 79 Steps) (batch time=0.01762s) (loss=9.17796) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:04<00:00, 15.80it/s] 


 * acc@1 1.040 acc@5 5.190
top-1 acc: 1.04%, top-5 acc: 5.19%
==> Initializing weight quantization parameters


Validating network (79 / 79 Steps) (batch time=0.02110s) (loss=8.81658) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:02<00:00, 31.80it/s] 


 * acc@1 1.300 acc@5 5.130
Quantized accuracy before brecq: 1.2999999523162842
==> Starting weight calibration
Reconstruction for layer conv1


Reconstructing Layer: Loss (3.518) b (2.0): 100%|| 1000/1000 [00:02<00:00, 382.29it/s] 


Reconstruction for block 0


Reconstructing Block: Loss (131.894) b (2.0): 100%|| 1000/1000 [00:06<00:00, 149.93it/s]


Reconstruction for block 1


Reconstructing Block: Loss (137.555) b (2.0): 100%|| 1000/1000 [00:05<00:00, 186.71it/s]


Reconstruction for block 2


Reconstructing Block: Loss (148.040) b (2.0): 100%|| 1000/1000 [00:05<00:00, 197.38it/s]


Reconstruction for block 0


Reconstructing Block: Loss (750.630) b (2.0): 100%|| 1000/1000 [00:07<00:00, 138.81it/s] 


Reconstruction for block 1


Reconstructing Block: Loss (610.501) b (2.0): 100%|| 1000/1000 [00:05<00:00, 184.58it/s] 


Reconstruction for block 2


Reconstructing Block: Loss (679.711) b (2.0): 100%|| 1000/1000 [00:05<00:00, 184.35it/s] 


Reconstruction for block 3


Reconstructing Block: Loss (737.091) b (2.0): 100%|| 1000/1000 [00:05<00:00, 177.89it/s] 


Reconstruction for block 0


Reconstructing Block: Loss (3709.193) b (2.0): 100%|| 1000/1000 [00:06<00:00, 148.69it/s] 


Reconstruction for block 1


Reconstructing Block: Loss (3016.720) b (2.0): 100%|| 1000/1000 [00:05<00:00, 185.82it/s] 


Reconstruction for block 2


Reconstructing Block: Loss (3416.183) b (2.0): 100%|| 1000/1000 [00:05<00:00, 185.97it/s] 


Reconstruction for block 3


Reconstructing Block: Loss (3782.748) b (2.0): 100%|| 1000/1000 [00:05<00:00, 191.07it/s] 


Reconstruction for block 4


Reconstructing Block: Loss (4219.022) b (2.0): 100%|| 1000/1000 [00:05<00:00, 169.18it/s] 


Reconstruction for block 5


Reconstructing Block: Loss (4588.935) b (2.0): 100%|| 1000/1000 [00:05<00:00, 169.72it/s] 


Reconstruction for block 0


Reconstructing Block: Loss (21842.725) b (2.0): 100%|| 1000/1000 [00:07<00:00, 141.11it/s]


Reconstruction for block 1


Reconstructing Block: Loss (16233.676) b (2.0): 100%|| 1000/1000 [00:05<00:00, 177.23it/s]


Reconstruction for block 2


Reconstructing Block: Loss (17142.783) b (2.0): 100%|| 1000/1000 [00:05<00:00, 185.41it/s]


Reconstruction for layer fc


Reconstructing Layer: Loss (808.284) b (2.0): 100%|| 1000/1000 [00:02<00:00, 462.91it/s] 
Validating network (79 / 79 Steps) (batch time=0.02218s) (loss=9.12264) (top1=0.00000) (top5=6.25000): 100%|| 79/79 [00:02<00:00, 31.10it/s] 


 * acc@1 1.410 acc@5 5.040
Weight quantization accuracy: 1.409999966621399
Reconstruction for layer conv1


Reconstructing Layer: Loss (0.001) b (0.0): 100%|| 1000/1000 [00:02<00:00, 481.38it/s]


Reconstruction for block 0


Reconstructing Block: Loss (0.012) b (0.0): 100%|| 1000/1000 [00:05<00:00, 194.20it/s]


Reconstruction for block 1


Reconstructing Block: Loss (0.036) b (0.0): 100%|| 1000/1000 [00:04<00:00, 241.08it/s]


Reconstruction for block 2


Reconstructing Block: Loss (0.095) b (0.0): 100%|| 1000/1000 [00:03<00:00, 250.58it/s]


Reconstruction for block 0


Reconstructing Block: Loss (0.170) b (0.0): 100%|| 1000/1000 [00:05<00:00, 179.23it/s]


Reconstruction for block 1


Reconstructing Block: Loss (0.329) b (0.0): 100%|| 1000/1000 [00:04<00:00, 216.24it/s]


Reconstruction for block 2


Reconstructing Block: Loss (1.400) b (0.0): 100%|| 1000/1000 [00:04<00:00, 207.09it/s]


Reconstruction for block 3


Reconstructing Block: Loss (3.608) b (0.0): 100%|| 1000/1000 [00:04<00:00, 221.21it/s]


Reconstruction for block 0


Reconstructing Block: Loss (10.746) b (0.0): 100%|| 1000/1000 [00:04<00:00, 205.43it/s]


Reconstruction for block 1


Reconstructing Block: Loss (19.581) b (0.0): 100%|| 1000/1000 [00:04<00:00, 235.84it/s]


Reconstruction for block 2


Reconstructing Block: Loss (53.050) b (0.0): 100%|| 1000/1000 [00:04<00:00, 235.11it/s]


Reconstruction for block 3


Reconstructing Block: Loss (193.496) b (0.0): 100%|| 1000/1000 [00:04<00:00, 217.34it/s]


Reconstruction for block 4


Reconstructing Block: Loss (516.756) b (0.0): 100%|| 1000/1000 [00:05<00:00, 181.13it/s]


Reconstruction for block 5


Reconstructing Block: Loss (1702.890) b (0.0): 100%|| 1000/1000 [00:05<00:00, 182.47it/s]


Reconstruction for block 0


Reconstructing Block: Loss (3935.217) b (0.0): 100%|| 1000/1000 [00:06<00:00, 150.33it/s]


Reconstruction for block 1


Reconstructing Block: Loss (7556.854) b (0.0): 100%|| 1000/1000 [00:05<00:00, 175.66it/s]


Reconstruction for block 2


Reconstructing Block: Loss (8373.395) b (0.0): 100%|| 1000/1000 [00:05<00:00, 173.69it/s]


Reconstruction for layer fc


Reconstructing Layer: Loss (16.654) b (0.0): 100%|| 1000/1000 [00:01<00:00, 804.19it/s]
Validating network (79 / 79 Steps) (batch time=0.03732s) (loss=9.09431) (top1=0.00000) (top5=6.25000): 100%|| 79/79 [00:03<00:00, 21.76it/s] 

 * acc@1 1.400 acc@5 5.040
Full quantization (W4A8) accuracy: 1.399999976158142
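
For intuition about what "W4" means in the W4A8 label, the sketch below rounds a handful of values onto a 4-bit uniform grid using plain round-to-nearest. This is not the BRECQ procedure, which learns the rounding through block-wise reconstruction; it is only a baseline illustration, and `quantize_uniform` is a hypothetical helper.

```python
def quantize_uniform(values, n_bits):
    """Round-to-nearest asymmetric uniform quantization, returned dequantized."""
    lo, hi = min(values), max(values)
    levels = 2 ** n_bits - 1                      # highest integer code
    scale = (hi - lo) / levels if hi > lo else 1.0
    zero_point = round(-lo / scale)               # integer code representing 0.0
    dequantized = []
    for v in values:
        q = round(v / scale) + zero_point         # map to an integer code
        q = max(0, min(levels, q))                # clamp to the n-bit range
        dequantized.append((q - zero_point) * scale)
    return dequantized, scale


weights = [-1.0, -0.42, -0.05, 0.0, 0.13, 0.37, 0.8, 1.0]
w4, scale = quantize_uniform(weights, n_bits=4)
max_error = max(abs(a - b) for a, b in zip(weights, w4))

print(w4)
print(max_error)
```

With 4 bits there are at most 16 representable values, so the per-weight error is on the order of the grid spacing `scale`. Reconstruction-based methods aim to reduce the effect of this error on the network output rather than on each weight individually.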





In [13]:
print('testing quantized model')
_, acc1, acc5 = quantizer.test(model=qmodel, dataloader=dataloaders['test'], loss_fn=torch.nn.CrossEntropyLoss())
print(f'top-1 acc: {acc1:.2f}%, top-5 acc: {acc5:.2f}%')

testing quantized model


Validating network (79 / 79 Steps) (batch time=0.03837s) (loss=9.09431) (top1=0.00000) (top5=6.25000): 100%|| 79/79 [00:03<00:00, 23.16it/s] 

 * acc@1 1.400 acc@5 5.040
top-1 acc: 1.40%, top-5 acc: 5.04%
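
The accuracies in this sample run, both before and after quantization, are close to what random guessing would give on a balanced 100-class test set, which is consistent with the model having been created with `pretrained=False`. A short sketch of that chance-level arithmetic (the helper `chance_accuracy` is hypothetical):

```python
def chance_accuracy(num_classes, k=1):
    """Expected top-k accuracy (in percent) of uniform random guessing
    on a class-balanced dataset."""
    return 100.0 * min(k, num_classes) / num_classes


top1_chance = chance_accuracy(100, k=1)
top5_chance = chance_accuracy(100, k=5)

print(top1_chance, top5_chance)
```

Because the full-precision model is itself at chance level here, this run exercises the quantization pipeline but says nothing about how well accuracy is preserved; that comparison requires starting from trained weights.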



