# Sources
* BiLSTM model architecture based on [Ozols et. al., 2021](https://www.mdpi.com/1422-0067/22/6/3071/htm)
* T5 Encoder taken from [Elnagger et al., 2020](https://ieeexplore.ieee.org/document/9477085), [Github](https://github.com/agemagician/ProtTrans), Model on [Huggingface Hub](https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc)


In [1]:
import os
import sys
import csv
import math
import pickle
import random
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from transformers import T5Tokenizer, T5EncoderModel

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(1234)

In [3]:
def read_data(path):
    with open(path, "r") as csvfile:
        train_data = list(csv.reader(csvfile))[1:]  # skip col name
        sents, lbls = [], []
        for s, l in train_data:
            sents.append(s)
            lbls.append(l)
    return sents, lbls


def apply_random_masking(seq, att, num_tokens):
    """
    Mask `num_tokens` as 0 at random positions per sequence.
    """
    dist = torch.rand(seq.shape)
    m, _ = torch.topk(dist, num_tokens)
    return seq * (dist < m), att * (dist < m)


def regularized_auc(train_auc, dev_auc, threshold=0.0025):
    """
    Returns development AUC if overfitting is below threshold, otherwise 0.
    """
    return dev_auc if (train_auc - dev_auc) < threshold else 0


def save_metrics(*args, path):
    if not os.path.isfile(path):
        with open(path, "w", newline="\n") as f:
            f.write(
                ",".join(
                    [
                        "fold",
                        "epoch",
                        "train_loss",
                        "train_acc",
                        "train_auc",
                        "val_loss",
                        "val_acc",
                        "val_auc",
                    ]
                )
            )
            f.write("\n")
    if args:
        with open(path, "a", newline="\n") as f:
            f.write(",".join([str(arg) for arg in args]))
            f.write("\n")


def gelu(x):
    """
    Facebook Research implementation of the gelu activation function.
    
    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def total_model_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def trainable_model_params(model):
    return sum(
        p[1].numel()
        for p in model.named_parameters()
        if p[1].requires_grad and not p[0].startswith("t5")
    )

In [4]:
class CleavageDataset(Dataset):
    def __init__(self, seq, lbl):
        self.seq = seq
        self.lbl = lbl

    def __getitem__(self, idx):
        return self.seq[idx], self.lbl[idx]

    def __len__(self):
        return len(self.lbl)
    
    
class TrainBatch:
    def __init__(self, batch):
        ordered_batch = list(zip(*batch))
        encoded = tokenizer.batch_encode_plus(
            [seq.replace("", " ").strip() for seq in ordered_batch[0]]
        )
        seq = torch.tensor(encoded["input_ids"], dtype=torch.int64)
        att = torch.tensor(encoded["attention_mask"], dtype=torch.int64)
        self.seq, self.att = apply_random_masking(seq, att, num_tokens=1)
        self.lbl = torch.tensor([int(l) for l in ordered_batch[1]], dtype=torch.float)

    def pin_memory(self):
        self.seq = self.seq.pin_memory()
        self.att = self.att.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self
    
def train_wrapper(batch):
    return TrainBatch(batch)
    

class EvalBatch:
    def __init__(self, batch):
        ordered_batch = list(zip(*batch))
        encoded = tokenizer.batch_encode_plus(
            [seq.replace("", " ").strip() for seq in ordered_batch[0]]
        )
        self.seq = torch.tensor(encoded["input_ids"], dtype=torch.int64)
        self.att = torch.tensor(encoded["attention_mask"], dtype=torch.int64)
        self.lbl = torch.tensor([int(l) for l in ordered_batch[1]], dtype=torch.float)

    def pin_memory(self):
        self.seq = self.seq.pin_memory()
        self.att = self.att.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self


def eval_wrapper(batch):
    return EvalBatch(batch)

In [5]:
class T5_BiLSTM(nn.Module):
    def __init__(self, rnn_size, hidden_size, dropout):
        super().__init__()

        self.t5_encoder = T5EncoderModel.from_pretrained(
            "Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16
        )

        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=self.t5_encoder.config.to_dict()['d_model'], # 1024
            hidden_size=rnn_size,
            bidirectional=True,
            batch_first=True,
        )

        self.fc1 = nn.Linear(rnn_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, seq, att):
        with torch.no_grad():
            # input shape: (batch_size, seq_len=10)
            # out: (batch_size, seq_len+1, embedding_dim=1024)
            embedded = self.dropout(self.t5_encoder(seq, att).last_hidden_state)

        # input shape: (batch_size, seq_len+1, embedding_dim)
        out, _ = self.lstm(embedded)

        # input shape: (batch_size, seq_len=1, 2*rnn_size)
        pooled, _ = torch.max(out, dim=1)

        # input shape: (batch_size, 2*rnn_size)
        out = self.dropout(gelu(self.fc1(pooled)))

        # input shape: (batch_size, hidden_size)
        # output shape: (batch_size)
        return self.fc2(out).squeeze()

In [6]:
def process(model, loader, criterion, optim=None):
    epoch_loss, num_correct, total = 0, 0, 0
    preds, lbls = [], []
    
    for batch in tqdm(
        loader,
        desc="Train: " if optim is not None else "Eval: ",
        file=sys.stdout,
        unit="batches"
    ):
        seq, att, lbl = batch.seq, batch.att, batch.lbl
        seq, att, lbl = seq.to(device), att.to(device), lbl.to(device)
        
        with torch.cuda.amp.autocast():
            scores = model(seq, att)
            loss = criterion(scores, lbl)
        
        if optim is not None:
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        
        epoch_loss += loss.item()
        num_correct += ((scores > 0) == lbl).sum().item()
        total += seq.shape[0]
        preds.extend(scores.detach().tolist())
        lbls.extend(lbl.detach().tolist())
    return epoch_loss / total, num_correct / total, roc_auc_score(lbls, preds)

In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 512

tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
)

# load train+dev set, mix it back as one
train_path = '../../data/c_train.csv'
dev_path = '../../data/c_val.csv'
test_path = '../../data/c_test.csv'

# combine previously split train and dev set
train_seqs, train_lbls = read_data(train_path)
dev_seqs, dev_lbls = read_data(dev_path)
total_seqs, total_lbls = np.array(train_seqs + dev_seqs), np.array(train_lbls + dev_lbls)

assert len(train_seqs) + len(dev_seqs) == len(total_seqs)
assert len(train_lbls) + len(dev_lbls) == len(total_lbls)

test_seqs, test_lbls = read_data(test_path)

test_data = CleavageDataset(test_seqs, test_lbls)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=eval_wrapper, pin_memory=True, num_workers=10)

In [8]:
RNN_SIZE = 512
HIDDEN_SIZE = 128
DROPOUT = 0.5
LEARNING_RATE = 3e-4

params = {
    "rnn_size": RNN_SIZE,
    "hidden_size": HIDDEN_SIZE,
    "dropout": DROPOUT
}

In [9]:
# scale everything to fp16
scaler = torch.cuda.amp.GradScaler()

NUM_EPOCHS = 10
LEARNING_RATE = 3e-4
criterion = nn.BCEWithLogitsLoss()

kf = KFold(n_splits=5, shuffle=True, random_state=1234)
logging_path = "../../params/c_term/T5_BiLSTM/results.csv"

start = time()
print("Starting Cross-Validation.")
highest_val_auc = 0

# get a new split
for fold, (train_idx, dev_idx) in enumerate(kf.split(total_seqs), 1):
    X_tr = total_seqs[train_idx]
    y_tr = total_lbls[train_idx]
    X_dev = total_seqs[dev_idx]
    y_dev = total_lbls[dev_idx]

    # create datasets and loads with current split
    train_data = CleavageDataset(X_tr, y_tr)
    train_loader = DataLoader(
        train_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=train_wrapper,
        pin_memory=True,
        num_workers=10,
    )

    dev_data = CleavageDataset(X_dev, y_dev)
    dev_loader = DataLoader(
        dev_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=eval_wrapper,
        pin_memory=True,
        num_workers=10,
    )

    # reset model weights with each new fold
    if 'model' in globals():
        del model
        print('deleted model')
    model = T5_BiLSTM(**params).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # normal training loop
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss, train_acc, train_auc = process(
            model, train_loader, criterion, optimizer
        )

        model.eval()
        with torch.no_grad():
            val_loss, val_acc, val_auc = process(model, dev_loader, criterion)

        # save metrics
        save_metrics(
            fold,
            epoch,
            train_loss,
            train_acc,
            train_auc,
            val_loss,
            val_acc,
            val_auc,
            path=logging_path,
        )

        print(
            f"Training:   [Fold {fold:2d}, Epoch {epoch:2d}, Loss: {train_loss:8.6f}, Acc: {train_acc:.4f}, AUC: {train_auc:.4f}]"
        )
        print(f"Evaluation: [Fold {fold:2d}, Epoch {epoch:2d}, Loss: {val_loss:8.6f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}]")
        
        reg_auc = regularized_auc(train_auc, val_auc, threshold=0)
        if reg_auc > highest_val_auc:
            highest_val_auc = reg_auc
            path = f"../../params/c_term/T5_BiLSTM/auc{reg_auc:.4f}_fold{fold}_epoch{epoch}.pt"
            torch.save(model.state_dict(), path)

print("Finished Cross-Validation.")
train_time = (time() - start) / 60
print(f"Cross-Validation took {train_time} minutes.")

Starting Cross-Validation.
Train: 100%|████████████████████████████████████████████████████| 1997/1997 [10:27<00:00,  3.18batches/s]
Eval: 100%|███████████████████████████████████████████████████████| 500/500 [02:26<00:00,  3.42batches/s]
Training:   [Fold  1, Epoch  1, Loss: 0.000723, Acc: 0.8416, AUC: 0.8086]
Evaluation: [Fold  1, Epoch  1, Loss: 0.000632, Acc: 0.8592, AUC: 0.8657]
Train: 100%|████████████████████████████████████████████████████| 1997/1997 [10:24<00:00,  3.20batches/s]
Eval: 100%|███████████████████████████████████████████████████████| 500/500 [02:26<00:00,  3.42batches/s]
Training:   [Fold  1, Epoch  2, Loss: 0.000671, Acc: 0.8507, AUC: 0.8436]
Evaluation: [Fold  1, Epoch  2, Loss: 0.000621, Acc: 0.8626, AUC: 0.8713]
Train: 100%|████████████████████████████████████████████████████| 1997/1997 [10:25<00:00,  3.19batches/s]
Eval: 100%|███████████████████████████████████████████████████████| 500/500 [02:25<00:00,  3.43batches/s]
Training:   [Fold  1, Epoch  3, Loss: 0.0

In [10]:
if 'model' in globals():
    del model
    print('deleted model')

model = T5_BiLSTM(**params).to('cpu')
criterion = nn.BCEWithLogitsLoss()

# load best model, evaluate on test set
best_model = sorted(
    [f for f in os.listdir("../../params/c_term/T5_BiLSTM/") if f.endswith(".pt")],
    reverse=True,
)[0]
print("Loaded model: ", best_model)
checkpoint = torch.load('../../params/c_term/T5_BiLSTM/' + best_model, map_location='cpu')
model.load_state_dict(checkpoint)
model = model.to(device)
model.eval()

with torch.no_grad():
    test_loss, test_acc, test_auc = process(model, test_loader, criterion)
print(
    f"Test Set Performance: Loss: {test_loss:.6f}, Acc: {test_acc:.4f}, AUC: {test_auc:.4f}"
)

print(
    f"Total model params: {total_model_params(model)}, trainable model params: {trainable_model_params(model)}"
)

deleted model
Loaded model:  auc0.8815_fold4_epoch10.pt
Eval: 100%|███████████████████████████████████████████████████████| 278/278 [01:19<00:00,  3.48batches/s]
Test Set Performance: Loss: 0.000597, Acc: 0.8689, AUC: 0.8832
Total model params: 1214572801, trainable model params: 6430977
