## Tutorial 2. Applying OTO to ResNet18 on CIFAR10


In this tutorial, we will show:

- How to train and compress ResNet18 end to end, from scratch, on CIFAR10.
- That the compressed ResNet18 achieves both **high performance** and **significant FLOPs and parameter reductions** compared with the full model.
- A more detailed setup of the DHSPG optimizer.


### Step 1. Create OTO instance

In [1]:
import torch
from backends import resnet18_cifar10
from only_train_once import OTO

model = resnet18_cifar10()
dummy_input = torch.zeros(1, 3, 32, 32)
oto = OTO(model=model.cuda(), dummy_input=dummy_input.cuda())

#### (Optional) Visualize the dependency graph of the DNN for ZIG partitions

In [2]:
# A ResNet_zig.gv.pdf will be generated to display the dependency graph.
oto.visualize_zigs()

### Step 2. Dataset Preparation

In [3]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

trainset = CIFAR10(root='cifar10', train=True, download=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
testset = CIFAR10(root='cifar10', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified
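`ToTensor` rescales 8-bit pixel values to [0, 1], and `Normalize` then standardizes each channel as (x - mean) / std. The mean and std passed above are the widely used ImageNet channel statistics rather than values computed on CIFAR10. The pure-Python sketch below reproduces that arithmetic for a single RGB pixel; the helper `normalize_pixel` is hypothetical and only illustrates what the two transforms compute.

```python
# ImageNet channel statistics, as passed to transforms.Normalize above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8):
    """Hypothetical helper: mimic ToTensor + Normalize for one RGB pixel.

    ToTensor maps uint8 values in [0, 255] to floats in [0, 1];
    Normalize then computes (x - mean) / std per channel.
    """
    scaled = [v / 255.0 for v in rgb_uint8]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]


# A pure white pixel maps to roughly (2.249, 2.429, 2.640).
print([round(v, 3) for v in normalize_pixel((255, 255, 255))])
```

Because each channel is divided by a std of about 0.22, normalized inputs span roughly [-2.1, 2.6] instead of [0, 1].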


### Step 3. Setup DHSPG optimizer

The following main hyperparameters need attention:

- `variant`: The base optimizer used for training the baseline full model (e.g. `'sgd'`).
- `lr`: The initial learning rate.
- `weight_decay`: Weight decay, as in standard DNN optimization.
- `target_group_sparsity`: The target group sparsity. A higher group sparsity typically yields larger FLOPs and model-size reductions, but may degrade model performance more.
- `start_pruning_steps`: The step at which pruning starts.
- `epsilon`: A coefficient in [0, 1) that controls how aggressively group sparsity is explored. A higher value means more aggressive exploration.
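To make `target_group_sparsity` more concrete: one natural reading of group sparsity is the fraction of parameter groups (here, ZIGs) whose entries are all exactly zero, since only such groups can be removed from the compressed model. The sketch below assumes that definition and represents each group as a plain list of floats; it is an illustration, not OTO's internal implementation.

```python
def group_sparsity(groups):
    """Fraction of groups whose parameters are all exactly zero.

    Assumption for illustration: each group is a flat list of floats,
    and a group counts as redundant (prunable) only if every entry is 0.
    """
    if not groups:
        return 0.0
    zero_groups = sum(1 for g in groups if all(v == 0.0 for v in g))
    return zero_groups / len(groups)


# 10 toy groups, 7 of which have been driven to zero.
groups = [[0.0, 0.0]] * 7 + [[0.3, -0.1], [0.0, 0.2], [1.5, 0.0]]
print(group_sparsity(groups))  # 0.7, i.e. target_group_sparsity=0.7 is reached
```

Note that a group with only some zero entries (such as `[0.0, 0.2]`) does not count, because removing it would change the model's output.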

In [4]:
optimizer = oto.dhspg(
    variant='sgd', 
    lr=0.1, 
    target_group_sparsity=0.7,
    warm_up_steps=50,
    weight_decay=1e-4,
    start_pruning_steps=50 * len(trainloader), # start pruning after 50 epochs
    epsilon=0.95)
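`start_pruning_steps` is counted in optimizer steps (mini-batches), not epochs, which is why it is written as `50 * len(trainloader)`. With the 50,000 CIFAR10 training images, `batch_size=64`, and the `DataLoader` default `drop_last=False`, the loader yields ceil(50000 / 64) batches per epoch. A quick standalone check of the resulting value:

```python
import math

NUM_TRAIN_IMAGES = 50_000  # size of the CIFAR10 training split
BATCH_SIZE = 64            # matches the trainloader above

# With drop_last=False (DataLoader's default), the last partial batch is kept.
steps_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
start_pruning_steps = 50 * steps_per_epoch

print(steps_per_epoch, start_pruning_steps)  # 782 39100
```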

### Step 4. Train ResNet18 as usual

In [5]:
from utils.utils import check_accuracy

max_epoch = 300
model.cuda()
criterion = torch.nn.CrossEntropyLoss()
# Every 75 epochs, decay lr by a factor of 10.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)

for epoch in range(max_epoch):
    f_avg_val = 0.0
    model.train()
    for X, y in trainloader:
        X = X.cuda()
        y = y.cuda()
        y_pred = model(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        # Accumulate a Python float so the autograd graph is not kept alive.
        f_avg_val += f.item()
        optimizer.step()
    # Step the scheduler once per epoch, after the optimizer updates.
    lr_scheduler.step()
    group_sparsity, omega = optimizer.compute_group_sparsity_omega()
    accuracy1, accuracy5 = check_accuracy(model, testloader)
    f_avg_val = f_avg_val / len(trainloader)
    print("Epoch: {ep}, loss: {f:.2f}, omega:{om:.2f}, group_sparsity: {gs:.2f}, acc1: {acc:.4f}".format(ep=epoch, f=f_avg_val, om=omega, gs=group_sparsity, acc=accuracy1))



Epoch: 0, loss: 1.67, omega:4122.27, group_sparsity: 0.00, acc1: 0.3347
Epoch: 1, loss: 1.15, omega:4112.58, group_sparsity: 0.00, acc1: 0.4246
Epoch: 2, loss: 0.88, omega:4103.96, group_sparsity: 0.00, acc1: 0.5233
Epoch: 3, loss: 0.71, omega:4094.89, group_sparsity: 0.00, acc1: 0.6019
Epoch: 4, loss: 0.60, omega:4084.36, group_sparsity: 0.00, acc1: 0.7093
Epoch: 5, loss: 0.52, omega:4073.36, group_sparsity: 0.00, acc1: 0.7460
Epoch: 6, loss: 0.47, omega:4062.06, group_sparsity: 0.00, acc1: 0.7607
Epoch: 7, loss: 0.43, omega:4050.87, group_sparsity: 0.00, acc1: 0.7199
Epoch: 8, loss: 0.39, omega:4039.35, group_sparsity: 0.00, acc1: 0.7358
Epoch: 9, loss: 0.36, omega:4028.00, group_sparsity: 0.00, acc1: 0.4426
Epoch: 10, loss: 0.33, omega:4015.55, group_sparsity: 0.00, acc1: 0.8550
Epoch: 11, loss: 0.30, omega:4003.24, group_sparsity: 0.00, acc1: 0.8667
Epoch: 12, loss: 0.28, omega:3990.84, group_sparsity: 0.00, acc1: 0.8558
Epoch: 13, loss: 0.27, omega:3978.91, group_sparsity: 0.00, a
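With one `lr_scheduler.step()` call at the end of each epoch, `StepLR(step_size=75, gamma=0.1)` follows the closed form lr(e) = 0.1 * 0.1 ** (e // 75) for epoch e. The helper `step_lr` below is hypothetical and dependency-free; it only reproduces that formula, which makes it easy to check which learning rate is active in each stage of the 300-epoch run.

```python
def step_lr(epoch, base_lr=0.1, step_size=75, gamma=0.1):
    """Learning rate used during `epoch` under StepLR's closed form.

    Assumes lr_scheduler.step() is called once at the end of every epoch.
    """
    return base_lr * gamma ** (epoch // step_size)


# Print the learning rate at the start of each 75-epoch stage.
for epoch in (0, 75, 150, 225):
    print(epoch, step_lr(epoch))
```

So the 300 epochs split into four stages with learning rates of about 0.1, 0.01, 0.001, and 0.0001.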

### Step 5. Get the compressed model in ONNX format

In [6]:
# A ResNet_compressed.onnx will be generated. 
oto.compress()

### (Optional) Compute FLOPs and the number of parameters before and after OTO training

The compressed ResNet18 uses only about 9% of the parameters and about 20% of the FLOPs of the full model.

In [7]:
full_flops = oto.compute_flops()
compressed_flops = oto.compute_flops(compressed=True)
full_num_params = oto.compute_num_params()
compressed_num_params = oto.compute_num_params(compressed=True)

print("Full FLOPs (M): {f_flops:.2f}. Compressed FLOPs (M): {c_flops:.2f}. Reduction Ratio: {f_ratio:.4f}"\
      .format(f_flops=full_flops, c_flops=compressed_flops, f_ratio=1 - compressed_flops/full_flops))
print("Full # Params: {f_params}. Compressed # Params: {c_params}. Reduction Ratio: {f_ratio:.4f}"\
      .format(f_params=full_num_params, c_params=compressed_num_params, f_ratio=1 - compressed_num_params/full_num_params))

Full FLOPs (M): 555.42. Compressed FLOPs (M): 112.52. Reduction Ratio: 0.7974
Full # Params: 11173962. Compressed # Params: 981808. Reduction Ratio: 0.9121
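The percentages quoted above follow directly from the printed counts: the compressed model keeps `compressed / full` of each quantity, and the reduction ratio is one minus that fraction. A small standalone check using the numbers from this run:

```python
# Numbers copied from the output above.
full_flops, compressed_flops = 555.42, 112.52          # MFLOPs
full_params, compressed_params = 11_173_962, 981_808

flops_kept = compressed_flops / full_flops      # ~0.203 -> about 20% of the FLOPs
params_kept = compressed_params / full_params   # ~0.088 -> about 9% of the parameters

print(f"FLOPs kept: {flops_kept:.1%}, reduction: {1 - flops_kept:.4f}")
print(f"Params kept: {params_kept:.1%}, reduction: {1 - params_kept:.4f}")
```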


### (Optional) Check the compressed model accuracy

#### The full and compressed models should return exactly the same accuracy.

Compared with the baseline full ResNet18 on CIFAR10, the compressed ResNet18 loses only 0.1% top-1 accuracy while reducing FLOPs by about 80% and parameters by about 91%.

In [8]:
from utils.utils import check_accuracy_onnx
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)

acc1_full, acc5_full = check_accuracy(model, testloader)
print("Full model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_full, acc5=acc5_full))

acc1_compressed, acc5_compressed = check_accuracy_onnx(oto.compressed_model_path, testloader)
print("Compressed model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_compressed, acc5=acc5_compressed))

Full model: Acc 1: 0.9286, Acc 5: 0.9974
Compressed model: Acc 1: 0.9286, Acc 5: 0.9974
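`check_accuracy` and `check_accuracy_onnx` come from the tutorial's `utils.utils` module and report top-1 and top-5 accuracy. As a reminder of what those two numbers mean, here is a hypothetical pure-Python sketch of top-k accuracy (not the tutorial's implementation): a sample counts as correct if its true label is among the k highest-scoring classes.

```python
def topk_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k largest scores.

    scores: list of per-class score lists (one per sample)
    labels: list of integer class indices
    """
    correct = 0
    for row, label in zip(scores, labels):
        # Indices of the k highest-scoring classes for this sample.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        correct += label in topk
    return correct / len(labels)


scores = [
    [0.1, 0.7, 0.2],  # true class 1 ranks first
    [0.5, 0.3, 0.2],  # true class 1 ranks second
    [0.3, 0.1, 0.6],  # true class 1 ranks last
]
labels = [1, 1, 1]
print(topk_accuracy(scores, labels, k=1))  # 1 of 3 samples is top-1 correct
print(topk_accuracy(scores, labels, k=2))  # 2 of 3 samples are top-2 correct
```

Top-5 accuracy is therefore always at least as high as top-1, which matches the 0.9974 vs. 0.9286 reported above.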
