## Tutorial 2. ResNet50 on ImageNet (2012). 


In this tutorial, we will show:

- How to train and compress a ResNet50 on ImageNet end to end, reproducing the results reported in the paper.

### Step 1. Create OTO instance

In [None]:
import torch
from only_train_once import OTO
import torchvision

model = torchvision.models.resnet50(pretrained=True).cuda()
dummy_input = torch.zeros(1, 3, 224, 224).cuda()
oto = OTO(model=model, dummy_input=dummy_input)

#### (Optional) Visualize the dependency graph of the DNN for ZIG partitions

- Set `view` to `False` if no browser is accessible.
- In that case, open the generated PDF file instead.

In [None]:
# A ResNet_zig.gv.pdf will be generated to display the dependency graph.
oto.visualize_zigs(view=False)

### Step 2. Dataset Preparation

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os

data_dir = "/data/imagenet" # Change to your own imagenet path
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'val')
batch_size = 128

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),])

trainset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train)
testset = torchvision.datasets.ImageFolder(root=test_dir, transform=transform_test)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=8)

### Step 3. Setup DHSPG optimizer

The following main hyperparameters need attention.

- `variant`: The optimizer used for training the baseline full model.
- `lr`: The initial learning rate.
- `weight_decay`: Weight decay, as in standard DNN optimization.
- `target_group_sparsity`: The target group sparsity. Higher group sparsity typically yields larger reductions in FLOPs and model size, but may degrade model performance more.
- `start_pruning_steps`: The number of training steps after which pruning starts.
- `first_momentum`: The first-order momentum.
- `epsilon`: A coefficient in [0, 1) that controls the aggressiveness of group sparsity exploration. A higher value means more aggressive exploration.
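
To make `target_group_sparsity` concrete, the sketch below measures group sparsity as the fraction of parameter groups whose entries are all zero. The toy groups and the helper name `group_sparsity` are illustrative assumptions, not OTO's internal implementation, which partitions the network into zero-invariant groups on its own.

```python
import math

def group_sparsity(groups, tol=0.0):
    """Fraction of groups whose L2 norm is at most `tol` (hypothetical helper)."""
    zero_groups = sum(
        1 for g in groups if math.sqrt(sum(w * w for w in g)) <= tol
    )
    return zero_groups / len(groups)

# Five toy groups, two of which are entirely zero.
toy_groups = [
    [0.3, -0.1, 0.7],
    [0.0, 0.0, 0.0],
    [1.2, 0.4, -0.5],
    [0.0, 0.0, 0.0],
    [-0.2, 0.9, 0.1],
]
print(group_sparsity(toy_groups))  # 0.4, i.e. 2 of 5 groups are zero
```

Under this reading, a target of `0.4` asks the optimizer to drive roughly 40% of the groups to zero so they can be removed at compression time.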

In [1]:
optimizer = oto.dhspg(
    variant='sgd', 
    lr=0.1, 
    target_group_sparsity=0.4,
    weight_decay=0.0, # Some training recipes set this to 1e-4.
    first_momentum=0.9,
    start_pruning_steps=15 * len(trainloader), # Start pruning after 15 passes over the training data.
    epsilon=0.95)
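
As a rough worked example of the `start_pruning_steps` arithmetic, the sketch below converts passes over the data into optimizer steps. It assumes the standard ILSVRC-2012 training set of 1,281,167 images and a loader that keeps the last partial batch; your own `len(trainloader)` may differ.

```python
import math

num_train_images = 1_281_167  # ILSVRC-2012 training set size (assumed dataset)
batch_size = 128

# With drop_last=False, len(DataLoader) equals ceil(N / batch_size).
steps_per_pass = math.ceil(num_train_images / batch_size)
start_pruning_steps = 15 * steps_per_pass

print(steps_per_pass)       # 10010
print(start_pruning_steps)  # 150150
```

If the training loop iterates the loader more than once per epoch (for example via a `train_time` multiplier), and assuming one counted step per `optimizer.step()` call, this threshold is reached proportionally sooner in terms of outer epochs.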

### Step 4. Train ResNet50 as normal.

Define helper functions for label smoothing and mix-up.

In [None]:
import numpy as np
import torch.nn.functional as F

def one_hot(y, num_classes, smoothing_eps=None):
    if smoothing_eps is None:
        one_hot_y = F.one_hot(y, num_classes).float()
        return one_hot_y
    else:
        one_hot_y = F.one_hot(y, num_classes).float()
        v1 = 1 - smoothing_eps + smoothing_eps / float(num_classes)
        v0 = smoothing_eps / float(num_classes)
        new_y = one_hot_y * (v1 - v0) + v0
        return new_y

def cross_entropy_onehot_target(logit, target):
    # target must be one-hot format!!
    prob_logit = F.log_softmax(logit, dim=1)
    loss = -(target * prob_logit).sum(dim=1).mean()
    return loss

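def mixup_func(input, target, alpha=0.2):
    # Sample the mixing weight from Beta(alpha, alpha).
    gamma = float(np.random.beta(alpha, alpha))
    # target must be in one-hot format!
    perm = torch.randperm(input.size(0), device=input.device)
    perm_input = input[perm]
    perm_target = target[perm]
    # Blend in place: gamma * x + (1 - gamma) * x[perm].
    # The add_(scalar, tensor) form is deprecated, so pass the scale via alpha=.
    return input.mul_(gamma).add_(perm_input, alpha=1 - gamma), target.mul_(gamma).add_(perm_target, alpha=1 - gamma)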
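
A quick way to sanity-check the label-smoothing and mix-up arithmetic is to run it on a tiny example. The sketch below uses NumPy rather than PyTorch so it runs without a GPU; `smooth_one_hot` is a hypothetical stand-in that follows the same formula as `one_hot` above, and the fixed `gamma` replaces the random Beta draw.

```python
import numpy as np

def smooth_one_hot(y, num_classes, eps):
    """Smoothed targets: eps / K on every class, plus 1 - eps on the true class."""
    one_hot_y = np.eye(num_classes)[y]
    return one_hot_y * (1.0 - eps) + eps / num_classes

y = np.array([0, 2])
targets = smooth_one_hot(y, num_classes=4, eps=0.1)
# True class: 1 - 0.1 + 0.1 / 4 = 0.925; every other class: 0.1 / 4 = 0.025.
print(targets)

# Mix-up blends each sample with a permuted partner using weight gamma.
gamma = 0.7
mixed = gamma * targets + (1.0 - gamma) * targets[::-1]
print(mixed.sum(axis=1))  # each row still sums to 1, up to rounding
```

Because both operations keep every target row a valid probability distribution, the soft-target cross entropy in `cross_entropy_onehot_target` remains well defined.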

#### Start official training and compression.

Under $40\%$ group sparsity, the top-1 accuracy reached $75.2\%$ in this particular run.

Across different random seeds in our experiments, the top-1 accuracy was $(75.2\pm0.1)\%$.

In [None]:
from utils.utils import check_accuracy

label_smooth = True
mix_up = True
train_time = 2 # Mix-up requires longer training time for better convergence.
ckpt_dir = './' # Checkpoint save directory

max_epoch = 120
if label_smooth or mix_up:
    # Both options produce soft (one-hot style) targets.
    criterion = cross_entropy_onehot_target
else:
    criterion = torch.nn.CrossEntropyLoss()
    
# Every 30 epochs, divide the lr by 10.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 
num_classes = 1000

best_acc_1 = 0.0

for epoch in range(max_epoch):
    for t in range(train_time):
        # Re-enter training mode and reset the running loss on every pass,
        # since the evaluation below may leave the model in eval mode.
        model.train()
        f_avg_val = 0.0
        for X, y in trainloader:
            X = X.cuda()
            y = y.cuda()
            with torch.no_grad():
                if label_smooth or mix_up:
                    # Convert labels to (optionally smoothed) one-hot targets.
                    y = one_hot(y, num_classes=num_classes,
                                smoothing_eps=0.1 if label_smooth else None)
                if mix_up:
                    X, y = mixup_func(X, y)

            y_pred = model(X)
            f = criterion(y_pred, y)
            optimizer.zero_grad()
            f.backward()
            # Accumulate a Python float so the autograd graph is not kept alive.
            f_avg_val += f.item()
            optimizer.step()
        group_sparsity, omega = optimizer.compute_group_sparsity_omega()
        norm_x_pm, norm_x_npm, num_groups_pm, num_groups_npm = optimizer.compute_norm_group_partitions()
        accuracy1, accuracy5 = check_accuracy(model, testloader)
        f_avg_val /= len(trainloader)
        print("Epoch: {ep}, loss: {f:.2f}, omega:{om:.2f}, group_sparsity: {gs:.2f}, acc1: {acc:.4f}"\
            .format(ep=epoch, f=f_avg_val, om=omega, gs=group_sparsity, acc=accuracy1))

        if accuracy1 > best_acc_1:
            best_acc_1 = accuracy1
            torch.save(model, os.path.join(ckpt_dir, 'best_epoch_' + str(epoch) + '_' + str(t) + '.pt'))
    # Step the scheduler once per epoch, after the optimizer updates.
    lr_scheduler.step()
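
For reference, the step schedule configured above can be written in closed form. The sketch below is a plain-Python illustration, with `step_lr` as a hypothetical helper name, giving the learning rate as a function of the number of `lr_scheduler.step()` calls made so far.

```python
def step_lr(initial_lr, num_scheduler_steps, step_size=30, gamma=0.1):
    """Closed form of a step schedule: multiply by gamma every `step_size` steps."""
    return initial_lr * gamma ** (num_scheduler_steps // step_size)

# Learning rate around each decay boundary of a 120-epoch run.
for n in (0, 29, 30, 60, 90, 119):
    print(n, step_lr(0.1, n))
```

With `lr=0.1`, the rate drops to about `0.01`, `0.001`, and `0.0001` after 30, 60, and 90 scheduler steps respectively.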

### Step 5. Get compressed model in ONNX format

By default, OTO compresses the last checkpoint.

To compress a different checkpoint, reinitialize OTO with that checkpoint and then compress:

    oto = OTO(model=torch.load(ckpt_path), dummy_input=dummy_input)

In [10]:
# A ResNet_compressed.onnx will be generated. 
oto.compress()

### (Optional) Compute FLOPs and number of parameters before and after OTO training

The compressed ResNet50 under 40% group sparsity reduces FLOPs by 61.5% and parameters by 50.6%.

In [11]:
full_flops = oto.compute_flops()
compressed_flops = oto.compute_flops(compressed=True)
full_num_params = oto.compute_num_params()
compressed_num_params = oto.compute_num_params(compressed=True)

print("Full FLOPs (M): {f_flops:.2f}. Compressed FLOPs (M): {c_flops:.2f}. Reduction Ratio: {f_ratio:.4f}"\
      .format(f_flops=full_flops, c_flops=compressed_flops, f_ratio=1 - compressed_flops/full_flops))
print("Full # Params: {f_params}. Compressed # Params: {c_params}. Reduction Ratio: {f_ratio:.4f}"\
      .format(f_params=full_num_params, c_params=compressed_num_params, f_ratio=1 - compressed_num_params/full_num_params))

Full FLOPs (M): 4089.18. Compressed FLOPs (M): 1574.71. Reduction Ratio: 0.6149
Full # Params: 25557032. Compressed # Params: 12626516. Reduction Ratio: 0.5059
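
The reduction ratios quoted above follow directly from the printed totals. The short check below redoes that arithmetic from the numbers in the cell output; the variable names are only for illustration.

```python
# Totals copied from the printed output above.
full_flops_m, compressed_flops_m = 4089.18, 1574.71
full_params, compressed_params = 25_557_032, 12_626_516

flops_reduction = 1 - compressed_flops_m / full_flops_m
params_reduction = 1 - compressed_params / full_params

print(f"FLOPs reduction: {flops_reduction:.1%}")    # 61.5%
print(f"Params reduction: {params_reduction:.1%}")  # 50.6%
```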


### (Optional) Check the output difference between full model and compressed model.

#### Both the full and the compressed model should return the same output for the same input, up to floating-point error.
#### The maximum deviation should be on the order of `1e-5`, which is negligible.


In [12]:
import onnxruntime as ort
full_ort_sess = ort.InferenceSession(oto.full_model_path)
compress_ort_sess = ort.InferenceSession(oto.compressed_model_path)

fake_input = torch.rand(1, 3, 224, 224)

full_output = full_ort_sess.run(None, {'input.1': fake_input.numpy()})[0]
compress_output = compress_ort_sess.run(None, {'input.1': fake_input.numpy()})[0]
print("Maximum output difference:")
print(np.max(np.abs(full_output - compress_output)))

Maximum output difference:
3.33786e-06
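
Beyond the maximum deviation, it can be useful to confirm that the two models also agree on the predicted class. The sketch below is one possible check, run here on synthetic arrays standing in for the two ONNX outputs; `compare_outputs` and the `1e-5` tolerance default are assumptions for illustration.

```python
import numpy as np

def compare_outputs(full_output, compressed_output, atol=1e-5):
    """Return (max abs difference, within tolerance?, same top-1 prediction?)."""
    max_diff = float(np.max(np.abs(full_output - compressed_output)))
    within_tol = bool(np.allclose(full_output, compressed_output, rtol=0.0, atol=atol))
    same_top1 = bool(np.array_equal(full_output.argmax(axis=1),
                                    compressed_output.argmax(axis=1)))
    return max_diff, within_tol, same_top1

# Synthetic stand-ins for the two outputs (shape: batch x classes).
full_output = np.linspace(-1.0, 1.0, 1000, dtype=np.float32).reshape(1, -1)
compressed_output = full_output + np.float32(3e-6)

max_diff, within_tol, same_top1 = compare_outputs(full_output, compressed_output)
print(max_diff, within_tol, same_top1)
```

To use it on real outputs, pass `full_output` and `compress_output` from the cell above in place of the synthetic arrays.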
