In [None]:
import torch
from backends import DemoNet

### Create OTO instance

In [None]:
from only_train_once import OTO
# Instantiate the network and one example input, then hand both to OTO.
# NOTE(review): OTO presumably uses `dummy_input` to trace the model and build
# its dependency graph — confirm against the only_train_once documentation.
model = DemoNet()
dummy_input = torch.zeros(1, 3, 32, 32)
# Reuse `dummy_input` rather than building a second identical tensor so the
# input shape is defined in exactly one place. `Tensor.cuda()` returns a GPU
# copy, so `dummy_input` itself remains a CPU tensor as before.
oto = OTO(model=model.cuda(), dummy_input=dummy_input.cuda())

#### Optional: Visualize the dependency graph of the DNN for ZIG partitions

In [None]:
# A DemoNet.gv.pdf file will be generated to display the dependency graph.
oto.visualize_zigs()

### Dataset Preparation

In [None]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

# Per-channel normalization shared by the training and evaluation pipelines.
# NOTE(review): these values look like the widely used ImageNet statistics
# rather than CIFAR-10 ones — confirm this is intended.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: random flip and padded random crop, then tensor conversion
# and normalization.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])
# Evaluation pipeline: tensor conversion and normalization only.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Both splits live under the same root directory and are downloaded on demand.
trainset = CIFAR10(root='cifar10', train=True, download=True, transform=train_transform)
testset = CIFAR10(root='cifar10', train=False, download=True, transform=test_transform)

# Loader settings common to both splits; only the training loader shuffles.
loader_kwargs = dict(batch_size=64, num_workers=4)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

### Setup DHSPG optimizer

In [None]:
# Obtain the DHSPG optimizer from the OTO instance, with lr=0.1.
# NOTE(review): target_group_sparsity=0.7 presumably requests that about 70% of
# the parameter groups end up zeroed out — confirm against the only_train_once docs.
optimizer = oto.dhspg(lr=0.1, target_group_sparsity=0.7)

### Start training

In [None]:
from tqdm import tqdm
from utils.utils import check_accuracy

# Train for `max_epoch` epochs on the GPU. After every epoch, record the mean
# training loss, the group sparsity reported by the optimizer, and the first
# accuracy value returned by `check_accuracy` on the test loader.
max_epoch = 60
model.cuda()
criterion = torch.nn.CrossEntropyLoss()
# Multiply the learning rate by 0.1 every 30 scheduler steps (one step per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

losses = list()
group_sparsies = list()
accuracies = list()

for epoch in range(max_epoch):
    f_avg_val = 0.0
    model.train()
    for X, y in trainloader:
        X = X.cuda()
        y = y.cuda()
        # Call the module itself rather than `.forward` so that any registered
        # forward hooks are run.
        y_pred = model(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        # Accumulate a detached copy of the loss: adding `f` directly would keep
        # the autograd history of every batch alive until the end of the epoch.
        f_avg_val += f.detach()
        optimizer.step()
    group_sparsity, omega = optimizer.compute_group_sparsity_omega()
    accuracy1, accuracy5 = check_accuracy(model, testloader)
    # `float()` handles both the accumulated tensor and the initial 0.0 (empty
    # loader); `max(..., 1)` avoids dividing by zero in that same case.
    f_avg_val = float(f_avg_val) / max(len(trainloader), 1)
    losses.append(f_avg_val)
    group_sparsies.append(group_sparsity)
    accuracies.append(accuracy1)
    print("Epoch: {ep}, loss: {f:.2f}, omega: {omega:.2f}, group_sparsity: {gs:.2f}, acc1: {acc:.4f}".format(ep=epoch, f=f_avg_val, omega=omega, gs=group_sparsity, acc=accuracy1))
    # Step the scheduler only after the epoch's optimizer steps; stepping it
    # first (as before) skips the first value of the learning-rate schedule.
    lr_scheduler.step()

### Get compressed model in ONNX format

In [None]:
# Run OTO's compression step.
oto.compress()

# Collect FLOPs and parameter counts for the full and the compressed model.
full_flops = oto.compute_flops()['total']
compressed_flops = oto.compute_flops(compressed=True)['total']
full_num_params = oto.compute_num_params()
compressed_num_params = oto.compute_num_params(compressed=True)

# Report templates; the reduction ratio is the fraction removed by compression.
flops_report = "Full FLOPs (M): {f_flops:.2f}. Compressed FLOPs (M): {c_flops:.2f}. Reduction Ratio: {f_ratio}"
params_report = "Full # Params: {f_params}. Compressed # Params: {c_params}. Reduction Ratio: {f_ratio}"

flops_reduction = 1 - compressed_flops/full_flops
print(flops_report.format(f_flops=full_flops, c_flops=compressed_flops, f_ratio=flops_reduction))

params_reduction = 1 - compressed_num_params/full_num_params
print(params_report.format(f_params=full_num_params, c_params=compressed_num_params, f_ratio=params_reduction))

### Check the compressed model accuracy

#### Both the full and the compressed model should return exactly the same accuracy

In [None]:
from utils.utils import check_accuracy_onnx
# Rebuild the test loader with batch_size=1; this rebinds the `testloader` name
# that the training cell used with batch_size=64.
# NOTE(review): presumably the exported ONNX model expects a batch of one,
# matching the (1, 3, 32, 32) dummy input — confirm.
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)

# Accuracy of the full PyTorch model on the test set.
acc1_full, acc5_full = check_accuracy(model, testloader)
print("Full torch model: Acc 1: ", acc1_full, ", Acc 5: ", acc5_full)

# Same evaluation for the compressed model stored at `oto.compressed_model_path`,
# using the ONNX-specific helper.
acc1_compressed, acc5_compressed_onnx = check_accuracy_onnx(oto.compressed_model_path, testloader)
print("Acc 1: ", acc1_compressed, ", Acc 5: ", acc5_compressed_onnx)  