# Train a neural network to predict MHC ligands
The notebook consists of the following sections:

0. Module imports, define functions, set constants
1. Load Data
2. Build Model
3. Select Hyper-parameters
4. Compile Model
5. Train Model
6. Evaluation

## Exercise

The exercise is to optimize the model given in this notebook by selecting hyper-parameters that improve performance. First run the notebook as is and note the performance (AUC, MCC). Then start a manual hyper-parameter search by following the instructions below. If your first run fits poorly (the model doesn't learn anything during training), do not despair! Hopefully you will see a rapid improvement when you start testing other hyper-parameters.

### Optimizer, learning rate, and mini-batches
The [optimizers](https://pytorch.org/docs/stable/optim.html) are different approaches to minimizing a loss function based on gradients. The learning rate determines how strongly we correct the weights at each update: the smaller the learning rate, the smaller the corrections. This may prolong training. To mitigate this, one can train with mini-batches. Instead of feeding the network all of the data before each update, you partition the training data into mini-batches and update the weights more frequently, so the model might converge faster. Small batch sizes also use less memory, which means you can train a model with more parameters.

If the model had trouble training at all, you might benefit from lowering the learning rate to 0.01 or 0.001, or perhaps even smaller.
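
To make the effect of the learning rate concrete, here is a minimal sketch that runs plain gradient descent on the toy function f(w) = w**2. The function `gradient_descent` and the chosen rates are illustrative assumptions and are not part of the notebook code.

```python
def gradient_descent(lr, steps=20, w0=1.0):
    """Minimize f(w) = w**2 with plain gradient descent; return the final w."""
    w = w0
    for _ in range(steps):
        grad = 2.0 * w      # derivative of w**2
        w = w - lr * grad   # correction scaled by the learning rate
    return w

# Too large: diverges. Moderate: converges. Tiny: barely moves in 20 steps.
for lr in (1.1, 0.1, 0.001):
    print(lr, gradient_descent(lr))
```

A rate that is too large overshoots the minimum on every step, while a very small rate needs many more updates to get anywhere.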

__Optimizers:__
1. SGD (+ Momentum)
2. Adam
3. Try others if you like...
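
As a hedged sketch, the optimizers listed above can be constructed as follows. The stand-in `model` and the specific learning rates and momentum value are assumptions for illustration; in the notebook you would pass `net.parameters()` and `LEARNING_RATE` instead.

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in model; 189 = 9 positions x 21 BLOSUM features
model = nn.Linear(189, 1)

# Plain SGD, SGD with momentum, and Adam share the same constructor pattern.
sgd = optim.SGD(model.parameters(), lr=0.01)
sgd_momentum = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
adam = optim.Adam(model.parameters(), lr=0.001)
```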

__Mini-batch size:__
When you have implemented and tested a smaller learning rate, also try a mini-batch size of 512 or 128. To set the mini-batch size, use the variable MINI_BATCH_SIZE and run train_with_minibatches() instead of train().
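
The following sketch shows how a `DataLoader` partitions a dataset into mini-batches and how many weight updates one epoch then contains. The toy tensors and the sizes (1000 samples, batch size 128) are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

x = torch.randn(1000, 189)  # toy inputs
y = torch.rand(1000, 1)     # toy targets

loader = DataLoader(TensorDataset(x, y), batch_size=128, shuffle=True)

n_batches = len(loader)                       # updates per epoch
batch_sizes = [len(xb) for xb, _ in loader]   # the last batch may be smaller
print(n_batches, batch_sizes)
```

With full-batch training there is one update per epoch; here there are as many updates as there are batches.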

### Number of hidden units
Try increasing the number of hidden units, and thereby the number of model parameters (weights), e.g. to 64, 128, or 512.
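
To see how the parameter count grows with the number of hidden units, here is a small sketch for a network with one hidden layer and a single output. The helper `count_mlp_parameters` is hypothetical, and it assumes 9 positions times 21 BLOSUM features as input.

```python
def count_mlp_parameters(n_features, n_hidden, n_out=1):
    """Weights plus biases of a one-hidden-layer fully connected network."""
    hidden = n_features * n_hidden + n_hidden  # first layer weights + biases
    output = n_hidden * n_out + n_out          # output layer weights + biases
    return hidden + output

n_features = 9 * 21  # flattened 9mer encoded with 21 features per position
for n_hidden in (16, 64, 128, 512):
    print(n_hidden, count_mlp_parameters(n_features, n_hidden))
```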

### Hidden layers
Add another layer to the network. To do so, you must edit the methods of the Net class.
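
One possible way to add a second hidden layer is sketched below. The class name `DeepNet`, the argument `n_l2`, and the layer sizes are assumptions for illustration, not the notebook's solution.

```python
import torch
import torch.nn as nn

class DeepNet(nn.Module):
    """Feed-forward network with two hidden layers (illustrative sketch)."""

    def __init__(self, n_features, n_l1, n_l2):
        super().__init__()
        self.fc1 = nn.Linear(n_features, n_l1)
        self.fc2 = nn.Linear(n_l1, n_l2)   # the additional hidden layer
        self.fc3 = nn.Linear(n_l2, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

deep_net = DeepNet(189, 64, 32)
out = deep_net(torch.zeros(4, 189))  # batch of 4 dummy inputs
```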

### Parameter initialization
Parameter initialization can be extremely important.
PyTorch has a lot of different [initializers](http://pytorch.org/docs/master/nn.html#torch-nn-init) and the most often used initializers are listed below. Try implementing one of them.
1. Kaiming He
2. Xavier Glorot
3. Uniform or Normal with small scale (0.1 - 0.01)

Bias is nearly always initialized to zero using [torch.nn.init.constant_(tensor, val)](http://pytorch.org/docs/master/nn.html#torch.nn.init.constant) (the in-place variant with a trailing underscore; the version without it is deprecated).

To implement an initialization method, uncomment #net.apply(init_weights); to select your favorite method, modify the init_weights function.
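
As a sketch of what a modified initialization function could look like, the example below uses Kaiming (He) uniform initialization for the weights and zeros for the biases. The function name `init_weights_kaiming` and the stand-in `model` are assumptions; in the notebook you would edit `init_weights` and apply it to `net`.

```python
import torch.nn as nn

def init_weights_kaiming(m):
    """Kaiming uniform weights and zero biases for every linear layer."""
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
        nn.init.constant_(m.bias, 0)

# Stand-in model with the same layer types as the notebook's network
model = nn.Sequential(nn.Linear(189, 16), nn.ReLU(), nn.Linear(16, 1))
model.apply(init_weights_kaiming)  # apply() visits every submodule
```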

### Nonlinearity
Nonlinearity is what lets neural networks act as universal function approximators. Not every relationship in our universe is linear, so we must use nonlinear activations to model them. [The most commonly used nonlinearities](http://pytorch.org/docs/master/nn.html#non-linear-activations) are listed below.
1. ReLU
2. Leaky ReLU
3. Sigmoid squashes the output into [0, 1] and is used on the output layer if your target is binary (generally not used in the hidden layers)
4. Tanh is similar to sigmoid, but squashes into [-1, 1]. It is rarely used any more.
5. Softmax normalizes the outputs so that they sum to 1, and is used as the output layer if you have a multi-class classification problem

Change the current function to another. To do so, you must modify the forward() method of the Net class.
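
For reference, the activations listed above can be written out in plain Python as in the sketch below. These helper functions are illustrative only; in the notebook you would use the corresponding `torch.nn` modules. The leaky ReLU slope of 0.01 is an assumed default.

```python
import math

def relu(x):
    return max(0.0, x)

def leaky_relu(x, slope=0.01):
    return x if x > 0 else slope * x

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(xs):
    m = max(xs)                             # subtract max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

for x in (-2.0, 0.0, 2.0):
    print(x, relu(x), leaky_relu(x), sigmoid(x), math.tanh(x))
print(softmax([1.0, 2.0, 3.0]))
```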

### Early stopping
Early stopping halts training once the model has stopped improving on the validation set, before it overfits. The method monitors the validation loss at every epoch and keeps a checkpoint of the model weights. Once the validation loss stops decreasing, the method raises a flag, but it allows a number of further epochs to pass before stopping. This number of epochs is referred to as the patience. If the validation loss drops below the previous global minimum before the patience runs out, the flag and the patience are reset. If no new global minimum is encountered, training is stopped and the weights from the epoch with the global minimum are loaded and define the final model.

To implement early stopping you must set implement=True in the invoke()-function called within train() or train_with_minibatches().
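
The patience logic described above can be sketched in a few lines of plain Python. The function `early_stopping_epoch` is a hypothetical illustration and is not the `EarlyStopping` class imported from `pytorchtools`; it only tracks epochs and does not save any weights.

```python
def early_stopping_epoch(valid_losses, patience):
    """Return (stop_epoch, best_epoch) for a sequence of validation losses."""
    best_loss = float('inf')
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch  # new minimum: checkpoint here
            epochs_without_improvement = 0       # reset the patience counter
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch         # patience exhausted
    return len(valid_losses) - 1, best_epoch     # stopping never triggered

losses = [1.0, 0.8, 0.6, 0.65, 0.7, 0.55, 0.6, 0.62, 0.63, 0.5]
print(early_stopping_epoch(losses, patience=3))
print(early_stopping_epoch(losses, patience=2))
```

Note how a short patience can stop training before a later, lower minimum is reached.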

### Regularization (optional)
Implement either L2 regularization, [dropout](https://pytorch.org/docs/stable/nn.html#dropout-layers) or [batch normalization](https://pytorch.org/docs/stable/nn.html#normalization-layers).
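
A hedged sketch of all three options is shown below: L2 regularization through the optimizer's `weight_decay` argument, and dropout and batch normalization as extra layers. The stand-in `model`, the layer sizes, the dropout probability, and the weight decay value are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(189, 64),
    nn.BatchNorm1d(64),   # batch normalization over the hidden units
    nn.ReLU(),
    nn.Dropout(p=0.5),    # randomly zeroes hidden units during training
    nn.Linear(64, 1),
)

# weight_decay adds an L2 penalty on the weights to each SGD update
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

model.eval()  # dropout and batch norm behave differently in eval mode
with torch.no_grad():
    out = model(torch.zeros(8, 189))
```

Because dropout and batch normalization change behavior between training and evaluation, the calls to `net.train()` and `net.eval()` in the training loop matter once these layers are added.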

### Mix of peptide lengths
Now you have hopefully found an architecture that yields pretty good performance. But of course it is not that simple... One of the issues that occurs when working with real data is that ligands are not always 9 amino acids long; they can also have lengths of 8, 10, or 11. To accommodate different lengths you need to pad your sequences so they still fit into the expected tensor. This, however, may mess with the weights of the anchor positions.
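
A minimal sketch of zero-padding is shown below, assuming right-padding of an already encoded peptide up to a maximum length of 11. The helper `pad_encoded` is hypothetical; the notebook's `encode_peptides` function achieves the same effect implicitly by writing each peptide into a zero-initialized array.

```python
import numpy as np

MAX_LEN = 11      # longest peptide to support
N_FEATURES = 21   # BLOSUM features per position

def pad_encoded(encoded, max_len=MAX_LEN):
    """Right-pad a (length, n_features) array with zero rows up to max_len."""
    padded = np.zeros((max_len, encoded.shape[1]), dtype=encoded.dtype)
    padded[:encoded.shape[0]] = encoded
    return padded

nine_mer = np.ones((9, N_FEATURES), dtype=np.int8)  # dummy encoded 9mer
padded = pad_encoded(nine_mer)
print(padded.shape)
```

With right-padding, the C-terminal residue of a 9mer and of an 11mer end up at different positions in the tensor, which is one way padding can interfere with position-specific weights.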

Try and include 8-9-10-11mers and take a look at how it affects performance. 

* set MAX_PEP_SEQ_LEN = 11
* set ALLELE = 'A0301'

#### Performance evaluation
Run the notebook and take a look at how the model performs on data partitioned by peptide length. 

1. What happens to the performance evaluated on 8-10-11mers (excluding 9mers) compared to performance evaluated only on peptides of length 9?

Can you explain why we would prefer a good performance on 8-9-10-11mers over a higher performance on only 9mers?

## ... continue exercise with notebook CNN-ligand_prediction

In [None]:
import torch
from torch.autograd import Variable
import torch.nn as nn
#import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [None]:
from pytorchtools import EarlyStopping

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, matthews_corrcoef

In [None]:
SEED=1
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def load_blosum(filename):
    """
    Read in BLOSUM values into matrix.
    """
    aa = ['A', 'R', 'N' ,'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V', 'X']
    df = pd.read_csv(filename, sep=r'\s+', comment='#', index_col=0)
    return df.loc[aa, aa]

In [None]:
def load_peptide_target(filename):
    """
    Read amino acid sequence of peptides and
    corresponding log transformed IC50 binding values from text file.
    """
    df = pd.read_csv(filename, sep=r'\s+', usecols=[0,1], names=['peptide','target'])
    return df.sort_values(by='target', ascending=False).reset_index(drop=True)

In [None]:
def encode_peptides(Xin):
    """
    Encode AA seq of peptides using BLOSUM50.
    Returns a tensor of encoded peptides of shape (batch_size, MAX_PEP_SEQ_LEN, n_features)
    """
    blosum = load_blosum(blosum_file)
    
    batch_size = len(Xin)
    n_features = len(blosum)
    
    Xout = np.zeros((batch_size, MAX_PEP_SEQ_LEN, n_features), dtype=np.int8)
    
    for peptide_index, row in Xin.iterrows():
        for aa_index in range(len(row.peptide)):
            Xout[peptide_index, aa_index] = blosum[ row.peptide[aa_index] ].values
            
    return Xout, Xin.target.values

In [None]:
def invoke(early_stopping, loss, model, implement=False):
    """
    Update the early-stopping monitor and return True if training should stop.
    """
    if not implement:
        return False
    early_stopping(loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        return True
    return False

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

## Arguments

In [None]:
MAX_PEP_SEQ_LEN = 9 #11
BINDER_THRESHOLD = 0.426

# Main

## Load

In [None]:
ALLELE = 'A0201' #'A0301'

In [None]:
blosum_file = "../data/BLOSUM50"
train_data = "../data/%s/train_BA" % ALLELE
valid_data = "../data/%s/valid_BA" % ALLELE
test_data = "../data/%s/test_BA" % ALLELE

In [None]:
train_raw = load_peptide_target(train_data)
valid_raw = load_peptide_target(valid_data)
test_raw = load_peptide_target(test_data)

### Visualize Data

In [None]:
def plot_peptide_distribution(raw_data, raw_set):
    raw_data['peptide_length'] = raw_data.peptide.str.len()
    raw_data['target_binary'] = (raw_data.target >= BINDER_THRESHOLD).astype(int)

    # Position of bars on x-axis
    ind = np.arange(raw_data.peptide_length.nunique())
    neg = raw_data[raw_data.target_binary == 0].peptide_length.value_counts().sort_index()
    pos = raw_data[raw_data.target_binary == 1].peptide_length.value_counts().sort_index()

    # Plotting
    plt.figure()
    width = 0.3  

    plt.bar(ind, neg, width, label='Non-binders')
    plt.bar(ind + width, pos, width, label='Binders')

    plt.xlabel('Peptide lengths')
    plt.ylabel('Count of peptides')
    plt.title('Distribution of peptide lengths in %s data' %raw_set)
    plt.xticks(ind + width / 2, ['%dmer' %i for i in neg.index])
    plt.legend(loc='best')
    plt.show()

In [None]:
plot_peptide_distribution(train_raw, 'train')

In [None]:
def plot_target_values(data=[(train_raw, 'Train set'), (valid_raw, 'Validation set'), (test_raw, 'Test set')]):
    plt.figure(figsize=(15,4))
    for partition, label in data:
        x = partition.index
        y = partition.target
        plt.scatter(x, y, label=label, marker='.')
    plt.axhline(y=BINDER_THRESHOLD, color='r', linestyle='--', label='Binder threshold')
    plt.legend(frameon=False)
    plt.title('Target values')
    plt.xlabel('Index of dependent variable')
    plt.ylabel('Dependent variable')
    plt.show()

In [None]:
plot_target_values()

### Encode data

In [None]:
x_train_, y_train_ = encode_peptides(train_raw)
x_valid_, y_valid_ = encode_peptides(valid_raw)
x_test_, y_test_ = encode_peptides(test_raw)

Check the data dimensions for the train set and validation set (batch_size, MAX_PEP_SEQ_LEN, n_features)

In [None]:
print(x_train_.shape)
print(x_valid_.shape)
print(x_test_.shape)

### Flatten tensors

In [None]:
x_train_ = x_train_.reshape(x_train_.shape[0], -1)
x_valid_ = x_valid_.reshape(x_valid_.shape[0], -1)
x_test_ = x_test_.reshape(x_test_.shape[0], -1)

In [None]:
batch_size = x_train_.shape[0]
n_features = x_train_.shape[1]

### Make data iterable

In [None]:
x_train = Variable(torch.from_numpy(x_train_.astype('float32')))
y_train = Variable(torch.from_numpy(y_train_.astype('float32'))).view(-1, 1)

x_valid = Variable(torch.from_numpy(x_valid_.astype('float32')))
y_valid = Variable(torch.from_numpy(y_valid_.astype('float32'))).view(-1, 1)

x_test = Variable(torch.from_numpy(x_test_.astype('float32')))
y_test = Variable(torch.from_numpy(y_test_.astype('float32'))).view(-1, 1)

## Build Model

In [None]:
class Net(nn.Module):

    def __init__(self, n_features, n_l1):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_features, n_l1)
        self.fc2 = nn.Linear(n_l1, 1)
        
        # Activation functions
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

## Select Hyper-parameters

In [None]:
def init_weights(m):
    """
    https://pytorch.org/docs/master/nn.init.html
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0) # alternative command: m.bias.data.fill_(0.01)

In [None]:
EPOCHS = 3000
MINI_BATCH_SIZE = 512
N_HIDDEN_NEURONS = 16
LEARNING_RATE = 0.1
PATIENCE = EPOCHS // 10

## Compile Model

In [None]:
net = Net(n_features, N_HIDDEN_NEURONS)
#net.apply(init_weights)

net

In [None]:
count_parameters(net)

In [None]:
optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

## Train Model

In [None]:
# Full-batch training (no mini-batch loading)
def train():
    train_loss, valid_loss = [], []

    early_stopping = EarlyStopping(patience=PATIENCE)

    for epoch in range(EPOCHS):
        net.train()
        pred = net(x_train)
        loss = criterion(pred, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())  # store a Python float, not a tensor

        if epoch % (EPOCHS//10) == 0:
            print('Train Epoch: {}\tLoss: {:.6f}'.format(epoch, loss.item()))

        # Validation: no gradients are needed
        net.eval()
        with torch.no_grad():
            pred = net(x_valid)
            loss = criterion(pred, y_valid)
        valid_loss.append(loss.item())

        if invoke(early_stopping, valid_loss[-1], net, implement=True):
            net.load_state_dict(torch.load('checkpoint.pt'))
            break
            
    return net, train_loss, valid_loss

In [None]:
# Train with mini_batches
train_loader = DataLoader(dataset=TensorDataset(x_train, y_train), batch_size=MINI_BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=TensorDataset(x_valid, y_valid), batch_size=MINI_BATCH_SIZE, shuffle=False)  # no need to shuffle validation data

def train_with_minibatches():
    
    train_loss, valid_loss = [], []

    early_stopping = EarlyStopping(patience=PATIENCE)
    for epoch in range(EPOCHS):
        batch_loss = 0
        net.train()
        for x_batch, y_batch in train_loader:
            pred = net(x_batch)
            loss = criterion(pred, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss += loss.item()
        train_loss.append(batch_loss / len(train_loader))

        # Validation: no gradients are needed
        batch_loss = 0
        net.eval()
        with torch.no_grad():
            for x_batch, y_batch in valid_loader:
                pred = net(x_batch)
                loss = criterion(pred, y_batch)
                batch_loss += loss.item()
        valid_loss.append(batch_loss / len(valid_loader))
        
        if epoch % (EPOCHS//10) == 0:
            print('Train Epoch: {}\tLoss: {:.6f}\tVal Loss: {:.6f}'.format(epoch, train_loss[-1], valid_loss[-1]))

        if invoke(early_stopping, valid_loss[-1], net, implement=True):
            net.load_state_dict(torch.load('checkpoint.pt'))
            break
            
    return net, train_loss, valid_loss

In [None]:
net, train_loss, valid_loss = train()

In [None]:
#net, train_loss, valid_loss = train_with_minibatches()

In [None]:
def plot_losses(burn_in=20):
    plt.figure(figsize=(15,4))
    plt.plot(list(range(burn_in, len(train_loss))), train_loss[burn_in:], label='Training loss')
    plt.plot(list(range(burn_in, len(valid_loss))), valid_loss[burn_in:], label='Validation loss')

    # find position of lowest validation loss
    minposs = valid_loss.index(min(valid_loss))  # zero-based, matching the x-axis
    plt.axvline(minposs, linestyle='--', color='r',label='Minimum Validation Loss')

    plt.legend(frameon=False)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.show()
    
plot_losses()

## Evaluation

### Predict on test set

In [None]:
net.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    pred = net(x_test)
    loss = criterion(pred, y_test)

In [None]:
plot_target_values(data=[(pd.DataFrame(pred.data.numpy(), columns=['target']), 'Prediction'),
                         (test_raw, 'Target')])

### Transform targets to class

In [None]:
y_test_class = np.where(y_test.flatten() >= BINDER_THRESHOLD, 1, 0)
y_pred_class = np.where(pred.flatten() >= BINDER_THRESHOLD, 1, 0)

### Receiver Operating Characteristic (ROC) curve

In [None]:
def plot_roc_curve(peptide_length=[9]):
    plt.title('Receiver Operating Characteristic')
    plt.plot(fpr, tpr, label = 'AUC = %0.2f (%smer)' %(roc_auc, '-'.join([str(i) for i in peptide_length])))
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1], c='black', linestyle='--')
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')

In [None]:
# Combining targets and prediction values with peptide length in a dataframe
pred_per_len = pd.DataFrame([test_raw.peptide.str.len().to_list(),
                             y_test_class,
                             pred.flatten().detach().numpy()],
                            index=['peptide_length','target','prediction']).T

plt.figure(figsize=(7,7))
# For each peptide length compute AUC and plot ROC
for length, grp in pred_per_len.groupby('peptide_length'):
    fpr, tpr, threshold = roc_curve(grp.target, grp.prediction)
    roc_auc = auc(fpr, tpr)
    
    plot_roc_curve(peptide_length=[int(length)])

# Evaluating model on peptides of length other than 9 AA.
for lengths in [[8,10,11],[8,9,10,11]]:
    grp = pred_per_len[pred_per_len.peptide_length.isin(lengths)]
    if not grp.empty:
        fpr, tpr, threshold = roc_curve(grp.target, grp.prediction)
        roc_auc = auc(fpr, tpr)

        plot_roc_curve(peptide_length=lengths)

    else:
        print("Data does not contain peptides of length other than 9 AA.")

### Matthews Correlation Coefficient (MCC)

In [None]:
mcc = matthews_corrcoef(y_test_class, y_pred_class)
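
For intuition about what this number measures, the coefficient can be computed directly from the four confusion-matrix counts. The helper `mcc_from_counts` below is a hypothetical sketch and not part of scikit-learn; returning 0 when the denominator is zero is a common convention assumed here.

```python
import math

def mcc_from_counts(tp, tn, fp, fn):
    """Matthews correlation coefficient from confusion-matrix counts."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        return 0.0  # assumed convention when a row or column is empty
    return (tp * tn - fp * fn) / denom

print(mcc_from_counts(50, 50, 0, 0))     # perfect prediction
print(mcc_from_counts(25, 25, 25, 25))   # no better than chance
print(mcc_from_counts(0, 0, 50, 50))     # every prediction wrong
```

The value ranges from -1 to 1, and unlike accuracy it stays informative when binders and non-binders are unevenly represented.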

In [None]:
def plot_mcc():
    plt.title('Matthews Correlation Coefficient')
    plt.scatter(y_test.flatten().detach().numpy(), pred.flatten().detach().numpy(), label = 'MCC = %0.2f' % mcc)
    plt.legend(loc = 'lower right')
    plt.ylabel('Predicted')
    plt.xlabel('Test targets')
    plt.show()

plot_mcc()