# PENNI framework workflow

Note: Load the ResNet56 checkpoint with the `dill` package rather than with plain `torch.load`.

In [None]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np

from models import utils, vgg16
from models.op_count import profile
from decompose import params_resolver

## Load Baseline Model

In [None]:
# Build the VGG16 baseline, restore its pretrained weights from the
# checkpoint's state_dict, move it to the GPU and report its CIFAR-10 accuracy.
baseline_ckpt = "ckpt/Baseline/VGG16_93.49.pth"
model = vgg16.VGG16()
model.load_state_dict(torch.load(baseline_ckpt))
model.cuda()
utils.eval_cifar10(model)

## Decomposition and Retraining

In [None]:
# Decompose the baseline into the PENNI model via params_resolver's PCA
# decomposition (basis=5 -- presumably the number of basis kernels kept per
# layer; confirm in decompose/params_resolver).
resolver = params_resolver.param_resolver(model)
spar_model = resolver.PCA_decomposing(basis=5)
spar_model.cuda()

# Retrain the decomposed model. spar_reg / spar_param presumably select an L1
# sparsity penalty and its strength -- confirm in utils.train_cifar10.
utils.train_cifar10(
    spar_model,
    lr=0.01,
    reg=1e-4,
    cross=True,
    cross_interval=2,
    spar_reg="l1",
    spar_param=1e-4,
)
# Evaluate the retrained decomposed model, not the untouched baseline `model`.
utils.eval_cifar10(spar_model)

## Parameter and FLOPs Count

In [None]:
# Profile the decomposed model on a single random 3x32x32 dummy input and
# print the formatted FLOP count, then report the model's sparsity.
spar_model.cuda()
inputs = torch.randn(1, 3, 32, 32).cuda()  # also reused by the later profiling cells
flops = profile.profile(spar_model, [inputs], verbose=False)
flops_str = profile.clever_format(flops)
print(flops_str)

utils.compute_sparsity(spar_model)

## Prune and Finetune

In [None]:
# Prune the decomposed model (utils.prune_by_std -- threshold presumably
# derived from weight standard deviation; confirm), then finetune it for
# 30 epochs and evaluate the result.
utils.prune_by_std(spar_model)
utils.train_cifar10(
    spar_model,
    lr=0.01,
    reg=1e-4,
    epochs=30,
    finetune=True,
)
utils.eval_cifar10(spar_model)

In [None]:
# Re-profile the pruned + finetuned model with the same dummy input and
# report its sparsity after pruning.
flops = profile.profile(spar_model, [inputs], verbose=False)
flops_str = profile.clever_format(flops)
print(flops_str)

utils.compute_sparsity(spar_model)

## Model Shrinking

In [None]:
# Load the saved pruned VGG16 (loaded as a whole model object -- it is used
# directly as a model below), replacing the in-session spar_model, then
# shrink it and evaluate the shrunk model on the GPU.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoints. On PyTorch >= 2.6 torch.load defaults to weights_only=True,
# so loading a whole pickled model there needs weights_only=False.
pruned_ckpt = "ckpt/PrunedFinalModel/VGG16_pruned_9312.h5"
spar_model = torch.load(pruned_ckpt)
model_s, _ = utils.shrink(spar_model, iterative=True)  # second return value unused

model_s.cuda()
utils.eval_cifar10(model_s)

In [None]:
# Profile the *shrunk* model (model_s) so the FLOP count and sparsity reflect
# the shrinking step rather than repeating the pruned spar_model's numbers.
# The dummy input is recreated here so this section, which starts from a
# checkpoint on disk, does not depend on the earlier profiling cells.
inputs = torch.randn(1, 3, 32, 32).cuda()
flops = profile.profile(model_s, [inputs], verbose=False)
print(profile.clever_format(flops))

utils.compute_sparsity(model_s)