## BitSplit
This notebook demonstrates the implementation of the paper [Towards accurate post-training network quantization via bit-split and stitching](https://dl.acm.org/doi/abs/10.5555/3524938.3525851).
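As we read the paper, its core idea is to write an M-bit signed integer weight as a sum of M-1 ternary "bits" scaled by powers of two, $q = \sum_{m=0}^{M-2} 2^{m} t_m$ with $t_m \in \{-1, 0, 1\}$. Each ternary component is then optimized separately, and the components are stitched back into a single low-bit integer. The NumPy sketch below only illustrates the split and stitch representation (no optimization). `bit_split` and `stitch` are hypothetical helper names, not part of trailmet. With `W_BITS = 4`, as used later in this notebook, three ternary planes cover the integer range -7..7.

```python
import numpy as np


def bit_split(q: np.ndarray, n_bits: int) -> np.ndarray:
    """Split signed integers into (n_bits - 1) ternary planes so that
    q == sum_m 2**m * planes[m], with every plane entry in {-1, 0, 1}."""
    max_q = 2 ** (n_bits - 1) - 1  # symmetric range, e.g. -7..7 for 4 bits
    if np.any(np.abs(q) > max_q):
        raise ValueError("values outside the symmetric n_bits range")
    sign = np.sign(q)
    mag = np.abs(q)
    # Plane m holds bit m of |q|, carrying the sign of q.
    planes = [sign * ((mag >> m) & 1) for m in range(n_bits - 1)]
    return np.stack(planes)


def stitch(planes: np.ndarray) -> np.ndarray:
    """Recombine ternary planes into integers by weighting plane m with 2**m."""
    weights = 2 ** np.arange(planes.shape[0])
    return np.tensordot(weights, planes, axes=1)


q = np.array([-7, -3, 0, 2, 5, 7], dtype=np.int64)
planes = bit_split(q, n_bits=4)
print(planes.shape)                      # (3, 6)
print(np.array_equal(stitch(planes), q))  # True
```

A real-valued quantized weight would then be a scale times `stitch(planes)`. In the paper, the scale and the ternary planes are what the optimization adjusts.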

### Steps to quantize the pretrained model
- Load the dataset and create dataloaders. A subset of the training data is used for calibration.
- Load the pretrained full precision model.
- Load the configurations from the YAML file.
- Create a `BitSplit` object and pass the full precision model, dataloaders and configurations.
- Quantize the model by calling the `compress_model` method.

In [1]:
import sys
sys.path.append("./trail/trailmet/")

In [3]:
import yaml
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from trailmet.datasets.classification import DatasetFactory
from trailmet.models import resnet, mobilenet
from trailmet.algorithms import quantize

## Datasets

### Augmentations

In [4]:
stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*stats, inplace=True)
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

input_transforms = {
    'train': train_transform, 
    'val': val_transform, 
    'test': test_transform}

target_transforms = {
    'train': None, 
    'val': None, 
    'test': None}
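In torchvision, `transforms.ToTensor()` turns an image into a CHW float tensor in [0, 1], and `transforms.Normalize(mean, std)` then computes `(x - mean) / std` separately for each channel. `stats` holds the per-channel mean and standard deviation values commonly used for CIFAR-100. The NumPy sketch below reproduces only that arithmetic. It is not torchvision's code, and `normalize_chw` is a hypothetical name.

```python
import numpy as np

# Per-channel statistics copied from the `stats` tuple above, shaped for CHW broadcasting.
MEAN = np.array([0.5071, 0.4867, 0.4408]).reshape(3, 1, 1)
STD = np.array([0.2675, 0.2565, 0.2761]).reshape(3, 1, 1)


def normalize_chw(img: np.ndarray) -> np.ndarray:
    """Apply (x - mean) / std per channel to a CHW image already scaled to [0, 1]."""
    return (img - MEAN) / STD


# A 32x32 image whose pixels equal the channel means normalizes to all zeros.
img = np.broadcast_to(MEAN, (3, 32, 32)).copy()
print(np.allclose(normalize_chw(img), 0.0))  # True
```

The validation and test transforms apply the same normalization without random cropping or flipping, so evaluation inputs are deterministic.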

### Load Datasets

In [5]:
cifar100_dataset = DatasetFactory.create_dataset(
        name = 'CIFAR100', 
        root = './data',
        split_types = ['train', 'val', 'test'],
        val_fraction = 0.2,
        transform = input_transforms,
        target_transform = target_transforms)

# getting the size of the different splits
print('Train samples: ',cifar100_dataset['info']['train_size'])
print('Val samples: ',cifar100_dataset['info']['val_size'])
print('Test samples: ',cifar100_dataset['info']['test_size'])

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Train samples:  40000
Val samples:  10000
Test samples:  10000
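The official CIFAR-100 training set has 50,000 images. With `val_fraction = 0.2`, 10,000 of them are held out for validation and 40,000 remain for training, which matches the printed sizes. The 10,000-image test set is separate. The sketch below shows one common way to produce such a split as disjoint index lists. It assumes nothing about how `DatasetFactory` actually builds its `train_sampler` and `val_sampler`, and `split_indices` and its `seed` parameter are hypothetical.

```python
import random


def split_indices(n_samples: int, val_fraction: float, seed: int = 42):
    """Shuffle sample indices reproducibly and split them into train/val lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_val = int(round(n_samples * val_fraction))
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(50_000, 0.2)
print(len(train_idx), len(val_idx))  # 40000 10000
```

Index lists like these are what a sampler such as `torch.utils.data.SubsetRandomSampler` consumes, which is why the dataloaders below take a `sampler` instead of `shuffle=True`.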


### Define Dataloaders

In [6]:
train_loader = DataLoader(
        cifar100_dataset['train'], batch_size=128, 
        sampler=cifar100_dataset['train_sampler'],
        num_workers=0)
val_loader = DataLoader(
        cifar100_dataset['val'], batch_size=128, 
        sampler=cifar100_dataset['val_sampler'],
        num_workers=0)
test_loader = DataLoader(
        cifar100_dataset['test'], batch_size=128, 
        sampler=cifar100_dataset['test_sampler'],
        num_workers=0)

dataloaders = {"train": train_loader, "val": val_loader, "test": test_loader}

print('No. of training batches: ', len(dataloaders['train']))
print('No. of validation batches: ', len(dataloaders['val']))
print('No. of test batches: ', len(dataloaders['test']))

No. of training batches:  313
No. of validation batches:  79
No. of test batches:  79
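These batch counts follow from PyTorch's `DataLoader` default `drop_last=False`, under which the last, partial batch is kept. The number of batches is therefore the sample count divided by the batch size, rounded up: 40,000 / 128 = 312.5 gives 313, and 10,000 / 128 = 78.125 gives 79. `num_batches` below is a hypothetical helper that reproduces this arithmetic.

```python
import math


def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch: a trailing partial batch counts unless drop_last is set."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


print(num_batches(40_000, 128))  # 313
print(num_batches(10_000, 128))  # 79
```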


### Load Pretrained Model

In [7]:
res50_model = resnet.make_resnet50(100,32)
checkpoint = torch.load("./weights/resnet50_cifar100_pretrained.pth", map_location='cuda:0')
res50_model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

### Load Method Config

In [8]:
with open('./bitsplit_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
    kwargs = config['GENERAL']
    assert kwargs['W_BUDGET'] in config['W_ARGS'], 'given weight budget not supported'
    kwargs.update(config['W_ARGS'][kwargs['W_BUDGET']])
    
# This is a sample run with a reduced number of calibration batches.
# Comment out the following line for best results.
kwargs['CALIB_BATCHES'] = 2

kwargs

{'ARCH': 'ResNet50',
 'DATASET': 'CIFAR100',
 'GPU_ID': 0,
 'SEED': 42,
 'W_BUDGET': 0.125,
 'A_BITS': 8,
 'ACT_QUANT': True,
 'CHANNEL_WISE': True,
 'CALIB_BATCHES': 2,
 'SAVE_PATH': './scales/',
 'LOAD_ACT_SCALES': False,
 'LOAD_WEIGHT_SCALES': False,
 'W_BITS': 4}
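The cell above merges the `GENERAL` section of the YAML file with the entry of `W_ARGS` selected by `W_BUDGET`. The actual contents of `bitsplit_config.yaml` are not shown in this notebook. The sketch below uses a hypothetical, minimal layout reconstructed from the printed `kwargs`, and only the `0.125` budget entry is known from that output. It also swaps the `assert` for a `ValueError`, because Python strips `assert` statements when run with `-O`. `load_kwargs` is a hypothetical name.

```python
import yaml

# Hypothetical minimal config, reconstructed from the printed kwargs (not the real file).
CONFIG_TEXT = """
GENERAL:
  ARCH: ResNet50
  DATASET: CIFAR100
  W_BUDGET: 0.125
  A_BITS: 8
  ACT_QUANT: True
  CHANNEL_WISE: True
W_ARGS:
  0.125:
    W_BITS: 4
"""


def load_kwargs(text: str) -> dict:
    """Merge GENERAL with the W_ARGS entry selected by W_BUDGET."""
    config = yaml.safe_load(text)
    kwargs = dict(config['GENERAL'])  # copy so the parsed config is not mutated
    budget = kwargs['W_BUDGET']
    if budget not in config['W_ARGS']:
        raise ValueError(f'weight budget {budget} not supported')
    kwargs.update(config['W_ARGS'][budget])
    return kwargs


kwargs = load_kwargs(CONFIG_TEXT)
print(kwargs['W_BITS'])  # 4
```

PyYAML parses both the `0.125:` key and the `W_BUDGET: 0.125` value as floats, so the membership test compares like with like.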

### Quantization Method: BitSplit

In [9]:
quantizer = quantize.bitsplit.BitSplit(res50_model, dataloaders, **kwargs)

print('testing pretrained model before quantization')
acc1, acc5 = quantizer.test(model=res50_model, dataloader=dataloaders['test'], device=torch.device('cuda:0'))
print(f'top-1 acc: {acc1:.2f}%, top-5 acc: {acc5:.2f}%')

qmodel = quantizer.compress_model()

testing pretrained model before quantization


100%|██████████████████████████████████████████████| 79/79 [00:08<00:00,  9.55it/s, acc1=72.5, acc5=91.5]


top-1 acc: 72.52%, top-5 acc: 91.53%
==> Starting weight quantization


100%|██████████████████████████████████████████████████████████████████| 54/54 [1:18:17<00:00, 86.99s/it]


==> Starting '8-bit' activation quantization


100%|█████████████████████████| 50/50 [2:45:04<00:00, 198.10s/it, batch=300/300, prev_layer_scale=0.0145]
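With `CHANNEL_WISE: True`, we assume each output channel of a weight tensor gets its own quantization scale. The NumPy sketch below shows what per-channel symmetric quantization to `W_BITS = 4` means, using the plain round-to-nearest, min-max baseline. This is deliberately not the BitSplit procedure: BitSplit instead optimizes the scale and the ternary bit components against a reconstruction objective, which is what the long weight-quantization phase above is spending its time on. `quantize_per_channel` is a hypothetical name.

```python
import numpy as np


def quantize_per_channel(w: np.ndarray, n_bits: int = 4):
    """Symmetric round-to-nearest quantization with one scale per output channel (axis 0).

    Returns integer codes in [-(2**(n_bits-1) - 1), 2**(n_bits-1) - 1] and the scales,
    so that q * scale approximates w. This is a min-max baseline, not BitSplit.
    """
    q_max = 2 ** (n_bits - 1) - 1
    flat = w.reshape(w.shape[0], -1)
    scale = np.abs(flat).max(axis=1) / q_max
    scale = np.where(scale == 0, 1.0, scale)               # avoid dividing by zero
    scale = scale.reshape((-1,) + (1,) * (w.ndim - 1))     # broadcast over each channel
    q = np.clip(np.round(w / scale), -q_max, q_max).astype(np.int32)
    return q, scale


rng = np.random.default_rng(0)
w = rng.normal(size=(8, 3, 3, 3))        # e.g. a small conv weight: (out, in, kH, kW)
q, scale = quantize_per_channel(w, n_bits=4)
print(q.min() >= -7, q.max() <= 7)      # True True
```

Per-channel scales help because channels with small weights are not forced onto the coarse grid of the channel with the largest weight.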


In [10]:
print('testing quantized model')
acc1, acc5 = quantizer.test(model=qmodel, dataloader=dataloaders['test'], device=torch.device('cuda:0'))
print(f'top-1 acc: {acc1:.2f}%, top-5 acc: {acc5:.2f}%')

testing quantized model


100%|██████████████████████████████████████████████| 79/79 [00:04<00:00, 16.65it/s, acc1=70.1, acc5=89.9]

top-1 acc: 70.07%, top-5 acc: 89.94%
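The numbers printed above can be summarized directly. Top-1 accuracy drops by 2.45 points (72.52% to 70.07%) and top-5 by 1.59 points (91.53% to 89.94%). The configured `W_BUDGET` of 0.125 equals 4 / 32, which suggests, as an inference not confirmed by the notebook, that the budget is the weight bit width expressed as a fraction of 32-bit floats. The snippet below recomputes these figures from the printed values. The storage ratio covers weights only and ignores the per-channel scales and the 8-bit activations.

```python
# Accuracies copied from the outputs above (percent).
fp_top1, fp_top5 = 72.52, 91.53   # full-precision model
q_top1, q_top5 = 70.07, 89.94     # BitSplit-quantized model

w_bits, fp_bits = 4, 32           # W_BITS from kwargs vs. 32-bit float weights

print(f'top-1 drop: {fp_top1 - q_top1:.2f} points')  # top-1 drop: 2.45 points
print(f'top-5 drop: {fp_top5 - q_top5:.2f} points')  # top-5 drop: 1.59 points
print(f'weight storage ratio: {w_bits / fp_bits}')   # weight storage ratio: 0.125
```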



