### Model architecture is based on https://www.biorxiv.org/content/10.1101/710699v2

In [1]:
import os
import sys
import csv
import math
import pickle
import random
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.parametrizations import spectral_norm

from torchtext.vocab import build_vocab_from_iterator

from sklearn.metrics import roc_auc_score

In [3]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(1234)

In [4]:
def read_data(path):
    with open(path, 'r') as csvfile:
        train_data = list(csv.reader(csvfile))[1:] # skip col name
        sents, lbls = [], []
        for s, l in train_data:
            sents.append(s)
            lbls.append(l)
    return sents, lbls

# number of trainable parameters in model
def get_total_model_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [5]:
def gelu(x):
    """
    Facebook Research implementation of the gelu activation function.
    
    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

In [6]:
class CleavageDataset(Dataset):
    def __init__(self, seq, lbl):
        self.seq = seq
        self.lbl = lbl
    
    def __getitem__(self, idx):
        return self.seq[idx], self.lbl[idx]
    
    def __len__(self):
        return len(self.lbl)
    
class CleavageBatch:
    def __init__(self, batch):
        ordered_batch = list(zip(*batch))
        self.seq = F.one_hot(
            torch.tensor([encode_text(seq) for seq in ordered_batch[0]]), num_classes=len(vocab)
        ).to(torch.float)
        self.lbl = torch.tensor([int(l) for l in ordered_batch[1]], dtype=torch.float)
        
    def pin_memory(self):
        self.seq = self.seq.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self
    
def collate_wrapper(batch):
    return CleavageBatch(batch)

In [16]:
class MLP(nn.Module):
    def __init__(self, vocab_size, seq_len, hidden_size1, hidden_size2, rnn_size, dropout):
        super().__init__()
        
        self.fc1 = nn.Linear(vocab_size * seq_len, hidden_size1)
        
        self.lstm = nn.LSTM(
            input_size=hidden_size1,
            hidden_size=rnn_size,
            bidirectional=True,
            batch_first=True
        )
        
        self.fc2 = nn.Linear(rnn_size*2, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, 1)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, seq):
        # input shape: (batch_size, seq_len*vocab_size)
        out = self.dropout(F.relu(self.fc1(seq)))
        
        out, _ = self.lstm(out.unsqueeze(1))
        
        out = self.dropout(F.relu(self.fc2(out.squeeze())))
        
        return self.fc3(out).squeeze()

In [8]:
def process(model, loader, criterion, optim=None):
    epoch_loss, num_correct, total = 0, 0, 0
    preds, lbls = [], []
    
    for batch in tqdm(
        loader,
        desc="Train: " if optim is not None else "Eval: ",
        file=sys.stdout,
        unit="batches"
    ):
        seq, lbl = batch.seq, batch.lbl
        seq, lbl = seq.to(device), lbl.to(device)
        
        scores = model(seq.view(seq.shape[0], -1))
        loss = criterion(scores, lbl)
        
        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()
        
        epoch_loss += loss.item()
        num_correct += ((scores > 0) == lbl).sum().item()
        total += len(seq)
        preds.extend(scores.detach().tolist())
        lbls.extend(lbl.detach().tolist())
    return epoch_loss / total, num_correct / total, roc_auc_score(lbls, preds)

In [9]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load train and dev data
train_seqs, train_lbl = read_data('../../data/n_train.csv')
dev_seqs, dev_lbl = read_data('../../data/n_val.csv')

# create vocab from train seqs
vocab = build_vocab_from_iterator(train_seqs, specials=['<UNK>'])
# vocab = build_vocab_from_iterator(train_seqs)
vocab.set_default_index(vocab['<UNK>'])
encode_text = lambda x: vocab(list(x))

In [20]:
NUM_EPOCHS = 25
BATCH_SIZE = 512
VOCAB_SIZE = len(vocab)
SEQ_LEN = 10
HIDDEN_SIZE1 = 128
HIDDEN_SIZE2 = 128
RNN_SIZE = 512
DROPOUT = 0.5
LEARNING_RATE = 1e-4

model = MLP(
    vocab_size=VOCAB_SIZE,
    seq_len=SEQ_LEN,
    hidden_size1=HIDDEN_SIZE1,
    hidden_size2=HIDDEN_SIZE2,
    rnn_size=RNN_SIZE,
    dropout=DROPOUT
).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

# create train and dev loader
train_data = CleavageDataset(train_seqs, train_lbl)
train_loader = DataLoader(train_data, batch_size = BATCH_SIZE, shuffle=True, collate_fn=collate_wrapper, pin_memory=True, num_workers=10)

dev_data = CleavageDataset(dev_seqs, dev_lbl)
dev_loader = DataLoader(dev_data, batch_size = BATCH_SIZE, shuffle=True, collate_fn=collate_wrapper, pin_memory=True, num_workers=10)

print(f"Total trainable model parameters: {get_total_model_params(model):,}")

Total trainable model parameters: 2,787,969


In [21]:
start = time()
print("Starting Training.")
highest_val_auc = 0
train_losses, train_accuracies, train_aucs = [], [], []
val_losses, val_accuracies, val_aucs = [], [], []

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss, train_acc, train_auc = process(model, train_loader, criterion, optimizer)
    
    model.eval()
    with torch.no_grad():
        val_loss, val_acc, val_auc = process(model, dev_loader, criterion)
        
#     # save current acc, loss
#     train_losses.append((epoch, train_loss))
#     train_accuracies.append((epoch, train_acc))
#     train_aucs.append((epoch, train_auc))
#     val_losses.append((epoch, val_loss))
#     val_accuracies.append((epoch, val_acc))
#     val_aucs.append((epoch, val_auc))
    
#     if val_auc > highest_val_auc:
#         highest_val_auc = val_auc
#         path = f"../../params/n_term/mlp/auc{val_auc:.4f}_epoch{epoch}.pt"
#         torch.save(model.state_dict(), path)
        
    print(
        f"Training:   [Epoch {epoch:2d}, Loss: {train_loss:8.4f}, Acc: {train_acc:.4f}, AUC: {train_auc:.4f}]"
    )
    print(f"Evaluation: [Epoch {epoch:2d}, Loss: {val_loss:8.4f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}]")
    
print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time} minutes.")

Starting Training.
Train: 100%|███████████████████████████████████████████████████| 2236/2236 [00:05<00:00, 414.65batches/s]
Eval: 100%|██████████████████████████████████████████████████████| 280/280 [00:00<00:00, 501.66batches/s]
Training:   [Epoch  1, Loss:   0.0012, Acc: 0.6647, AUC: 0.7273]
Evaluation: [Epoch  1, Loss:   0.0011, Acc: 0.6866, AUC: 0.7561]
Train: 100%|███████████████████████████████████████████████████| 2236/2236 [00:05<00:00, 436.85batches/s]
Eval: 100%|██████████████████████████████████████████████████████| 280/280 [00:00<00:00, 501.18batches/s]
Training:   [Epoch  2, Loss:   0.0011, Acc: 0.6845, AUC: 0.7510]
Evaluation: [Epoch  2, Loss:   0.0011, Acc: 0.6875, AUC: 0.7581]
Train: 100%|███████████████████████████████████████████████████| 2236/2236 [00:05<00:00, 445.25batches/s]
Eval: 100%|██████████████████████████████████████████████████████| 280/280 [00:00<00:00, 511.98batches/s]
Training:   [Epoch  3, Loss:   0.0011, Acc: 0.6872, AUC: 0.7544]
Evaluation: [Epoch  

In [None]:
# save training stats
lsts = [train_losses, train_accuracies, val_losses, val_accuracies, train_aucs, val_aucs, train_time]
names = [
    "train_losses",
    "train_accuracies",
    "val_losses",
    "val_accuracies",
    "train_aucs",
    "val_aucs",
    "train_time",
]
to_save = dict()
for name, lst in zip(names, lsts):
    to_save[name] = lst

with open(f"../params/n_term/mlp/metrics.pkl", "wb") as f:
    pickle.dump(to_save, f, pickle.HIGHEST_PROTOCOL)

print("Finished Saving Details.")