In [None]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
import os
from utils import progress_bar

In [None]:
# Training-time preprocessing: random horizontal flip and RandomRotation(45)
# as augmentation, then tensor conversion and per-channel normalisation.
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomRotation(45),
     transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

# Evaluation preprocessing: no augmentation, same normalisation constants.
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

# download=True on BOTH splits: they share one root, and the train split is
# constructed first, so it must be allowed to fetch the data on a fresh machine.
trainset = torchvision.datasets.CIFAR100(root='./../data', train=True,
                                        download=True, transform=transform)
# NOTE(review): shuffle=False means every training epoch visits the samples in
# the same order. It is left unchanged because the importance-analysis cells
# below read only the first ~5000 samples of this loader and appear to rely on
# that being a stable subset -- confirm before enabling shuffling for training.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./../data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

In [None]:
# Number of batches per training epoch (DataLoader exposes this via len(),
# it has no `size` attribute).
len(trainloader)

In [None]:
import torch
import torch.nn as nn

class AlexNet(nn.Module):
    """AlexNet-style CNN whose layer widths are taken from ``cfg``.

    Args:
        cfg: indexable with at least seven entries. ``cfg[0]``..``cfg[4]`` are
            the output channels of the five 3x3 convolutions; ``cfg[5]`` and
            ``cfg[6]`` are the widths of the two hidden linear layers.
        classes: size of the final linear layer (default 100).

    The first linear layer takes ``cfg[4]`` inputs, so the flattened feature
    map must contain exactly ``cfg[4]`` values per sample.
    """

    def __init__(self, cfg, classes=100):
        super(AlexNet, self).__init__()

        # One row per convolution stage:
        # (output channels, convolution stride, followed by a max-pool?)
        stage_plan = (
            (cfg[0], 2, True),
            (cfg[1], 1, True),
            (cfg[2], 1, False),
            (cfg[3], 1, False),
            (cfg[4], 1, True),
        )

        # Each stage is Conv -> ReLU -> BatchNorm [-> MaxPool]. Modules are
        # appended in that order so positions inside `features` stay fixed
        # (other cells index into `features` by position).
        stack = []
        in_channels = 3
        for out_channels, conv_stride, pooled in stage_plan:
            stack.append(nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                   stride=conv_stride, padding=1))
            stack.append(nn.ReLU(inplace=True))
            stack.append(nn.BatchNorm2d(out_channels))
            if pooled:
                stack.append(nn.MaxPool2d(kernel_size=3, stride=2))
            in_channels = out_channels
        self.features = nn.Sequential(*stack)

        hidden_a, hidden_b = cfg[5], cfg[6]
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(cfg[4] * 1 * 1, hidden_a),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden_a, hidden_b),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_b, classes),
        )

    def forward(self, x):
        """Run the feature extractor, flatten per sample, then classify."""
        feats = self.features(x)
        flat = torch.flatten(feats, 1)
        return self.classifier(flat)

In [None]:
def w_diag(net):
    """Print, for seven layers of ``net``, the diagonal share of a Gram matrix.

    For each layer the weight is flattened to ``(out, -1)`` and the bias is
    appended as one extra column, giving a matrix ``P``. A Gram matrix of
    ``P`` is formed and the ratio (sum of absolute diagonal entries) /
    (sum of absolute entries) is printed, one tensor per line.

    ``features[0]`` and ``classifier[1]`` use ``P^T P``; ``features[4]``,
    ``features[8]``, ``features[11]``, ``features[14]`` and ``classifier[4]``
    use ``P P^T``.

    Args:
        net: object exposing indexable ``features`` and ``classifier`` whose
            addressed entries have ``weight`` and ``bias`` tensors.
    """

    def _print_ratio(layer, gram_over_columns):
        # Stack flattened weights and bias side by side: one row per output unit.
        kernel = layer.weight
        offset = layer.bias
        params = torch.cat(
            (kernel.reshape(kernel.shape[0], -1),
             offset.reshape(offset.shape[0], -1)),
            dim=1,
        )
        if gram_over_columns:
            gram = torch.matmul(torch.t(params), params)
        else:
            gram = torch.matmul(params, torch.t(params))
        diag_mass = gram.diag().norm(1)
        total_mass = gram.norm(1)
        print(diag_mass.cpu() / total_mass.cpu())

    # First convolution.
    _print_ratio(net.features[0], True)

    # Remaining convolutions sit two positions before the listed indices.
    for conv_ind in [6, 10, 13, 16]:
        _print_ratio(net.features[conv_ind - 2], False)

    # The two hidden linear layers.
    _print_ratio(net.classifier[1], True)
    _print_ratio(net.classifier[4], False)

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
# Every model and batch in this notebook is moved to this device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Layer widths consumed by AlexNet: cfg[0..4] are the conv output channels,
# cfg[5..6] are the two hidden linear widths.
cfg = [64, 192, 384, 256, 256, 4096, 4096]
# Best baseline test accuracy seen so far; net_test() raises it when it saves
# a checkpoint.
best_acc = 0

In [None]:
# Baseline network with freshly initialised weights on the selected device.
net = AlexNet(cfg).to(device)
# Classification loss read (as a global) by the train/test helpers below.
criterion = nn.CrossEntropyLoss()

In [None]:
# Training
def net_train(epoch):
    """Train the global ``net`` for one pass over ``trainloader``.

    Reads the globals ``net``, ``trainloader``, ``device``, ``optimizer`` and
    ``criterion``. After every batch the running mean loss and accuracy are
    reported through ``progress_bar``.

    Args:
        epoch: number shown in the "Epoch" banner; not used otherwise.
    """
    print('\nEpoch: %d' % epoch)
    net.train()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for step, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        # Standard optimisation step.
        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Running statistics for the progress display.
        loss_sum += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
        progress_bar(step, len(trainloader), status)
        
def net_test(epoch):
    """Evaluate the global ``net`` on ``testloader`` and checkpoint on improvement.

    Runs in eval mode under ``torch.no_grad()``, reporting running loss and
    accuracy through ``progress_bar``. When the final accuracy (in percent)
    exceeds the global ``best_acc``, the state dict and accuracy are written
    to ``./checkpoint/ckpt.pth`` and ``best_acc`` is updated.

    Args:
        epoch: accepted for symmetry with ``net_train``; not used in the body.
    """
    global best_acc
    net.eval()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)
            logits = net(images)
            batch_loss = criterion(logits, labels)

            loss_sum += batch_loss.item()
            predicted = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

            status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
            progress_bar(step, len(testloader), status)

    # Keep only the best-performing weights on disk.
    acc = 100.*n_correct/n_seen
    if not acc > best_acc:
        return

    print('Saving..')
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save({'net': net.state_dict(), 'best_acc': acc}, './checkpoint/ckpt.pth')
    best_acc = acc

In [None]:
from __future__ import print_function, absolute_import

__all__ = ['accuracy']

def kaccuracy(output, target, topk=(5,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def top5cal(net):
    """Print the top-5 accuracy (percent) of ``net`` over the global ``testloader``.

    Per-batch percentages from ``kaccuracy`` are weighted by batch size and
    averaged over the number of samples actually seen, so the result does not
    depend on a hard-coded dataset size. Top-1 is accumulated the same way
    but only top-5 is printed.

    Args:
        net: model to evaluate; it is switched to eval mode.
    """
    net.eval()
    total = 0
    top1 = 0
    top5 = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            acc1, acc5 = kaccuracy(outputs, targets, topk=(1, 5))
            batch_n = inputs.shape[0]
            top1 += (acc1.item()*batch_n)
            top5 += (acc5.item()*batch_n)
            total += batch_n

    # Guard keeps an empty loader at 0.0 instead of dividing by zero.
    denom = max(total, 1)
    top1 /= denom
    top5 /= denom

    print("top5", top5)

# Train baseline

In [None]:
import torch.optim as optim

# Cross-entropy loss used by net_train()/net_test(); rebinds the global `criterion`.
criterion = nn.CrossEntropyLoss()

In [None]:
# Resume the baseline from its saved checkpoint. map_location remaps the saved
# tensors onto the active device, so a checkpoint written on a GPU machine
# still loads on a CPU-only one.
net_dict = torch.load('./checkpoint/ckpt.pth', map_location=device)
net.load_state_dict(net_dict['net'])
# Continue the "best so far" bookkeeping from the checkpointed accuracy.
best_acc = net_dict['best_acc']

In [None]:
# Adam over the baseline network's parameters with a very small learning rate
# (1e-6) and weight decay 5e-4; net_train() reads this global `optimizer`.
optimizer = optim.Adam(net.parameters(), lr=0.000001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [None]:
# best_acc = 0

In [None]:
# The outer loop runs exactly once (range(1)); inside it the baseline is
# trained and evaluated for five epochs, numbered from 1.
for whatev in range(1):
    for epoch in range(5):
        net_train(epoch+1)
        net_test(epoch+1)

#     net_dict = torch.load('./checkpoint/ckpt.pth')
#     net.load_state_dict(net_dict['net'])
#     best_acc = net_dict['best_acc']

#     optimizer = optim.Adam(net.parameters(), lr=0.000001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

# Inner product training

In [None]:
# Network trained with the Gram-matrix penalty, resumed from its own
# checkpoint directory. map_location remaps the saved tensors onto the active
# device so a GPU-written checkpoint also loads on a CPU-only machine.
net_ortho = AlexNet(cfg).to(device)
net_dict = torch.load('./ortho_checkpoint/ckpt.pth', map_location=device)
net_ortho.load_state_dict(net_dict['net'])
# net_test_ortho() compares against, and updates, this global.
best_acc_ortho = net_dict['best_acc']

In [None]:
# Per-layer weights for the penalty term in net_train_ortho(): the bias length
# of each listed layer divided by the sum over all listed layers.
# Keys 2/6/10/13/16 index `net.features`; keys 1/4 index `net.classifier`.
l_imp = {conv_ind: net.features[conv_ind].bias.shape[0]
         for conv_ind in (2, 6, 10, 13, 16)}
l_imp.update({lin_ind: net.classifier[lin_ind].bias.shape[0]
              for lin_ind in (1, 4)})

# Normalise so the weights sum to one.
normalizer = sum(l_imp.values())
l_imp = {key: val / normalizer for key, val in l_imp.items()}

In [None]:
def net_test_ortho(epoch):
    """Evaluate the global ``net_ortho`` on ``testloader`` and checkpoint on improvement.

    Runs in eval mode under ``torch.no_grad()``, reports running loss and
    accuracy through ``progress_bar``, then prints the final accuracy in
    percent. When it exceeds the global ``best_acc_ortho`` the state dict and
    accuracy are written to ``./ortho_checkpoint/ckpt.pth`` and
    ``best_acc_ortho`` is updated.

    Args:
        epoch: accepted for symmetry with the training helper; not used in the body.
    """
    global best_acc_ortho
    net_ortho.eval()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)
            logits = net_ortho(images)
            batch_loss = criterion(logits, labels)

            loss_sum += batch_loss.item()
            predicted = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

            status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
            progress_bar(step, len(testloader), status)

    # Keep only the best-performing weights on disk.
    acc = 100.*n_correct/n_seen
    print(acc)
    if not acc > best_acc_ortho:
        return

    print('Saving..')
    if not os.path.isdir('ortho_checkpoint'):
        os.mkdir('ortho_checkpoint')
    torch.save({'net': net_ortho.state_dict(), 'best_acc': acc},
               './ortho_checkpoint/ckpt.pth')
    best_acc_ortho = acc

In [None]:
def _layer_params(layer):
    """Flatten ``layer.weight`` to (out, -1) and append ``layer.bias`` as a last column."""
    weight = layer.weight
    bias = layer.bias
    return torch.cat((weight.reshape(weight.shape[0], -1),
                      bias.reshape(bias.shape[0], -1)), dim=1)


def _gram_deviation(params, over_columns):
    """Sum of absolute entries of (Gram matrix of ``params`` minus identity).

    ``over_columns=True`` uses ``params^T params``; otherwise ``params params^T``.
    """
    if over_columns:
        gram = torch.matmul(torch.t(params), params)
    else:
        gram = torch.matmul(params, torch.t(params))
    # Build the identity directly on the Gram matrix's device instead of
    # allocating it on the CPU and copying it across on every step.
    eye = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
    return (gram - eye).norm(1)


def _ortho_penalty(model):
    """Weighted sum of Gram deviations over the layers regularised in this notebook."""
    penalty = 0
    # First convolution (weighted by the entry stored under key 2 of l_imp).
    penalty += (l_imp[2])*_gram_deviation(_layer_params(model.features[0]), True)
    # Remaining convolutions sit two positions before their l_imp keys.
    for conv_ind in [6, 10, 13, 16]:
        penalty += (l_imp[conv_ind])*_gram_deviation(
            _layer_params(model.features[conv_ind-2]), False)
    # The two hidden linear layers.
    penalty += (l_imp[1])*_gram_deviation(_layer_params(model.classifier[1]), True)
    penalty += (l_imp[4])*_gram_deviation(_layer_params(model.classifier[4]), False)
    return penalty


def net_train_ortho(epoch):
    """Train the global ``net_ortho`` for one epoch with the Gram-matrix penalty.

    The optimised loss is ``cross_entropy + 0.1 * penalty``, where the penalty
    is computed by ``_ortho_penalty``. Reads the globals ``net_ortho``,
    ``trainloader``, ``device``, ``optimizer``, ``criterion`` and ``l_imp``.
    Running loss and accuracy are reported through ``progress_bar``; at the
    end the mean penalty per batch is printed.

    Args:
        epoch: number shown in the "Epoch" banner; not used otherwise.
    """
    print('\nEpoch: %d' % epoch)
    net_ortho.train()
    correct = 0
    total = 0
    num_batches = 0
    running_loss = 0.0
    angle_cost = 0.0

    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net_ortho(inputs)

        L_angle = _ortho_penalty(net_ortho)
        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        angle_cost += (L_angle).item()
        num_batches += 1

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # The penalty is added once per batch, so average it over batches (it was
    # previously divided by the number of samples). max() guards an empty loader.
    print("angle_cost: ", angle_cost/max(num_batches, 1))

In [None]:
import torch.optim as optim
# Cross-entropy term of the penalised loss; rebinds the global `criterion`.
criterion = nn.CrossEntropyLoss()

In [None]:
# Adam over net_ortho's parameters (lr 1e-4, no weight decay). Rebinds the
# global `optimizer`, which net_train_ortho() reads.
optimizer = optim.Adam(net_ortho.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [None]:
# Five epochs of penalised training; after each one, evaluate (checkpointing
# on improvement) and print the per-layer diagonal ratios of net_ortho.
for epoch in range(5):
    net_train_ortho(epoch)
    net_test_ortho(epoch)
    w_diag(net_ortho)

In [None]:
# PATH = './w_decorr/base_params/wnet_base_2.pth'
# torch.save(net.state_dict(), PATH)

# Importance analysis

### Layer index

In [None]:
# Position inside `features` of the layer analysed below. With the AlexNet
# defined in this notebook, index 10 is the BatchNorm2d after the third
# convolution.
l_index = 10
# Label for the kind of layer analysed; no cell visible here reads it.
layer_id = 'bn'

### Correlated Net

In [None]:
# Fresh copy of the baseline ("correlated") network restored from its
# checkpoint. map_location remaps the saved tensors onto the active device so
# a GPU-written checkpoint also loads on a CPU-only machine.
net = AlexNet(cfg).to(device)
net_dict = torch.load('./checkpoint/ckpt.pth', map_location=device)
net.load_state_dict(net_dict['net'])
# Eval mode for the whole analysis.
net = net.eval()

In [None]:
# Snapshot of the analysed layer's weight and bias; the ablation loop below
# restores the layer from these copies after zeroing each channel.
weight_base = net.features[l_index].weight.data.clone().detach()
bias_base = net.features[l_index].bias.data.clone().detach()

In [None]:
# Reference loss of the unmodified baseline network: sum of batch losses over
# the first batches of `trainloader`, stopping once more than 5000 samples
# have been seen.
loss_base_corr = 0
num_stop = 0
# Forward passes only -- no_grad avoids building autograd graphs.
with torch.no_grad():
    for epoch in range(1):
#     for i, data in enumerate(testloader, 0):
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss_base_corr += loss.item()
            num_stop += labels.shape[0]
            if(num_stop > 5000):
                break
# NOTE(review): the summed loss is squared here and later compared as
# |loss_mat - loss_base|, i.e. a difference of squares rather than a squared
# difference -- confirm this is the intended "actual importance".
loss_base_corr = loss_base_corr**2

In [None]:
# Per-channel ablation: zero one channel of the analysed layer, measure the
# summed loss on the same leading slice of `trainloader`, then restore it.
loss_mat_corr = torch.zeros(weight_base.shape[0])

for n_index in range(weight_base.shape[0]):
    num_stop = 0
    print(n_index)
    running_loss = 0.0

    # Remove channel n_index by zeroing its weight and bias entries.
    net.features[l_index].weight.data[n_index] = 0
    net.features[l_index].bias.data[n_index] = 0

    # Forward passes only -- no_grad avoids building autograd graphs.
    with torch.no_grad():
#     for i, data in enumerate(testloader, 0):
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            outputs = net(inputs)

            loss = (criterion(outputs, labels))

            running_loss += loss.item()

            num_stop += labels.shape[0]
            if(num_stop > 5000):
                break

    # Stored squared, matching how loss_base_corr is stored.
    loss_mat_corr[n_index] = running_loss**2

    # Restore the full layer from the snapshot before the next channel.
    net.features[l_index].weight.data = weight_base.clone().detach()
    net.features[l_index].bias.data = bias_base.clone().detach()

In [None]:
# On-disk cache of the per-channel ablation losses for the baseline network.
loss_corr_path = './w_decorr/loss_mats/corr/'+str(l_index)+'/loss_corr_bn_train_'+str(l_index)+'.pt'
if os.path.isfile(loss_corr_path):
    # An existing cache takes precedence over the value computed above,
    # exactly as the previous unconditional load did.
    loss_mat_corr = torch.load(loss_corr_path)
else:
    # No cache yet: persist the freshly computed tensor instead of failing
    # on the missing file.
    os.makedirs(os.path.dirname(loss_corr_path), exist_ok=True)
    torch.save(loss_mat_corr, loss_corr_path)

In [None]:
# Estimated importance for the baseline network. For each batch the gradient
# of the loss w.r.t. the analysed layer's weight and bias is taken, and
# (grad*weight + grad*bias)^2 is accumulated per channel. The estimate is then
# correlated with |loss_mat_corr - loss_base_corr|.
# The optimizer has lr=0 and step() is never called; it is used only for zero_grad().
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    running_loss = 0.0  # never updated in this cell
    imp_corr_conv = torch.zeros(bias_base.shape[0]).to(device)  # never updated in this cell
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)
    
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Accumulate BEFORE the stop check, so the batch that crosses the
        # 5000-sample mark is included (as in the loss-measurement loops).
        imp_corr_bn += (((net.features[l_index].weight.grad)*(net.features[l_index].weight.data)) + ((net.features[l_index].bias.grad)*(net.features[l_index].bias.data))).abs().pow(2)
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
         
    # Pearson correlation between estimated and measured importance.
    corrval = (np.corrcoef(imp_corr_bn.cpu().detach().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

### Decorrelated net

In [None]:
# Penalty-trained ("decorrelated") network restored from its checkpoint.
# map_location remaps the saved tensors onto the active device so a
# GPU-written checkpoint also loads on a CPU-only machine.
net_decorr = AlexNet(cfg).to(device)
net_dict = torch.load('./ortho_checkpoint/ckpt.pth', map_location=device)
net_decorr.load_state_dict(net_dict['net'])
# Eval mode for the whole analysis.
net_decorr = net_decorr.eval()

In [None]:
# Snapshot of the analysed layer of net_decorr. This REBINDS weight_base and
# bias_base, which previously held the baseline network's values.
weight_base = net_decorr.features[l_index].weight.data.clone().detach()
bias_base = net_decorr.features[l_index].bias.data.clone().detach()

In [None]:
# Rebinds the global optimizer to net_decorr (lr=0); it is not used inside
# this cell.
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
# Reference loss of the unmodified decorrelated network: sum of batch losses
# over the first batches of `trainloader`, stopping once more than 5000
# samples have been seen.
num_stop = 0
loss_base_decorr = 0
# Forward passes only -- no_grad avoids building autograd graphs.
with torch.no_grad():
    for epoch in range(1):
        for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net_decorr(inputs)
            loss = criterion(outputs, labels)
            loss_base_decorr += loss.item()
            num_stop += labels.shape[0]
            if(num_stop > 5000):
                break
# Stored squared, matching how the per-channel losses are stored.
loss_base_decorr = loss_base_decorr**2

In [None]:
# Rebinds the global optimizer to net_decorr (lr=0); it is not used inside
# this cell.
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)

# Per-channel ablation for the decorrelated network: zero one channel of the
# analysed layer, measure the summed loss on the leading slice of
# `trainloader`, then restore the layer from the snapshot.
loss_mat_decorr = torch.zeros(weight_base.shape[0])

for n_index in range(weight_base.shape[0]):
    print(n_index)
    num_stop = 0
    running_loss = 0.0

    # Zero the channel once per channel rather than once per batch; the
    # result is the same because nothing modifies the layer inside the loop.
    net_decorr.features[l_index].weight.data[n_index] = 0
    net_decorr.features[l_index].bias.data[n_index] = 0

    # Forward passes only -- no_grad avoids building autograd graphs.
    with torch.no_grad():
        for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            outputs = net_decorr(inputs)

            loss = criterion(outputs, labels)

            running_loss += loss.item()

            num_stop += labels.shape[0]
            if(num_stop > 5000):
                break

    # Stored squared, matching how loss_base_decorr is stored.
    loss_mat_decorr[n_index] = running_loss**2

    # Restore the full layer before moving to the next channel.
    net_decorr.features[l_index].weight.data = weight_base.clone().detach()
    net_decorr.features[l_index].bias.data = bias_base.clone().detach()

In [None]:
# On-disk cache of the per-channel ablation losses for the decorrelated network.
loss_decorr_path = './w_decorr/loss_mats/decorr/'+str(l_index)+'/loss_decorr_bn_train_'+str(l_index)+'.pt'
if os.path.isfile(loss_decorr_path):
    # An existing cache takes precedence over the value computed above,
    # exactly as the previous unconditional load did.
    loss_mat_decorr = torch.load(loss_decorr_path)
else:
    # No cache yet: persist the freshly computed tensor instead of failing
    # on the missing file.
    os.makedirs(os.path.dirname(loss_decorr_path), exist_ok=True)
    torch.save(loss_mat_decorr, loss_decorr_path)

In [None]:
# Estimated importance for the decorrelated network. For each batch the
# gradient of the loss w.r.t. the analysed layer's weight and bias is taken,
# and (grad*weight + grad*bias)^2 is accumulated per channel. The estimate is
# then correlated with |loss_mat_decorr - loss_base_decorr|.
# The optimizer has lr=0 and step() is never called; it is used only for zero_grad().
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    imp_decorr_conv = torch.zeros(bias_base.shape[0]).to(device)  # never updated in this cell
    imp_decorr_bn = torch.zeros(bias_base.shape[0]).to(device)

    running_loss = 0.0  # never updated in this cell
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Accumulate BEFORE the stop check so the batch that crosses the
        # 5000-sample mark is counted. The baseline-network loop and both
        # loss-measurement loops include that batch; breaking first dropped
        # it here and made the two networks use different sample counts.
        imp_decorr_bn += (((net_decorr.features[l_index].weight.grad)*(net_decorr.features[l_index].weight.data)) + ((net_decorr.features[l_index].bias.grad)*(net_decorr.features[l_index].bias.data))).abs().pow(2)

        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break

    # Pearson correlation between estimated and measured importance.
    corrval = (np.corrcoef(imp_decorr_bn.cpu().detach().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

# Graphs

In [None]:
# Baseline ("correlated") network: estimated importance sorted ascending,
# with the measured importance drawn in the same channel order.
figure(figsize=(20,5))
s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
order = imp_corr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max(), label="Estimated importance")
plt.title("Correlated (Taylor FO) for "+str(l_index))
# Measured importance must come from the baseline network's own losses; the
# decorrelated network's tensors were used here by mistake.
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.xlabel("Neuron index")
plt.ylabel("Normalized importance")
plt.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
plt.legend()
plt.savefig("./w_decorr/loss_mats/corr/graphs/"+str(l_index)+".png")

In [None]:
# Decorrelated network: estimated importance sorted ascending, with the
# measured importance drawn in the same channel order. Both curves are scaled
# by their own maximum.
figure(figsize=(20,5))
s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
order = imp_decorr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max(), label="Estimated importance")
plt.title("Decorrelated (Taylor FO) for "+str(l_index))
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.xlabel("Neuron index")
plt.ylabel("Normalized importance")
plt.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
plt.legend()
plt.savefig("./w_decorr/loss_mats/decorr/graphs/"+str(l_index)+".png")

In [None]:
# Mismatch between the normalised estimated and measured importance curves,
# first for the decorrelated network and then for the baseline.
# Despite the `_rms` names, each value is a plain sum of squared differences:
# no mean or square root is taken in this cell.
s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
s = s/s.max()
order = imp_decorr_bn.sort()[1].cpu().numpy()
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
loss_diff = (loss_diff[order]/loss_diff.max())
ortho_rms = ((loss_diff - s)**2).sum()

# Same computation for the baseline network; s, order and loss_diff are reused.
s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
s = s/s.max()
order = imp_corr_bn.sort()[1].cpu().numpy()
loss_diff = (loss_mat_corr - loss_base_corr).abs()
loss_diff = (loss_diff[order]/loss_diff.max())

base_rms = ((loss_diff - s)**2).sum()

In [None]:
# Display both mismatch scores for the current l_index: (decorrelated, baseline).
(ortho_rms, base_rms)

In [None]:
# rms_ortho = np.sqrt(np.array(rms_ortho) / np.array([64, 64, 128, 128, 256, 256, 512, 512, 512, 512]))
# rms_base = np.sqrt(np.array(rms_base) / np.array([64, 64, 128, 128, 256, 256, 512, 512, 512, 512]))

In [None]:
# Side-by-side bars comparing the two networks at ten x positions spread over
# 0..30.
# NOTE(review): `rms_ortho` and `rms_base` are not assigned by any cell visible
# in this notebook (the cell above defines `ortho_rms` / `base_rms`, single
# values for one layer). Presumably these are length-10 sequences collected
# by hand across layers -- define them before running this cell.
plt.figure(figsize=(10,5))
plt.bar(np.linspace(0,30,10)-0.5, rms_ortho, label="Decorrelated network")
plt.bar(np.linspace(0,30,10)+0.5, rms_base, label="Correlated network")
plt.xlabel("Layer ID")
plt.ylabel("RMS")
plt.legend()
plt.savefig("./w_decorr/loss_mats/rms.png")

## Subplots

In [None]:
# Two-panel figure: decorrelated network on the left, baseline on the right.
# Each panel shows the estimated importance sorted ascending and the measured
# importance in the same channel order, both scaled by their own maximum.
fig, axes = plt.subplots(1, 2, figsize=(20,5))

# Left panel: decorrelated network.
s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
order = imp_decorr_bn.sort()[1].cpu().numpy()
axes[0].plot(s/s.max(), label="Estimated importance")
axes[0].set_title("Decorrelated Network (layer "+str(l_index)+")")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
axes[0].set_xlabel("Neuron index")
axes[0].set_ylabel("Normalized importance")
axes[0].plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
axes[0].legend()

# Right panel: baseline network (s, order and loss_diff are reused).
s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
order = imp_corr_bn.sort()[1].cpu().numpy()
axes[1].plot(s/s.max(), label="Estimated importance")
axes[1].set_title("Correlated Network (layer "+str(l_index)+")")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
axes[1].set_xlabel("Neuron index")
axes[1].set_ylabel("Normalized importance")
axes[1].plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
axes[1].legend()

plt.savefig("./w_decorr/loss_mats/subplots/"+str(l_index)+".png")

# Other metrics

### Net-Slim Train

In [None]:
# Scale-based importance for the baseline network: magnitude of the analysed
# layer's weight vector. The absolute value matches the decorrelated-network
# cell; without it a large negative scale would rank as least important.
scale_corr = net.features[l_index].weight.data.clone().abs()
# Correlation matrix between the scale magnitudes and the measured importance.
np.corrcoef(scale_corr.cpu().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
# Scale-based importance for the decorrelated network: absolute value of the
# analysed layer's weight vector, correlated with the measured importance.
scale_decorr = net_decorr.features[l_index].weight.data.clone().abs()
np.corrcoef((scale_decorr).cpu().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

### L2 based pruning Train

In [None]:
# L2-based importance for the baseline network: squared weights summed over
# every dimension except the first, taken from the layer two positions before
# l_index (with l_index = 10 that is features[8], the third convolution).
w_corr = net.features[l_index - 2].weight.data.clone()
w_imp_corr = w_corr.pow(2).sum(dim=(1,2,3)).cpu()
# Correlation matrix between the L2 scores and the measured importance.
np.corrcoef(w_imp_corr.numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
# L2-based importance for the decorrelated network: squared weights summed
# over every dimension except the first, from the layer two positions before
# l_index.
w_decorr = net_decorr.features[l_index - 2].weight.data.clone()
w_imp_decorr = w_decorr.pow(2).sum(dim=(1,2,3)).cpu()
# Min-max rescale to [0, 1]. NOTE(review): the baseline counterpart
# (w_imp_corr) is not rescaled this way, so plots of the two use different
# scalings -- confirm whether that difference is intended.
w_imp_decorr = (w_imp_decorr - w_imp_decorr.min())
w_imp_decorr = w_imp_decorr/w_imp_decorr.max()
np.corrcoef(w_imp_decorr.numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

#### Importance plots Netslim Train

In [None]:
# Baseline network, Net-Slim score: sorted scale values against the measured
# importance in the same channel order, each scaled by its own maximum.
figure(figsize=(20,5))

s = scale_corr.cpu().sort()[0].cpu().numpy()
order = scale_corr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (Net-Slim)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
# Decorrelated network, Net-Slim score: sorted scale magnitudes against the
# measured importance in the same channel order, each scaled by its own maximum.
figure(figsize=(20,5))

s = scale_decorr.cpu().sort()[0].cpu().numpy()
order = scale_decorr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (Net-Slim)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

#### Importance plots L2 train

In [None]:
# Baseline network, L2 score: sorted filter energies against the measured
# importance in the same channel order, each scaled by its own maximum.
figure(figsize=(20,5))
s = w_imp_corr.sort()[0].cpu().numpy()
order = w_imp_corr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (L2)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
# Decorrelated network, L2 score: sorted (min-max rescaled) filter energies
# against the measured importance in the same channel order.
figure(figsize=(20,5))
s = w_imp_decorr.sort()[0].cpu().numpy()
order = w_imp_decorr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (L2)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())