In [57]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.optim as optim

import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import os
from utils import progress_bar
from imp_baselines import*
# import imp_baselines as baselines

In [58]:
from ptflops import get_model_complexity_info


In [None]:
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomRotation(45),
     transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

trainset = torchvision.datasets.CIFAR100(root='./../data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./../data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

In [None]:
import torch
import torch.nn as nn

class AlexNet(nn.Module):

    def __init__(self, cfg, classes=100):
        super(AlexNet, self).__init__()
        self.features = nn.Sequentiaal(
            nn.Conv2d(3, cfg[0], kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[0]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[1]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[2]),
            nn.Conv2d(cfg[2], cfg[3], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[3]),
            nn.Conv2d(cfg[3], cfg[4], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[4]),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(cfg[4] * 1 * 1, cfg[5]),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(cfg[5], cfg[6]),
            nn.ReLU(inplace=True),
            nn.Linear(cfg[6], classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
cfg = [64, 192, 384, 256, 256, 4096, 4096]

In [None]:
net = AlexNet(cfg).to(device)
criterion = nn.CrossEntropyLoss()

In [None]:
net

In [None]:
PATH_corr = './w_decorr/base_params/cifar100_net.pth'
net_dict = torch.load(PATH_corr)
net.load_state_dict(net_dict['net'])

### Accuracies

In [None]:
def cal_acc(net_test):
    net_test.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_test(inputs)

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print(100 * correct / total)
        
    return 100 * correct / total

### Data Driven Trimming

In [None]:
from imp_baselines import*
import imp_baselines as baselines

In [None]:
baselines

In [None]:
#Pre-Train AlexNet on google drive
net = AlexNet(cfg).to(device)
PATH_pre = './pretrained_alex.pth'
net_dict = torch.load(PATH_pre, map_location=torch.device('cpu'))
net.load_state_dict(net_dict['net'])
relu_layers = [1,5,9,12,15]
classifier_relu_layers = [2, 5]
#Prune with testset
#Re-train / fine-Tune new Net
#Re-Test new net
#order and ratios 



In [None]:
from collections import Counter

In [None]:
def calc_importance_DD(net, relu_layers, classifier_relu_layers):

    total_neurons = [net.features[l_id-1].weight.shape[0] for l_id in relu_layers]  #array with # neurons in each layer
    for l_id in classifier_relu_layers:
        total_neurons.append(net.classifier[l_id-1].weight.shape[0])
    print((total_neurons))
    imp_matrix = np.zeros((sum(total_neurons),3)) #index in layer, layer index, importance
    print(imp_matrix.shape)
    a = []
    for i in range(len(total_neurons)-2):
        print(i)
        a = np.concatenate((a, np.ones(total_neurons[i])*(relu_layers[i] + 1)), axis=None)
    for i in range(len(total_neurons)-2, len(total_neurons)):
        print(i)
        a = np.concatenate((a, np.ones(total_neurons[i])*(classifier_relu_layers[i-5] - 1)), axis=None)

    imp_matrix[:,0] = a
    num = 0
    i = 0

    for l_id in relu_layers:
        neuron_order = cal_importance_dataDriven_conv(net, l_id, num_stop=100)  ##CHANGE NUM_STOP TO 40 OR SOMETHING
        imp_matrix[:,1][num:num+total_neurons[i]] = neuron_order[0]
        imp_matrix[:,2][num:num+total_neurons[i]] = neuron_order[1] / total_neurons[i]
        num += total_neurons[i]
        i += 1

    for l_id in classifier_relu_layers:
        neuron_order = cal_importance_dataDriven_linear(net, l_id, num_stop=100)  ##CHANGE NUM_STOP TO 40 OR SOMETHING
        imp_matrix[:,1][num:num+total_neurons[i]] = neuron_order[0]
        imp_matrix[:,2][num:num+total_neurons[i]] = neuron_order[1] / total_neurons[i]
        num += total_neurons[i]
        i += 1
    return imp_matrix

### TFO importance

In [None]:
import pickle

In [None]:
def order_and_ratios(imp_order, prune_ratio):
    imp_sort = np.argsort(imp_order[:,2])
    temp_order = imp_order[imp_sort]

    n_prune = int(prune_ratio * imp_order.shape[0])

    prune_list = temp_order[0:n_prune]

    imp_order_tfo = {}
    ratios = []

    for l_index in [2, 6, 10, 13, 16, 1, 4]:
        nlist = temp_order[(temp_order[:,0] == l_index), 1].astype(int)
        imp_order_tfo.update({l_index: nlist})
        nlist = np.sort(prune_list[(prune_list[:,0] == l_index), 1].astype(int))
        ratios.append(nlist.shape[0])
    return imp_order_tfo, ratios

# Pruning

In [None]:
def cfg_p(prune_ratio, orig_size, save_cfg_corr=0, save_cfg=0):
    cfg_list = []

    for i in range(7):
        cfg_list.append(orig_size[i] - prune_ratio[i])

    if(save_cfg == 1):
        with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(prune_iter)+".pkl", 'wb') as f:
            pickle.dump(cfg_list, f)
    
    return cfg_list

In [None]:
vector = [1, 1, 384, 1, 256, 1, 4096]
nets = AlexNet(vector).to(device)


In [None]:
def pruner(net, imp_order, prune_ratio, orig_size, net_type=0):
    if(net_type==1):
        cfg = cfg_p(prune_ratio, orig_size, save_cfg=1)
    else:
        cfg = cfg_p(prune_ratio, orig_size)        
    print('Size cfg ',cfg)
    ###########
    ######New.  Enforce at least 1 neuron########
    for i in range(len(cfg)):
        if(cfg[i] == 0):
            cfg[i] = 1
    #########
    ##########
    print('Size cfg ',cfg)

    net_pruned = AlexNet(cfg).to(device)
    conv_layers = [2, 6, 10, 13, 16]
    lin_layers = [1, 4]
    
    for l in range(len(conv_layers)):
        if(l == 0):
            n_c = prune_ratio[l]
            order_c = np.sort(imp_order[conv_layers[l]][n_c:])
            net_pruned.features[conv_layers[l]-2].weight.data = net.features[conv_layers[l]-2].weight[order_c].data.detach().clone()
            net_pruned.features[conv_layers[l]-2].bias.data = net.features[conv_layers[l]-2].bias[order_c].data.detach().clone()

            net_pruned.features[conv_layers[l]].weight.data = net.features[conv_layers[l]].weight[order_c].data.detach().clone()
            net_pruned.features[conv_layers[l]].bias.data = net.features[conv_layers[l]].bias[order_c].data.detach().clone()
            net_pruned.features[conv_layers[l]].running_var.data = net.features[conv_layers[l]].running_var[order_c].detach().clone()
            net_pruned.features[conv_layers[l]].running_mean.data = net.features[conv_layers[l]].running_mean[order_c].detach().clone()    
            continue
        
        n_p = prune_ratio[l-1]        
        n_c = prune_ratio[l]

        order_p = np.sort(imp_order[conv_layers[l-1]][n_p:])
        order_c = np.sort(imp_order[conv_layers[l]][n_c:])
        
        net_pruned.features[conv_layers[l]-2].weight.data = net.features[conv_layers[l]-2].weight[order_c][:,order_p].detach().clone()
        net_pruned.features[conv_layers[l]-2].bias.data = net.features[conv_layers[l]-2].bias[order_c].detach().clone()

        net_pruned.features[conv_layers[l]].weight.data = net.features[conv_layers[l]].weight[order_c].detach().clone()
        net_pruned.features[conv_layers[l]].bias.data = net.features[conv_layers[l]].bias[order_c].detach().clone()    
        net_pruned.features[conv_layers[l]].running_var.data = net.features[conv_layers[l]].running_var[order_c].detach().clone()
        net_pruned.features[conv_layers[l]].running_mean.data = net.features[conv_layers[l]].running_mean[order_c].detach().clone()    

    n_p = prune_ratio[4]        
    n_c = prune_ratio[5]
    order_p = np.sort(imp_order[conv_layers[4]][n_p:])
    order_c = np.sort(imp_order[lin_layers[0]][n_c:])    
    net_pruned.classifier[lin_layers[0]].weight.data = net.classifier[lin_layers[0]].weight[order_c][:,order_p].detach().clone()
    net_pruned.classifier[lin_layers[0]].bias.data = net.classifier[lin_layers[0]].bias[order_c].detach().clone()

    n_p = prune_ratio[5]        
    n_c = prune_ratio[6]
    order_p = np.sort(imp_order[lin_layers[0]][n_p:])
    order_c = np.sort(imp_order[lin_layers[1]][n_c:])    
    net_pruned.classifier[lin_layers[1]].weight.data = net.classifier[lin_layers[1]].weight[order_c][:,order_p].detach().clone()
    net_pruned.classifier[lin_layers[1]].bias.data = net.classifier[lin_layers[1]].bias[order_c].detach().clone()

    net_pruned.classifier[lin_layers[1]].weight.data = net.classifier[lin_layers[1]].weight[order_c][:,order_p].detach().clone()
    net_pruned.classifier[lin_layers[1]].bias.data = net.classifier[lin_layers[1]].bias[order_c].detach().clone()


    n_classifier = prune_ratio[-1]
    order_classifier = np.sort(imp_order[lin_layers[1]][n_classifier:])

    net_pruned.classifier[6].weight.data = net.classifier[6].weight[:,order_classifier].detach().clone()
    net_pruned.classifier[6].bias.data = net.classifier[6].bias.detach().clone()
    
    return net_pruned

In [None]:
### PRUNING
orig_size = [64, 192, 384, 256, 256, 4096, 4096]
prune_ratio = 0.6
imp_order, ratios = order_and_ratios(imp_matrix, prune_ratio)
print('ratios', (ratios))
print('orig_siz', (orig_size))
net_p = pruner(net, imp_order, ratios, orig_size, net_type=0)

In [None]:
ratios = [0, 0, 0, 0, 0, 0, 0]

In [None]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [None]:
#Prune 60%
net = AlexNet(cfg).to(device)
PATH_pre = './pretrained_alex.pth'
net_dict = torch.load(PATH_pre, map_location=torch.device('cpu'))
net.load_state_dict(net_dict['net'])
relu_layers = [1,5,9,12,15]
classifier_relu_layers = [2, 5]
orig_size = [64, 192, 384, 256, 256, 4096, 4096]
prune_ratio = 0.6 ###You can change here
imp_matrix = calc_importance_DD(net, relu_layers, classifier_relu_layers)
print(imp_matrix)

imp_order, ratios = order_and_ratios(imp_matrix, prune_ratio)
net_p = pruner(net, imp_order, ratios, orig_size, net_type=0)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)
for i in range(10):
    net_p_train(200)


In [None]:
for i in range(50):
    net_p_train(200)
##Prune Ratio of 0.6 w/ no iterative pruning, accuracy = 46.1%.  Decreases by 4%


In [None]:
##Iterative pruning...
net = AlexNet(cfg).to(device)
PATH_pre = './pretrained_alex.pth'
net_dict = torch.load(PATH_pre, map_location=torch.device('cpu'))
net.load_state_dict(net_dict['net'])
relu_layers = [1,5,9,12,15]
classifier_relu_layers = [2, 5]
orig_size = [64, 192, 384, 256, 256, 4096, 4096]

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)
prune_ratios = [0.30,0.25,0.20] #This is 58% pruned

for prune_ratio in prune_ratios:
    imp_matrix = calc_importance_DD(net, relu_layers, classifier_relu_layers)
    imp_order, ratios = order_and_ratios(imp_matrix, prune_ratio)
    net_p = pruner(net, imp_order, ratios, orig_size, net_type=0)
    orig_size = ratios
    for i in range(30):
        net_p_train(200)

In [None]:
##Run more 
###Save the pruned network 


#RESULTS
#60% ITERATIVE PRUNING, 50.781% accuracy
#60% intitial pruning, 46.43% accuracy

In [None]:
for i in range(12):
    net_p_train(100)

In [None]:
cal_acc(net_p.eval())

In [None]:
net_p

In [109]:
cfg = [64,158,264,203,144,138,106]
#cfg = [64, 192, 384, 256, 256, 4096, 4096]
net = AlexNet(cfg).to(device)
PATH_pre = './ortho_p_final_ckpts/nets/ortho_ckpt12.pth'
net_dict = torch.load(PATH_pre, map_location=torch.device('cpu'))
net.load_state_dict(net_dict['net_p_ortho'])
cal_acc(net.eval())

51.96


51.96

In [110]:
cfg = [64,157,212,186,113,138,100]
#cfg = [64, 192, 384, 256, 256, 4096, 4096]
net = AlexNet(cfg).to(device)
PATH_pre = './ortho_p_final_ckpts/nets/ortho_ckpt13.pth'
net_dict = torch.load(PATH_pre, map_location=torch.device('cpu'))
net.load_state_dict(net_dict['net_p_ortho'])
cal_acc(net.eval())

50.33


50.33

In [86]:
net_p = net
net_p_train(2)


Epoch: 2
0
Saving.....................]  Step: 21s567ms | Tot: 0ms | Loss: 1.494 | Acc: 64.062% (82/12 1/391 
1
Saving.....................]  Step: 379ms | Tot: 380ms | Loss: 1.599 | Acc: 61.328% (157/25 2/391 
2
Saving.....................]  Step: 346ms | Tot: 726ms | Loss: 1.719 | Acc: 57.031% (219/38 3/391 
3
Saving.....................]  Step: 358ms | Tot: 1s85ms | Loss: 1.808 | Acc: 53.125% (272/51 4/391 
4
Saving.....................]  Step: 356ms | Tot: 1s442ms | Loss: 1.887 | Acc: 50.781% (325/64 5/391 
5
6[>........................]  Step: 368ms | Tot: 1s810ms | Loss: 1.972 | Acc: 48.698% (374/76 6/391 
7[>........................]  Step: 350ms | Tot: 2s160ms | Loss: 2.016 | Acc: 47.545% (426/89 7/391 
8[>........................]  Step: 379ms | Tot: 2s539ms | Loss: 2.035 | Acc: 47.070% (482/102 8/391 
9[>........................]  Step: 367ms | Tot: 2s907ms | Loss: 2.049 | Acc: 46.528% (536/115 9/391 
10>........................]  Step: 341ms | Tot: 3s248ms | Loss: 2.055 | A





In [87]:
cal_acc(net_p.eval())
#net_p_test(2)

34.94


34.94

In [94]:
###ADVERSARIAL ATTACK ON PRUNNED
import foolbox as fb
import eagerpy as ep
from foolbox import PyTorchModel, accuracy, samples



fmodel = fb.PyTorchModel(net_p.eval(), bounds=(-3.3,3.3))
images, labels = ep.astensors(*samples(fmodel, dataset="cifar100", batchsize=16))

attack = fb.attacks.LinfPGD()
epsilons = [0.0, 0.001, 0.005, 0.01, 0.03, 0.1, 0.3, 0.5]
advs, _, success = attack(fmodel, images, labels, epsilons=epsilons)

In [95]:
# calculate and report the robust accuracy
print('Robust accuracy for Pruned Network (epsilon, accuracy)')
robust_accuracy = 1 - success.float32().mean(axis=-1)
for eps, acc in zip(epsilons, robust_accuracy):
    print('eps:',eps, '||','acc:', acc.item())


Robust accuracy for Pruned Network (epsilon, accuracy)
eps: 0.0 || acc: 0.0
eps: 0.001 || acc: 0.0
eps: 0.005 || acc: 0.0
eps: 0.01 || acc: 0.0
eps: 0.03 || acc: 0.0
eps: 0.1 || acc: 0.0
eps: 0.3 || acc: 0.0
eps: 0.5 || acc: 0.0


In [100]:
###ADVERSARIAL ATTACK ON Pre-Trained
cfg = [64, 192, 384, 256, 256, 4096, 4096]
net = AlexNet(cfg).to(device)
PATH_pre = './pretrained_alex.pth'
net_dict = torch.load(PATH_pre, map_location=torch.device('cpu'))
net.load_state_dict(net_dict['net'])

fmodel = fb.PyTorchModel(net.eval(), bounds=(0,1))
images, labels = ep.astensors(*samples(fmodel, dataset="cifar100", batchsize=16))

attack = fb.attacks.LinfPGD()
epsilons = [0.0, 0.001, 0.005, 0.01, 0.03, 0.1, 0.3, 0.5]
advs, _, success = attack(fmodel, images, labels, epsilons=epsilons)

# calculate and report the robust accuracy
robust_accuracy = 1 - success.float32().mean(axis=-1)
print('Robust accuracy for Pre-Trained Network (epsilon, accuracy)')
for eps, acc in zip(epsilons, robust_accuracy):
    print('eps:',eps, '||','acc:', acc.item())


Robust accuracy for Pre-Trained Network (epsilon, accuracy)
eps: 0.0 || acc: 0.125
eps: 0.001 || acc: 0.125
eps: 0.005 || acc: 0.0625
eps: 0.01 || acc: 0.0625
eps: 0.03 || acc: 0.0
eps: 0.1 || acc: 0.0
eps: 0.3 || acc: 0.0
eps: 0.5 || acc: 0.0


In [98]:
import torchvision.models as models
from foolbox.attacks import LinfPGD

model = models.resnet18(pretrained=True).eval()
preprocessing = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], axis=-3)
fmodel = PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)

# get data and test the model
# wrapping the tensors with ep.astensors is optional, but it allows
# us to work with EagerPy tensors in the following
images, labels = ep.astensors(*samples(fmodel, dataset="imagenet", batchsize=16))
print(accuracy(fmodel, images, labels))

# apply the attack
attack = LinfPGD()
epsilons = [0.0, 0.001, 0.01, 0.03, 0.1, 0.3, 0.5, 1.0]
advs, _, success = attack(fmodel, images, labels, epsilons=epsilons)

# calculate and report the robust accuracy
robust_accuracy = 1 - success.float32().mean(axis=-1)
for eps, acc in zip(epsilons, robust_accuracy):
    print(eps, acc.item())

# we can also manually check this
for eps, advs_ in zip(epsilons, advs):
    print(eps, accuracy(fmodel, advs_, labels))
    # but then we also need to look at the perturbation sizes
    # and check if they are smaller than eps
    print((advs_ - images).norms.linf(axis=(1, 2, 3)).numpy())

0.9375
0.0 0.9375
0.001 0.25
0.01 0.0
0.03 0.0
0.1 0.0
0.3 0.0
0.5 0.0
1.0 0.0
0.0 0.9375
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
0.001 0.25
[0.001 0.001 0.001 0.001 0.001 0.001 0.001 0.001 0.001 0.001 0.001 0.001
 0.001 0.001 0.001 0.001]
0.01 0.0
[0.01000002 0.01000002 0.01000002 0.01000002 0.01000002 0.01000002
 0.01000002 0.01000002 0.01000002 0.01000002 0.01000002 0.01000002
 0.01000002 0.01000002 0.01000002 0.01000002]
0.03 0.0
[0.03 0.03 0.03 0.03 0.03 0.03 0.03 0.03 0.03 0.03 0.03 0.03 0.03 0.03
 0.03 0.03]
0.1 0.0
[0.10000002 0.10000002 0.10000002 0.10000002 0.10000002 0.10000002
 0.10000002 0.10000002 0.10000002 0.10000002 0.10000002 0.10000002
 0.10000002 0.10000002 0.10000002 0.10000002]
0.3 0.0
[0.30000004 0.30000004 0.30000004 0.30000004 0.30000004 0.30000004
 0.30000004 0.30000004 0.30000004 0.30000004 0.30000004 0.30000004
 0.30000004 0.30000004 0.30000004 0.30000004]
0.5 0.0
[0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]
1.0 0.0
[1. 1. 1. 1

# Retraining

In [None]:
prune_iter = 1

## Pruning

In [None]:
orig_size = []

for i in [2, 6, 10, 13, 16]:
    orig_size.append(net_corr.features[i].bias.shape[0])

for i in [1, 4]:
    orig_size.append(net_corr.classifier[i].bias.shape[0])
    
orig_size = np.array(orig_size)

In [None]:
order_corr, prune_ratio = order_and_ratios(imp_order_corr, 0.2)
prune_ratio, orig_size

#### Define pruned network

In [None]:
net_dict = torch.load(PATH_corr)
net_corr.load_state_dict(net_dict['net'])
net_p = pruner(net_corr, order_corr, prune_ratio, orig_size, net_type=1)

In [None]:
cal_acc(net_p.eval()), cal_acc(net_corr.eval())

#### Retraining

In [83]:
# Training
def net_p_train(epoch):
    print('\nEpoch: %d' % epoch)
    net_p.train()
    train_loss = 0
    correct = 0
    total = 0
    optimizer = optim.Adam(net_p.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        print(batch_idx)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net_p(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        acc =  100.*correct/total
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), acc, correct, total))
        if acc > 50.0:
            print('Saving..')
            state = {
                'net_p': net_p.state_dict(),
                'best_p_acc': acc
            }
            if not os.path.isdir('net_p_checkpoint'):
                os.mkdir('net_p_checkpoint')
            torch.save(state, './net_p_checkpoint/ckpt'+str(51)+'.pth')
            best_p_acc = acc
def net_p_test(epoch):
    global best_p_acc
    global prune_iter
    net_p.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_p(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > 50.0:
        print('Saving..')
        state = {
            'net_p': net_p.state_dict(),
            'best_p_acc': acc
        }
        if not os.path.isdir('net_p_checkpoint'):
            os.mkdir('net_p_checkpoint')
        torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')
        best_p_acc = acc

In [84]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [None]:
best_p_acc = 0

In [None]:
for epoch in range(1):
    net_p_train(epoch)
    net_p_test(epoch)

#### Load correlated pruned network

In [None]:
prune_iter

In [None]:
net_dict = torch.load('./net_p_checkpoint/ckpt1.pth')
net_p.load_state_dict(net_dict['net_p'])
best_p_acc = net_dict['best_p_acc']

### Subsequent pruning

#### Importance

In [None]:
''' Correlated network '''
# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
#     imp_order_p = pickle.load(f)

In [None]:
optimizer = optim.SGD(net_p.parameters(), lr=0, weight_decay=0)
imp_order_p = np.array([[],[],[]]).transpose()
i = 0
for l_index in [2, 6, 10, 13, 16, 1, 4]:
    print(l_index)
    if(l_index != 1 and l_index != 4):
        nlist = cal_importance_conv(net_p, l_index)
    else:
        nlist = cal_importance_linear(net_p, l_index)
    imp_order_p = np.concatenate((imp_order_p,np.array([np.repeat([l_index],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p, f)

#### Pruned network pruning

In [None]:
''' Correlated network '''
orig_size = []
for i in [2, 6, 10, 13, 16]:
    orig_size.append(net_p.features[i].bias.shape[0])
for i in [1, 4]:
    orig_size.append(net_p.classifier[i].bias.shape[0])
orig_size = np.array(orig_size)

#### Pruning order

In [None]:
''' Correlated network '''
order_p, prune_ratio = order_and_ratios(imp_order_p, 0.1)
prune_ratio, orig_size

#### Define pruned network

In [None]:
prune_iter = 2

In [None]:
''' Correlated network pruning '''
net_p1 = pruner(net_p, order_p, prune_ratio, orig_size, net_type=1)

print("Accs:", cal_acc(net_p1.eval()), cal_acc(net_p.eval()))

#### Save pruned network

In [None]:
''' Correlated network saving '''
net_p = net_p1

print('Saving..')
state = {
    'net_p': net_p.state_dict(),
    'best_p_acc': cal_acc(net_p.eval())
}
if not os.path.isdir('net_p_checkpoint'):
    os.mkdir('net_p_checkpoint')
torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')

### Load pruned network

In [None]:
# ''' Correlated network loading '''
# with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(1)+".pkl", 'rb') as f:
#     cfg_p1 = pickle.load(f)
    
# net_p = AlexNet(cfg_p1).to(device)
# PATH = './net_p_checkpoint/ckpt'+str(1)+'.pth'
# net_p.load_state_dict(torch.load(PATH)['net_p'])

In [None]:
# cal_acc(net_p.eval()), cal_acc(net_decorr.eval())

### FLOPS calculator

In [None]:
#with torch.cuda.device(0):
flops, params = get_model_complexity_info(net_p, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
print('{:<30}  {:<8}'.format('Computational complexity 60% Pruned: ', flops))    

In [None]:
flops, params = get_model_complexity_info(net, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
print('{:<30}  {:<8}'.format('Computational complexity Original: ', flops))