# Sources
* BiLSTM model architecture based on [Ozols et. al., 2021](https://www.mdpi.com/1422-0067/22/6/3071/htm)
* Prot2Vec embeddings based on [Asgari and Mofrad, 2015](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0141287), available on [Github](https://github.com/ehsanasgari/Deep-Proteomics)
* JoCoR loss function and training process adaptations are based on [Wei et al., 2020](https://openaccess.thecvf.com/content_CVPR_2020/html/Wei_Combating_Noisy_Labels_by_Agreement_A_Joint_Training_Method_with_CVPR_2020_paper.html), and official implementation on [Github](https://github.com/hongxin001/JoCoR)


In [1]:
import os
import sys
import csv
import random
import math
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchtext.vocab import build_vocab_from_iterator

from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(1234)

In [3]:
def read_data(path):
    with open(path, 'r') as f:
        seqs, lbls = [], []
        for l in f.readlines()[1:]:
            seq, lbl = l.strip().split('\t')
            seqs.append(seq)
            lbls.append(lbl)
    return seqs, lbls

def read_embeddings(path):
    with open(path, 'r') as f:
        seq, vec = [], []
        for line in f.readlines()[2:]: # skip first special chars
            lst = line.split()
            seq.append(lst[0].upper())
            vec.append([float(i) for i in lst[1:]])
        vocab = {s: i for i, s in enumerate(seq)}
        prot2vec = torch.tensor(vec, dtype=torch.float)
    return vocab, prot2vec


def apply_random_masking(seq, num_tokens):
    """
    Mask `num_tokens` as 0 at random positions per sequence.
    """
    dist = torch.rand(seq.shape)
    m, _ = torch.topk(dist, num_tokens)
    return seq * (dist < m)


def regularized_auc(train_auc, dev_auc, threshold=0.0025):
    """
    Returns development AUC if overfitting is below threshold, otherwise 0.
    """
    return dev_auc if (train_auc - dev_auc) < threshold else 0


def gelu(x):
    """
    Facebook Research implementation of the gelu activation function.
    
    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def trainable_model_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def total_model_params(model):
    return sum(p.numel() for p in model.parameters())

In [4]:
class CleavageDataset(Dataset):
    def __init__(self, seq, lbl):
        self.seq = seq
        self.lbl = lbl
    
    def __getitem__(self, idx):
        return self.seq[idx], self.lbl[idx]
    
    def __len__(self):
        return len(self.lbl)
    
class TrainBatch:
    def __init__(self, batch):
        ordered_batch = list(zip(*batch))
        seq = torch.tensor([encode_text(seq) for seq in ordered_batch[0]], dtype=torch.int64)
        self.seq = apply_random_masking(seq, num_tokens=1)
        self.lbl = torch.tensor([int(l) for l in ordered_batch[1]], dtype=torch.float)
        
    def pin_memory(self):
        self.seq = self.seq.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self
    
def train_wrapper(batch):
    return TrainBatch(batch)


class EvalBatch:
    def __init__(self, batch):
        ordered_batch = list(zip(*batch))
        self.seq = torch.tensor([encode_text(seq) for seq in ordered_batch[0]], dtype=torch.int64)
        self.lbl = torch.tensor([int(l) for l in ordered_batch[1]], dtype=torch.float)
        
    def pin_memory(self):
        self.seq = self.seq.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self
    
def eval_wrapper(batch):
    return EvalBatch(batch)

In [5]:
class BiLSTMProt2Vec(nn.Module):
    def __init__(self, pretrained_embeds, rnn_size, hidden_size, dropout):
        super().__init__()
        
        embeding_dim = pretrained_embeds.shape[1]
        
        self.embedding = nn.Embedding.from_pretrained(
            embeddings=pretrained_embeds,
            freeze=True
        )
        
        self.dropout=nn.Dropout(dropout)
        
        self.lstm = nn.LSTM(
            input_size=embeding_dim,
            hidden_size=rnn_size,
            bidirectional=True,
            batch_first=True,
        )
        
        self.fc1 = nn.Linear(rnn_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)
        
    def forward(self, seq):
        # input shape: (batch_size, seq_len=10)
        embedded = self.dropout(self.embedding(seq))
        
        # input shape: (batch_size, seq_len, embedding_dim)
        out, _ = self.lstm(embedded)
        
        # input shape: (batch_size, seq_len, 2*hidden_size)
        pooled, _ = torch.max(out, dim=1)
        
        # input shape: (batch_size, 2*hidden_size)
        out = self.dropout(gelu(self.fc1(pooled)))
        
        # input shape: (batch_size, hidden_size)
        # output shape: (batch_size)
        return self.fc2(out).squeeze()

In [6]:
def kl_loss_compute(pred, soft_targets):
    # adjusted for binary case
    kl = F.kl_div(F.logsigmoid(pred), torch.sigmoid(soft_targets), reduction='none')
    return torch.sum(kl)


class JoCoRLoss:
    def __call__(self, y1, y2, lbls, forget_rate, loss_func, kl_loss, co_lambda=0.1):
        loss_pick_1 = loss_func(y1, lbls) * (1 - co_lambda)
        loss_pick_2 = loss_func(y2, lbls) * (1 - co_lambda)
        loss_pick = (
            loss_pick_1
            + loss_pick_2
            + co_lambda * kl_loss_compute(y1, y2)
            + co_lambda * kl_loss_compute(y2, y1)
        ).cpu()

        ind_sorted = np.argsort(loss_pick.data)
        loss_sorted = loss_pick[ind_sorted]

        remember_rate = 1 - forget_rate
        num_remember = int(remember_rate * len(loss_sorted))

        ind_update = ind_sorted[:num_remember]

        loss = torch.mean(loss_pick[ind_update])

        return loss, loss

In [7]:
def train(model1, model2, optim, loss_func, loader, forget_rate, BCEWLL, kl_loss):
    epoch_loss1, num_correct1, total = 0, 0, 0
    epoch_loss2, num_correct2 = 0, 0
    preds1, preds2, lbls = [], [], []
    
    for batch in tqdm(loader, desc="Train: ", file=sys.stdout, unit="batches"):
        seq, lbl = batch.seq, batch.lbl
        seq, lbl = seq.to(device), lbl.to(device)
        
        scores1 = model1(seq)
        scores2 = model2(seq)

        # JoCoR loss
        loss1, loss2 = loss_func(scores1, scores2, lbl, forget_rate, BCEWLL, kl_loss)
        
        optim.zero_grad()
        loss1.backward()
        optim.step()
        
        epoch_loss1 += loss1.item()
        epoch_loss2 += loss2.item()
        num_correct1 += ((scores1 > 0) == lbl).sum().item()
        num_correct2 += ((scores2 > 0) == lbl).sum().item()
        total += seq.shape[0]
        preds1.extend(scores1.detach().tolist())
        preds2.extend(scores2.detach().tolist())
        lbls.extend(lbl.detach().tolist())
        
    return (
        epoch_loss1 / total,
        epoch_loss2 / total,
        num_correct1 / total,
        num_correct2 / total,
        roc_auc_score(lbls, preds1),
        roc_auc_score(lbls, preds2)
    )

In [8]:
def evaluate(model1, model2, loader):
    num_correct1, num_correct2, total = 0, 0, 0
    preds1, preds2, lbls = [], [], []
    
    for batch in tqdm(loader, desc="Eval: ", file=sys.stdout, unit="batches"):
        seq, lbl = batch.seq, batch.lbl
        seq, lbl = seq.to(device), lbl.to(device)
        
        scores1 = model1(seq)
        scores2 = model2(seq)
        
        num_correct1 += ((scores1 > 0) == lbl).sum().item()
        num_correct2 += ((scores2 > 0) == lbl).sum().item()
        total += seq.shape[0]
        preds1.extend(scores1.detach().tolist())
        preds2.extend(scores2.detach().tolist())
        lbls.extend(lbl.detach().tolist())
        
    return (
        num_correct1 / total,
        num_correct2 / total,
        roc_auc_score(lbls, preds1),
        roc_auc_score(lbls, preds2)
    )

In [9]:
def test(model, loader):
    num_correct, total = 0, 0
    preds, lbls = [], []
    
    for batch in tqdm(loader, desc="Eval: ", file=sys.stdout, unit="batches"):
        seq, lbl = batch.seq, batch.lbl
        seq, lbl = seq.to(device), lbl.to(device)
        
        scores = model(seq)
        
        num_correct += ((scores > 0) == lbl).sum().item()
        total += seq.shape[0]
        preds.extend(scores.detach().tolist())
        lbls.extend(lbl.detach().tolist())
        
    return num_correct / total, roc_auc_score(lbls, preds)

In [10]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load train and dev data
train_seqs, train_lbl = read_data('../../data/c_train_3mer.tsv')
dev_seqs, dev_lbl = read_data('../../data/c_val_3mer.tsv')

# create vocab from train seqs
# load vocab and embeddings
vocab, prot2vec = read_embeddings('../../params/uniref_3M/uniref_3M.vec')

# encodes kmer sequence
encode_text = lambda seq: [vocab.get(s, 0) for s in seq.split()]

In [11]:
NUM_EPOCHS = 15
BATCH_SIZE = 512
VOCAB_SIZE = len(vocab)
RNN_SIZE = 480
HIDDEN_SIZE = 145
DROPOUT = 0.5
LEARNING_RATE = 3e-4

NUM_GRADUAL = 10 # how many epochs for linear drop rate
NOISY_RATE = 0.2 # assumed
FORGET_RATE = NOISY_RATE / 2 # recommendation when facing asymmetric loss
EXPONENT = 1

# define drop rate schedule
rate_schedule = np.ones(NUM_EPOCHS)*FORGET_RATE
rate_schedule[:NUM_GRADUAL] = np.linspace(0, FORGET_RATE**EXPONENT, NUM_GRADUAL)

params = {
    "pretrained_embeds": prot2vec,
    "rnn_size": RNN_SIZE,
    "hidden_size": HIDDEN_SIZE,
    "dropout": DROPOUT
}

model1 = BiLSTMProt2Vec(**params).to(device)
model2 = BiLSTMProt2Vec(**params).to(device)

BCEWLL = nn.BCEWithLogitsLoss(reduction='none')
criterion = JoCoRLoss()
optimizer = optim.Adam(list(model1.parameters()) + list(model2.parameters()), lr=LEARNING_RATE)

# create train and dev loader
train_data = CleavageDataset(train_seqs, train_lbl)
train_loader = DataLoader(train_data, batch_size = BATCH_SIZE, shuffle=True, collate_fn=train_wrapper, pin_memory=True, num_workers=10)

dev_data = CleavageDataset(dev_seqs, dev_lbl)
dev_loader = DataLoader(dev_data, batch_size = BATCH_SIZE, shuffle=True, collate_fn=eval_wrapper, pin_memory=True, num_workers=10)

In [12]:
start = time()
print("Starting Training.")
highest_val_auc = 0

# normal training loop
for epoch in range(1, NUM_EPOCHS + 1):
    model1.train()
    model2.train()
    train_loss1, train_loss2, train_acc1, train_acc2, train_auc1, train_auc2 = train(
        model1=model1,
        model2=model2,
        optim=optimizer,
        loss_func=criterion, # JoCoRLoss
        loader=train_loader,
        forget_rate=rate_schedule[epoch-1],
        BCEWLL=BCEWLL, # nn.BCEWithLogitsLoss with reduction=none
        kl_loss=kl_loss_compute
    )

    model1.eval()
    model2.eval()
    with torch.no_grad():
        val_acc1, val_acc2, val_auc1, val_auc2 = evaluate(
            model1=model1,
            model2=model2,
            loader=dev_loader,
        )
    
    print(
        f"Model 1 Training:   [Epoch {epoch:2d}, Loss: {train_loss1:8.6f}, Acc: {train_acc1:.4f}, AUC: {train_auc1:.4f}]"
    )
    print(
        f"Model 1 Evaluation: [Epoch {epoch:2d}, Acc: {val_acc1:.4f}, AUC: {val_auc1:.4f}]"
    )
    print(
        f"Model 2 Training:   [Epoch {epoch:2d}, Loss: {train_loss2:8.6f}, Acc: {train_acc2:.4f}, AUC: {train_auc2:.4f}]"
    )
    print(
        f"Model 2 Evaluation: [Epoch {epoch:2d}, Acc: {val_acc2:.4f}, AUC: {val_auc2:.4f}]"
    )

    if val_auc1 > val_auc2:
        reg_auc = regularized_auc(train_auc1, val_auc1, threshold=0)
        model = model1.state_dict()
        print('saved model1')
    else:
        reg_auc = regularized_auc(train_auc2, val_auc2, threshold=0)
        model = model2.state_dict()
        print('saved model2')
    if reg_auc > highest_val_auc:
        highest_val_auc = reg_auc
        path = f"../../params/c_term/BiLSTM_prot2vec_jocor/auc{reg_auc:.4f}_epoch{epoch}.pt"
        torch.save(model, path)

print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time} minutes.")

Starting Training.
Train: 100%|████████████████████████████████████████████████████| 2218/2218 [00:30<00:00, 71.66batches/s]
Eval: 100%|██████████████████████████████████████████████████████| 278/278 [00:01<00:00, 201.15batches/s]
Model 1 Training:   [Epoch  1, Loss: 0.001895, Acc: 0.8222, AUC: 0.5214]
Model 1 Evaluation: [Epoch  1, Acc: 0.8215, AUC: 0.6479]
Model 2 Training:   [Epoch  1, Loss: 0.001895, Acc: 0.8225, AUC: 0.5202]
Model 2 Evaluation: [Epoch  1, Acc: 0.8215, AUC: 0.6508]
saved model2
Train: 100%|████████████████████████████████████████████████████| 2218/2218 [00:30<00:00, 71.76batches/s]
Eval: 100%|██████████████████████████████████████████████████████| 278/278 [00:01<00:00, 201.36batches/s]
Model 1 Training:   [Epoch  2, Loss: 0.001757, Acc: 0.8231, AUC: 0.6110]
Model 1 Evaluation: [Epoch  2, Acc: 0.8215, AUC: 0.7223]
Model 2 Training:   [Epoch  2, Loss: 0.001757, Acc: 0.8231, AUC: 0.6061]
Model 2 Evaluation: [Epoch  2, Acc: 0.8215, AUC: 0.7225]
saved model2
Train: 100%

In [13]:
test_path = '../../data/c_test_3mer.tsv'
test_seqs, test_lbls = read_data(test_path)

test_data = CleavageDataset(test_seqs, test_lbls)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=eval_wrapper, pin_memory=True, num_workers=10)

# load best model, evaluate on test set
best_model = sorted(
    [f for f in os.listdir("../../params/c_term/BiLSTM_prot2vec_jocor/") if f.endswith(".pt")],
    reverse=True,
)[0]
print("Loaded model: ", best_model)
model1.load_state_dict(torch.load('../../params/c_term/BiLSTM_prot2vec_jocor/' + best_model))
model1.eval()
with torch.no_grad():
    test_acc, test_auc = test(model1, test_loader)
print(
    f"Test Set Performance: Acc: {test_acc:.4f}, AUC: {test_auc:.4f}"
)
print(
    f"Total model params: {total_model_params(model1)}, trainable model params: {trainable_model_params(model1)}"
)

Loaded model:  auc0.7633_epoch5.pt
Eval: 100%|██████████████████████████████████████████████████████| 278/278 [00:00<00:00, 333.69batches/s]
Test Set Performance: Acc: 0.8220, AUC: 0.7580
Total model params: 16009371, trainable model params: 5830371
