## Tutorial 4. VGG16-BN on CIFAR10


In this tutorial, we show how to train and compress a VGG16-BN on CIFAR10 end to end, reproducing the results reported in the paper.

Please ensure the `only_train_once` version is `>=2.0.16`.


### Step 1. Create OTO instance

In [1]:
import torch
import random
import numpy as np
from only_train_once import OTO
from backends import vgg16_bn

# Set the random seed. Experimental results may still vary across GPUs and CUDA versions.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

model = vgg16_bn().cuda()
dummy_input = torch.zeros(1, 3, 32, 32).cuda()
oto = OTO(model=model, dummy_input=dummy_input)

graph constructor
grow_non_stem_connected_components
group_individual_nodes
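For closer run-to-run reproducibility, the seeding in the cell above can be extended to CUDA and cuDNN. The helper below is a minimal sketch and not part of `only_train_once`; `set_seed` is a hypothetical name, and the deterministic cuDNN settings may slow training. Results can still differ across GPU models and CUDA versions.

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when no GPU is present
    # Trade some speed for deterministic cuDNN algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)
first = torch.rand(3)
set_seed(42)
second = torch.rand(3)
print(torch.equal(first, second))  # same seed, same random draw
```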


#### (Optional) Visualize the dependency graph of the DNN for ZIG partitions

- Set `view` to `False` if no browser is accessible.
- Open the generated PDF file instead.

In [2]:
# A vgg16.gv.pdf will be generated to display the dependency graph.
oto.visualize_zigs(view=False)

### Step 2. Dataset Preparation

In [3]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

trainset = CIFAR10(root='cifar10', train=True, download=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
testset = CIFAR10(root='cifar10', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified
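`transforms.Normalize` applies `(x - mean) / std` per channel to tensors that `ToTensor` has already scaled to `[0, 1]`. The mean and standard deviation used above are the widely used ImageNet statistics rather than values computed on CIFAR10. The pure-Python sketch below only illustrates the arithmetic; `normalize_pixel` is a hypothetical helper, not a torchvision function.

```python
# Per-channel statistics copied from the transforms above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (x - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, MEAN, STD)]


black = normalize_pixel([0.0, 0.0, 0.0])
white = normalize_pixel([1.0, 1.0, 1.0])
for name, lo, hi in zip("RGB", black, white):
    print(f"{name}: normalized range [{lo:.3f}, {hi:.3f}]")
```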


### Step 3. Setup DHSPG optimizer

The main hyperparameters to set are:

- `variant`: The optimizer used for training the baseline full model.
- `lr`: The initial learning rate.
- `weight_decay`: Weight decay, as in standard DNN optimization.
- `target_group_sparsity`: The target group sparsity. Higher group sparsity typically yields larger reductions in FLOPs and model size, but may degrade model performance more.
- `tolerance_group_sparsity`: The percentage of additional groups to feed into the half-space projection (specific to the VGG16 experiments).
- `start_pruning_steps`: The number of training steps after which pruning starts.
- `epsilon`: The coefficient in [0, 1) that controls the aggressiveness of group sparsity exploration. Higher values mean more aggressive exploration.
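To make `target_group_sparsity` concrete: group sparsity is the fraction of parameter groups whose weights are all zero. The sketch below is a toy illustration under that definition, with hypothetical groups given as plain lists; it is not how `only_train_once` computes the value internally.

```python
import math


def group_norm(group):
    """Euclidean norm of one parameter group."""
    return math.sqrt(sum(w * w for w in group))


def group_sparsity(groups):
    """Fraction of groups whose norm is exactly zero."""
    zero_groups = sum(1 for g in groups if group_norm(g) == 0.0)
    return zero_groups / len(groups)


# Five toy groups, e.g. five convolution filters flattened to lists.
groups = [
    [0.0, 0.0, 0.0],
    [0.3, -0.1, 0.2],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [1.2, 0.0, -0.4],
]
print(f"group sparsity: {group_sparsity(groups):.2f}")  # 3 of 5 groups are zero
```

A target of `0.79` therefore asks the optimizer to drive roughly 79% of the groups to zero so that they can be removed at compression time.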

In [4]:
optimizer = oto.dhspg(
    variant='sgd', 
    lr=0.1, 
    target_group_sparsity=0.79,
    tolerance_group_sparsity=0.21, 
    lmbda=5e-4,
    weight_decay=5e-4, 
    weight_decay_type='l1_norm', 
    start_pruning_steps=50 * len(trainloader), # start pruning after 50 epochs (50 * steps per epoch).
    epsilon=0.9)
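`start_pruning_steps` is counted in optimizer steps, not epochs. CIFAR10 has 50,000 training images, so with `batch_size=64` and the `DataLoader` default of keeping the last partial batch, one epoch is 782 steps. The arithmetic can be checked without loading the data; `warmup_epochs` is a hypothetical name for the multiplier used above.

```python
import math

num_train_images = 50_000  # size of the CIFAR10 training split
batch_size = 64            # as in the DataLoader above
warmup_epochs = 50         # multiplier used for start_pruning_steps

# DataLoader keeps the final partial batch by default (drop_last=False).
steps_per_epoch = math.ceil(num_train_images / batch_size)
start_pruning_steps = warmup_epochs * steps_per_epoch

print(steps_per_epoch)      # 782
print(start_pruning_steps)  # 39100
```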

### Step 4. Train VGG16 as normal

Define the helper functions used for `mix_up`.

In [5]:
import numpy as np
import torch.nn.functional as F

def one_hot(y, num_classes, smoothing_eps=None):
    if smoothing_eps is None:
        one_hot_y = F.one_hot(y, num_classes).float()
        return one_hot_y
    else:
        one_hot_y = F.one_hot(y, num_classes).float()
        v1 = 1 - smoothing_eps + smoothing_eps / float(num_classes)
        v0 = smoothing_eps / float(num_classes)
        new_y = one_hot_y * (v1 - v0) + v0
        return new_y

def mixup_func(input, target, alpha=0.2):
    gamma = np.random.beta(alpha, alpha)
    # target is onehot format!
    perm = torch.randperm(input.size(0))
    perm_input = input[perm]
    perm_target = target[perm]
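    # Tensor.add_(other, alpha=...) replaces the deprecated add_(alpha, other) overload.
    mixed_input = input.mul_(gamma).add_(perm_input, alpha=1 - gamma)
    mixed_target = target.mul_(gamma).add_(perm_target, alpha=1 - gamma)
    return mixed_input, mixed_target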
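As a quick check of what the two helpers compute, the pure-Python sketch below reproduces the same arithmetic on plain lists. The function names `smooth_one_hot` and `mix` are hypothetical stand-ins for illustration, and the values `smoothing_eps=0.1` and `gamma=0.7` are arbitrary.

```python
def smooth_one_hot(label, num_classes, smoothing_eps):
    """Label-smoothed one-hot vector, same formula as one_hot above."""
    v0 = smoothing_eps / num_classes  # mass given to every wrong class
    v1 = 1 - smoothing_eps + v0       # mass given to the true class
    return [v1 if k == label else v0 for k in range(num_classes)]


def mix(a, b, gamma):
    """Convex combination gamma * a + (1 - gamma) * b, as in mixup."""
    return [gamma * x + (1 - gamma) * y for x, y in zip(a, b)]


smoothed = smooth_one_hot(label=2, num_classes=10, smoothing_eps=0.1)
print([round(v, 2) for v in smoothed])

# Mixing two hard labels yields a soft target that still sums to 1.
cat = [1.0 if k == 3 else 0.0 for k in range(10)]
dog = [1.0 if k == 5 else 0.0 for k in range(10)]
mixed = mix(cat, dog, gamma=0.7)
print([round(v, 2) for v in mixed])
```

Because the mixed targets are probability vectors rather than class indices, the loss function has to accept soft targets, which is why the labels are converted to one-hot form before mixing.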

#### Start official training and compression.

The top-1 accuracy reached $93.3\%$ on this specific run.

Results for VGG16 fluctuate a bit more than for the other models, perhaps due to its straight architecture.

Across our experiments, the top-1 accuracy was $93.1\pm0.3\%$ over different random seeds and GPUs.

In [6]:
from tqdm import tqdm 
from utils.utils import check_accuracy

num_classes = 10

mix_up = True
max_epoch = 300
model.cuda()
criterion = torch.nn.CrossEntropyLoss()
# Divide the learning rate by 10 every 75 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1) 

for epoch in range(max_epoch):
    f_avg_val = 0.0
    model.train()
    lr_scheduler.step()
    for X, y in trainloader:
        X = X.cuda()
        y = y.cuda()   
        with torch.no_grad():
            if mix_up:
                y = one_hot(y, num_classes=num_classes)
                X, y = mixup_func(X, y)
        y_pred = model.forward(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        f_avg_val += f.detach()  # detach so the running sum does not keep autograd history
        optimizer.step()
    group_sparsity, omega = optimizer.compute_group_sparsity_omega()
    # Record the four metrics below after the pruning step if you want to troubleshoot in depth.
    # norm_redundant, norm_important, num_groups_redundant, num_groups_important = optimizer.compute_norm_group_partitions()
    accuracy1, accuracy5 = check_accuracy(model, testloader)
    f_avg_val = f_avg_val.cpu().item() / len(trainloader)
    print("Epoch: {ep}, loss: {f:.2f}, omega:{om:.2f}, group_sparsity: {gs:.2f}, acc1: {acc:.4f}".\
          format(ep=epoch, f=f_avg_val, om=omega, gs=group_sparsity, acc=accuracy1))
    # print("Epoch: {ep}, norm_redundant: {norm_redundant:.2f}, norm_important:{norm_important:.2f}, num_groups_redundant: {num_groups_redundant}, num_groups_important: {num_groups_important}".\
    #       format(ep=epoch, norm_redundant=norm_redundant, norm_important=norm_important, num_groups_redundant=num_groups_redundant, num_groups_important=num_groups_important))

    # Save model checkpoint
    # torch.save(model, CKPT_PATH)

Epoch: 0, loss: 1.79, omega:7314.65, group_sparsity: 0.00, acc1: 0.4082


Epoch: 1, loss: 1.44, omega:7061.52, group_sparsity: 0.00, acc1: 0.5302


Epoch: 2, loss: 1.26, omega:6818.39, group_sparsity: 0.00, acc1: 0.5962


Epoch: 3, loss: 1.14, omega:6585.05, group_sparsity: 0.00, acc1: 0.6300


Epoch: 4, loss: 1.06, omega:6360.30, group_sparsity: 0.00, acc1: 0.7180


Epoch: 5, loss: 1.02, omega:6144.53, group_sparsity: 0.00, acc1: 0.6150


......


Epoch: 295, loss: 0.49, omega:667.54, group_sparsity: 0.79, acc1: 0.9323


Epoch: 296, loss: 0.48, omega:667.58, group_sparsity: 0.79, acc1: 0.9332


Epoch: 297, loss: 0.50, omega:667.56, group_sparsity: 0.79, acc1: 0.9311


Epoch: 298, loss: 0.49, omega:667.54, group_sparsity: 0.79, acc1: 0.9321


Epoch: 299, loss: 0.49, omega:667.54, group_sparsity: 0.79, acc1: 0.9327
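The learning-rate schedule behind this log can be tabulated in plain Python. `StepLR` sets the rate to `lr * gamma ** (k // step_size)` after `k` calls to `step()`. Because the loop above calls `lr_scheduler.step()` at the start of each epoch, epoch `e` (zero-indexed) trains after `e + 1` calls. The sketch assumes the DHSPG optimizer exposes its learning rate to the scheduler the way standard PyTorch optimizers do; `lr_at_epoch` is a hypothetical helper.

```python
def lr_at_epoch(epoch, base_lr=0.1, step_size=75, gamma=0.1):
    """Learning rate used while training `epoch` (zero-indexed).

    The scheduler is stepped once before each epoch's batches,
    so it has been stepped epoch + 1 times by then.
    """
    scheduler_steps = epoch + 1
    return base_lr * gamma ** (scheduler_steps // step_size)


for epoch in (0, 73, 74, 149, 224, 299):
    print(f"epoch {epoch:3d}: lr = {lr_at_epoch(epoch):.0e}")
```

Under these assumptions each decay takes effect one epoch earlier than a schedule stepped at the end of the epoch would give.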


### Step 5. Get compressed model in ONNX format

By default, OTO will compress the last checkpoint. 

To compress another checkpoint, reinitialize OTO with that checkpoint and then compress:

    oto = OTO(model=torch.load(ckpt_path), dummy_input=dummy_input)

In [7]:
# A VGG16_compressed.onnx will be generated. 
oto.compress()

### (Optional) Compute FLOPs and number of parameters before and after OTO training

The compressed VGG16-BN under about 80% group sparsity reduces FLOPs by 73.4% and parameters by 95.0% on this specific run.

In [8]:
full_flops = oto.compute_flops()
compressed_flops = oto.compute_flops(compressed=True)
full_num_params = oto.compute_num_params()
compressed_num_params = oto.compute_num_params(compressed=True)

print("Full FLOPs (M): {f_flops:.2f}. Compressed FLOPs (M): {c_flops:.2f}. Reduction Ratio: {f_ratio:.4f}"\
      .format(f_flops=full_flops, c_flops=compressed_flops, f_ratio=1 - compressed_flops/full_flops))
print("Full # Params: {f_params}. Compressed # Params: {c_params}. Reduction Ratio: {f_ratio:.4f}"\
      .format(f_params=full_num_params, c_params=compressed_num_params, f_ratio=1 - compressed_num_params/full_num_params))

Full FLOPs (M): 313.73. Compressed FLOPs (M): 83.44. Reduction Ratio: 0.7340
Full # Params: 15253578. Compressed # Params: 766831. Reduction Ratio: 0.9497
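The reduction ratios printed above follow directly from the raw counts, and the same numbers can also be read as compression factors. A small check using the values from this run:

```python
# Values copied from the output of this specific run.
full_flops_m, compressed_flops_m = 313.73, 83.44
full_params, compressed_params = 15_253_578, 766_831

flops_reduction = 1 - compressed_flops_m / full_flops_m
params_reduction = 1 - compressed_params / full_params

print(f"FLOPs reduction:  {flops_reduction:.4f}")
print(f"Params reduction: {params_reduction:.4f}")
print(f"FLOPs factor:  {full_flops_m / compressed_flops_m:.2f}x")
print(f"Params factor: {full_params / compressed_params:.2f}x")
```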


### (Optional) Check the compressed model accuracy

Both the full and the compressed model should return exactly the same accuracy.

In [9]:
from utils.utils import check_accuracy_onnx
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)

acc1_full, acc5_full = check_accuracy(model, testloader)
print("Full model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_full, acc5=acc5_full))

acc1_compressed, acc5_compressed = check_accuracy_onnx(oto.compressed_model_path, testloader)
print("Compressed model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_compressed, acc5=acc5_compressed))

Full model: Acc 1: 0.9327, Acc 5: 0.9965
Compressed model: Acc 1: 0.9327, Acc 5: 0.9965
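`check_accuracy` and `check_accuracy_onnx` come from the repository's `utils`, and their implementation is not shown here. As an illustration of what top-1 and top-5 accuracy mean, the following pure-Python sketch scores a few hypothetical logit vectors; `topk_accuracy` is an assumed helper name, not the repository's function.

```python
def topk_accuracy(logits, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for scores, label in zip(logits, labels):
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


# Three toy samples over four classes (hypothetical scores).
logits = [
    [2.0, 0.5, 0.1, -1.0],  # highest score: class 0
    [0.2, 0.1, 1.5, 1.4],   # highest score: class 2
    [0.9, 1.0, -0.3, 0.0],  # highest score: class 1
]
labels = [0, 3, 2]

print(topk_accuracy(logits, labels, k=1))  # only the first sample is right
print(topk_accuracy(logits, labels, k=2))  # the second sample's label ranks 2nd
```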
