In [None]:
l_index = 16
layer_id = 'bn'

In [1]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

In [2]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np

In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
         #(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

#classes = ('plane', 'car', 'bird', 'cat',
#          'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified


In [4]:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


cfg = {
    'VGG11': [32, 32, 'M', 64, 64, 'M', 128, 128, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 100)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif x == 'D2':
                layers += [nn.Dropout(p=0.2)]
            elif x == 'D3':
                layers += [nn.Dropout(p=0.3)]
            elif x == 'D4':
                layers += [nn.Dropout(p=0.4)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True),
                           nn.BatchNorm2d(x)]
                in_channels = x
        #layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


def test():
    net = VGG('VGG11')
    x = torch.randn(2,3,32,32)
    y = net(x)
    print(y.size())

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [6]:
net = VGG('VGG13').to(device)
criterion = nn.CrossEntropyLoss()

In [7]:
def cal_acc(net):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %.4f %%' % (
        100 * correct / total))

In [8]:
def cal_acc_train(net):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in trainloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 50000 train images: %d %%' % (
        100 * correct / total))

In [9]:
def cal_mass(net, l_index):
    num_iter = 0
    r = 0.0
    with torch.no_grad():
        for i, data in enumerate(trainloader, 0):
            num_iter += 1
            if(num_iter == 40):
                break
            inputs, labels = data[0].to(device), data[1].to(device)
            L_self = 0.0
            L_mat = 0.0

            for epoch_num in range(1):
                out_features = net.features[0:l_index](inputs)
                X_t = out_features.reshape(out_features.shape[0], out_features.shape[1], -1)
#                 X_t = X_t - X_t.mean(2).reshape(out_features.shape[0], out_features.shape[1], 1)
#                 X_t = torch.div(X_t, X_t.norm(dim=2).reshape(X_t.shape[0],X_t.shape[1],1) + 1e-10)
                cov_mat = torch.matmul(X_t, X_t.permute(0,2,1))
                L_mat = cov_mat.norm().pow(2)
                
                ident = (1 - torch.eye(out_features.shape[1])).to(device)
                cov_mat = cov_mat*ident
                L_self = cov_mat.norm().pow(2)
                
                r += 1 - L_self/L_mat

            del L_self, L_mat, out_features
            torch.cuda.empty_cache()
        return r/num_iter

### Correlated Net

In [None]:
PATH = './cifar100_net.pth'
# PATH = './tempnet1.pth'
net.load_state_dict(torch.load(PATH))
net = net.eval()

In [None]:
weight_base = net.features[l_index].weight.data.clone().detach()
bias_base = net.features[l_index].bias.data.clone().detach()

In [None]:
loss_base_corr = 0
num_stop = 0
for epoch in range(1):
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss_base_corr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break

In [None]:
loss_mat_corr = torch.zeros(weight_base.shape[0])

for n_index in range(weight_base.shape[0]): 
    num_stop = 0
    print(n_index)
    running_loss = 0.0

    net.features[l_index].weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    net.features[l_index].bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs = net(inputs)

        loss = (criterion(outputs, labels))

        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_corr[n_index] = running_loss
    
    net.features[l_index].weight.data = weight_base.clone().detach()
    net.features[l_index].bias.data = bias_base.clone().detach()

# torch.save(loss_mat_corr, './decorr (features over samples)/loss_corrnet_bn_test_'+str(l_index)+'.pt')

# torch.save(loss_mat_corr, './w_decorr/loss_corr_bn_train_'+str(l_index)+'.pt')

In [None]:
# loss_mat_corr = torch.load('./w_decorr/loss_corr_bn_train_'+str(l_index)+'.pt')
torch.save(loss_mat_corr,'./temp'+str(l_index)+'.pt')

In [None]:
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    running_loss = 0.0
    imp_corr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)
    
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        imp_corr_bn += (((net.features[l_index].weight.grad)*(net.features[l_index].weight.data)) + ((net.features[l_index].bias.grad)*(net.features[l_index].bias.data))).abs().pow(2)
        
        num_stop += labels.shape[0]
        if(num_stop > 50000):
            break
         
    corrval = (np.corrcoef(imp_corr_bn.cpu().detach().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

### Decorrelated net

In [None]:
# PATH = './inner_decorr/dnet_all.pth'
PATH = './tempnet.pth'
net_decorr = VGG('VGG13').to(device)
net_decorr.load_state_dict(torch.load(PATH))
net_decorr = net_decorr.eval()

In [None]:
weight_base = net_decorr.features[l_index].weight.data.clone().detach()
bias_base = net_decorr.features[l_index].bias.data.clone().detach()

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
num_stop = 0
loss_base_decorr = 0
for epoch in range(1):
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):        
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss_base_decorr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)

loss_mat_decorr = torch.zeros(weight_base.shape[0])

for n_index in range(weight_base.shape[0]): 
    print(n_index)
    num_stop = 0
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        net_decorr.features[l_index].weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
        net_decorr.features[l_index].bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
        outputs = net_decorr(inputs)
        
        loss = criterion(outputs, labels)
        
        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_decorr[n_index] = running_loss
    
    net_decorr.features[l_index].weight.data = weight_base.clone().detach()
    net_decorr.features[l_index].bias.data = bias_base.clone().detach()

torch.save(loss_mat_decorr, './inner_decorr/loss_bn_train_'+str(l_index)+'.pt')

In [None]:
# loss_mat_decorr = torch.load('./w_decorr/loss_bn_train_'+str(l_index)+'.pt')

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    imp_decorr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_decorr_bn = torch.zeros(bias_base.shape[0]).to(device)

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
        
        imp_decorr_bn += (((net_decorr.features[l_index].weight.grad)*(net_decorr.features[l_index].weight.data)) + ((net_decorr.features[l_index].bias.grad)*(net_decorr.features[l_index].bias.data))).pow(2)
    
    corrval = (np.corrcoef(imp_decorr_bn.cpu().detach().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

# Net-Slim Train

In [None]:
scale_corr = net.features[l_index].weight.data.clone()
np.corrcoef(scale_corr.cpu().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
scale_decorr = net_decorr.features[l_index].weight.data.clone().abs()
np.corrcoef((scale_decorr).cpu().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

# L2 based pruning Train

In [None]:
w_corr = net.features[l_index - 2].weight.data.clone()
w_imp_corr = w_corr.pow(2).sum(dim=(1,2,3)).cpu()
np.corrcoef(w_imp_corr.numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
w_decorr = net_decorr.features[l_index - 2].weight.data.clone()
w_imp_decorr = w_decorr.pow(2).sum(dim=(1,2,3)).cpu()
w_imp_decorr = (w_imp_decorr - w_imp_decorr.min())
w_imp_decorr = w_imp_decorr/w_imp_decorr.max()
np.corrcoef(w_imp_decorr.numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

# Importance plots TFO Train

In [None]:
figure(figsize=(20,5))
s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
order = imp_corr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (Taylor FO)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))
s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
order = imp_corr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (Taylor FO)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))

s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
order = imp_decorr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (Taylor FO)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

# Importance plots Netslim Train

In [None]:
figure(figsize=(20,5))

s = scale_corr.cpu().sort()[0].cpu().numpy()
order = scale_corr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (Net-Slim)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))

s = scale_decorr.cpu().sort()[0].cpu().numpy()
order = scale_decorr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (Net-Slim)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

# Importance plots L2 train

In [None]:
figure(figsize=(20,5))
s = w_imp_corr.sort()[0].cpu().numpy()
order = w_imp_corr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (L2)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))
s = w_imp_decorr.sort()[0].cpu().numpy()
order = w_imp_decorr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (L2)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

# Train networks

In [10]:
net = VGG('VGG13').to(device)
PATH = './cifar100_net.pth'
# PATH = './inner_decorr/onet_all.pth'
net.load_state_dict(torch.load(PATH))

# net_d = VGG('VGG13').to(device)
# PATH_d = './w_decorr/cifar100_w_decorr.pth'
# net_d.load_state_dict(torch.load(PATH_d))

<All keys matched successfully>

In [11]:
cal_acc(net.eval())

Accuracy of the network on the 10000 test images: 60.5200 %


In [12]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()

In [13]:
l_temp = []

for layer_index in [3, 6, 10, 13, 17, 20, 24, 27, 31, 34]:
    
    _, _, w_in, h_in = net.features[0:layer_index](torch.zeros(1,3,32,32).to(device)).shape
    
    c_out, c_in, w_f, h_f = net.features[layer_index-3].weight.shape
    
    l_temp.append((c_in*w_f*h_f)*(w_in*h_in)*c_out*(c_in*w_f*h_f*c_out**(1/5)))
    
    
l_temp = np.array(l_temp)
l_temp = l_temp/l_temp.sum()

l_imp = {}
i = 0
for layer_index in [3, 6, 10, 13, 17, 20, 24, 27, 31, 34]:
    
    l_imp.update({layer_index : l_temp[i]})
    i+=1

### Channel-wise Inner product 

In [None]:
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
l_inds = [0, 6, 10, 13, 17, 20, 24, 27, 31, 34, 35]

for epoch in range(1):  
    running_loss = 0.0
    cov_loss = 0
    num_iter = 0
    av_cov_mass = 0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        out_features, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
    
        for epoch_num in range(1):            
            L_cov = 0.0
            for ind in range(len(l_inds)-1):
        
                out_features = net.features[l_inds[ind]:l_inds[ind+1]](out_features)
        
                X_t = out_features.reshape(out_features.shape[0], out_features.shape[1], -1)
#                 X_t = X_t - X_t.mean(2).reshape(out_features.shape[0], out_features.shape[1], 1)
#                 X_t = torch.div(X_t, X_t.norm(dim=2).reshape(X_t.shape[0],X_t.shape[1],1) + 1e-15)
                cov_mat = torch.matmul(X_t, X_t.permute(0,2,1))
                L_cov += l_imp[layer_index]*(cov_mat - (torch.eye(out_features.shape[1])).to(device) / 1000).norm(1) / 128

            outputs = net.classifier(out_features.reshape(out_features.shape[0], -1))
            Lc = criterion(outputs, labels)
            loss = Lc + 1e-5*L_cov

            loss.backward()
            optimizer.step()
            
        # print statistics
            running_loss += loss.item()
            cov_loss += L_cov.item()
        
    print("Covariance loss: " + str(cov_loss/num_iter))
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    cal_acc(net.eval())
print('Finished Training')

### Scaled constraint

In [15]:
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
l_inds = [0, 2, 5, 9, 12, 16, 19, 23, 26, 30, 33]

for epoch in range(1):  
    running_loss = 0.0
    cov_loss = 0
    num_iter = 0
    av_cov_mass = 0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        out_features, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
    
        for epoch_num in range(1):            
            L_cov = 0.0
            for ind in range(len(l_inds)-1):
        
                out_features = net.features[l_inds[ind]:l_inds[ind+1]](out_features)
        
                X_t = out_features.reshape(out_features.shape[0], out_features.shape[1], -1)
                x = net.features[l_inds[ind+1]].running_mean.data
                mu_ij = torch.matmul(x.reshape(x.shape[0], 1), x.reshape(x.shape[0], 1).t())
                
                x = net.features[l_inds[ind+1]].running_var.data
                sigma_ij = torch.matmul(x.reshape(x.shape[0], 1), x.reshape(x.shape[0], 1).t())

                x = net.features[l_inds[ind+1]].bias.data
                beta_ij = torch.matmul(x.reshape(x.shape[0], 1), x.reshape(x.shape[0], 1).t())

                x = net.features[l_inds[ind+1]].weight.data
                gamma_ij = torch.matmul(x.reshape(x.shape[0], 1), x.reshape(x.shape[0], 1).t())
                
                P = mu_ij - (beta_ij * (sigma_ij/ (gamma_ij + 1e-5)))
                
                X_t = out_features.reshape(out_features.shape[0], out_features.shape[1], -1)
                cov_mat = torch.matmul(X_t, X_t.permute(0,2,1))
                L_cov += (cov_mat - P).norm(1) / 128
            
            out_features = net.features[33:](out_features)
            
            outputs = net.classifier(out_features.reshape(out_features.shape[0], -1))
            Lc = criterion(outputs, labels)
            loss = Lc + 1e-7*L_cov

            loss.backward()
            optimizer.step()
            
        # print statistics
            running_loss += loss.item()
            cov_loss += L_cov.item()
        
    print("Covariance loss: " + str(cov_loss/num_iter))
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    cal_acc(net.eval())
print('Finished Training')

Covariance loss: 1288020.9359015345
[1,   391] loss: 0.545
Accuracy of the network on the 10000 test images: 56.0800 %
Finished Training


In [None]:
cal_mass(net, 34)

In [None]:
(sigma_ij.abs()/(gamma_ij.abs()+1e-5)).min()

### Full data inner product

In [None]:
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
l_inds = [0, 3, 6, 10, 13, 17, 20, 24, 27, 31, 34, 35]

for epoch in range(1):  
    running_loss = 0.0
    cov_loss = 0
    num_iter = 0
    av_cov_mass = 0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        out_features, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
    
        for epoch_num in range(1):            
            L_cov = 0.0
            for ind in range(len(l_inds)-1):
        
                out_features = net.features[l_inds[ind]:l_inds[ind+1]](out_features)
        
                X_t = out_features.permute(1,0,2,3).reshape(out_features.shape[1], -1)
                cov_mat = torch.matmul(X_t, X_t.t())
                L_cov += l_imp[layer_index]*(cov_mat*(1-torch.eye(out_features.shape[1])).to(device)).norm(1)
            
            out_features = net.features[l_inds[-2]:](out_features)
            
            outputs = net.classifier(out_features.reshape(out_features.shape[0], -1))
            Lc = criterion(outputs, labels)
            loss = Lc + 1e-5*L_cov

            loss.backward()
            optimizer.step()
            
        # print statistics
            running_loss += loss.item()
            cov_loss += L_cov.item()
        
    print("Covariance loss: " + str(cov_loss/num_iter))
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    cal_acc(net.eval())
print('Finished Training')

In [None]:
cal_mass(net, 34)

In [None]:
def cal_mass(net, l_index):
    num_iter = 0
    r = 0.0
    with torch.no_grad():
        for i, data in enumerate(trainloader, 0):
            num_iter += 1
            if(num_iter == 40):
                break
            inputs, labels = data[0].to(device), data[1].to(device)
            L_self = 0.0
            L_mat = 0.0

        for epoch_num in range(1):            
            L_cov = 0.0
            for layer_index in [34]: # [3, 6, 10, 13, 17, 20, 24, 27, 31, 34]:
        
                out_features = net.features[0:layer_index](inputs)
        
                X_t = out_features.permute(1,0,2,3).reshape(out_features.shape[1], -1)
                cov_mat = torch.matmul(X_t, X_t.t())
                L_mat = cov_mat.norm(1)
                L_self += l_imp[layer_index]*(cov_mat - (torch.eye(out_features.shape[1])).to(device)).norm(1)
                
                r += 1 - L_self/L_mat

            del L_self, L_mat, out_features
            torch.cuda.empty_cache()
        return r/num_iter

### VND regularization

In [None]:
optimizer = optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
gamma = 1
for epoch in range(3):  
    running_loss = 0.0
    cov_loss = 0
    num_iter = 0
    av_cov_mass = 0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        
        for epoch_num in range(1):            
            VND = 0.0
            for layer_index in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]:
        
                w_tensors = net.features[layer_index].weight.data
                w_mat = w_tensors.reshape(w_tensors.shape[0], -1)
                G = torch.matmul(w_mat, w_mat.t())
                
                VND += (G - torch.eye(w_mat.shape[0]).to(device)).norm(1) # + torch.trace(G) #+ gamma*(w_mat.norm(1))
#                 VND += torch.trace(G) - torch.log(torch.logdet(G + 1e-5*torch.eye(w_mat.shape[0]).to(device))) + gamma*(w_mat.norm(1))
        
            Lc = criterion(outputs, labels)
            loss = Lc + 1e-4*VND

            loss.backward()
            optimizer.step()
            
        # print statistics
            running_loss += loss.item()
            cov_loss += VND.item()
        
    print("Divergence loss: " + str(cov_loss/num_iter))
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    cal_acc(net.eval())
print('Finished Training')

In [None]:
for l in [3, 6, 10, 13, 17, 20, 24, 27, 31, 34]:
    print("decorr: ", str(l), cal_mass(net, l).item())
    print("orig ", str(l), cal_mass(net1, l).item())

In [None]:
PATH = './tempnet1.pth'
torch.save(net.state_dict(), PATH)