# Sources
* Customised implementation of the journal article:
    * pepsickle rapidly and accurately predicts proteasomal cleavage sites for improved neoantigen identification
    * Benjamin R Weeder, Mary A Wood, Ellysia Li, Abhinav Nellore, Reid F Thompson
    * https://academic.oup.com/bioinformatics/article/37/21/3723/6363787
    * https://github.com/pdxgx/pepsickle
    * https://github.com/pdxgx/pepsickle-paper

In [1]:
import os
import sys
import csv
import random
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchtext.vocab import build_vocab_from_iterator
from torchtext.vocab import vocab

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

In [2]:
def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic
    behaviour.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    # NOTE: this only reaches processes started afterwards; hash randomisation
    # of the already running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # seeds all visible GPUs; silently ignored when CUDA is unavailable
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode auto-tunes convolution algorithms at runtime and may pick
    # a different one on each run, so it has to be off for reproducibility
    torch.backends.cudnn.benchmark = False


seed_everything(1234)

In [3]:
def read_data(path):
    """Read a two-column CSV file of sequences and labels.

    The first row is treated as the column header and skipped. Completely
    empty rows (e.g. a trailing blank line) are ignored.

    Args:
        path: Location of the CSV file.

    Returns:
        A tuple ``(sents, lbls)`` of two equally long lists holding the first
        and second column of every data row as strings.

    Raises:
        ValueError: If a non-empty row does not have exactly two columns.
    """
    sents, lbls = [], []
    # newline="" hands newline handling to the csv module, as its docs require
    with open(path, "r", newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip col name
        for row in reader:
            if not row:
                continue
            seq, label = row
            sents.append(seq)
            lbls.append(label)
    return sents, lbls


def regularized_auc(train_auc, dev_auc, threshold=0.0025):
    """
    Gate the development AUC on the train/dev gap.

    The development AUC is passed through only while the training AUC exceeds
    it by less than ``threshold``; in every other case 0 is returned.
    """
    overfit_gap = train_auc - dev_auc
    if overfit_gap < threshold:
        return dev_auc
    return 0


def save_metrics(*args, path):
    """Append one row of metrics to a CSV log, creating the log on first use.

    If ``path`` does not exist yet, its parent directory is created when
    necessary and a header row is written first. When no positional values
    are given, only that initialisation happens.

    Args:
        *args: Values of one row, written in the order given and converted
            with ``str``. Expected to line up with the header columns.
        path: Location of the CSV log file (keyword-only).
    """
    header = (
        "fold",
        "epoch",
        "seq_train_loss",
        "motif_train_loss",
        "seq_train_acc",
        "motif_train_acc",
        "combined_train_acc",
        "seq_train_auc",
        "motif_train_auc",
        "combined_train_auc",
        "seq_val_loss",
        "motif_val_loss",
        "seq_val_acc",
        "motif_val_acc",
        "combined_val_acc",
        "seq_val_auc",
        "motif_val_auc",
        "combined_val_auc",
    )

    # a fresh output directory would otherwise make the first write fail
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if not os.path.isfile(path):
        with open(path, "w", newline="\n") as f:
            f.write(",".join(header) + "\n")
    if args:
        with open(path, "a", newline="\n") as f:
            f.write(",".join(str(arg) for arg in args) + "\n")
            
def trainable_model_params(model):
    """Count the scalar parameters of ``model`` that have ``requires_grad`` set."""
    count = 0
    for param in model.parameters():
        if param.requires_grad:
            count += param.numel()
    return count

def total_model_params(model):
    """Count all scalar parameters of ``model``, whether trainable or frozen."""
    count = 0
    for param in model.parameters():
        count += param.numel()
    return count

In [4]:
class CleavageDataset(Dataset):
    """Map-style dataset that pairs each sequence with its label.

    Args:
        seq: Indexable, sized collection of sequences.
        lbl: Indexable, sized collection of labels, aligned with ``seq``.

    Raises:
        ValueError: If ``seq`` and ``lbl`` differ in length.
    """

    def __init__(self, seq, lbl):
        # a silent mismatch would surface as an IndexError mid-epoch or drop
        # samples unnoticed, so fail early
        if len(seq) != len(lbl):
            raise ValueError(
                f"seq and lbl must have the same length, got {len(seq)} and {len(lbl)}"
            )
        self.seq = seq
        self.lbl = lbl

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at ``idx``."""
        return self.seq[idx], self.lbl[idx]

    def __len__(self):
        """Return the number of samples."""
        return len(self.lbl)


class Batch:
    """Collated mini-batch holding featurised sequences and their labels.

    ``seq`` is a float tensor built by looking up every character of every
    sequence in the module-level ``_features`` table; ``lbl`` is a float
    tensor of the labels converted via ``int``.
    """

    def __init__(self, batch):
        # turn [(seq, lbl), ...] into [(seq, seq, ...), (lbl, lbl, ...)]
        columns = list(zip(*batch))

        encoded = []
        for sequence in columns[0]:
            encoded.append([_features[residue] for residue in sequence])
        self.seq = torch.tensor(encoded, dtype=torch.float)

        labels = [int(label) for label in columns[1]]
        self.lbl = torch.tensor(labels, dtype=torch.float)

    def pin_memory(self):
        """Pin both tensors and return ``self``, as ``DataLoader(pin_memory=True)`` expects of custom batch types."""
        self.seq = self.seq.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self


def collate_wrapper(batch):
    """Collate function for ``DataLoader``: wrap the sampled items in a ``Batch``."""
    collated = Batch(batch)
    return collated

In [5]:
# model architectures taken from 
# https://github.com/pdxgx/pepsickle-paper/blob/master/scripts/modeling/epitope_based_ensemble_net.py

class SeqNet(nn.Module):
    """Fully connected network that scores a flattened one-hot encoded window.

    Three hidden layers, each followed by batch norm, ReLU and dropout, and a
    final linear layer that emits one raw logit per sample.

    Args:
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
        hidden_size3: Width of the third hidden layer.
        dropout: Dropout probability applied after every hidden layer.
        input_size: Number of input features (seq_len * 20); defaults to 200.
    """

    def __init__(self, hidden_size1, hidden_size2, hidden_size3, dropout, input_size=200):
        super().__init__()

        self.dropout = nn.Dropout(dropout)

        # input to linear: seq_len * 20
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, hidden_size3)
        self.fc4 = nn.Linear(hidden_size3, 1)

        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)
        self.bn3 = nn.BatchNorm1d(hidden_size3)

    def forward(self, seq):
        """Map a ``(batch, input_size)`` tensor to ``(batch,)`` logits."""
        out = self.dropout(F.relu(self.bn1(self.fc1(seq))))
        out = self.dropout(F.relu(self.bn2(self.fc2(out))))
        out = self.dropout(F.relu(self.bn3(self.fc3(out))))
        # squeeze only the feature axis: a bare squeeze() would also drop the
        # batch axis for a batch of one and yield a 0-d tensor
        return self.fc4(out).squeeze(-1)
    
    
class MotifNet(nn.Module):
    """Grouped 1-D convolution followed by a small fully connected head.

    Each of the four input feature channels is convolved on its own
    (``groups=4``) with a width-3 kernel; the result is flattened and passed
    through two hidden layers (batch norm, ReLU, dropout) and a final linear
    layer that emits one raw logit per sample.

    Args:
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
        dropout: Dropout probability applied after every hidden layer.
        seq_len: Length of the input window; defaults to 10.
    """

    def __init__(self, hidden_size1, hidden_size2, dropout, seq_len=10):
        super().__init__()

        self.dropout = nn.Dropout(dropout)

        # conv parameters are fixed due to feature assembly process
        # see dictionary variable _features
        self.conv = nn.Conv1d(
            in_channels=4,
            out_channels=4,
            kernel_size=3,
            groups=4
        )

        # input to linear: groups * (seq_len-2)
        self.fc1 = nn.Linear(4 * (seq_len - 2), hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, 1)

        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, seq):
        """Map a ``(batch, seq_len, 4)`` tensor to ``(batch,)`` logits."""
        # Conv1d expects channels first: (batch, 4, seq_len)
        out = self.conv(seq.transpose(1, 2))

        out = self.dropout(F.relu(self.bn1(self.fc1(out.flatten(1)))))
        out = self.dropout(F.relu(self.bn2(self.fc2(out))))
        # squeeze only the feature axis: a bare squeeze() would also drop the
        # batch axis for a batch of one and yield a 0-d tensor
        return self.fc3(out).squeeze(-1)

In [6]:
def process(seq_model, motif_model, loader, criterion, optims=None):
    """Run both sub-models over ``loader`` once and collect their metrics.

    The sequence model receives the first 20 features of every position,
    flattened per sample; the motif model receives the features from index 22
    onwards. The combined score is the mean of both logits. Batches are moved
    to the module-level ``device``.

    Args:
        seq_model: Model scoring the flattened one-hot part of the features.
        motif_model: Model scoring the trailing feature columns.
        loader: Iterable of batches exposing ``seq`` and ``lbl`` tensors.
        criterion: Loss applied separately to each model's logits.
        optims: ``[seq_optimizer, motif_optimizer]`` to train both models, or
            ``None`` to only evaluate. Switching train/eval mode and disabling
            gradients is left to the caller.

    Returns:
        A list ``[seq_loss, motif_loss, seq_acc, motif_acc, combined_acc,
        seq_auc, motif_auc, combined_auc]`` where losses and accuracies are
        per-sample averages over the whole loader.
    """
    training = optims is not None
    # With reduction="mean" every batch loss is already a per-sample average,
    # so it must be re-weighted by the batch size before dividing by the
    # number of samples; otherwise the epoch loss is too small by roughly the
    # batch size.
    mean_reduced = getattr(criterion, "reduction", "mean") == "mean"

    seq_epoch_loss, seq_num_correct, num_correct, total = 0, 0, 0, 0
    motif_epoch_loss, motif_num_correct = 0, 0
    seq_preds, motif_preds, preds, lbls = [], [], [], []

    for batch in tqdm(
        loader,
        desc="Train: " if training else "Eval: ",
        file=sys.stdout,
        unit="batches"
    ):
        seq, lbl = batch.seq.to(device), batch.lbl.to(device)
        batch_size = seq.shape[0]

        # reshape(-1) keeps the scores 1-D even if a model returns a 0-d
        # tensor for a batch of one
        motif_scores = motif_model(seq[:, :, 22:]).reshape(-1)
        seq_scores = seq_model(seq[:, :, :20].reshape(batch_size, -1)).reshape(-1)
        scores = (motif_scores + seq_scores) / 2

        motif_loss = criterion(motif_scores, lbl)
        seq_loss = criterion(seq_scores, lbl)

        if training:
            # the two models share no parameters, so each gets its own step
            optims[0].zero_grad()
            seq_loss.backward()
            optims[0].step()
            optims[1].zero_grad()
            motif_loss.backward()
            optims[1].step()

        loss_weight = batch_size if mean_reduced else 1
        seq_epoch_loss += seq_loss.item() * loss_weight
        motif_epoch_loss += motif_loss.item() * loss_weight
        # a logit above 0 corresponds to a predicted probability above 0.5
        seq_num_correct += ((seq_scores > 0) == lbl).sum().item()
        motif_num_correct += ((motif_scores > 0) == lbl).sum().item()
        num_correct += ((scores > 0) == lbl).sum().item()
        total += batch_size
        seq_preds.extend(seq_scores.detach().tolist())
        motif_preds.extend(motif_scores.detach().tolist())
        preds.extend(scores.detach().tolist())
        lbls.extend(lbl.detach().tolist())

    return [
        seq_epoch_loss / total,
        motif_epoch_loss / total,
        seq_num_correct / total,
        motif_num_correct / total,
        num_correct / total,
        roc_auc_score(lbls, seq_preds),
        roc_auc_score(lbls, motif_preds),
        roc_auc_score(lbls, preds)
    ]

In [7]:
# run on the first GPU when one is visible, otherwise on the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
BATCH_SIZE = 512

# locations of the previously split train / dev / test CSV files
train_path = '../../data/n_train.csv'
dev_path = '../../data/n_val.csv'
test_path = '../../data/n_test.csv'

# train and dev are merged back into one pool; cross-validation draws its
# own splits from it later on
train_seqs, train_lbls = read_data(train_path)
dev_seqs, dev_lbls = read_data(dev_path)
total_seqs = np.array(train_seqs + dev_seqs)
total_lbls = np.array(train_lbls + dev_lbls)

assert len(total_seqs) == len(train_seqs) + len(dev_seqs)
assert len(total_lbls) == len(train_lbls) + len(dev_lbls)

# the test split stays untouched and gets a loader of its own
test_seqs, test_lbls = read_data(test_path)
test_data = CleavageDataset(test_seqs, test_lbls)
test_loader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    collate_fn=collate_wrapper,
    pin_memory=True,
    num_workers=10,
)

In [8]:
# see https://github.com/pdxgx/pepsickle README for more info

# Per-residue feature vectors, keyed by one-letter amino acid code and looked
# up character by character in `Batch`. Every vector has 26 entries:
#   * indices 0-19:  one-hot encoding of the residue; this is the slice
#                    `[:, :, :20]` that `process` feeds to the sequence model
#   * indices 20-21: two binary flags that neither model in `process` reads
#                    (presumably aromatic / hydroxyl indicators -- TODO confirm
#                    against the pepsickle README)
#   * indices 22-25: four real-valued properties; this is the slice
#                    `[:, :, 22:]` that `process` feeds to the motif model
#                    (meaning and units are not stated here -- see the README)
_features = {
    'A': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,      6.0, 56.15265,   -0.495,  -2.4],
    'C': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.07, 69.61701,    0.081,  -4.7],
    'D': [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     2.77, 70.04515,    9.573,  -4.5],
    'E': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     3.22, 86.35615,    3.173,  -5.2],
    'F': [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   1,   0,     5.48,  119.722,   -0.370,  -4.9],
    'G': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.97, 37.80307,    0.386,  -1.9],
    'H': [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   1,   0,     7.59, 97.94236,    2.029,  -4.4],
    'I': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     6.02, 103.6644,   -0.528,  -6.6],
    'K': [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     9.74, 102.7783,    2.101,  -7.5],
    'L': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.98, 102.7545,   -0.342,  -6.3],
    'M': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.74,  103.928,   -0.324,  -6.1],
    'N': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,   0,   0,     5.41, 76.56687,    2.354,  -4.7],
    'P': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,   0,   0,      6.3, 71.24858,   -0.322,  -0.8],
    'Q': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,   0,   0,     5.65, 88.62562,    2.176,  -5.5],
    'R': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,   0,   0,    10.76, 110.5867,    4.383,  -6.9],
    'S': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,   0,   1,     5.68, 55.89516,    0.936,  -4.6],
    'T': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,   0,   1,      5.6,  72.0909,    0.853,  -5.1],
    'V': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,   0,   0,     5.96, 86.28358,   -0.308,  -4.6],
    'W': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,   1,   0,     5.89, 137.5186,    -0.27,  -4.8],
    'Y': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,   1,   1,     5.66, 121.5862,    1.677,  -5.4],
}

In [9]:
# training schedule
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# keyword arguments for MotifNet
motif_params = dict(hidden_size1=200, hidden_size2=23, dropout=0.12)

# keyword arguments for SeqNet
seq_params = dict(
    hidden_size1=175,
    hidden_size2=96,
    hidden_size3=44,
    dropout=0.32,
)

# both networks end in a plain linear layer, so the sigmoid lives in the loss
criterion = nn.BCEWithLogitsLoss()

In [None]:
# 10-fold cross-validation over the merged train+dev pool. Both networks are
# re-initialised for every fold; a checkpoint pair is written whenever the
# combined validation AUC beats the best one seen so far (across all folds)
# and neither sub-model has a higher train AUC than validation AUC.
kf = KFold(n_splits=10, shuffle=True, random_state=1234)
path = "../../params/n_term/pepsickle/"
logging_path = path + "results.csv"

# the metrics log and the checkpoints are both written here
os.makedirs(path, exist_ok=True)

start = time()
print("Starting Cross-Validation.")
highest_val_auc = 0

# get a new split
for fold, (train_idx, dev_idx) in enumerate(kf.split(total_seqs), 1):
    X_tr = total_seqs[train_idx]
    y_tr = total_lbls[train_idx]
    X_dev = total_seqs[dev_idx]
    y_dev = total_lbls[dev_idx]

    # create datasets and loads with current split
    train_data = CleavageDataset(X_tr, y_tr)
    train_loader = DataLoader(
        train_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_wrapper,
        pin_memory=True,
        num_workers=10,
    )

    dev_data = CleavageDataset(X_dev, y_dev)
    # evaluation metrics do not depend on sample order, so no shuffling here
    dev_loader = DataLoader(
        dev_data,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_wrapper,
        pin_memory=True,
        num_workers=10,
    )

    # reset model weights with each new fold
    motif_model = MotifNet(**motif_params).to(device)
    seq_model = SeqNet(**seq_params).to(device)
    motif_optimizer = optim.Adam(motif_model.parameters(), lr=LEARNING_RATE)
    seq_optimizer = optim.Adam(seq_model.parameters(), lr=LEARNING_RATE)

    # normal training loop
    for epoch in range(1, NUM_EPOCHS + 1):
        seq_model.train()
        motif_model.train()
        train_results = process(
            seq_model,
            motif_model,
            train_loader,
            criterion,
            [seq_optimizer, motif_optimizer],
        )

        seq_model.eval()
        motif_model.eval()
        with torch.no_grad():
            val_results = process(seq_model, motif_model, dev_loader, criterion)

        results = train_results + val_results
        # save metrics
        save_metrics(
            fold,
            epoch,
            *results,
            path=logging_path,
        )

        # indices into the result lists: 5 = seq AUC, 6 = motif AUC,
        # 7 = combined AUC. With threshold=0, regularized_auc is truthy only
        # when the validation AUC is strictly above the training AUC.
        if (
            regularized_auc(train_results[5], val_results[5], threshold=0)
            and regularized_auc(train_results[6], val_results[6], threshold=0)
            and val_results[7] > highest_val_auc
        ):
            highest_val_auc = val_results[7]
            seq_path = path + f"auc{val_results[7]:.4f}_fold{fold}_epoch{epoch}_seq.pt"
            motif_path = (
                path + f"auc{val_results[7]:.4f}_fold{fold}_epoch{epoch}_motif.pt"
            )
            torch.save(seq_model.state_dict(), seq_path)
            torch.save(motif_model.state_dict(), motif_path)

print("Finished Cross-Validation.")
train_time = (time() - start) / 60
print(f"Cross-Validation took {train_time} minutes.")

In [11]:
# load best model, evaluate on test set
# Checkpoint names start with "auc<value>", so a reverse lexicographic sort
# puts the highest validation AUC first.
seq_checkpoints = sorted(
    (f for f in os.listdir(path) if f.endswith("seq.pt")),
    reverse=True,
)
motif_checkpoints = sorted(
    (f for f in os.listdir(path) if f.endswith("motif.pt")),
    reverse=True,
)
if not seq_checkpoints or not motif_checkpoints:
    raise FileNotFoundError(f"No seq/motif checkpoints found in {path}")
seq_best_model = seq_checkpoints[0]
motif_best_model = motif_checkpoints[0]

print(f'Loaded {seq_best_model} and {motif_best_model}')

motif_model = MotifNet(**motif_params).to(device)
seq_model = SeqNet(**seq_params).to(device)
# map_location lets checkpoints written on a GPU be restored on a CPU-only host
seq_model.load_state_dict(torch.load(path + seq_best_model, map_location=device))
motif_model.load_state_dict(torch.load(path + motif_best_model, map_location=device))

seq_model.eval()
motif_model.eval()
with torch.no_grad():
    test_results = process(seq_model, motif_model, test_loader, criterion)
# indices into the result list: 4 = combined accuracy, 7 = combined AUC
print(
    f"Test Set Performance: Acc: {test_results[4]:.4f}, AUC: {test_results[7]:.4f}"
)
print(
    f"Seq: Total model params: {total_model_params(seq_model)}, trainable model params: {trainable_model_params(seq_model)}"
)
print(
    f"Motif: Total model params: {total_model_params(motif_model)}, trainable model params: {trainable_model_params(motif_model)}"
)

Loaded auc0.7883_fold3_epoch10_seq.pt and auc0.7883_fold3_epoch10_motif.pt
Eval: 100%|██████████████████████████████████████████████████████| 281/281 [00:00<00:00, 321.81batches/s]
Test Set Performance: Acc: 0.8318, AUC: 0.7888
Seq: Total model params: 57014, trainable model params: 57014
Motif: Total model params: 11709, trainable model params: 11709
