# Sources
* Customised implementation of the journal article:
    * pepsickle rapidly and accurately predicts proteasomal cleavage sites for improved neoantigen identification
    * Benjamin R Weeder, Mary A Wood, Ellysia Li, Abhinav Nellore, Reid F Thompson
    * https://academic.oup.com/bioinformatics/article/37/21/3723/6363787
    * https://github.com/pdxgx/pepsickle
    * https://github.com/pdxgx/pepsickle-paper

In [1]:
import os
import csv
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import roc_auc_score

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [2]:
def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs of the current process.

    Also configures cuDNN for deterministic kernels. Note that setting
    ``PYTHONHASHSEED`` here only affects subprocesses started afterwards; the
    hash seed of the already-running interpreter is fixed at startup.

    Args:
        seed: value used for every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose convolution algorithms by timing them,
    # which can differ between runs and defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False


seed_everything(1234)

In [3]:
def read_data(path):
    """Read a two-column ``sequence,label`` CSV file that starts with a header row.

    Args:
        path: path to the CSV file.

    Returns:
        ``(sents, lbls)``: two parallel lists of strings, one entry per data row.
        Labels are returned as raw strings (not converted to numbers).
    """
    sents, lbls = [], []
    # The csv module requires newline='' so quoted fields and line endings are
    # handled by the reader, not by the text layer.
    with open(path, 'r', newline='') as csvfile:
        rows = csv.reader(csvfile)
        next(rows, None)  # skip column names (tolerates an empty file)
        for seq, label in rows:
            sents.append(seq)
            lbls.append(label)
    return sents, lbls

def regularized_auc(train_auc, dev_auc, threshold=0.0025):
    """Score a model by its development AUC, zeroing it out when it overfits.

    Args:
        train_auc: AUC on the training set.
        dev_auc: AUC on the development set.
        threshold: largest allowed (exclusive) gap ``train_auc - dev_auc``.

    Returns:
        ``dev_auc`` when the train/dev gap is below ``threshold``, else ``0``.
    """
    overfit_gap = train_auc - dev_auc
    if overfit_gap < threshold:
        return dev_auc
    return 0

In [4]:
class CleavageDataset(Dataset):
    """Map-style dataset that pairs raw sequences with their labels.

    Items are returned untouched as ``(sequence, label)`` tuples; any encoding
    into tensors is left to the DataLoader's ``collate_fn``.

    Args:
        seq: indexable collection of sequences.
        lbl: indexable collection of labels, parallel to ``seq``.
    """

    def __init__(self, seq, lbl):
        self.seq, self.lbl = seq, lbl

    def __len__(self):
        # the label collection defines the dataset size
        return len(self.lbl)

    def __getitem__(self, idx):
        sample = (self.seq[idx], self.lbl[idx])
        return sample
    
def collate_batch(batch):
    """Turn a list of ``(sequence, label)`` pairs into two float tensors.

    Every residue character is replaced by its feature vector from the
    module-level ``_features`` table. Labels are parsed with ``int``.

    Returns:
        seq: float tensor of shape (batch, seq_len, n_features), assuming all
            sequences in the batch have the same length.
        lbl: float tensor of shape (batch,).
    """
    columns = list(zip(*batch))
    sequences = columns[0]
    labels = columns[1]

    encoded = [[_features[residue] for residue in sequence] for sequence in sequences]
    seq = torch.tensor(encoded, dtype=torch.float)

    parsed_labels = [int(label) for label in labels]
    lbl = torch.tensor(parsed_labels, dtype=torch.float)
    return seq, lbl

In [5]:
# model architectures taken from 
# https://github.com/pdxgx/pepsickle-paper/blob/master/scripts/modeling/epitope_based_ensemble_net.py

class SeqNet(nn.Module):
    """Feed-forward cleavage classifier over a flattened one-hot sequence window.

    Three ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stages are followed by a
    single output unit that yields one raw logit per example.

    Args:
        hidden_size1, hidden_size2, hidden_size3: widths of the hidden layers.
        dropout: dropout probability applied after every hidden layer.
        input_size: length of the flattened input vector. The default of 200
            is seq_len (10) * 20 one-hot features.
    """

    def __init__(self, hidden_size1, hidden_size2, hidden_size3, dropout, input_size=200):
        super().__init__()

        self.dropout = nn.Dropout(dropout)

        # layer creation order is unchanged so seeded initialisation is preserved
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, hidden_size3)
        self.fc4 = nn.Linear(hidden_size3, 1)

        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)
        self.bn3 = nn.BatchNorm1d(hidden_size3)

    def forward(self, seq):
        """Map ``seq`` of shape (batch, input_size) to logits of shape (batch,)."""
        out = self.dropout(F.relu(self.bn1(self.fc1(seq))))
        out = self.dropout(F.relu(self.bn2(self.fc2(out))))
        out = self.dropout(F.relu(self.bn3(self.fc3(out))))
        # Squeeze only the size-1 output dim. A bare squeeze() would also drop
        # the batch dim for a batch of one, returning a 0-d tensor.
        return self.fc4(out).squeeze(-1)
    
    
class MotifNet(nn.Module):
    """Cleavage classifier that scans the per-residue physicochemical features.

    A grouped ``Conv1d`` slides over the residue window. It uses kernel size 3
    and ``groups=4``, so each of the 4 input features gets its own filter. The
    output is flattened and passed through two
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stages and a single-logit head.

    Args:
        hidden_size1, hidden_size2: widths of the hidden layers.
        dropout: dropout probability applied after every hidden layer.
        seq_len: residues per input window (default 10). Sets the flattened
            conv output size, ``4 * (seq_len - 2)``.
    """

    def __init__(self, hidden_size1, hidden_size2, dropout, seq_len=10):
        super().__init__()

        self.dropout = nn.Dropout(dropout)

        # conv parameters are fixed due to the feature assembly process:
        # four per-residue input features (see dictionary variable _features)
        self.conv = nn.Conv1d(
            in_channels=4,
            out_channels=4,
            kernel_size=3,
            groups=4
        )

        # input to linear: groups * (seq_len - 2); an unpadded kernel of 3
        # removes two positions from the window
        self.fc1 = nn.Linear(4 * (seq_len - 2), hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, 1)

        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, seq):
        """Map ``seq`` of shape (batch, seq_len, 4) to logits of shape (batch,)."""
        # Conv1d expects channels first: (batch, 4, seq_len)
        out = self.conv(seq.transpose(1, 2))

        # conv output (batch, 4, seq_len - 2) flattened to (batch, 4 * (seq_len - 2))
        out = self.dropout(F.relu(self.bn1(self.fc1(out.flatten(1)))))
        out = self.dropout(F.relu(self.bn2(self.fc2(out))))
        # squeeze only the logit dim so a batch of one keeps shape (1,)
        return self.fc3(out).squeeze(-1)

In [6]:
def process(seq_model, motif_model, loader, criterion, optims=None):
    """Run both models over one full pass of ``loader``, optionally training them.

    ``seq_model`` receives feature columns 0-19 of every residue, flattened into
    one vector per example. ``motif_model`` receives feature columns 22 onward.
    Columns 20-21 are not used by either model.

    Args:
        seq_model, motif_model: the two classifiers; each returns one logit per example.
        loader: iterable of ``(seq, lbl)`` tensor batches; batches are moved to the
            module-level ``device``.
        criterion: loss on (logits, labels). Assumed to average over the batch
            (``reduction='mean'``); the mean is re-weighted by batch size below.
        optims: ``[seq_optimizer, motif_optimizer]`` to take one step per batch,
            or ``None`` to only evaluate. The caller is responsible for
            ``train()``/``eval()`` mode and ``torch.no_grad()``.

    Returns:
        ``(seq_loss, motif_loss, seq_acc, motif_acc, seq_auc, motif_auc)``. The
        losses and accuracies are per-sample averages over the whole pass.
    """
    seq_loss_sum, motif_loss_sum, total = 0.0, 0.0, 0
    seq_num_correct, motif_num_correct = 0, 0
    seq_preds, motif_preds, lbls = [], [], []

    for seq, lbl in loader:
        seq, lbl = seq.to(device), lbl.to(device)
        batch_size = seq.shape[0]

        motif_scores = motif_model(seq[:, :, 22:])
        seq_scores = seq_model(seq[:, :, :20].reshape(batch_size, -1))

        motif_loss = criterion(motif_scores, lbl)
        seq_loss = criterion(seq_scores, lbl)

        if optims is not None:
            # the models have separate parameters and optimizers, so each loss
            # is backpropagated and stepped independently
            optims[0].zero_grad()
            seq_loss.backward()
            optims[0].step()
            optims[1].zero_grad()
            motif_loss.backward()
            optims[1].step()

        # criterion returns a batch mean; weight it by the batch size so that
        # dividing by `total` yields a true per-sample mean (also correct when
        # the last batch is smaller)
        seq_loss_sum += seq_loss.item() * batch_size
        motif_loss_sum += motif_loss.item() * batch_size
        # a logit > 0 corresponds to a predicted probability > 0.5
        seq_num_correct += ((seq_scores > 0) == lbl).sum().item()
        motif_num_correct += ((motif_scores > 0) == lbl).sum().item()
        total += batch_size
        seq_preds.extend(seq_scores.detach().tolist())
        motif_preds.extend(motif_scores.detach().tolist())
        lbls.extend(lbl.detach().tolist())

    return (
        seq_loss_sum / total,
        motif_loss_sum / total,
        seq_num_correct / total,
        motif_num_correct / total,
        roc_auc_score(lbls, seq_preds),
        roc_auc_score(lbls, motif_preds)
    )

In [7]:
# see https://github.com/pdxgx/pepsickle README for more info

# Per-residue feature table: amino-acid letter -> list of 26 numbers.
# collate_batch looks up one row per residue of each sequence.
#   columns 0-19 : one-hot identity of the 20 amino acids listed below
#                  (the sequence model is fed seq[:, :, :20] in process)
#   columns 20-21: binary flags. Column 20 is set for F, H, W, Y and column 21
#                  for S, T, Y. Presumably these mark an aromatic side chain
#                  and a hydroxyl group; verify against the pepsickle README.
#                  Neither column is fed to either model in process.
#   columns 22-25: four continuous physicochemical values (the motif model is
#                  fed seq[:, :, 22:] in process). Column 22 looks like the
#                  isoelectric point (D 2.77, K 9.74, R 10.76); TODO confirm
#                  the meaning of all four against the README.
# The continuous values are stored raw (no scaling is applied here).
_features = {
    'A': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,      6.0, 56.15265,   -0.495,  -2.4],
    'C': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.07, 69.61701,    0.081,  -4.7],
    'D': [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     2.77, 70.04515,    9.573,  -4.5],
    'E': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     3.22, 86.35615,    3.173,  -5.2],
    'F': [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   1,   0,     5.48,  119.722,   -0.370,  -4.9],
    'G': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.97, 37.80307,    0.386,  -1.9],
    'H': [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   1,   0,     7.59, 97.94236,    2.029,  -4.4],
    'I': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     6.02, 103.6644,   -0.528,  -6.6],
    'K': [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     9.74, 102.7783,    2.101,  -7.5],
    'L': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.98, 102.7545,   -0.342,  -6.3],
    'M': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.74,  103.928,   -0.324,  -6.1],
    'N': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.41, 76.56687,    2.354,  -4.7],
    'P': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,   0,   0,      6.3, 71.24858,   -0.322,  -0.8],
    'Q': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,   0,   0,     5.65, 88.62562,    2.176,  -5.5],
    'R': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,   0,   0,    10.76, 110.5867,    4.383,  -6.9],
    'S': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,   0,   1,     5.68, 55.89516,    0.936,  -4.6],
    'T': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,   0,   1,      5.6,  72.0909,    0.853,  -5.1],
    'V': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,   0,   0,     5.96, 86.28358,   -0.308,  -4.6],
    'W': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,   1,   0,     5.89, 137.5186,    -0.27,  -4.8],
    'Y': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,   1,   1,     5.66, 121.5862,    1.677,  -5.4],
}

In [8]:
# Read the training split and the development (validation) split; each
# read_data call returns parallel (sequences, labels) lists of strings.
(train_seqs, train_lbl), (dev_seqs, dev_lbl) = (
    read_data('../../data/n_train.csv'),
    read_data('../../data/n_val.csv'),
)

In [9]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def train(
    config, checkpoint_dir=None, tr_seqs=None, tr_lbls=None, val_seqs=None, val_lbls=None
):
    """Ray Tune trainable that fits a SeqNet and a MotifNet side by side.

    Each epoch (up to 100) trains both models on the training split, then
    evaluates them on the dev split. It checkpoints models, optimizers and
    epoch number, and reports metrics with ``tune.report``. The run stops
    early once both models' ``regularized_auc`` is 0 (both overfit).

    Args:
        config: sampled hyper-parameters with keys ``seq_hidden1``,
            ``seq_hidden2``, ``seq_hidden3``, ``seq_dropout``, ``motif_hidden1``,
            ``motif_hidden2`` and ``motif_dropout``.
        checkpoint_dir: directory of a previous checkpoint that Ray Tune passes
            in when it restores a trial (e.g. with ``resume='AUTO'``). ``None``
            starts a fresh trial.
        tr_seqs, tr_lbls: training sequences and labels.
        val_seqs, val_lbls: development sequences and labels.
    """
    # create train and dev loader
    train_data = CleavageDataset(tr_seqs, tr_lbls)
    train_loader = DataLoader(
        train_data, batch_size=512, shuffle=True, collate_fn=collate_batch, num_workers=8
    )

    # the dev metrics (mean loss, accuracy, AUC) do not depend on sample order
    dev_data = CleavageDataset(val_seqs, val_lbls)
    dev_loader = DataLoader(
        dev_data, batch_size=512, shuffle=False, collate_fn=collate_batch, num_workers=8
    )

    seq_model = SeqNet(
        hidden_size1=config['seq_hidden1'],
        hidden_size2=config['seq_hidden2'],
        hidden_size3=config['seq_hidden3'],
        dropout=config['seq_dropout']
    ).to(device)

    motif_model = MotifNet(
        hidden_size1=config['motif_hidden1'],
        hidden_size2=config['motif_hidden2'],
        dropout=config['motif_dropout']
    ).to(device)

    seq_optimizer = optim.Adam(seq_model.parameters(), lr=1e-3)
    motif_optimizer = optim.Adam(motif_model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()

    # Without restoring, a resumed trial would silently retrain from epoch 1.
    start_epoch = _restore_training_state(
        checkpoint_dir, seq_model, motif_model, seq_optimizer, motif_optimizer
    )

    # normal train loop
    for epoch in range(start_epoch, 100 + 1):
        seq_model.train()
        motif_model.train()
        train_results = process(
            seq_model, motif_model, train_loader, criterion, [seq_optimizer, motif_optimizer]
        )

        seq_model.eval()
        motif_model.eval()
        with torch.no_grad():
            val_results = process(seq_model, motif_model, dev_loader, criterion)

        # a distinct name, so the `checkpoint_dir` argument is not shadowed
        with tune.checkpoint_dir(step=epoch) as ckpt_dir:
            _save_training_state(
                ckpt_dir, epoch, seq_model, motif_model, seq_optimizer, motif_optimizer
            )

        # metrics that will be reported back to main process
        seq_reg_auc = regularized_auc(train_results[4], val_results[4])
        motif_reg_auc = regularized_auc(train_results[5], val_results[5])

        tune.report(
            seq_train_loss=train_results[0],
            seq_val_loss=val_results[0],
            motif_train_loss=train_results[1],
            motif_val_loss=val_results[1],
            seq_train_acc=train_results[2],
            seq_val_acc=val_results[2],
            motif_train_acc=train_results[3],
            motif_val_acc=val_results[3],
            seq_train_auc=train_results[4],
            seq_val_auc=val_results[4],
            motif_train_auc=train_results[5],
            motif_val_auc=val_results[5],
            seq_reg_auc=seq_reg_auc,
            motif_reg_auc=motif_reg_auc,
            summed_reg_auc=seq_reg_auc+motif_reg_auc
        )

        # if both models are overfitting, stop the run
        # otherwise continue the run to find optimal param combination afterwards
        if seq_reg_auc == 0 and motif_reg_auc == 0:
            break


def _save_training_state(ckpt_dir, epoch, seq_model, motif_model, seq_optimizer, motif_optimizer):
    """Write model weights, optimizer states and the finished epoch into ``ckpt_dir``."""
    torch.save(seq_model.state_dict(), os.path.join(ckpt_dir, "seq_checkpoint.pt"))
    torch.save(motif_model.state_dict(), os.path.join(ckpt_dir, "motif_checkpoint.pt"))
    torch.save(
        {
            "epoch": epoch,
            "seq_optimizer": seq_optimizer.state_dict(),
            "motif_optimizer": motif_optimizer.state_dict(),
        },
        os.path.join(ckpt_dir, "train_state.pt"),
    )


def _restore_training_state(checkpoint_dir, seq_model, motif_model, seq_optimizer, motif_optimizer):
    """Load a checkpoint in place (if given) and return the first epoch still to run."""
    if checkpoint_dir is None:
        return 1
    seq_model.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, "seq_checkpoint.pt"), map_location=device)
    )
    motif_model.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, "motif_checkpoint.pt"), map_location=device)
    )
    state_path = os.path.join(checkpoint_dir, "train_state.pt")
    if not os.path.exists(state_path):
        # older checkpoints hold only model weights; continue from epoch 1
        return 1
    state = torch.load(state_path, map_location=device)
    seq_optimizer.load_state_dict(state["seq_optimizer"])
    motif_optimizer.load_state_dict(state["motif_optimizer"])
    return state["epoch"] + 1

In [10]:
class TuneReporter(CLIReporter):
    """CLI progress reporter that prints a table only when another trial terminates.

    It keeps the number of TERMINATED trials seen at the previous check and
    reports whenever that number grows.
    """

    def __init__(self):
        super().__init__()
        self.num_terminated = 0  # TERMINATED trials counted at the last check

    def should_report(self, trials, done=False):
        """Reports only on trial termination events (``done`` is ignored)."""
        previously_terminated = self.num_terminated
        self.num_terminated = sum(1 for trial in trials if trial.status == "TERMINATED")
        return previously_terminated < self.num_terminated

    def report(self, trials, done, *sys_info):
        """Print the progress table built by CLIReporter's private ``_progress_str``."""
        table = self._progress_str(trials, done, *sys_info)
        print(table)
    
# Result columns shown in the progress table, in display order.
metrics = [
    # losses
    'seq_train_loss',
    'motif_train_loss',
    'seq_val_loss',
    'motif_val_loss',
    # accuracies
    'seq_train_acc',
    'motif_train_acc',
    'seq_val_acc',
    'motif_val_acc',
    # ROC AUCs
    'seq_train_auc',
    'motif_train_auc',
    'seq_val_auc',
    'motif_val_auc',
    # overfitting-penalised AUCs
    'seq_reg_auc',
    'motif_reg_auc',
    'summed_reg_auc',
]

reporter = TuneReporter()
for metric in metrics:
    reporter.add_metric_column(metric=metric)

In [11]:
# Hyper-parameter search space sampled independently for every trial.
search_space = dict(
    # SeqNet hidden widths and dropout
    seq_hidden1=tune.randint(120, 201),
    seq_hidden2=tune.randint(60, 141),
    seq_hidden3=tune.randint(20, 81),
    seq_dropout=tune.quniform(0.2, 0.36, 0.02),
    # MotifNet hidden widths and dropout
    motif_hidden1=tune.randint(120, 201),
    motif_hidden2=tune.randint(20, 101),
    motif_dropout=tune.quniform(0.1, 0.26, 0.02),
)

In [None]:
path = '../../params/n_term/pepsickle/'
experiment = 'search'
num_samples = 30

# train() receives the data splits as extra keyword arguments
trainable = tune.with_parameters(
    train,
    tr_seqs=train_seqs,
    tr_lbls=train_lbl,
    val_seqs=dev_seqs,
    val_lbls=dev_lbl,
)

# early stopping on the overfitting-penalised objective reported by train()
asha_scheduler = ASHAScheduler(
    metric='summed_reg_auc',
    mode='max',
    reduction_factor=2,
    grace_period=4,
)

analysis = tune.run(
    trainable,
    name=experiment,
    config=search_space,
    sync_config=tune.SyncConfig(syncer=None),
    num_samples=num_samples,
    scheduler=asha_scheduler,
    progress_reporter=reporter,
    local_dir=path,
    keep_checkpoints_num=None,  # keep all checkpoints
    checkpoint_score_attr='summed_reg_auc',
    resume='AUTO',
    resources_per_trial={'cpu': 16, 'gpu': 1},
)