In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import numpy as np

from models import *
from utils import progress_bar
from imp_baselines import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

Files already downloaded and verified
Files already downloaded and verified


In [2]:
from ptflops import get_model_complexity_info

In [3]:
# Training-time augmentation: random 32x32 crop after 4-pixel padding, random
# horizontal flip, then tensor conversion and per-channel normalisation.
# NOTE(review): the mean/std constants look like the commonly used CIFAR-10
# statistics, applied here to CIFAR-100 — confirm this is intended.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# NOTE(review): shuffle=False on the training loader. This loader is used both by
# the importance estimators and by the retraining loops below — confirm that
# training without shuffling is intended.
trainset = torchvision.datasets.CIFAR100(root='./../data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=False, num_workers=4)

testset = torchvision.datasets.CIFAR100(root='./../data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


### Build and load base networks

In [4]:
print('==> Building model..')
criterion = nn.CrossEntropyLoss()

# Two identically-shaped ResNet-34 instances; their weights are loaded from
# separate checkpoints in the next cell.
net_corr = ResNet34().to(device)
net_decorr = ResNet34().to(device)

if device == 'cuda':
    # Wrap both networks so that checkpoints and later cells can address the
    # underlying model through `.module`.
    net_corr, net_decorr = (torch.nn.DataParallel(model) for model in (net_corr, net_decorr))
    cudnn.benchmark = True

==> Building model..


In [5]:
# Load the two base checkpoints. `map_location=device` lets checkpoints saved on a
# GPU be restored on whatever device this session actually uses.
PATH_corr = './w_decorr/base_params/cifar100_net.pth'
net_dict = torch.load(PATH_corr, map_location=device)
net_corr.load_state_dict(net_dict['net'])

PATH_decorr = './w_decorr/base_params/wnet_base.pth'
net_dict = torch.load(PATH_decorr, map_location=device)
net_decorr.load_state_dict(net_dict['net_ortho'])

<All keys matched successfully>

In [8]:
# Inspect the importance list returned for the first conv of layer1's first block.
# cal_importance_l2 is not defined in this notebook — presumably it comes from the
# `imp_baselines` star import; verify.
cal_importance_l2(net_corr, net_corr.module.layer1[0].conv1)

[array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
        13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
        26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,
        39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51.,
        52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63.]),
 tensor([0.9281, 0.6417, 1.1833, 1.3572, 1.6598, 1.1922, 0.9523, 0.9194, 1.3004,
         0.9991, 0.6146, 1.8701, 0.9258, 0.9355, 1.6306, 1.8105, 1.6758, 1.1271,
         1.0802, 1.2220, 0.7697, 1.8095, 1.5406, 0.7243, 1.4830, 1.1796, 2.2444,
         0.6324, 0.8965, 2.5882, 1.1646, 0.8950, 0.8849, 1.2358, 1.0347, 2.0343,
         0.6978, 1.2911, 0.9901, 1.0696, 0.9263, 1.3304, 2.6852, 1.0501, 1.0119,
         1.5679, 0.8214, 1.2203, 1.4939, 1.8544, 2.5120, 2.4663, 1.1768, 1.0522,
         0.9579, 1.0658, 1.6419, 0.9822, 0.9519, 1.0726, 1.0564, 1.8252, 0.6968,
         1.0791], device='cuda:0')]

### Accuracy Calculator

In [None]:
def cal_acc(net_test):
    """Print and return the top-1 accuracy (in percent) of `net_test` on `testloader`.

    The network is switched to eval mode and evaluated without gradient tracking.
    Relies on the notebook globals `testloader` and `device`.
    """
    net_test.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Index of the largest logit per sample is the predicted class.
            predicted = net_test(inputs).max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

    accuracy = 100 * n_correct / n_seen
    print(accuracy)
    return accuracy

### Importance Calculator

In [None]:
def cal_importance_fisher(net, l_id, num_stop=100):
    """Accumulate a squared gradient-times-parameter importance score per channel.

    For each batch drawn from the global `trainloader`, the loss given by the global
    `criterion` is back-propagated through `net`, and
    ``(weight.grad * weight + bias.grad * bias) ** 2`` of the layer `l_id` is added
    to a running per-channel total.

    Args:
        net: network to run forward/backward passes on.
        l_id: layer of `net` exposing 1-D `weight` and `bias` parameters (the
            notebook passes BatchNorm2d layers).
        num_stop: stop after the first batch that brings the number of processed
            samples above this value.

    Returns:
        ``[indices, scores]`` where `indices` is a float NumPy array
        ``0 .. C-1`` and `scores` is a tensor holding the accumulated score of
        each of the C channels.
    """
    num = 0
    imp_corr_bn = torch.zeros(l_id.bias.shape[0], device=l_id.bias.device)

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Clear gradients on the network being measured. The original relied on a
        # global `optimizer`, which only works if it happens to wrap this `net`.
        net.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()

        imp_corr_bn += ((l_id.weight.grad * l_id.weight.data) + (l_id.bias.grad * l_id.bias.data)).pow(2)
        num += labels.shape[0]
        if num > num_stop:
            break

    neuron_order = [np.linspace(0, imp_corr_bn.shape[0]-1, imp_corr_bn.shape[0]), imp_corr_bn]

    return neuron_order

In [None]:
def cal_importance_tfo(net, l_id, num_stop=100):
    """Accumulate an absolute gradient-times-parameter importance score per channel.

    For each batch drawn from the global `trainloader`, the loss given by the global
    `criterion` is back-propagated through `net`, and
    ``|weight.grad * weight + bias.grad * bias|`` of the layer `l_id` is added to a
    running per-channel total.

    Args:
        net: network to run forward/backward passes on.
        l_id: layer of `net` exposing 1-D `weight` and `bias` parameters (the
            notebook passes BatchNorm2d layers).
        num_stop: stop after the first batch that brings the number of processed
            samples above this value.

    Returns:
        ``[indices, scores]`` where `indices` is a float NumPy array
        ``0 .. C-1`` and `scores` is a tensor holding the accumulated score of
        each of the C channels.
    """
    num = 0
    imp_corr_bn = torch.zeros(l_id.bias.shape[0], device=l_id.bias.device)

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Clear gradients on the network being measured. The original relied on a
        # global `optimizer`, which only works if it happens to wrap this `net`.
        net.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()

        imp_corr_bn += ((l_id.weight.grad * l_id.weight.data) + (l_id.bias.grad * l_id.bias.data)).abs()
        num += labels.shape[0]
        if num > num_stop:
            break

    neuron_order = [np.linspace(0, imp_corr_bn.shape[0]-1, imp_corr_bn.shape[0]), imp_corr_bn]

    return neuron_order

In [None]:
def cal_importance_netslim(net, l_id, num_stop=100):
    """Score each channel of `l_id` by the magnitude of its `weight` entry.

    `net` and `num_stop` are accepted for signature parity with the other
    importance functions and are not used.

    Returns:
        ``[indices, scores]``: float NumPy indices ``0 .. C-1`` and the tensor
        ``|l_id.weight|``.
    """
    scores = l_id.weight.data.abs()
    n_channels = scores.shape[0]
    channel_ids = np.linspace(0, n_channels - 1, n_channels)
    return [channel_ids, scores]

In [None]:
def cal_importance_netslim(net, l_id, num_stop=100):
    """Score each channel of `l_id` by the magnitude of its `weight` entry.

    NOTE(review): this cell redefines `cal_importance_netslim` from the previous
    cell with the same result; one of the two definitions should be deleted.

    `net` and `num_stop` are accepted for signature parity with the other
    importance functions and are not used. The unused locals of the original
    (including a clone of `l_id.bias`) were removed, so the layer no longer needs
    a bias for this function to work.

    Returns:
        ``[indices, scores]``: float NumPy indices ``0 .. C-1`` and the tensor
        ``|l_id.weight|``.
    """
    imp_corr_bn = l_id.weight.data.abs()

    neuron_order = [np.linspace(0, imp_corr_bn.shape[0]-1, imp_corr_bn.shape[0]), imp_corr_bn]

    return neuron_order

### Time Calculator

In [None]:
import time

In [None]:
def cal_time(net_acc):
    """Return the mean wall-clock time (seconds) of one forward pass of `net_acc`.

    A single random 1x3x32x32 input is used. Five warm-up passes are run first,
    then 25 timed passes; the total elapsed time is divided by 25.

    Fixes over the original: the start time was taken once but the elapsed time
    was added on every iteration, so cumulative times were summed and the result
    was inflated roughly 13-fold. Timing now runs without gradient tracking, and
    on CUDA the device is synchronised before reading the clock so that queued
    kernels are included in the measurement.
    """
    n_warmup, n_timed = 5, 25
    net_acc.eval()

    # Put the sample where the model lives; fall back to the notebook-level
    # `device` for parameter-less modules.
    first_param = next(net_acc.parameters(), None)
    sample_device = first_param.device if first_param is not None else device
    testsamp = torch.rand(1, 3, 32, 32).to(sample_device)
    on_cuda = testsamp.is_cuda

    with torch.no_grad():
        for _ in range(n_warmup):
            net_acc(testsamp)
        if on_cuda:
            torch.cuda.synchronize()

        t_s = time.perf_counter()
        for _ in range(n_timed):
            net_acc(testsamp)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t_s

    return elapsed / n_timed

In [None]:
# Mean forward-pass time of each unpruned base network (see cal_time).
t_corr = cal_time(net_corr)
t_decorr = cal_time(net_decorr)

### TFO importance

In [None]:
# Sample budget passed to the importance estimators below as their `num_stop` argument.
num_stop = 100

In [None]:
import pickle

In [None]:
# Load previously saved importance scores for net_corr. The next cell recomputes
# and overwrites this same file, so run one of the two cells, not both.
# pickle.load executes arbitrary code from the file — only load trusted files.
with open("./w_decorr/base_params/tfo_corr.pkl", 'rb') as f:
    imp_order_corr = pickle.load(f)

In [None]:
# Zero-learning-rate optimizer: it exists only so gradients can be cleared.
optimizer = optim.SGD(net_corr.parameters(), lr=0, weight_decay=0)

# One row per BatchNorm channel: (layer index, channel index, importance score).
imp_order_corr = np.array([[],[],[]]).transpose()
bn_layers = [m for m in net_corr.module.modules() if isinstance(m, nn.BatchNorm2d)]

for layer_idx, bn_layer in enumerate(bn_layers):
    print(bn_layer)
    # NOTE(review): `cal_importance` is not defined in this notebook; presumably it
    # comes from a star import. Given the output file name, `cal_importance_tfo`
    # may be what was intended — confirm.
    layer_imp = cal_importance(net_corr, bn_layer, num_stop)
    layer_ids = np.repeat([layer_idx], layer_imp[1].shape[0]).tolist()
    channel_ids = layer_imp[0].tolist()
    scores = layer_imp[1].detach().cpu().numpy().tolist()
    layer_rows = np.array([layer_ids, channel_ids, scores]).transpose()
    imp_order_corr = np.concatenate((imp_order_corr, layer_rows), 0)

with open("./w_decorr/base_params/tfo_corr.pkl", 'wb') as f:
    pickle.dump(imp_order_corr, f)

In [None]:
# Load previously saved importance scores for net_decorr. The next cell recomputes
# and overwrites this same file, so run one of the two cells, not both.
# pickle.load executes arbitrary code from the file — only load trusted files.
with open("./w_decorr/base_params/tfo_w_decorr.pkl", 'rb') as f:
    imp_order_decorr = pickle.load(f)

In [None]:
# Zero-learning-rate optimizer: it exists only so gradients can be cleared.
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)

# One row per BatchNorm channel: (layer index, channel index, importance score).
imp_order_decorr = np.array([[],[],[]]).transpose()
i = 0

# Iterate over the BatchNorm layers of net_decorr itself. The original walked
# net_corr's layers while running net_decorr, so it read gradients of layers that
# the backward pass never touched.
for l_index in net_decorr.module.modules():
    if(isinstance(l_index, nn.BatchNorm2d)):
        print(l_index)
        # NOTE(review): `cal_importance` is not defined in this notebook; presumably
        # it comes from a star import. Given the output file name,
        # `cal_importance_tfo` may be what was intended — confirm.
        nlist = cal_importance(net_decorr, l_index, num_stop)
        imp_order_decorr = np.concatenate((imp_order_decorr,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
        i+=1

with open("./w_decorr/base_params/tfo_w_decorr.pkl", 'wb') as f:
    pickle.dump(imp_order_decorr, f)

### Order and ratios for pruning

In [None]:
def order_and_ratios(imp_order, prune_ratio, num_layers=None):
    """Rank channels globally by importance and count the pruned channels per layer.

    Args:
        imp_order: 2-D array with one row per channel; column 0 is the layer index,
            column 1 the channel index within that layer, column 2 the importance.
        prune_ratio: fraction (0..1) of all rows to mark for pruning, taken from
            the least important end of the global ranking.
        num_layers: number of layer indices to report. Defaults to
            ``len(orig_size)`` (notebook global) to match the original behaviour;
            pass it explicitly to avoid depending on that global.

    Returns:
        ``(imp_order_tfo, ratios)``. `imp_order_tfo` maps each layer index to its
        channel indices ordered from least to most important; `ratios[l]` is the
        number of channels of layer ``l`` among the globally least important
        ``int(prune_ratio * len(imp_order))`` rows.
    """
    if num_layers is None:
        num_layers = len(orig_size)

    ranked = imp_order[np.argsort(imp_order[:, 2])]
    n_prune = int(prune_ratio * imp_order.shape[0])
    prune_list = ranked[0:n_prune]

    imp_order_tfo = {}
    ratios = []
    for l_index in range(num_layers):
        imp_order_tfo[l_index] = ranked[(ranked[:, 0] == l_index), 1].astype(int)
        # Only the count is needed, so the selected rows are not materialised.
        ratios.append(int(np.count_nonzero(prune_list[:, 0] == l_index)))
    return imp_order_tfo, ratios

In [None]:
def cal_size(net_size):
    """Return the channel count of every BatchNorm2d in `net_size.module`.

    Layers are listed in `modules()` traversal order; the result is a NumPy array.
    """
    widths = [
        layer.bias.shape[0]
        for layer in net_size.module.modules()
        if isinstance(layer, nn.BatchNorm2d)
    ]
    return np.array(widths)

In [None]:
# Per-BatchNorm-layer channel counts of the unpruned network; order_and_ratios
# reads this global.
orig_size = cal_size(net_corr)

### Pruned Architecture

In [None]:
class BasicBlock_p(nn.Module):
    """Residual block with an independently sized hidden width.

    Two 3x3 conv + BatchNorm stages (`in_planes -> mid_planes -> out_planes`) are
    followed by a residual addition. When `stride != 1` the shortcut is a strided
    1x1 conv + BatchNorm; otherwise it is an empty `nn.Sequential`, so the input
    is added unchanged and must already have `out_planes` channels.

    `lnum` is accepted but not used by the block.
    """
    expansion = 1

    def __init__(self, lnum, in_planes, mid_planes, out_planes, stride=1):
        super(BasicBlock_p, self).__init__()

        # Main path is identical for both shortcut kinds; only conv1 carries the stride.
        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.conv2 = nn.Conv2d(mid_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)

        if stride == 1:
            self.shortcut = nn.Sequential()
        else:
            shortcut_planes = self.expansion * out_planes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, shortcut_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(shortcut_planes)
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)

In [None]:
class ResNet34_p(nn.Module):
    """ResNet-style classifier whose per-layer widths come from a config dict.

    `cfg['base']` is the width of the stem conv. `cfg['l1']` .. `cfg['l4']` are flat
    lists holding, for each block of the stage, its hidden width followed by its
    output width. `num_blocks[k]` is the number of blocks in stage ``k + 1``.
    """

    def __init__(self, cfg, num_blocks, num_classes=100):
        super(ResNet34_p, self).__init__()
        stem_width = cfg['base']
        # Tracks the input width of the next block while the stages are built.
        self.in_planes = stem_width

        self.conv1 = nn.Conv2d(3, stem_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.layer1 = self._make_layer(BasicBlock_p, cfg['l1'], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlock_p, cfg['l2'], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlock_p, cfg['l3'], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlock_p, cfg['l4'], num_blocks[3], stride=2)
        self.linear = nn.Linear(cfg['l4'][-1] * BasicBlock_p.expansion, num_classes)

    def _make_layer(self, block, l_cfg, num_blocks, stride):
        """Build one stage; only its first block uses `stride`, the rest use 1."""
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for lnum, block_stride in enumerate(block_strides):
            mid_planes, out_planes = l_cfg[2*lnum], l_cfg[2*lnum+1]
            blocks.append(block(lnum, self.in_planes, mid_planes, out_planes, block_stride))
            self.in_planes = out_planes
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        flat = out.view(out.size(0), -1)
        return self.linear(flat)

In [None]:
def ResPruned(layer_cfgs):
    """Build a `ResNet34_p` with the 3-4-6-3 block layout from a width config dict."""
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet34_p(layer_cfgs, num_blocks=blocks_per_stage)

### Pruning

In [None]:
def cfg_p(prune_ratio, orig_size, net_type=0):
    """Build the width configuration of the pruned network.

    `orig_size` and `prune_ratio` are indexed by BatchNorm layer in module order:
    index 0 is the stem, then for every block its bn1 and bn2, with the shortcut
    BatchNorm following the first block of stages 2-4.

    Only the hidden width of each block (its bn1) is reduced. Block outputs keep
    their original width because they are summed with a shortcut branch.

    Fix over the original: the stem width is no longer reduced. The stem output is
    added through an identity shortcut to the unpruned output of the first block of
    stage 1 (see `BasicBlock_p`), so a narrower stem makes that addition fail.
    Widths are also cast to plain `int` so the pickled config holds no NumPy scalars.

    Args:
        prune_ratio: per-layer number of channels to remove.
        orig_size: per-layer original channel counts.
        net_type: 1 or 2 additionally pickles the config under the corr / decorr
            directory, using the notebook global `prune_iter` in the file name;
            any other value only returns it.

    Returns:
        dict with key 'base' (stem width) and keys 'l1'..'l4', each a flat list of
        ``[hidden_width, output_width]`` pairs, one pair per block.
    """
    num_blocks = [3, 4, 6, 3]
    cfg_list = {'base': int(orig_size[0])}

    b_id = 1  # index of the next BatchNorm layer to consume
    for l_id, n_blocks in enumerate(num_blocks, start=1):
        l_list = []
        for block_num in range(n_blocks):
            l_list.append(int(orig_size[b_id] - prune_ratio[b_id]))  # bn1: pruned
            l_list.append(int(orig_size[b_id + 1]))                  # bn2: kept
            b_id += 2
            if l_id > 1 and block_num == 0:
                # The first block of stages 2-4 has a learned shortcut whose
                # BatchNorm occupies the next index; it is never pruned.
                b_id += 1
        cfg_list['l' + str(l_id)] = l_list

    if(net_type == 1):
        with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(prune_iter)+".pkl", 'wb') as f:
            pickle.dump(cfg_list, f)

    elif(net_type == 2):
        with open("./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(prune_iter)+".pkl", 'wb') as f:
            pickle.dump(cfg_list, f)

    return cfg_list

In [None]:
def _copy_conv_filters(dst_conv, src_conv, out_idx, in_idx=None):
    """Copy the selected output filters (and, if given, input channels) of a conv."""
    weight = src_conv.weight[out_idx]
    if in_idx is not None:
        weight = weight[:, in_idx]
    dst_conv.weight.data = weight.data.detach().clone()


def _copy_bn_channels(dst_bn, src_bn, idx):
    """Copy the affine parameters and running statistics of the selected channels."""
    dst_bn.weight.data = src_bn.weight[idx].data.detach().clone()
    dst_bn.bias.data = src_bn.bias[idx].data.detach().clone()
    dst_bn.running_var.data = src_bn.running_var[idx].data.detach().clone()
    dst_bn.running_mean.data = src_bn.running_mean[idx].data.detach().clone()


def pruner(net, imp_order, prune_ratio, orig_size, net_type=0):
    """Build a pruned copy of `net`, keeping the most important channels per layer.

    The width config comes from `cfg_p`. For every BatchNorm-indexed layer, the
    number of channels to drop is ``orig_size[layer] - configured width``; the
    channels kept are the tail of `imp_order[layer]`, sorted ascending so the
    original channel order is preserved. Conv weights are sliced on the output
    axis by the layer's own kept channels and on the input axis by the kept
    channels of the layer feeding it. The final linear layer is copied unchanged.

    Changes over the original: the four copy-pasted per-stage sections are folded
    into one loop with two helpers, and the new model is created on the same
    device as `net`, so buffers that are not overwritten below do not stay on
    the CPU.

    Args:
        net: `DataParallel`-wrapped source network (accessed through `.module`).
        imp_order: mapping layer index -> channel indices ordered from least to
            most important.
        prune_ratio: per-layer number of channels to remove (forwarded to cfg_p).
        orig_size: per-layer original channel counts.
        net_type: forwarded to cfg_p (controls whether the config is pickled).

    Returns:
        The pruned network wrapped in `torch.nn.DataParallel`.
    """
    cfg = cfg_p(prune_ratio, orig_size, net_type=net_type)
    print(cfg)

    src = net.module
    net_pruned = torch.nn.DataParallel(ResPruned(cfg).to(src.conv1.weight.device))
    dst = net_pruned.module

    def kept_channels(layer_id, width):
        """Ascending indices of the `width` most important channels of a layer."""
        n_c = orig_size[layer_id] - width
        return np.sort(imp_order[layer_id][n_c:])

    block_id = 0

    # Stem: there is no preceding layer, so only the output axis is sliced.
    order_p = kept_channels(block_id, cfg['base'])
    _copy_conv_filters(dst.conv1, src.conv1, order_p)
    _copy_bn_channels(dst.bn1, src.bn1, order_p)
    block_id += 1

    for stage in range(1, 5):
        widths = cfg['l' + str(stage)]
        src_layer = getattr(src, 'layer' + str(stage))
        dst_layer = getattr(dst, 'layer' + str(stage))

        for block_num in range(len(widths) // 2):
            src_block, dst_block = src_layer[block_num], dst_layer[block_num]
            # Channels entering the block; the learned shortcut consumes these too.
            order_in = order_p

            # conv1 / bn1
            order_c = kept_channels(block_id, widths[2*block_num])
            _copy_conv_filters(dst_block.conv1, src_block.conv1, order_c, order_in)
            _copy_bn_channels(dst_block.bn1, src_block.bn1, order_c)
            order_p = order_c
            block_id += 1

            # conv2 / bn2
            order_c = kept_channels(block_id, widths[2*block_num+1])
            _copy_conv_filters(dst_block.conv2, src_block.conv2, order_c, order_p)
            _copy_bn_channels(dst_block.bn2, src_block.bn2, order_c)
            order_p = order_c
            block_id += 1

            # Learned shortcut: only the first block of stages 2-4 has one. It does
            # not update order_p; the next block is fed by conv2's kept channels.
            if stage > 1 and block_num == 0:
                order_c = kept_channels(block_id, widths[2*block_num+1])
                _copy_conv_filters(dst_block.shortcut[0], src_block.shortcut[0], order_c, order_in)
                _copy_bn_channels(dst_block.shortcut[1], src_block.shortcut[1], order_c)
                block_id += 1

    ### Linear block
    dst.linear.weight.data = src.linear.weight.data.detach().clone()
    dst.linear.bias.data = src.linear.bias.data.detach().clone()

    return net_pruned

## Retraining

In [None]:
# Pruning round counter; used in the config and checkpoint file names below.
prune_iter = 1

### Correlated network pruning

In [None]:
# Per-BatchNorm-layer channel counts of net_corr before pruning.
orig_size = cal_size(net_corr)

In [None]:
# Rank channels of net_corr and count, per layer, how many fall in the globally
# least important 30%. The second line displays those counts next to the layer widths.
order_corr, prune_ratio = order_and_ratios(imp_order_corr, 0.3)
np.array(prune_ratio), orig_size

In [None]:
# Restore the base weights before pruning so that re-running this cell always
# starts from the same network. `map_location=device` lets a GPU-saved checkpoint
# load on the device used in this session.
net_dict = torch.load(PATH_corr, map_location=device)
net_corr.load_state_dict(net_dict['net'])
net_p = pruner(net_corr, order_corr, prune_ratio, orig_size, net_type=1)

### Accuracies

In [None]:
# Test accuracy of the base network and of its pruned copy, before any retraining.
cal_acc(net_corr), cal_acc(net_p)

In [None]:
# Mean forward-pass time of the base network and of its pruned copy.
t_corr = cal_time(net_corr)
t_p = cal_time(net_p)

In [None]:
# Percentage reduction in forward-pass time of the pruned network relative to the base.
100*(1 - (t_p / t_corr))

#### Retraining

In [None]:
# Training
def net_p_train(epoch):
    """Train the global `net_p` for one pass over `trainloader`.

    Uses the notebook globals `optimizer`, `criterion` and `device`, and reports
    the running mean loss and accuracy through `progress_bar` after every batch.
    `epoch` is only used in the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net_p.train()
    loss_sum, n_correct, n_seen = 0, 0, 0
    n_batches = len(trainloader)

    for step, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = net_p(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
        progress_bar(step, n_batches, status)
        
def net_p_test(epoch):
    """Evaluate the global `net_p` on `testloader` and checkpoint on improvement.

    When the accuracy exceeds the global `best_p_acc`, the state dict and accuracy
    are saved under ``./net_p_checkpoint/`` (file name includes `prune_iter`) and
    `best_p_acc` is updated. `epoch` is not used.
    """
    global best_p_acc
    global prune_iter
    net_p.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    n_batches = len(testloader)

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net_p(inputs)

            loss_sum += criterion(outputs, targets).item()
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
            progress_bar(step, n_batches, status)

    # Save checkpoint only when this evaluation is the best so far.
    acc = 100.*n_correct/n_seen
    if not acc > best_p_acc:
        return

    print('Saving..')
    state = {
        'net_p': net_p.state_dict(),
        'best_p_acc': acc
    }
    if not os.path.isdir('net_p_checkpoint'):
        os.mkdir('net_p_checkpoint')
    torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')
    best_p_acc = acc

In [None]:
import torch.optim as optim

# Loss and optimizer for retraining the pruned correlated network. These rebind the
# notebook-level `criterion` and `optimizer` that net_p_train / net_p_test read.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.000001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [None]:
# Best test accuracy seen so far; net_p_test checkpoints only when this is exceeded.
best_p_acc = 0

In [None]:
# Retrain the pruned network for one epoch, then evaluate and checkpoint.
for epoch in range(1):
    net_p_train(epoch)
    net_p_test(epoch)

#### Save correlated pruned network

In [None]:
# Display the current pruning round (determines which checkpoint file is used).
prune_iter

In [None]:
# Reload the best checkpoint written by net_p_test for the current pruning round.
# The file name is derived from `prune_iter` (it was hard-coded to round 1), matching
# how the checkpoint is saved; `map_location=device` keeps the load device-independent.
net_dict = torch.load('./net_p_checkpoint/ckpt'+str(prune_iter)+'.pth', map_location=device)
net_p.load_state_dict(net_dict['net_p'])
best_p_acc = net_dict['best_p_acc']

### Decorrelated network pruning

In [None]:
# Per-BatchNorm-layer channel counts of net_decorr before pruning.
orig_size = cal_size(net_decorr)

#### Pruning order

In [None]:
# Rank channels of net_decorr and count, per layer, how many fall in the globally
# least important 10%. Note this rebinds `prune_ratio` from the correlated section.
order_decorr, prune_ratio = order_and_ratios(imp_order_decorr, 0.1)
np.array(prune_ratio), orig_size

#### Define pruned network

In [None]:
# Restore the base weights before pruning so that re-running this cell always
# starts from the same network. `map_location=device` lets a GPU-saved checkpoint
# load on the device used in this session.
net_dict = torch.load(PATH_decorr, map_location=device)
net_decorr.load_state_dict(net_dict['net_ortho'])
net_p_ortho = pruner(net_decorr, order_decorr, prune_ratio, orig_size, net_type=2)

In [None]:
# Test accuracy of the pruned decorrelated network and of its base, before retraining.
cal_acc(net_p_ortho.eval()), cal_acc(net_decorr.eval())

In [None]:
# Mean forward-pass time of the decorrelated base network and of its pruned copy.
t_decorr = cal_time(net_decorr)
t_p_ortho = cal_time(net_p_ortho)

In [None]:
# Percentage reduction in forward-pass time of the pruned decorrelated network
# relative to its own base (t_decorr, measured in the previous cell). The original
# divided by t_corr, the timing of the other base network.
100*(1 - (t_p_ortho / t_decorr))

#### Retraining

In [None]:
# Number of residual blocks in layer1..layer4; read by the l_imp cell and by
# net_p_train_ortho.
num_blocks = [3,4,6,3]

In [None]:
# Per-layer weights for the orthogonality penalty: each layer's channel count as a
# fraction of the total over all listed layers. Key -1 is the stem; keys 0..3 are
# layer1..layer4, with entries 2*b (bn1) and 2*b+1 (bn2) for block b.
# NOTE(review): the original referenced `net_ortho`, which is never defined in this
# notebook. `net_decorr` is the network loaded from the 'net_ortho' checkpoint
# entry, so it is used here — confirm whether the pruned `net_p_ortho` was intended.
l_imp = {-1:{'conv1':net_decorr.module.bn1.bias.shape[0]}, 0:{}, 1:{}, 2:{}, 3:{}}

stages = [net_decorr.module.layer1, net_decorr.module.layer2, net_decorr.module.layer3, net_decorr.module.layer4]
for mod_id, module_id in enumerate(stages):
    for b_id in range(num_blocks[mod_id]):
        l_imp[mod_id][2*b_id] = module_id[b_id].bn1.bias.shape[0]
        l_imp[mod_id][2*b_id+1] = module_id[b_id].bn2.bias.shape[0]

# Normalise so that all weights sum to 1.
normalizer = sum(val1 for val in l_imp.values() for val1 in val.values())
for val in l_imp.values():
    for key1 in val:
        val[key1] /= normalizer

In [None]:
# Training
def net_p_train_ortho(epoch):
    """Train the global `net_p_ortho` for one pass with an orthogonality penalty.

    The loss is ``criterion(outputs, labels) + 1e-2 * L_angle``, where `L_angle`
    is a weighted sum (weights from the global `l_imp`) of the entry-wise 1-norm of
    ``Gram - I`` for each conv weight flattened to ``(out_channels, -1)``. The stem
    uses the Gram matrix ``W^T W``; conv1 and conv2 of every block use ``W W^T``.

    Uses the notebook globals `trainloader`, `optimizer`, `criterion`, `device`,
    `num_blocks` and `l_imp`. `epoch` is only used in the printed header.

    Changes over the original:
      * the printed mean angle cost was ``angle_cost/batch_idx+1`` (divide, then add
        one); it is now the mean over the processed batches;
      * the shortcut term sat inside a bare ``except: pass`` and always failed (it
        called ``.reshape`` on the Conv2d module instead of its weight), so it never
        contributed to the loss. The dead code is removed and the loss value is
        unchanged. NOTE(review): if a shortcut penalty is wanted, it has to be added
        deliberately, including the choice of Gram form for the 1x1 conv.
    """
    print('\nEpoch: %d' % epoch)
    net_p_ortho.train()
    running_loss = 0
    correct = 0
    total = 0
    angle_cost = 0.0
    n_batches = 0

    def gram_penalty(weight, over_inputs=False):
        """Entry-wise 1-norm of (Gram - I) for a conv weight flattened per filter."""
        params = weight.reshape(weight.shape[0], -1)
        if over_inputs:
            gram = torch.matmul(torch.t(params), params)
        else:
            gram = torch.matmul(params, torch.t(params))
        return (gram - torch.eye(gram.shape[0]).to(device)).norm(1)

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, labels = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net_p_ortho(inputs)

        ### Conv_ind == 0 ###
        L_angle = l_imp[-1]['conv1'] * gram_penalty(net_p_ortho.module.conv1.weight, over_inputs=True)

        ### Conv_ind != 0 ###
        stages = [net_p_ortho.module.layer1, net_p_ortho.module.layer2,
                  net_p_ortho.module.layer3, net_p_ortho.module.layer4]
        for mod_id, module_id in enumerate(stages):
            for b_id in range(num_blocks[mod_id]):
                block = module_id[b_id]
                L_angle += l_imp[mod_id][2*b_id] * gram_penalty(block.conv1.weight)
                L_angle += l_imp[mod_id][2*b_id+1] * gram_penalty(block.conv2.weight)

        Lc = criterion(outputs, labels)
        loss = (1e-2)*(L_angle) + Lc

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        angle_cost += (L_angle).item()
        n_batches += 1

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))

    print("angle_cost: ", angle_cost/max(n_batches, 1))

In [None]:
def net_p_test_ortho(epoch):
    """Evaluate ``net_p_ortho`` on ``testloader`` and checkpoint it on a new best accuracy.

    Reads the notebook globals ``net_p_ortho``, ``testloader``, ``criterion``,
    ``device`` and ``prune_iter``; reads and updates ``best_p_ortho_acc``.

    Args:
        epoch: not used inside this function.

    Side effects:
        Prints the test accuracy in percent.  When that accuracy is greater
        than ``best_p_ortho_acc``, stores the model state dict and the accuracy
        in ``./ortho_p_checkpoint/ortho_ckpt<prune_iter>.pth`` and raises
        ``best_p_ortho_acc`` to the new value.
    """
    global best_p_ortho_acc

    net_p_ortho.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    n_batches = len(testloader)

    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)
            logits = net_p_ortho(images)

            loss_sum += criterion(logits, labels).item()
            n_seen += labels.size(0)
            # max(1) returns (values, indices); the indices are the predicted classes.
            n_correct += logits.max(1)[1].eq(labels).sum().item()

            progress_bar(step, n_batches, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen))

    acc = 100.*n_correct/n_seen
    print(acc)

    # Nothing to save unless this run beats the best accuracy seen so far.
    if not acc > best_p_ortho_acc:
        return

    print('Saving..')
    state = dict(net_p_ortho=net_p_ortho.state_dict(), best_p_ortho_acc=acc)
    os.makedirs('ortho_p_checkpoint', exist_ok=True)
    torch.save(state, './ortho_p_checkpoint/ortho_ckpt' + str(prune_iter) + '.pth')
    best_p_ortho_acc = acc

In [None]:
# Loss and optimizer for fine-tuning `net_p_ortho`.  Both are bound as notebook
# globals; `net_p_test_ortho` reads `criterion` that way.
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# Adam over every parameter of net_p_ortho with a small learning rate (1e-5);
# the remaining keyword arguments are written out explicitly.
optimizer = optim.Adam(net_p_ortho.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [None]:
# best_p_ortho_acc = 0

In [None]:
# Run one pass of train -> test -> orthogonality report.
# net_p_test_ortho writes a checkpoint when the test accuracy improves.
# NOTE(review): the `w_diag` visible in this notebook is defined in a later cell
# and inspects `net_ortho`, not `net_p_ortho` -- confirm which definition is
# live in the kernel when this cell runs.
for epoch in range(1):
    net_p_train_ortho(epoch)
    net_p_test_ortho(epoch)
    w_diag()

#### Reload the best saved decorrelated pruned network

In [None]:
# Show the current pruning iteration; it selects the checkpoint file name below.
prune_iter 

In [None]:
# Load the checkpoint for the current `prune_iter` and restore its weights into
# net_p_ortho.  The file layout ({'net_p_ortho': state_dict, ...}) matches what
# net_p_test_ortho saves.
# NOTE(review): torch.load is called without map_location -- presumably the
# checkpoint is loaded on the same kind of device it was saved from; confirm.
net_dict = torch.load('./ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')
net_p_ortho.load_state_dict(net_dict['net_p_ortho'])

#### Evaluate orthogonality of filters in pruned network

In [None]:
def _print_diag_ratio(weight, gram_over_inputs=False):
    """Print sum|diag(G)| / sum|G| for the Gram matrix G of a flattened conv weight.

    The weight is reshaped to ``(weight.shape[0], -1)``.  ``G`` is
    ``flat @ flat.T`` by default, or ``flat.T @ flat`` when
    ``gram_over_inputs`` is true.
    """
    flat = weight.reshape(weight.shape[0], -1)
    if gram_over_inputs:
        gram = torch.matmul(torch.t(flat), flat)
    else:
        gram = torch.matmul(flat, torch.t(flat))
    # Tensor.norm(1) with no dim is the entry-wise L1 norm (sum of |entries|).
    diag_mass = gram.diag().norm(1)
    total_mass = gram.norm(1)
    print(diag_mass.cpu() / total_mass.cpu())


def w_diag(net=None, blocks_per_layer=None):
    """Print, per conv weight, the share of the Gram matrix's L1 mass on its diagonal.

    A printed value of 1 means every off-diagonal entry of the Gram matrix is
    zero.  One line is printed for the stem ``conv1`` and then, for every
    visited block, one line each for ``conv1``, ``conv2`` and, when present,
    the first module of ``shortcut``.

    Args:
        net: wrapped network exposing ``net.module.conv1`` and
            ``net.module.layer1`` .. ``net.module.layer4``.  Defaults to the
            notebook global ``net_ortho``.
        blocks_per_layer: indexable giving how many blocks to visit in each of
            the four layers.  Defaults to the notebook global ``num_blocks``.

    Returns:
        None; the ratios are printed.

    NOTE(review): the surrounding section is titled as evaluating the *pruned*
    network, but the default target is ``net_ortho``.  Pass ``net_p_ortho``
    explicitly to inspect the pruned model.
    """
    if net is None:
        net = net_ortho
    if blocks_per_layer is None:
        blocks_per_layer = num_blocks

    ### Conv_ind == 0 ###
    # The stem uses flat.T @ flat, unlike every other layer below.
    _print_diag_ratio(net.module.conv1.weight, gram_over_inputs=True)

    ### Conv_ind != 0 ###
    layers = [net.module.layer1, net.module.layer2, net.module.layer3, net.module.layer4]
    for mod_id, layer in enumerate(layers):
        for b_id in range(blocks_per_layer[mod_id]):
            block = layer[b_id]
            _print_diag_ratio(block.conv1.weight)
            _print_diag_ratio(block.conv2.weight)

            # The previous version read `shortcut[0]` (a module, not its weight)
            # inside a bare try/except, so the resulting AttributeError was
            # swallowed and shortcut convolutions were never reported.  Check
            # explicitly for a missing or empty shortcut instead.
            shortcut = getattr(block, "shortcut", None)
            if shortcut is None or len(shortcut) == 0:
                continue
            shortcut_weight = getattr(shortcut[0], "weight", None)
            if shortcut_weight is not None:
                _print_diag_ratio(shortcut_weight)

### Subsequent pruning

#### Importance

In [None]:
''' Correlated network '''
# Disabled: loading a previously pickled importance table for this prune_iter.
# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'rb') as f:
#     imp_order_p = pickle.load(f)

In [None]:
# Build the importance table for the pruned correlated network `net_p` and
# pickle it.  The table has one row per entry returned by `cal_importance`
# and three columns:
#   column 0 -- running index `i` of the BatchNorm2d layer,
#   column 1 -- nlist[0],
#   column 2 -- nlist[1]  (presumably channel ids and importance scores;
#               TODO confirm against imp_baselines.cal_importance).
#
# NOTE(review): this rebinds the notebook-global `optimizer` to an SGD with
# lr=0; re-run the optimizer cell before any further fine-tuning.
optimizer = optim.SGD(net_p.parameters(), lr=0, weight_decay=0)
imp_order_p = np.array([[],[],[]]).transpose()  # empty (0, 3) table to append to
i = 0

for l_index in net_p.module.modules():
    if isinstance(l_index, nn.BatchNorm2d):
        print(l_index)
        nlist = cal_importance(net_p, l_index)
        layer_rows = np.array([
            np.repeat([i], nlist[1].shape[0]).tolist(),
            nlist[0].tolist(),
            nlist[1].detach().cpu().numpy().tolist(),
        ]).transpose()
        # Fix: accumulate into imp_order_p.  The previous code assigned the
        # result to `imp_order_corr`, so imp_order_p stayed empty and an empty
        # table was pickled below.
        imp_order_p = np.concatenate((imp_order_p, layer_rows), 0)
        i += 1

with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p, f)

In [None]:
''' De-Correlated network '''
# Load a previously pickled importance table for the decorrelated pruned
# network at the current prune_iter.
# NOTE(review): pickle.load executes arbitrary code from the file -- only load
# files produced by this project.
with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p"+str(prune_iter)+".pkl", 'rb') as f:
    imp_order_p_ortho = pickle.load(f)

In [None]:
# Build the importance table for the pruned decorrelated network
# `net_p_ortho` and pickle it.  The table has one row per entry returned by
# `cal_importance` and three columns:
#   column 0 -- running index `i` of the BatchNorm2d layer,
#   column 1 -- nlist[0],
#   column 2 -- nlist[1]  (presumably channel ids and importance scores;
#               TODO confirm against imp_baselines.cal_importance).
#
# NOTE(review): this rebinds the notebook-global `optimizer` to an SGD with
# lr=0; re-run the optimizer cell before any further fine-tuning.
optimizer = optim.SGD(net_p_ortho.parameters(), lr=0, weight_decay=0)
imp_order_p_ortho = np.array([[],[],[]]).transpose()  # empty (0, 3) table to append to
i = 0

for l_index in net_p_ortho.module.modules():
    if isinstance(l_index, nn.BatchNorm2d):
        print(l_index)
        nlist = cal_importance(net_p_ortho, l_index)
        layer_rows = np.array([
            np.repeat([i], nlist[1].shape[0]).tolist(),
            nlist[0].tolist(),
            nlist[1].detach().cpu().numpy().tolist(),
        ]).transpose()
        # Fix: accumulate into imp_order_p_ortho.  The previous code assigned
        # the result to `imp_order_corr`, so this table stayed empty and an
        # empty array was pickled below.
        imp_order_p_ortho = np.concatenate((imp_order_p_ortho, layer_rows), 0)
        i += 1

# Fix: write to the decorrelated path that the matching loader cell reads
# ("decorr/tfo_order/tfo_w_decorr_p<iter>.pkl").  The previous target was
# "corr/tfo_order/tfo_corr_p_ortho<iter>.pkl", which no visible cell reads.
with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p_ortho, f)

#### Pruned network pruning

In [None]:
# ''' Correlated network '''
# orig_size = cal_size(net_p)

In [None]:
''' De-Correlated network '''
# Size of net_p_ortho before this pruning round, as reported by cal_size;
# passed to `pruner` further down.
orig_size = cal_size(net_p_ortho)

#### Pruning order

In [None]:
# ''' Correlated network '''
# order_p, prune_ratio = order_and_ratios(imp_order_p, 0.1)
# np.array(prune_ratio), orig_size

In [None]:
''' De-Correlated network '''
# Derive the pruning order and ratios from the importance table.
# NOTE(review): the meaning of the second argument (0.1) is defined by
# order_and_ratios in imp_baselines -- presumably a prune fraction; confirm.
order_p, prune_ratio = order_and_ratios(imp_order_p_ortho, 0.1)
# The cell's last expression displays the ratios next to the original size.
np.array(prune_ratio), orig_size

#### Define pruned network

In [None]:
# Pruning round counter; it is embedded in checkpoint and pickle file names.
# NOTE(review): set by hand -- update it before each new pruning round.
prune_iter = 2

In [None]:
# ''' Correlated network pruning '''
# net_p1 = pruner(net_p, order_p, prune_ratio, orig_size, net_type=1)

# print("Accs:", cal_acc(net_p1.eval()), cal_acc(net_p.eval()))
# print("Time:", cal_time(net_p1), t_corr)

In [None]:
''' De-Correlated network pruning '''
# Prune net_p_ortho using the order and ratios computed above.
# NOTE(review): net_type=2 is interpreted by `pruner` in imp_baselines --
# presumably it selects the decorrelated variant; confirm.
net_p1_ortho = pruner(net_p_ortho, order_p, prune_ratio, orig_size, net_type=2)

# Report the new model first, then the reference, for accuracy and timing.
# `t_decorr` is defined outside this section.
print("Accs:", cal_acc(net_p1_ortho.eval()), cal_acc(net_p_ortho.eval()))
print("Time:", cal_time(net_p1_ortho), t_decorr)

#### Prune the pruned network again

In [None]:
# ''' Correlated network saving '''
# net_p = net_p1

# print('Saving..')
# state = {
#     'net_p': net_p.state_dict(),
#     'best_p_acc': cal_acc(net_p.eval())
# }
# if not os.path.isdir('net_p_checkpoint'):
#     os.mkdir('net_p_checkpoint')
# torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')

In [None]:
''' De-Correlated network saving '''
# Promote the freshly pruned model to be the working decorrelated network.
net_p_ortho = net_p1_ortho

print('Saving..')
# Checkpoint layout matches what net_p_test_ortho writes: weights + accuracy.
state = dict(
    net_p_ortho=net_p_ortho.state_dict(),
    best_p_ortho_acc=cal_acc(net_p_ortho.eval()),
)
os.makedirs('ortho_p_checkpoint', exist_ok=True)
torch.save(state, './ortho_p_checkpoint/ortho_ckpt' + str(prune_iter) + '.pth')

### Load saved network

In [None]:
# ''' Correlated network loading '''
# with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(1)+".pkl", 'rb') as f:
#     cfg_p1 = pickle.load(f)
    
# net_p = torch.nn.DataParallel(MobileNet_p(cfg_p1[0], cfg_p1[1:]))
# PATH = './net_p_checkpoint/ckpt'+str(1)+'.pth'
# net_p.load_state_dict(torch.load(PATH)['net_p'])

In [None]:
# cal_acc(net_p.eval()), cal_acc(net_corr.eval())

In [None]:
''' De-Correlated network loading '''
# Rebuild a pruned decorrelated network from its pickled config, then restore
# its weights from the matching checkpoint.
# NOTE(review): the iteration is hard-coded to 1 in both paths rather than
# using `prune_iter` -- confirm this is intentional.
with open("./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(1)+".pkl", 'rb') as f:
    cfg_p1 = pickle.load(f)

# NOTE(review): the base networks in this notebook are built with ResNet34()
# and other cells access layer1..layer4; MobileNet_p here may be a leftover
# from another notebook -- verify the model class.
net_p_ortho = torch.nn.DataParallel(MobileNet_p(cfg_p1[0], cfg_p1[1:]))
PATH = './ortho_p_checkpoint/ortho_ckpt'+str(1)+'.pth'
net_p_ortho.load_state_dict(torch.load(PATH)['net_p_ortho'])    

In [None]:
# Display cal_acc for the pruned network next to the unpruned decorrelated one.
cal_acc(net_p_ortho.eval()), cal_acc(net_decorr.eval())

### FLOPS calculator

In [None]:
# Per-layer complexity report (ptflops) for the pruned decorrelated network on
# a (3, 32, 32) input; only the returned `flops` value is printed here.
# NOTE(review): torch.cuda.device(0) presumably requires a CUDA device --
# confirm before running on a CPU-only machine.
with torch.cuda.device(0):
    flops, params = get_model_complexity_info(net_p_ortho, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
    
# with torch.cuda.device(0):
#     flops, params = get_model_complexity_info(net_p, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
#     print('{:<30}  {:<8}'.format('Computational complexity: ', flops))    

In [None]:
# Same complexity report for the unpruned decorrelated network `net_decorr`,
# for comparison with the pruned model's figure.
with torch.cuda.device(0):
    flops, params = get_model_complexity_info(net_decorr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

# with torch.cuda.device(0):
#     flops, params = get_model_complexity_info(net_corr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
#     print('{:<30}  {:<8}'.format('Computational complexity: ', flops))