In [3]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.optim as optim

import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import os
from utils import progress_bar
from imp_baselines import*
# import imp_baselines as baselines

In [4]:
# from ptflops import get_model_complexity_info

In [5]:
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomRotation(45),
     transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

trainset = torchvision.datasets.CIFAR100(root='./../data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./../data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified


In [6]:
import torch
import torch.nn as nn

class AlexNet(nn.Module):

    def __init__(self, cfg, classes=100):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, cfg[0], kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[0]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[1]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[2]),
            nn.Conv2d(cfg[2], cfg[3], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[3]),
            nn.Conv2d(cfg[3], cfg[4], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[4]),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(cfg[4] * 1 * 1, cfg[5]),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(cfg[5], cfg[6]),
            nn.ReLU(inplace=True),
            nn.Linear(cfg[6], classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [8]:
cfg = [64, 192, 384, 256, 256, 4096, 4096]

In [9]:
net_corr = AlexNet(cfg).to(device)
net_decorr = AlexNet(cfg).to(device)
criterion = nn.CrossEntropyLoss()

In [10]:
PATH_corr = './w_decorr/base_params/cifar100_net.pth'
net_dict = torch.load(PATH_corr)
net_corr.load_state_dict(net_dict['net'])

PATH_decorr = './w_decorr/base_params/wnet_base.pth'
net_dict = torch.load(PATH_decorr)
net_decorr.load_state_dict(net_dict['net'])

FileNotFoundError: [Errno 2] No such file or directory: './w_decorr/base_params/cifar100_net.pth'

In [11]:
trainset.data[1].shape

(32, 32, 3)

In [12]:
net_corr

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [None]:
x = torch.rand((1,3,32,32))

outputAt2 = net_corr.features[0:2](x)

z = net_corr.features[2:6](y)

### Accuracies

In [13]:
def cal_acc(net_test):
    net_test.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_test(inputs)

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print(100 * correct / total)
        
    return 100 * correct / total

### Importance

In [None]:
def cal_importance_conv(net, l_index):
    bias_base = net.features[l_index].bias.data.clone().detach()
    av_corrval = 0

    running_loss = 0.0
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        imp_corr_bn += (((net.features[l_index].weight.grad)*(net.features[l_index].weight.data)) + ((net.features[l_index].bias.grad)*(net.features[l_index].bias.data))).abs().pow(2)

    imp_norm = imp_corr_bn
    
    neuron_order = [np.linspace(0, imp_norm.shape[0]-1, imp_norm.shape[0]), imp_norm]
    
    return neuron_order

In [None]:
def cal_importance_linear(net, l_index):
    bias_base = net.classifier[l_index].bias.data.clone().detach()
    av_corrval = 0

    running_loss = 0.0
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        imp_corr_bn += (((net.classifier[l_index].weight.grad)*(net.classifier[l_index].weight.data)).sum(dim=1) + ((net.classifier[l_index].bias.grad)*(net.classifier[l_index].bias.data))).pow(2)

    imp_norm = imp_corr_bn
    
    neuron_order = [np.linspace(0, imp_norm.shape[0]-1, imp_norm.shape[0]), imp_norm]
    
    return neuron_order

### Timer

In [None]:
import time

In [62]:
def cal_time(net_acc):
    net_acc.eval()
    testsamp = torch.rand(1,3,32,32).to(device)
    
    for i in range(5):
        net_acc(testsamp)    
    t_end = 0
    t_s = time.time()
    for i in range(25):
        net_acc(testsamp)
        t_end += time.time() - t_s
    
    return (t_end / 25)

In [None]:
t_corr = cal_time(net_corr)
t_decorr = cal_time(net_decorr)

### Data Driven Trimming

In [14]:
from imp_baselines import*
import imp_baselines as baselines

In [15]:
baselines

<module 'imp_baselines' from '/Users/skylergranatir/Desktop/EECS545/FinalProject/net-compression-master/working_code/alexnet/imp_baselines.py'>

In [53]:
#Pre-Train AlexNet on google drive
net = AlexNet(cfg).to(device)
PATH_pre = './pretrained_alex.pth'
net_dict = torch.load(PATH_pre, map_location=torch.device('cpu'))
net.load_state_dict(net_dict['net'])
relu_layers = [1,5,9,12,15]
#Prune with testset
#Re-train / fine-Tune new Net
#Re-Test new net
#order and ratios 



In [17]:
neuron_order = cal_importance_dataDriven(net, 1, num_stop=10)

Working 1
Sample  1


In [18]:
from collections import Counter

In [19]:
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
num = 0
size = net.features[0].weight.shape[0]
aPoZ = np.zeros(size) #size of the conv layer before reLu...
for i, data in enumerate(trainloader, 0):
    inputs, labels = data[0].to(device), data[1].to(device)
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()


    num += 1 
    if(num > 10):
        break
    print('Sample ', num)

    ##for neuron in post relu...
    ##if activation of weight is zero, add 1 to List at the weights index
    ##should this be before back propogation?
    l_id_output = net.features[0:2](inputs)  #output of the relu layer 
    # print(l_id_output.shape)

    for a in range(0,l_id_output.shape[1]):
           # print(l_id_output[:,a,:,:].shape)
            sum_ = 0
            sum_ = (0,np.count_nonzero(l_id_output.detach().numpy()[:,a,:,:] == 0))
            aPoZ[a] += sum_[1]*0.001

imp_corr_bn = 1/aPoZ #so that entries with highest aPoZ have lowest importance
print(imp_corr_bn)
neuron_order = [np.linspace(0, imp_corr_bn.shape[0]-1, imp_corr_bn.shape[0]), imp_corr_bn]
print('Neuron Order shape', len(neuron_order), len(neuron_order[1]))
#index in layer, layer index, importance.... m x 3 array.

# #  imp_corr_bn = np.argsort(aPoZ)
# imp_corr_bn = aPoZ
# neuron_order = [np.linspace(0, imp_corr_bn.shape[0]-1, imp_corr_bn.shape[0]), imp_corr_bn]
#index in layer, layer index, importance.... m x 3 array.


Sample  1
Sample  2
Sample  3
Sample  4
Sample  5
Sample  6
Sample  7
Sample  8
Sample  9
Sample  10
[0.00484222 0.00610448 0.00763435 0.00427466 0.01091548 0.01008776
 0.02645923 0.00392362 0.00666285 0.00564299 0.01130889 0.00822558
 0.00533459 0.00756664 0.01671598 0.01022987 0.00484639 0.00589564
 0.00417866 0.00997218 0.00421914 0.00387312 0.00745573 0.00579391
 0.00372617 0.00931584 0.0274763  0.01244617 0.01432562 0.00591632
 0.00622622 0.00516748 0.00925969 0.00560381 0.00793065 0.00854723
 0.00578352 0.00892108 0.00454756 0.00690593 0.01211402 0.01118418
 0.00838251 0.00828899 0.01179329 0.00394123 0.0043793  0.00863632
 0.00708306 0.00699061 0.00735321 0.00720939 0.01014868 0.00505976
 0.00643927 0.00348453 0.00599298 0.00828356 0.00741032 0.00842119
 0.00374348 0.00789235 0.00489877 0.00334543]
Neuron Order shape 2 64


In [101]:
total_neurons = [net.features[l_id-1].weight.shape[0] for l_id in relu_layers]  #array with # neurons in each layer
print((total_neurons))
imp_matrix = np.zeros((sum(total_neurons),3)) #index in layer, layer index, importance
print(imp_matrix.shape)
a = []
for i in range(len(total_neurons)):
    a = np.concatenate((a, np.ones(total_neurons[i])*relu_layers[i]), axis=None)

imp_matrix[:,1] = a
num = 0
i = 0

for l_id in relu_layers:
    neuron_order = cal_importance_dataDriven(net, l_id, num_stop=1)  ##CHANGE NUM_STOP TO 40 OR SOMETHING
    imp_matrix[:,0][num:num+total_neurons[i]] = neuron_order[0]
    imp_matrix[:,2][num:num+total_neurons[i]] = neuron_order[1]
    num += total_neurons[i]
    i += 1

print(imp_matrix)

[64, 192, 384, 256, 256]
(1152, 3)
Working 1
Sample  1
Working 1
Sample  1
Working 1
Sample  1
Working 1
Sample  1
Working 1
Sample  1
[[0.00000000e+00 1.00000000e+00 4.97388728e-02]
 [1.00000000e+00 1.00000000e+00 6.20155036e-02]
 [2.00000000e+00 1.00000000e+00 7.34861866e-02]
 ...
 [2.53000000e+02 1.50000000e+01 1.89035928e+00]
 [2.54000000e+02 1.50000000e+01 1.44092214e+00]
 [2.55000000e+02 1.50000000e+01 1.45985401e+00]]


In [103]:
### PRUNING
orig_size = [64, 192, 384, 256, 256, 4096, 4096]
prune_ratio = 0.3
imp_order, ratios = order_and_ratios(imp_matrix, prune_ratio)
print('ratios', len(ratios))
print('orig_siz', len(orig_size))
net_pruned = pruner(net, imp_order, ratios, orig_size, net_type=0)
##I changed pruner to concatenate the 4096,4096, thus they are unpruned.... Is this right? 

ratios 7
orig_siz 7


In [105]:
net

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [109]:
for i in range(17):
    if(i%3==0 or i == 1 or i == 5 or i == 7):
        continue
    print('layer', i, net.features[i].weight.shape)

layer 2 torch.Size([64])
layer 4 torch.Size([192, 64, 3, 3])
layer 8 torch.Size([384, 192, 3, 3])
layer 10 torch.Size([384])
layer 11 torch.Size([256, 384, 3, 3])
layer 13 torch.Size([256])
layer 14 torch.Size([256, 256, 3, 3])
layer 16 torch.Size([256])


In [111]:
for i in range(17):
    if(i%3==0 or i == 1 or i == 5 or i == 7):
        continue
    print('layer', i, net_pruned.features[i].weight.shape)

layer 2 torch.Size([3])
layer 4 torch.Size([3, 3, 3, 3])
layer 8 torch.Size([2, 3, 3, 3])
layer 10 torch.Size([2])
layer 11 torch.Size([3, 2, 3, 3])
layer 13 torch.Size([3])
layer 14 torch.Size([3, 3, 3, 3])
layer 16 torch.Size([3])


In [74]:
#Re-Train...
optimizer = optim.SGD(net_pruned.parameters(), lr=1e-4, momentum=0.9)
for i, data in enumerate(trainloader, 0):
    inputs, labels = data[0].to(device), data[1].to(device)
    optimizer.zero_grad()
    outputs = net_pruned(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

In [43]:
#Re-Train...
optimizer = optim.SGD(net.parameters(), lr=1e-4, momentum=0.9)
for i, data in enumerate(trainloader, 0):
    inputs, labels = data[0].to(device), data[1].to(device)
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

In [75]:
#Test, want to get around 50.96
cal_acc(net_pruned.eval())

1.0


1.0

### TFO importance

In [None]:
import pickle

In [None]:
with open("./w_decorr/base_params/tfo_corr.pkl", 'rb') as f:
    imp_order_corr = pickle.load(f)

In [None]:
optimizer = optim.SGD(net_corr.parameters(), lr=0, weight_decay=0)
imp_order_corr = np.array([[],[],[]]).transpose()
i = 0
for l_index in [2, 6, 10, 13, 16, 1, 4]:
    print(l_index)
    if(l_index != 1 and l_index != 4):
        nlist = cal_importance_conv(net_corr, l_index)
    else:
        nlist = cal_importance_linear(net_corr, l_index)
    imp_order_corr = np.concatenate((imp_order_corr,np.array([np.repeat([l_index],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
with open("./w_decorr/base_params/tfo_corr.pkl", 'wb') as f:
    pickle.dump(imp_order_corr, f)

In [None]:
with open("./w_decorr/base_params/tfo_w_decorr.pkl", 'rb') as f:
    imp_order_decorr = pickle.load(f)

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
imp_order_decorr = np.array([[],[],[]]).transpose()
i = 0
for l_index in [2, 6, 10, 13, 16, 1, 4]:
    print(l_index)
    if(l_index != 1 and l_index != 4):
        nlist = cal_importance_conv(net_decorr, l_index)
    else:
        nlist = cal_importance_linear(net_decorr, l_index)
    imp_order_decorr = np.concatenate((imp_order_decorr,np.array([np.repeat([l_index],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
with open("./w_decorr/base_params/tfo_w_decorr.pkl", 'wb') as f:
    pickle.dump(imp_order_decorr, f)

In [22]:
def order_and_ratios(imp_order, prune_ratio):
    imp_sort = np.argsort(imp_order[:,2])
    temp_order = imp_order[imp_sort]

    n_prune = int(prune_ratio * imp_order.shape[0])

    prune_list = temp_order[0:n_prune]

    imp_order_tfo = {}
    ratios = []

    for l_index in [2, 6, 10, 13, 16, 1, 4]:
        nlist = temp_order[(temp_order[:,0] == l_index), 1].astype(int)
        imp_order_tfo.update({l_index: nlist})
        nlist = np.sort(prune_list[(prune_list[:,0] == l_index), 1].astype(int))
        ratios.append(nlist.shape[0])
    return imp_order_tfo, ratios

# Pruning

In [23]:
def cfg_p(prune_ratio, orig_size, save_cfg_corr=0, save_cfg=0):
    cfg_list = []

    for i in range(7):
        cfg_list.append(orig_size[i] - prune_ratio[i])
    
    if(save_cfg == 1):
        with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(prune_iter)+".pkl", 'wb') as f:
            pickle.dump(cfg_list, f)

    elif(save_cfg == 2):
        with open("./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(prune_iter)+".pkl", 'wb') as f:
            pickle.dump(cfg_list, f)
            
    return cfg_list

In [24]:
def func(a,b):
    return max(a,b)

max(0,1)
a = np.zeros((2,4))
a[0,:] = [1,2,3,4]
a[1,:] = [4,5,6,7]
print(a)

[[1. 2. 3. 4.]
 [4. 5. 6. 7.]]


In [None]:
y = [0,1]
m = 2
a[0,:] = [-1,2,3,4]
a[1,:] = [4,5,6,7]
print(np.maximum(a,0))

In [None]:
np.max(a)

In [25]:
def pruner(net, imp_order, prune_ratio, orig_size, net_type=0):
    
    if(net_type==1):
        cfg = cfg_p(prune_ratio, orig_size, save_cfg=1)
    elif(net_type==2):
        cfg = cfg_p(prune_ratio, orig_size, save_cfg=2)
    else:
        cfg = cfg_p(prune_ratio, orig_size)
        
    cfg = np.concatenate((cfg,[4096,4096]), axis=None)
    
    net_pruned = AlexNet(cfg).to(device)
    conv_layers = [2, 6, 10, 13, 16]
    lin_layers = [1, 4]
    
    for l in range(len(conv_layers)):
        if(l == 0):
            n_c = prune_ratio[l]
            order_c = np.sort(imp_order[conv_layers[l]][n_c:])
            net_pruned.features[conv_layers[l]-2].weight.data = net.features[conv_layers[l]-2].weight[order_c].data.detach().clone()
            net_pruned.features[conv_layers[l]-2].bias.data = net.features[conv_layers[l]-2].bias[order_c].data.detach().clone()

            net_pruned.features[conv_layers[l]].weight.data = net.features[conv_layers[l]].weight[order_c].data.detach().clone()
            net_pruned.features[conv_layers[l]].bias.data = net.features[conv_layers[l]].bias[order_c].data.detach().clone()
            net_pruned.features[conv_layers[l]].running_var.data = net.features[conv_layers[l]].running_var[order_c].detach().clone()
            net_pruned.features[conv_layers[l]].running_mean.data = net.features[conv_layers[l]].running_mean[order_c].detach().clone()    
            continue
        
        n_p = prune_ratio[l-1]        
        n_c = prune_ratio[l]

        order_p = np.sort(imp_order[conv_layers[l-1]][n_p:])
        order_c = np.sort(imp_order[conv_layers[l]][n_c:])
        
        net_pruned.features[conv_layers[l]-2].weight.data = net.features[conv_layers[l]-2].weight[order_c][:,order_p].detach().clone()
        net_pruned.features[conv_layers[l]-2].bias.data = net.features[conv_layers[l]-2].bias[order_c].detach().clone()

        net_pruned.features[conv_layers[l]].weight.data = net.features[conv_layers[l]].weight[order_c].detach().clone()
        net_pruned.features[conv_layers[l]].bias.data = net.features[conv_layers[l]].bias[order_c].detach().clone()    
        net_pruned.features[conv_layers[l]].running_var.data = net.features[conv_layers[l]].running_var[order_c].detach().clone()
        net_pruned.features[conv_layers[l]].running_mean.data = net.features[conv_layers[l]].running_mean[order_c].detach().clone()    

    n_p = prune_ratio[4]        
    n_c = prune_ratio[5]
    order_p = np.sort(imp_order[conv_layers[4]][n_p:])
    order_c = np.sort(imp_order[lin_layers[0]][n_c:])    
    net_pruned.classifier[lin_layers[0]].weight.data = net.classifier[lin_layers[0]].weight[order_c][:,order_p].detach().clone()
    net_pruned.classifier[lin_layers[0]].bias.data = net.classifier[lin_layers[0]].bias[order_c].detach().clone()

    n_p = prune_ratio[5]        
    n_c = prune_ratio[6]
    order_p = np.sort(imp_order[lin_layers[0]][n_p:])
    order_c = np.sort(imp_order[lin_layers[1]][n_c:])    
    net_pruned.classifier[lin_layers[1]].weight.data = net.classifier[lin_layers[1]].weight[order_c][:,order_p].detach().clone()
    net_pruned.classifier[lin_layers[1]].bias.data = net.classifier[lin_layers[1]].bias[order_c].detach().clone()

    net_pruned.classifier[lin_layers[1]].weight.data = net.classifier[lin_layers[1]].weight[order_c][:,order_p].detach().clone()
    net_pruned.classifier[lin_layers[1]].bias.data = net.classifier[lin_layers[1]].bias[order_c].detach().clone()


    n_classifier = prune_ratio[-1]
    order_classifier = np.sort(imp_order[lin_layers[1]][n_classifier:])

    net_pruned.classifier[6].weight.data = net.classifier[6].weight[:,order_classifier].detach().clone()
    net_pruned.classifier[6].bias.data = net.classifier[6].bias.detach().clone()
    
    return net_pruned

# Retraining

In [None]:
prune_iter = 1

## Correlated network pruning

In [None]:
orig_size = []

for i in [2, 6, 10, 13, 16]:
    orig_size.append(net_corr.features[i].bias.shape[0])

for i in [1, 4]:
    orig_size.append(net_corr.classifier[i].bias.shape[0])
    
orig_size = np.array(orig_size)

In [None]:
order_corr, prune_ratio = order_and_ratios(imp_order_corr, 0.2)
prune_ratio, orig_size

#### Define pruned network

In [None]:
net_dict = torch.load(PATH_corr)
net_corr.load_state_dict(net_dict['net'])
net_p = pruner(net_corr, order_corr, prune_ratio, orig_size, net_type=1)

In [None]:
cal_acc(net_p.eval()), cal_acc(net_corr.eval())

In [None]:
100*(1 - cal_time(net_p) / t_corr)

#### Retraining

In [None]:
# Training
def net_p_train(epoch):
    print('\nEpoch: %d' % epoch)
    net_p.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net_p(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        
def net_p_test(epoch):
    global best_p_acc
    global prune_iter
    net_p.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_p(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_p_acc:
        print('Saving..')
        state = {
            'net_p': net_p.state_dict(),
            'best_p_acc': acc
        }
        if not os.path.isdir('net_p_checkpoint'):
            os.mkdir('net_p_checkpoint')
        torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')
        best_p_acc = acc

In [None]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [None]:
best_p_acc = 0

In [None]:
for epoch in range(1):
    net_p_train(epoch)
    net_p_test(epoch)

#### Load correlated pruned network

In [None]:
prune_iter

In [None]:
net_dict = torch.load('./net_p_checkpoint/ckpt1.pth')
net_p.load_state_dict(net_dict['net_p'])
best_p_acc = net_dict['best_p_acc']

## Decorrelated network pruning

In [11]:
orig_size = []
for i in [2, 6, 10, 13, 16]:
    orig_size.append(net_corr.features[i].bias.shape[0])
for i in [1, 4]:
    orig_size.append(net_corr.classifier[i].bias.shape[0])
orig_size = np.array(orig_size)

In [None]:
order_decorr, prune_ratio = order_and_ratios(imp_order_decorr, 0.2)
prune_ratio, orig_size

#### Define pruned network

In [None]:
net_dict = torch.load(PATH_decorr)
net_decorr.load_state_dict(net_dict['net'])
net_p_ortho = pruner(net_decorr, order_decorr, prune_ratio, orig_size, net_type=2)

In [None]:
cal_acc(net_p_ortho.eval()), cal_acc(net_decorr.eval())

In [None]:
100*(1 - cal_time(net_p_ortho) / t_decorr)

In [None]:
def net_p_test_ortho(epoch):
    global best_p_ortho_acc
    global prune_iter
    net_p_ortho.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_p_ortho(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    print(acc)
    if acc > best_p_ortho_acc:
        print('Saving..')
        state = {
            'net_p_ortho': net_p_ortho.state_dict(),
            'best_p_ortho_acc': acc
        }
        if not os.path.isdir('ortho_p_checkpoint'):
            os.mkdir('ortho_p_checkpoint')
        torch.save(state, './ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')
        best_p_ortho_acc = acc

In [None]:
def net_p_train_ortho(epoch):
    print('\nEpoch: %d' % epoch)
    net_p_ortho.train()
    correct = 0
    total = 0
    running_loss = 0.0
    angle_cost = 0.0
            
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net_p_ortho(inputs)
        L_angle = 0
        
        ### Conv_ind == 0 ###
        w_mat = net_p_ortho.features[0].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_p_ortho.features[0].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[2])*(angle_mat).norm(1) #.norm().pow(2))

        ### Conv_ind != 0 ###
        for conv_ind in [6, 10, 13, 16]:
            w_mat = net_p_ortho.features[conv_ind-2].weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
            b_mat = net_p_ortho.features[conv_ind-2].bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
            params = torch.cat((w_mat1, b_mat1), dim=1)
            angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(w_mat.shape[0]).to(device)            
            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1) #.norm().pow(2))
    
        ### lin_ind = 1 ###        
        w_mat = net_p_ortho.classifier[1].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_p_ortho.classifier[1].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[1])*(angle_mat).norm(1) #.norm().pow(2))
        
        ### lin_ind = 4 ###        
        w_mat = net_p_ortho.classifier[4].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_p_ortho.classifier[4].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(params.shape[0]).to(device)
        L_angle += (l_imp[4])*(angle_mat).norm(1) #.norm().pow(2))        
        
        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
    
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    print("angle_cost: ", angle_cost/total)

#### Computational Importance

In [None]:
l_imp = {}

for conv_ind in [2, 6, 10, 13, 16]:
    l_imp.update({conv_ind: net_p_ortho.features[conv_ind].bias.shape[0]})
    
for lin_ind in [1, 4]:
    l_imp.update({lin_ind: net_p_ortho.classifier[lin_ind].bias.shape[0]})
    
normalizer = 0
for key, val in l_imp.items():
    normalizer += val
for key, val in l_imp.items():
    l_imp[key] = val / normalizer

In [None]:
# l_imp[0] = 0 #l_imp[31]

#### Retraining

In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p_ortho.parameters(), lr=0.000001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [None]:
best_p_ortho_acc = 0

In [None]:
for epoch in range(1):
    net_p_train_ortho(epoch)
    net_p_test_ortho(epoch)

#### Load decorrelated pruned network

In [None]:
prune_iter 

In [None]:
net_dict = torch.load('./ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')
net_p_ortho.load_state_dict(net_dict['net_p_ortho'])

#### Evaluate orthogonality of filters in pruned network

In [None]:
### Conv_ind == 0 ###
w_mat = net_p_ortho.features[0].weight
w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
b_mat = net_p_ortho.features[0].bias
b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
params = torch.cat((w_mat1, b_mat1), dim=1)
angle_mat = torch.matmul(torch.t(params), params)# - torch.eye(params.shape[1]).to(device)
L_diag = (angle_mat.diag().norm(1))
L_angle = (angle_mat.norm(1))
print(L_diag.cpu()/L_angle.cpu())

for conv_ind in [6, 10, 13, 16]:
    w_mat = net_p_ortho.features[conv_ind-2].weight
    w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
    b_mat = net_p_ortho.features[conv_ind-2].bias
    b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
    params = torch.cat((w_mat1, b_mat1), dim=1)
    angle_mat = torch.matmul(params, torch.t(params))# - torch.eye(w_mat.shape[0]).to(device)            
    L_diag = (angle_mat.diag().norm(1))
    L_angle = (angle_mat.norm(1))
    print(L_diag.cpu()/L_angle.cpu())

### lin_ind = 1 ###        
w_mat = net_p_ortho.classifier[1].weight
w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
b_mat = net_p_ortho.classifier[1].bias
b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
params = torch.cat((w_mat1, b_mat1), dim=1)
angle_mat = torch.matmul(torch.t(params), params)# - torch.eye(params.shape[1]).to(device)
L_diag = (angle_mat.diag().norm(1))
L_angle = (angle_mat.norm(1))
print(L_diag.cpu()/L_angle.cpu())

### lin_ind = 4 ###        
w_mat = net_p_ortho.classifier[4].weight
w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
b_mat = net_p_ortho.classifier[4].bias
b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
params = torch.cat((w_mat1, b_mat1), dim=1)
angle_mat = torch.matmul(params, torch.t(params))# - torch.eye(params.shape[0]).to(device)
L_diag = (angle_mat.diag().norm(1))
L_angle = (angle_mat.norm(1))
print(L_diag.cpu()/L_angle.cpu())

### Subsequent pruning

#### Importance

In [None]:
''' Correlated network '''
# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
#     imp_order_p = pickle.load(f)

In [None]:
optimizer = optim.SGD(net_p.parameters(), lr=0, weight_decay=0)
imp_order_p = np.array([[],[],[]]).transpose()
i = 0
for l_index in [2, 6, 10, 13, 16, 1, 4]:
    print(l_index)
    if(l_index != 1 and l_index != 4):
        nlist = cal_importance_conv(net_p, l_index)
    else:
        nlist = cal_importance_linear(net_p, l_index)
    imp_order_p = np.concatenate((imp_order_p,np.array([np.repeat([l_index],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p, f)

In [None]:
''' De-Correlated network '''
# with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p"+str(prune_iter)+".pkl", 'rb') as f:
#     imp_order_p_ortho = pickle.load(f)

In [None]:
optimizer = optim.SGD(net_p_ortho.parameters(), lr=0, weight_decay=0)
imp_order_p_ortho = np.array([[],[],[]]).transpose()
i = 0
for l_index in [2, 6, 10, 13, 16, 1, 4]:
    print(l_index)
    if(l_index != 1 and l_index != 4):
        nlist = cal_importance_conv(net_p_ortho, l_index)
    else:
        nlist = cal_importance_linear(net_p_ortho, l_index)
    imp_order_p_ortho = np.concatenate((imp_order_p_ortho,np.array([np.repeat([l_index],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1

with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p_ortho"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p_ortho, f)

#### Pruned network pruning

In [None]:
# ''' Correlated network '''
# orig_size = []
# for i in [2, 6, 10, 13, 16]:
#     orig_size.append(net_p.features[i].bias.shape[0])
# for i in [1, 4]:
#     orig_size.append(net_p.classifier[i].bias.shape[0])
# orig_size = np.array(orig_size)

In [None]:
''' De-Correlated network '''
orig_size = []
for i in [2, 6, 10, 13, 16]:
    orig_size.append(net_p_ortho.features[i].bias.shape[0])
for i in [1, 4]:
    orig_size.append(net_p_ortho.classifier[i].bias.shape[0])
orig_size = np.array(orig_size)

#### Pruning order

In [None]:
# ''' Correlated network '''
# order_p, prune_ratio = order_and_ratios(imp_order_p, 0.1)
# prune_ratio, orig_size

In [None]:
''' De-Correlated network '''
order_p, prune_ratio = order_and_ratios(imp_order_p_ortho, 0.1)
prune_ratio, orig_size

#### Define pruned network

In [None]:
prune_iter = 2

In [None]:
# ''' Correlated network pruning '''
# net_p1 = pruner(net_p, order_p, prune_ratio, orig_size, net_type=1)

# print("Accs:", cal_acc(net_p1.eval()), cal_acc(net_p.eval()))
# print("Time:", cal_time(net_p1), t_corr)

In [None]:
''' De-Correlated network pruning '''
net_p1_ortho = pruner(net_p_ortho, order_p, prune_ratio, orig_size, net_type=2)

print("Accs:", cal_acc(net_p1_ortho.eval()), cal_acc(net_p_ortho.eval()))
print("Time:", cal_time(net_p1_ortho), t_decorr)

#### Save pruned network

In [None]:
# ''' Correlated network saving '''
# net_p = net_p1

# print('Saving..')
# state = {
#     'net_p': net_p.state_dict(),
#     'best_p_acc': cal_acc(net_p.eval())
# }
# if not os.path.isdir('net_p_checkpoint'):
#     os.mkdir('net_p_checkpoint')
# torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')

In [None]:
''' De-Correlated network saving '''
net_p_ortho = net_p1_ortho

print('Saving..')
state = {
    'net_p_ortho': net_p_ortho.state_dict(),
    'best_p_ortho_acc': cal_acc(net_p_ortho.eval())
}
if not os.path.isdir('ortho_p_checkpoint'):
    os.mkdir('ortho_p_checkpoint')
torch.save(state, './ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')

### Load pruned network

In [None]:
# ''' Correlated network loading '''
# with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(1)+".pkl", 'rb') as f:
#     cfg_p1 = pickle.load(f)
    
# net_p = AlexNet(cfg_p1).to(device)
# PATH = './net_p_checkpoint/ckpt'+str(1)+'.pth'
# net_p.load_state_dict(torch.load(PATH)['net_p'])

In [None]:
cal_acc(net_p.eval()), cal_acc(net_decorr.eval())

In [None]:
''' De-Correlated network loading '''
with open("./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(1)+".pkl", 'rb') as f:
    cfg_p1 = pickle.load(f)

net_p_ortho = AlexNet(cfg_p1).to(device)
PATH = './ortho_p_checkpoint/ortho_ckpt'+str(2)+'.pth'
net_p_ortho.load_state_dict(torch.load(PATH)['net_p_ortho'])    

In [None]:
cal_acc(net_p_ortho.eval()), cal_acc(net_decorr.eval())

### FLOPS calculator

In [None]:
with torch.cuda.device(0):
    flops, params = get_model_complexity_info(net_p_ortho, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
    
# with torch.cuda.device(0):
#     flops, params = get_model_complexity_info(net_p, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
#     print('{:<30}  {:<8}'.format('Computational complexity: ', flops))    

In [None]:
with torch.cuda.device(0):
    flops, params = get_model_complexity_info(net_decorr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

# with torch.cuda.device(0):
#     flops, params = get_model_complexity_info(net_corr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
#     print('{:<30}  {:<8}'.format('Computational complexity: ', flops))