# A novel loss function to improve generalisation in neural networks

This loss function is the basis of my dissertation. We seek to support the generalisation of neural networks through a novel loss function: Correlation Penalty Loss (CP Loss). We derive a correlation measure between a network’s activations and a set of labels associated with an orthogonal task. For example, the speaker’s gender is orthogonal in a word recognition task. 

**Experimental work**
We implement CPL and experiment with audio and image classification tasks using Convolutional Neural Networks. In addition to using existing orthogonal tasks, we use data augmentation strategies, such as adding noise to create orthogonal labels. This latter point relieves us from the constraint of sourcing or the expense of collecting data with these additional labels. 

**Results**
The experimental results showed some positive effects of using CP Loss, with a statistically significant result in reducing the standard deviation of the accuracy scores between orthogonal classes in an image classification task (using CIFAR-10 and labels generated by data augmentation).

**Credits**
My supervisor for this project was Dr Tillman Weyde (https://www.city.ac.uk/about/people/academics/tillman-weyde) and many thanks to Dr Eric Guizzo for his support and whose work inspired this project (https://arxiv.org/abs/2006.06494)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time

from cp_loss import CorrelationPenaltyLoss 

## Toy example

In [8]:
# create a simple MLP
# takes 10 input features and outputs 10 features
class MyMLP(nn.Module):
    def __init__(self):
        super(MyMLP, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

In [None]:
# random dataset
# input
input = torch.randn(64, 10)
v_input = torch.randn(16,10)
t_input = torch.randn(8,10)

# "true" labels int 0-9
target = torch.randint(0, 10, (64,))
v_target = torch.randint(0, 10, (16,))
t_target = torch.randint(0, 10, (8,))

# orthogonal labels int 0-2
orthogonal_labels = torch.randint(0, 3, (64,))
v_orthogonal_labels = torch.randint(0, 3, (16,))
t_orthogonal_labels = torch.randint(0, 3, (8,))

In [18]:
# set up training
model = MyMLP()
base_loss_fn = nn.CrossEntropyLoss()
criterion=CorrelationPenaltyLoss(base_loss_fn=base_loss_fn, 
                                 model=model,
                                 layer_type="nn.Linear",
                                 num_orthog_classes=3,
                                 one_hot=True,
                                 alpha=10)
optimizer = optim.Adam(params=model.parameters())

n_epochs = 10

# for validation
val_err = torch.zeros(n_epochs)
 
for epoch in range(n_epochs):
    t0=time.time()
    model.train()
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, target, orthogonal_labels)
    
    loss.backward()
    optimizer.step()
    
    # validation loss
    model.eval()
    running_vloss = 0.
    with torch.no_grad():
        v_output = model(v_input)
        vloss = base_loss_fn(v_output, v_target)
        running_vloss += vloss

    avg_vloss = running_vloss / (epoch + 1)
    val_err[epoch]=avg_vloss
    print(f"epoch: {epoch+1}, validation loss: {avg_vloss:0.2f}, time:{time.time()-t0:0.2f}")

epoch: 1, validation loss: 2.36, time:0.03
epoch: 2, validation loss: 1.18, time:0.01
epoch: 3, validation loss: 0.79, time:0.00
epoch: 4, validation loss: 0.59, time:0.00
epoch: 5, validation loss: 0.47, time:0.01
epoch: 6, validation loss: 0.39, time:0.00
epoch: 7, validation loss: 0.34, time:0.00
epoch: 8, validation loss: 0.29, time:0.00
epoch: 9, validation loss: 0.26, time:0.00
epoch: 10, validation loss: 0.24, time:0.00


In [None]:
# accuracy - should not be better than random chance (~10%)! 
num_correct = 0
num_samples = 0
model.eval()
 
with torch.no_grad():  
    scores=model(t_input)
    _, predictions = scores.max(1)
    num_correct += (predictions == t_target).sum()
    num_samples += predictions.size(0)

print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}') 

Got 1 / 8 with accuracy 12.50
