* https://github.com/jowoojun/biovec/blob/master/word2vec/models.py
* https://github.com/ehsanasgari/Deep-Proteomics

In [1]:
import os
import sys
import csv
import pickle
import random
import numpy as np
from time import time
from tqdm import tqdm

from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(1234)

In [3]:
def read_data(path):
    with open(path, 'r') as f:
        seqs, lbls = [], []
        for l in f.readlines()[1:]:
            seq, lbl = l.strip().split('\t')
            seqs.append(seq)
            lbls.append(lbl)
    return seqs, lbls

def read_embeddings(path):
    with open(path, 'r') as f:
        seq, vec = [], []
        for line in f.readlines()[2:]: # skip first 2 special chars
            lst = line.split()
            seq.append(lst[0].upper())
            vec.append([float(i) for i in lst[1:]])
        vocab = {s: i for i, s in enumerate(seq[1:], 1)}
        prot2vec = torch.tensor(vec, dtype=torch.float)
    return vocab, prot2vec

# encodes kmer sequence
# automatically returns 0 for unknown kmer
encode_text = lambda seq: [vocab.get(s) for s in seq.split()]

# number of trainable parameters in model
def get_total_model_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [4]:
class CleavageDataset(Dataset):
    def __init__(self, seq, lbl):
        self.seq = seq
        self.lbl = lbl
    
    def __getitem__(self, idx):
        return self.seq[idx], self.lbl[idx]
    
    def __len__(self):
        return len(self.lbl)
    
class CleavageBatch:
    def __init__(self, batch: List[Tuple[str, str]]):
        ordered_batch = list(zip(*batch))
        self.seq = torch.tensor([encode_text(seq) for seq in ordered_batch[0]], dtype=torch.int64)
        self.lbl = torch.tensor([int(l) for l in ordered_batch[1]], dtype=torch.float)
        
    def pin_memory(self):
        self.seq = self.seq.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self
    
def collate_wrapper(batch):
    return CleavageBatch(batch)

In [5]:
class BiLSTM(nn.Module):
    def __init__(self, pretrained_embeds, rnn_size, hidden_size, dropout):
        super().__init__()
        
        embeding_dim = pretrained_embeds.shape[1]
        
        self.embedding = nn.Embedding.from_pretrained(
            embeddings=pretrained_embeds,
            freeze=True # will be unfrozen in epoch 20
        )
        
        self.dropout=nn.Dropout(dropout)
        
        self.lstm = nn.LSTM(
            input_size=embeding_dim,
            hidden_size=rnn_size,
            bidirectional=True,
            batch_first=True,
        )
        
        self.fc1 = nn.Linear(rnn_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)
        
    def forward(self, seq):
        # input shape: (batch_size, seq_len=10)
        embedded = self.dropout(self.embedding(seq))
        
        # input shape: (batch_size, seq_len, embedding_dim)
        out, _ = self.lstm(embedded)
        
        # input shape: (batch_size, seq_len, 2*hidden_size)
        pooled = torch.mean(out, dim=1)
        
        # input shape: (batch_size, 2*hidden_size)
        out = self.dropout(F.relu(self.fc1(pooled)))
        
        # input shape: (batch_size, hidden_size)
        # output shape: (batch_size)
        out = self.fc2(out).squeeze()
        return out 

In [6]:
def process(model, loader, criterion, optim=None):
    epoch_loss, num_correct, total = 0, 0, 0
    
    for batch in tqdm(
        loader,
        desc="Train: " if optim is not None else "Eval: ",
        file=sys.stdout,
        unit="batches"
    ):
        seq, lbl = batch.seq, batch.lbl
        seq, lbl = seq.to(device), lbl.to(device)
        
        scores = model(seq)
        loss = criterion(scores, lbl)
        
        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()
        
        epoch_loss += loss.item()
        num_correct += ((scores > 0) == lbl).sum()
        total += len(seq)
    return epoch_loss / total, num_correct / total

In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load train and dev data
train_seqs, train_lbl = read_data('../../data/n_train_3mer.tsv')
dev_seqs, dev_lbl = read_data('../../data/n_val_3mer.tsv')

# load vocab and embeddings
vocab, prot2vec = read_embeddings('../../params/uniref_3M/uniref_3M.vec')

In [8]:
NUM_EPOCHS = 30
BATCH_SIZE = 512
RNN_SIZE = 768
HIDDEN_SIZE = 128
DROPOUT = 0.5
LEARNING_RATE = 1e-4

model = BiLSTM(
    pretrained_embeds=prot2vec,
    rnn_size=RNN_SIZE,
    hidden_size=HIDDEN_SIZE,
    dropout=DROPOUT
).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

# create train and dev loader
train_data = CleavageDataset(train_seqs, train_lbl)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_wrapper, pin_memory=True, num_workers=10)

dev_data = CleavageDataset(dev_seqs, dev_lbl)
dev_loader = DataLoader(dev_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_wrapper, pin_memory=True, num_workers=10)

print(f"Total trainable model parameters: {get_total_model_params(model):,}")

Total trainable model parameters: 11,071,745


In [9]:
start = time()
print("Starting Training.")
highest_val_acc = 0
train_losses, train_accuracies= [], []
val_losses, val_accuracies = [], []

for epoch in range(1, NUM_EPOCHS + 1):
    
    if epoch == 20:
        model.embedding.requires_grad_()
        print(f'Embeddings unfrozen in epoch {epoch}.')
        
    model.train()
    train_loss, train_acc = process(model, train_loader, criterion, optimizer)
    
    model.eval()
    with torch.no_grad():
        val_loss, val_acc = process(model, dev_loader, criterion)
        
    # save current acc, loss
    train_losses.append((epoch, train_loss))
    train_accuracies.append((epoch, train_acc))
    val_losses.append((epoch, val_loss))
    val_accuracies.append((epoch, val_acc))
    
    if val_acc > highest_val_acc:
        highest_val_acc = val_acc
        path = f"../../params/n_term/quadBiLSTM/acc{val_acc:.4f}_epoch{epoch}.pt"
        torch.save(model.state_dict(), path)
        
    print(
        f"Training:   [Epoch {epoch:2d}, Loss: {train_loss:8.4f}, Acc: {train_acc:.4f}]"
    )
    print(f"Evaluation: [Epoch {epoch:2d}, Loss: {val_loss:8.4f}, Acc: {val_acc:.4f}]")
    
print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time} minutes.")

Starting Training.
Train: 100%|████████████████████████████████████████████████████| 2236/2236 [00:26<00:00, 83.02batches/s]
Eval: 100%|██████████████████████████████████████████████████████| 280/280 [00:01<00:00, 221.69batches/s]
Training:   [Epoch  1, Loss:   0.0013, Acc: 0.6030]
Evaluation: [Epoch  1, Loss:   0.0012, Acc: 0.6349]
Train: 100%|████████████████████████████████████████████████████| 2236/2236 [00:27<00:00, 82.78batches/s]
Eval: 100%|██████████████████████████████████████████████████████| 280/280 [00:01<00:00, 225.99batches/s]
Training:   [Epoch  2, Loss:   0.0012, Acc: 0.6291]
Evaluation: [Epoch  2, Loss:   0.0012, Acc: 0.6561]
Train: 100%|████████████████████████████████████████████████████| 2236/2236 [00:27<00:00, 82.53batches/s]
Eval: 100%|██████████████████████████████████████████████████████| 280/280 [00:01<00:00, 225.90batches/s]
Training:   [Epoch  3, Loss:   0.0012, Acc: 0.6497]
Evaluation: [Epoch  3, Loss:   0.0012, Acc: 0.6727]
Train: 100%|█████████████████████

In [None]:
# save training stats
lsts = [train_losses, train_accuracies, val_losses, val_accuracies, train_time]
names = [
    "train_losses",
    "train_accuracies",
    "val_losses",
    "val_accuracies",
    "train_time",
]
to_save = dict()
for name, lst in zip(names, lsts):
    to_save[name] = lst

with open(f"../params/n_term/quadBiLSTM/metrics.pkl", "wb") as f:
    pickle.dump(to_save, f, pickle.HIGHEST_PROTOCOL)

print("Finished Saving Details.")