In [1]:
import os
import sys
import csv
import pickle
import random
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler, BatchSampler

from transformers import T5Tokenizer, T5EncoderModel

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(1234)

In [3]:
def read_data(path):
    with open(path, 'r') as csvfile:
        train_data = list(csv.reader(csvfile))[1:] # skip col name
        sents, lbls = [], []
        for s, l in train_data:
            sents.append(s)
            lbls.append(l)
    return sents, lbls

# number of trainable parameters in model
def get_total_model_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_total_trainable_model_params(model):
    return sum(p[1].numel() for p in model.named_parameters() if p[1].requires_grad and not p[0].startswith('t5'))

In [4]:
class CleavageDataset(Dataset):
    def __init__(self, seq, lbl):
        self.seq = seq
        self.lbl = lbl

    def __getitem__(self, idx):
        return self.seq[idx], self.lbl[idx]

    def __len__(self):
        return len(self.lbl)


class BucketSampler(Sampler):
    def __init__(self, seqs, batch_size):

        # pair each sequence with their *tokenized* length
        indices = [
            (idx, len(tokenizer.encode(s.replace("", " ").strip())))
            for idx, s in enumerate(seqs)
        ]
        random.shuffle(indices)

        idx_pools = []
        # generate pseudo-random batches of (arbitrary) size batch_size * 100
        # each batch of size batch_size * 100 is sorted in itself by seq length
        for i in range(0, len(indices), batch_size * 100):
            idx_pools.extend(
                sorted(indices[i : i + batch_size * 100], key=lambda x: x[1])
            )

        # filter only indices
        self.idx_pools = [x[0] for x in idx_pools]

    def __iter__(self):
        return iter(self.idx_pools)

    def __len__(self):
        return len(self.idx_pools)


class CleavageBatch:
    def __init__(self, batch):
        ordered_batch = list(zip(*batch))
        encoded = tokenizer.batch_encode_plus(
            [seq.replace("", " ").strip() for seq in ordered_batch[0]]
        )
        self.seq = torch.tensor(encoded["input_ids"], dtype=torch.int64)
        self.att = torch.tensor(encoded["attention_mask"], dtype=torch.int64)
        self.lbl = torch.tensor([int(l) for l in ordered_batch[1]], dtype=torch.float)

    def pin_memory(self):
        self.seq = self.seq.pin_memory()
        self.att = self.att.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self


def collate_wrapper(batch):
    return CleavageBatch(batch)

In [5]:
class T5_BiLSTM(nn.Module):
    def __init__(self, rnn_size, hidden_size, dropout):
        super().__init__()

        self.t5_encoder = T5EncoderModel.from_pretrained(
            "Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16
        )

        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=self.t5_encoder.config.to_dict()['d_model'], # 1024
            hidden_size=rnn_size,
            bidirectional=True,
            batch_first=True,
        )

        self.fc1 = nn.Linear(rnn_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, seq, att):
        with torch.no_grad():
            # input shape: (batch_size, seq_len=10)
            # out: (batch_size, seq_len+1, embedding_dim=1024)
            embedded = self.dropout(self.t5_encoder(seq, att).last_hidden_state)

        # input shape: (batch_size, seq_len+1, embedding_dim)
        out, _ = self.lstm(embedded)

        # input shape: (batch_size, seq_len=1, 2*rnn_size)
        pooled, _ = torch.max(out, dim=1)

        # input shape: (batch_size, 2*rnn_size)
        out = self.dropout(F.relu(self.fc1(pooled)))

        # input shape: (batch_size, hidden_size)
        # output shape: (batch_size)
        return self.fc2(out).squeeze()

In [6]:
def process(model, loader, criterion, optim=None):
    epoch_loss, num_correct, total = 0, 0, 0
    
    for batch in tqdm(
        loader,
        desc="Train: " if optim is not None else "Eval: ",
        file=sys.stdout,
        unit="batches"
    ):
        seq, att, lbl = batch.seq, batch.att, batch.lbl
        seq, att, lbl = seq.to(device), att.to(device), lbl.to(device)
        
        with torch.cuda.amp.autocast():
            scores = model(seq, att)
            loss = criterion(scores, lbl)
        
        if optim is not None:
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        
        epoch_loss += loss.item()
        num_correct += ((scores > 0) == lbl).sum()
        total += len(seq)
    return epoch_loss / total, num_correct / total

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load train and dev data
train_seqs, train_lbl = read_data("../../data/n_train.csv")
dev_seqs, dev_lbl = read_data("../../data/n_val.csv")

tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
)

In [8]:
NUM_EPOCHS = 10
BATCH_SIZE = 512
RNN_SIZE = 512
HIDDEN_SIZE = 128
DROPOUT = 0.5
LEARNING_RATE = 1e-4

model = T5_BiLSTM(
    rnn_size=RNN_SIZE,
    hidden_size=HIDDEN_SIZE,
    dropout=DROPOUT
).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

# create train and dev loader
train_data = CleavageDataset(train_seqs, train_lbl)
train_bucket_sampler = BucketSampler(train_seqs, BATCH_SIZE)
train_sampler = BatchSampler(train_bucket_sampler, BATCH_SIZE, drop_last=True)
train_loader = DataLoader(train_data, batch_sampler=train_sampler, collate_fn=collate_wrapper, pin_memory=True, num_workers=10)

dev_data = CleavageDataset(dev_seqs, dev_lbl)
dev_bucket_sampler = BucketSampler(dev_seqs, BATCH_SIZE)
dev_sampler = BatchSampler(dev_bucket_sampler, BATCH_SIZE, drop_last=True)
dev_loader = DataLoader(dev_data, batch_sampler=dev_sampler, collate_fn=collate_wrapper, pin_memory=True, num_workers=10)

print(f"Total model parameters: {get_total_model_params(model):,}")
print(f"Total trainable parameters: {get_total_trainable_model_params(model):,}")

Total model parameters: 1,214,572,801
Total trainable parameters: 6,430,977


In [9]:
# scale everything to fp16
scaler = torch.cuda.amp.GradScaler()

start = time()
print("Starting Training.")
highest_val_acc = 0
train_losses, train_accuracies= [], []
val_losses, val_accuracies = [], []

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss, train_acc = process(model, train_loader, criterion, optimizer)
    
    model.eval()
    with torch.no_grad():
        val_loss, val_acc = process(model, dev_loader, criterion)
        
    # save current acc, loss
    train_losses.append((epoch, train_loss))
    train_accuracies.append((epoch, train_acc))
    val_losses.append((epoch, val_loss))
    val_accuracies.append((epoch, val_acc))
    
    if val_acc > highest_val_acc:
        highest_val_acc = val_acc
        path = f"../../params/n_term/t5encBiLSTM/acc{val_acc:.4f}_epoch{epoch}.pt"
        torch.save(model.state_dict(), path)
        
    print(
        f"Training:   [Epoch {epoch:2d}, Loss: {train_loss:8.4f}, Acc: {train_acc:.4f}]"
    )
    print(f"Evaluation: [Epoch {epoch:2d}, Loss: {val_loss:8.4f}, Acc: {val_acc:.4f}]")
    
print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time} minutes.")

Starting Training.
Train: 100%|████████████████████████████████████████████████████| 2235/2235 [11:39<00:00,  3.20batches/s]
Eval: 100%|███████████████████████████████████████████████████████| 279/279 [01:21<00:00,  3.41batches/s]
Training:   [Epoch  1, Loss:   0.0012, Acc: 0.6542]
Evaluation: [Epoch  1, Loss:   0.0011, Acc: 0.6859]
Train: 100%|████████████████████████████████████████████████████| 2235/2235 [11:50<00:00,  3.15batches/s]
Eval: 100%|███████████████████████████████████████████████████████| 279/279 [01:24<00:00,  3.29batches/s]
Training:   [Epoch  2, Loss:   0.0011, Acc: 0.6830]
Evaluation: [Epoch  2, Loss:   0.0011, Acc: 0.6903]
Train: 100%|████████████████████████████████████████████████████| 2235/2235 [12:03<00:00,  3.09batches/s]
Eval: 100%|███████████████████████████████████████████████████████| 279/279 [01:24<00:00,  3.30batches/s]
Training:   [Epoch  3, Loss:   0.0011, Acc: 0.6869]
Evaluation: [Epoch  3, Loss:   0.0011, Acc: 0.6935]
Train: 100%|█████████████████████

In [None]:
# save training stats
lsts = [train_losses, train_accuracies, val_losses, val_accuracies, train_time]
names = [
    "train_losses",
    "train_accuracies",
    "val_losses",
    "val_accuracies",
    "train_time",
]
to_save = dict()
for name, lst in zip(names, lsts):
    to_save[name] = lst

with open(f"../params/n_term/quadBiLSTM/metrics.pkl", "wb") as f:
    pickle.dump(to_save, f, pickle.HIGHEST_PROTOCOL)

print("Finished Saving Details.")