## Tutorial 3. ResNet50 on CIFAR10. 


In this tutorial, we will show 

- How to end-to-end train and compress a ResNet50 on CIFAR10 to reproduce the results shown in the paper.

### Step 1. Create OTO instance

In [1]:
import torch
from only_train_once import OTO
from backends import resnet50_cifar10

model = resnet50_cifar10().cuda()
dummy_input = torch.zeros(1, 3, 32, 32).cuda()
oto = OTO(model=model, dummy_input=dummy_input)

graph constructor
grow_non_stem_connected_components
group_individual_nodes


#### (Optional) Visualize the dependency graph of the DNN for ZIG partitions

- Set `view` to `False` if no browser is accessible.
- Open the generated PDF file instead.

In [2]:
# A ResNet_zig.gv.pdf will be generated to display the dependency graph.
oto.visualize_zigs(view=False)

### Step 2. Dataset Preparation

In [3]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

trainset = CIFAR10(root='cifar10', train=True, download=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
testset = CIFAR10(root='cifar10', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


### Step 3. Setup DHSPG optimizer

The following main hyperparameters need attention:

- `variant`: The optimizer used for training the baseline full model.
- `lr`: The initial learning rate.
- `weight_decay`: Weight decay, as in standard DNN optimization.
- `target_group_sparsity`: The target group sparsity. Higher group sparsity typically yields larger reductions in FLOPs and model size, but may degrade model performance more.
- `start_pruning_steps`: The number of training steps after which pruning starts.
- `epsilon`: A coefficient in [0, 1) that controls the aggressiveness of group sparsity exploration. Higher values mean more aggressive exploration.
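
To make `target_group_sparsity` concrete: group sparsity is the fraction of parameter groups whose entries are all zero. The following is a minimal pure-Python sketch of that ratio on toy data. The helper name `group_sparsity` and the toy groups are hypothetical, and this is not how OTO computes the value internally.

```python
def group_sparsity(groups, tol=0.0):
    """Return the fraction of groups whose entries are all (near) zero."""
    if not groups:
        return 0.0
    num_zero = sum(1 for g in groups if all(abs(w) <= tol for w in g))
    return num_zero / len(groups)


# Toy example: five groups, four of them zeroed out.
toy_groups = [
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.3, -0.1, 0.7],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
]
print(group_sparsity(toy_groups))  # 0.8
```

A value of 0.8 here corresponds to the `target_group_sparsity=0.8` setting used in this tutorial: four out of every five groups are expected to end up as zeros.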

In [4]:
optimizer = oto.dhspg(
    variant='sgd', 
    lr=0.1, 
    target_group_sparsity=0.8,
    weight_decay=1e-3, # Weight decay is important for ResNet50 on CIFAR10
    start_pruning_steps=10 * len(trainloader), # Start pruning after 10 epochs.
    epsilon=0.95)
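
Note that `start_pruning_steps` is counted in optimizer steps, not epochs. As a rough check of the arithmetic, the sketch below assumes the standard CIFAR10 training split of 50,000 images, the batch size of 64 from Step 2, and the `DataLoader` default of keeping the last partial batch. The variable names are illustrative only.

```python
import math

num_train_images = 50_000  # size of the CIFAR10 training split
batch_size = 64            # matches the DataLoader in Step 2
warmup_epochs = 10         # epochs trained before pruning starts

# With drop_last=False (the default), the last partial batch counts as a step.
steps_per_epoch = math.ceil(num_train_images / batch_size)
start_pruning_steps = warmup_epochs * steps_per_epoch
print(steps_per_epoch, start_pruning_steps)  # 782 7820
```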

### Step 4. Train ResNet50 as normal.

Optionally, add functions for `label smoothing` and `mix-up` here; the training loop below uses plain cross-entropy.

#### Start official training and compression.

With the target group sparsity of $80\%$ set in Step 3, the top-1 accuracy reaches $94.4\%$ in the run logged below.

In [5]:
from tqdm import tqdm 
from utils.utils import check_accuracy

max_epoch = 300
model.cuda()
criterion = torch.nn.CrossEntropyLoss()
# Every 75 epochs, decay lr by a factor of 10
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1) 

for epoch in range(max_epoch):
    f_avg_val = 0.0
    model.train()
    for X, y in tqdm(trainloader):
        X = X.cuda()
        y = y.cuda()
        y_pred = model(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        # Accumulate a Python float so the loss tensor is not kept alive across steps.
        f_avg_val += f.item()
        optimizer.step()
    # Step the scheduler once per epoch, after the optimizer updates.
    lr_scheduler.step()
    group_sparsity, omega = optimizer.compute_group_sparsity_omega()
    # Record the four metrics below after the pruning step if you want to troubleshoot in depth.
    # norm_redundant, norm_important, num_groups_redundant, num_groups_important = optimizer.compute_norm_group_partitions()
    accuracy1, accuracy5 = check_accuracy(model, testloader)
    f_avg_val = f_avg_val / len(trainloader)
    print("Epoch: {ep}, loss: {f:.2f}, omega:{om:.2f}, group_sparsity: {gs:.2f}, acc1: {acc:.4f}".\
          format(ep=epoch, f=f_avg_val, om=omega, gs=group_sparsity, acc=accuracy1))
    # print("Epoch: {ep}, norm_redundant: {norm_redundant:.2f}, norm_important:{norm_important:.2f}, num_groups_redundant: {num_groups_redundant}, num_groups_important: {num_groups_important}".\
    #       format(ep=epoch, norm_redundant=norm_redundant, norm_important=norm_important, num_groups_redundant=num_groups_redundant, num_groups_important=num_groups_important))

    # Save model checkpoint
    # torch.save(model, CKPT_PATH)

Epoch: 0, loss: 2.00, omega:17252.28, group_sparsity: 0.00, acc1: 0.3472


Epoch: 1, loss: 1.35, omega:15983.06, group_sparsity: 0.00, acc1: 0.4438


Epoch: 2, loss: 1.00, omega:14818.10, group_sparsity: 0.00, acc1: 0.5120


Epoch: 3, loss: 0.80, omega:13744.97, group_sparsity: 0.00, acc1: 0.6011


Epoch: 4, loss: 0.67, omega:12755.82, group_sparsity: 0.00, acc1: 0.7155


Epoch: 5, loss: 0.59, omega:11843.66, group_sparsity: 0.00, acc1: 0.5383


Epoch: 6, loss: 0.54, omega:11002.42, group_sparsity: 0.00, acc1: 0.7137


Epoch: 7, loss: 0.50, omega:10227.55, group_sparsity: 0.00, acc1: 0.7381


Epoch: 8, loss: 0.47, omega:9514.79, group_sparsity: 0.00, acc1: 0.6899


Epoch: 9, loss: 0.44, omega:8860.10, group_sparsity: 0.00, acc1: 0.6625


partition_groups


Epoch: 10, loss: 0.45, omega:7330.37, group_sparsity: 0.00, acc1: 0.5874


Epoch: 11, loss: 0.51, omega:5844.28, group_sparsity: 0.00, acc1: 0.6718


Epoch: 12, loss: 0.60, omega:4443.87, group_sparsity: 0.05, acc1: 0.4572


Epoch: 13, loss: 0.66, omega:3331.90, group_sparsity: 0.23, acc1: 0.6489


Epoch: 14, loss: 0.64, omega:2589.54, group_sparsity: 0.42, acc1: 0.7427


Epoch: 15, loss: 0.58, omega:2053.05, group_sparsity: 0.56, acc1: 0.7016


Epoch: 16, loss: 0.56, omega:1552.30, group_sparsity: 0.60, acc1: 0.6447


Epoch: 17, loss: 0.58, omega:1183.77, group_sparsity: 0.75, acc1: 0.5474


Epoch: 18, loss: 0.60, omega:1054.18, group_sparsity: 0.80, acc1: 0.7526


Epoch: 19, loss: 0.54, omega:1079.29, group_sparsity: 0.80, acc1: 0.6060


......


Epoch: 292, loss: 0.01, omega:733.84, group_sparsity: 0.80, acc1: 0.9434


Epoch: 293, loss: 0.01, omega:733.79, group_sparsity: 0.80, acc1: 0.9425


Epoch: 294, loss: 0.01, omega:733.73, group_sparsity: 0.80, acc1: 0.9421


Epoch: 295, loss: 0.01, omega:733.68, group_sparsity: 0.80, acc1: 0.9417


Epoch: 296, loss: 0.01, omega:733.62, group_sparsity: 0.80, acc1: 0.9435


Epoch: 297, loss: 0.01, omega:733.57, group_sparsity: 0.80, acc1: 0.9433


Epoch: 298, loss: 0.01, omega:733.51, group_sparsity: 0.80, acc1: 0.9441


Epoch: 299, loss: 0.01, omega:733.51, group_sparsity: 0.80, acc1: 0.9439
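
The step schedule configured in the training cell above (`step_size=75`, `gamma=0.1`, initial `lr=0.1`) can be written in closed form: after `k` calls to `lr_scheduler.step()`, the learning rate is `lr * gamma ** (k // step_size)`. The sketch below reproduces that formula in plain Python. The function name `step_lr` is hypothetical and only mirrors the documented behavior of `StepLR`.

```python
import math


def step_lr(base_lr, gamma, step_size, num_scheduler_steps):
    """Learning rate after a given number of scheduler steps under step decay."""
    return base_lr * gamma ** (num_scheduler_steps // step_size)


# Learning rate at the start of each 75-epoch stage of the 300-epoch run.
for k in (0, 75, 150, 225):
    print(k, f"{step_lr(0.1, 0.1, 75, k):.4g}")
```

Over 300 epochs this gives four stages, with the learning rate dropping by a factor of 10 at each boundary.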


### Step 5. Get compressed model in ONNX format

By default, OTO compresses the last checkpoint.

To compress another checkpoint, reinitialize OTO with that checkpoint and then compress:

    oto = OTO(model=torch.load(ckpt_path), dummy_input=dummy_input)

In [6]:
# A ResNet_compressed.onnx will be generated. 
oto.compress()

### (Optional) Compute FLOPs and number of parameters before and after OTO training

The compressed ResNet50 under 80% group sparsity reduces FLOPs by 90.0% and parameters by 96.4%.

In [7]:
full_flops = oto.compute_flops()
compressed_flops = oto.compute_flops(compressed=True)
full_num_params = oto.compute_num_params()
compressed_num_params = oto.compute_num_params(compressed=True)

print("Full FLOPs (M): {f_flops:.2f}. Compressed FLOPs (M): {c_flops:.2f}. Reduction Ratio: {f_ratio:.4f}"\
      .format(f_flops=full_flops, c_flops=compressed_flops, f_ratio=1 - compressed_flops/full_flops))
print("Full # Params: {f_params}. Compressed # Params: {c_params}. Reduction Ratio: {f_ratio:.4f}"\
      .format(f_params=full_num_params, c_params=compressed_num_params, f_ratio=1 - compressed_num_params/full_num_params))

Full FLOPs (M): 1297.83. Compressed FLOPs (M): 129.29. Reduction Ratio: 0.9004
Full # Params: 23520842. Compressed # Params: 837108. Reduction Ratio: 0.9644
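
The reduction ratios printed above follow directly from the reported numbers as `1 - compressed / full`. The short check below recomputes them from the logged values; the helper `reduction_ratio` is an illustrative name, not part of OTO.

```python
def reduction_ratio(full, compressed):
    """Fraction of the full cost removed by compression."""
    return 1.0 - compressed / full


full_flops_m, compressed_flops_m = 1297.83, 129.29  # values from the log, in MFLOPs
full_params, compressed_params = 23_520_842, 837_108

flops_ratio = reduction_ratio(full_flops_m, compressed_flops_m)
params_ratio = reduction_ratio(full_params, compressed_params)
print(f"{flops_ratio:.4f} {params_ratio:.4f}")  # 0.9004 0.9644
```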


### (Optional) Check the compressed model accuracy

#### Both the full and compressed models should return exactly the same accuracy.

Compared to the [baseline of full ResNet50 on CIFAR10 (93.62% top-1 accuracy)](https://github.com/kuangliu/pytorch-cifar), the compressed ResNet50 improves top-1 accuracy by about 0.8 percentage points (94.39% vs. 93.62%) while significantly reducing FLOPs and parameters.

In [None]:
from utils.utils import check_accuracy_onnx
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)

acc1_full, acc5_full = check_accuracy(model, testloader)
print("Full model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_full, acc5=acc5_full))

acc1_compressed, acc5_compressed = check_accuracy_onnx(oto.compressed_model_path, testloader)
print("Compressed model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_compressed, acc5=acc5_compressed))

Full model: Acc 1: 0.9439, Acc 5: 0.9976
Compressed model: Acc 1: 0.9439, Acc 5: 0.9976
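
For readers unfamiliar with the two numbers reported above: top-k accuracy is the fraction of samples whose true label is among the k highest-scoring classes. The implementation of `check_accuracy` is not shown in this tutorial, so the following is only a pure-Python sketch of the metric on toy scores, with the hypothetical helper name `topk_accuracy`.

```python
def topk_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    correct = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by descending score, truncated to the top k.
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        correct += label in topk
    return correct / len(labels)


# Toy example: four samples, three classes.
scores = [
    [0.10, 0.70, 0.20],  # highest score: class 1
    [0.50, 0.30, 0.20],  # highest score: class 0
    [0.20, 0.30, 0.50],  # highest score: class 2
    [0.40, 0.35, 0.25],  # highest score: class 0
]
labels = [1, 1, 2, 2]
print(topk_accuracy(scores, labels, k=1))  # 0.5
print(topk_accuracy(scores, labels, k=2))  # 0.75
```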