In [None]:
import os
import sys
import csv
import pickle
import random
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import roc_auc_score

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and CUDA), exports ``PYTHONHASHSEED`` and configures cuDNN for
    reproducible execution.

    Args:
        seed: Value applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run, so it
    # has to be switched off for the seeding above to give repeatable results.
    torch.backends.cudnn.benchmark = False


seed_everything(1234)

In [None]:
def read_data(path):
    """Read a two-column CSV file into parallel lists.

    The first row is treated as the column header and skipped. Completely
    empty rows (for example a trailing blank line) are ignored.

    Args:
        path: Path of the CSV file to read.

    Returns:
        A tuple ``(sents, lbls)`` of two equally long lists holding the first
        and the second column of every data row, as the raw strings found in
        the file.

    Raises:
        ValueError: If a non-empty data row does not have exactly two columns.
    """
    sents, lbls = [], []
    # newline='' leaves newline handling to the csv module, as its
    # documentation requires for files passed to csv.reader.
    with open(path, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip col name
        for row in reader:
            if not row:
                continue
            if len(row) != 2:
                raise ValueError(
                    f"{path}, line {reader.line_num}: expected 2 columns, got {len(row)}"
                )
            sent, lbl = row
            sents.append(sent)
            lbls.append(lbl)
    return sents, lbls

def get_total_model_params(model):
    """Count the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
class CleavageDataset(Dataset):
    """Map-style dataset over parallel collections of sequences and labels."""

    def __init__(self, seq, lbl):
        """Keep references to the sequences and to their labels."""
        self.seq, self.lbl = seq, lbl

    def __len__(self):
        """Return the number of labels, which is used as the dataset size."""
        return len(self.lbl)

    def __getitem__(self, idx):
        """Return the ``(label, sequence)`` pair stored at ``idx``."""
        label = self.lbl[idx]
        sequence = self.seq[idx]
        return label, sequence


class CleavageBatch:
    """One collated batch: tokenised sequences in ``seq``, float labels in ``lbl``."""

    def __init__(self, batch):
        """Tokenise ``batch`` with the module-level ``tokenizer`` and tensorise the labels.

        Args:
            batch: Samples in the form accepted by ``tokenizer``.
        """
        # The tokenizer returns three values; the middle one is not used here.
        raw_labels, _, tokens = tokenizer(batch)
        self.seq = tokens
        # Every label must be convertible with int(); values are stored as float32.
        self.lbl = torch.tensor(list(map(int, raw_labels)), dtype=torch.float)

    def pin_memory(self):
        """Replace both tensors with pinned-memory copies and return ``self``."""
        for attr in ("seq", "lbl"):
            setattr(self, attr, getattr(self, attr).pin_memory())
        return self


def collate_wrapper(batch):
    """Collate function for ``DataLoader``: wrap the sampled items in a ``CleavageBatch``."""
    collated = CleavageBatch(batch)
    return collated

In [None]:
class ESM2Finetune(nn.Module):
    """Binary classifier head on top of a pre-trained ESM-2 model.

    The per-position logits of the wrapped model are max-pooled over the
    residue positions and mapped to one score per sample by a linear layer.

    Args:
        pretrained_model: Callable returning a mapping whose ``'logits'``
            entry has shape ``(batch_size, seq_len + 2, vocab_size)``.
        dropout: Dropout probability applied to the logits before pooling.
        seq_len: Number of residue positions kept after the leading cls
            token. Defaults to 10.
        vocab_size: Size of the last logits dimension. Defaults to 33.
    """

    def __init__(self, pretrained_model, dropout, seq_len=10, vocab_size=33):
        super().__init__()

        self.esm2 = pretrained_model

        self.dropout = nn.Dropout(dropout)

        self.seq_len = seq_len
        self.fc = nn.Linear(vocab_size, 1)

    def forward(self, seq):
        """Return one logit per sample, shape ``(batch_size,)``."""
        # input shape: (batch_size, seq_len + 2 (cls, eos))
        logits = self.esm2(seq)['logits']
        # remove cls, eos token position
        result = logits[:, 1:self.seq_len + 1, :]
        result = self.dropout(result)

        # in: (batch_size, seq_len, vocab_size), out: (batch_size, vocab_size)
        result, _ = result.max(dim=1)

        # in: (batch_size, vocab_size), out: (batch_size)
        # Squeeze only the feature axis: a bare squeeze() would also drop the
        # batch axis when the batch holds a single sample.
        return self.fc(result).squeeze(-1)

In [None]:
def process(model, loader, criterion, optim=None):
    """Run one pass over ``loader``, updating the model when an optimizer is given.

    Args:
        model: Callable producing one logit per sample for a token batch.
        loader: Iterable of batches exposing ``seq`` and ``lbl`` tensors.
        criterion: Loss called as ``criterion(scores, lbl)``.
        optim: Optimizer used for the update step. When ``None`` the pass is
            evaluation only: no backward pass and no parameter update.
            (Inside this function the name shadows the ``torch.optim`` module
            import; it is kept for backward compatibility.)

    Returns:
        A tuple ``(loss, accuracy, auc)``: the loss averaged over all samples,
        the fraction of samples whose logit sign matches the label, and the
        ROC AUC of the logits against the labels.
    """
    is_training = optim is not None
    epoch_loss, num_correct, total = 0, 0, 0
    preds, lbls = [], []
    # A mean-reduced criterion returns a per-batch average. It must be
    # weighted by the batch size before dividing by the number of samples,
    # otherwise the reported loss shrinks by roughly the batch size.
    mean_reduced = getattr(criterion, "reduction", "mean") == "mean"

    for batch in tqdm(
        loader,
        desc="Train: " if is_training else "Eval: ",
        file=sys.stdout,
        unit="batches"
    ):
        seq, lbl = batch.seq.to(device), batch.lbl.to(device)
        batch_size = len(seq)

        # Flatten so that a batch with a single sample still has shape (1,),
        # matching the shape of the labels.
        scores = model(seq).reshape(-1)
        loss = criterion(scores, lbl)

        if is_training:
            optim.zero_grad()
            loss.backward()
            optim.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss * batch_size if mean_reduced else batch_loss
        # A logit above zero corresponds to a predicted probability above 0.5.
        num_correct += ((scores > 0) == lbl).sum().item()
        total += batch_size
        preds.extend(scores.detach().tolist())
        lbls.extend(lbl.detach().tolist())
    return epoch_loss / total, num_correct / total, roc_auc_score(lbls, preds)

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the pre-trained ESM-2 model (esm2_t33_650M_UR50D) and its vocabulary
# through torch.hub, then build the batch converter used as tokenizer.
esm2, vocab = torch.hub.load('facebookresearch/esm:main', 'esm2_t33_650M_UR50D')
tokenizer = vocab.get_batch_converter()

# Read the training and the validation split.
train_seqs, train_lbl = read_data("../../data/n_train.csv")
dev_seqs, dev_lbl = read_data("../../data/n_val.csv")

In [None]:
# Hyperparameters of the fine-tuning run.
NUM_EPOCHS = 1
BATCH_SIZE = 128
DROPOUT = 0.
LEARNING_RATE = 2e-5
NUM_WORKERS = 10  # worker processes used by each DataLoader

model = ESM2Finetune(
    pretrained_model=esm2,
    dropout=DROPOUT
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Expects raw logits; the sigmoid is applied inside the loss.
criterion = nn.BCEWithLogitsLoss()

# create train and dev loader
train_data = CleavageDataset(train_seqs, train_lbl)
train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_wrapper,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

dev_data = CleavageDataset(dev_seqs, dev_lbl)
# Loss, accuracy and AUC do not depend on the sample order, so the
# validation split is iterated in its stored order instead of being shuffled.
dev_loader = DataLoader(
    dev_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_wrapper,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

print(f"Total model parameters: {get_total_model_params(model):,}")

In [None]:
start = time()
print("Starting Training.")
highest_val_auc = 0
# Per-epoch history, stored as (epoch, value) pairs.
train_losses, train_accuracies, train_aucs = [], [], []
val_losses, val_accuracies, val_aucs = [], [], []

for epoch in range(1, NUM_EPOCHS + 1):
    # Optimisation pass over the training split.
    model.train()
    train_loss, train_acc, train_auc = process(model, train_loader, criterion, optimizer)

    # Evaluation pass: no optimizer is passed and gradient tracking is disabled.
    model.eval()
    with torch.no_grad():
        val_loss, val_acc, val_auc = process(model, dev_loader, criterion)

    # save current acc, loss
    train_losses.append((epoch, train_loss))
    train_accuracies.append((epoch, train_acc))
    train_aucs.append((epoch, train_auc))
    val_losses.append((epoch, val_loss))
    val_accuracies.append((epoch, val_acc))
    val_aucs.append((epoch, val_auc))

    # Write a checkpoint only when the validation AUC beats the best value so far.
    if val_auc > highest_val_auc:
        highest_val_auc = val_auc
        path = f"../../params/n_term/esm2finetune/auc{val_auc:.4f}_epoch{epoch}.pt"
        # torch.save does not create missing directories, so make sure the
        # target directory exists before writing.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(model.state_dict(), path)

    print(
        f"Training:   [Epoch {epoch:2d}, Loss: {train_loss:8.4f}, Acc: {train_acc:.4f}, AUC: {train_auc:.4f}]"
    )
    print(f"Evaluation: [Epoch {epoch:2d}, Loss: {val_loss:8.4f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}]")

print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time} minutes.")

In [None]:
# save training stats
lsts = [train_losses, train_accuracies, val_losses, val_accuracies, train_aucs, val_aucs, train_time]
names = [
    "train_losses",
    "train_accuracies",
    "val_losses",
    "val_accuracies",
    "train_aucs",
    "val_aucs",
    "train_time",
]
# Map every metric name to its recorded values.
to_save = dict(zip(names, lsts))

# NOTE(review): the original path was "../params/...", one level above the
# directory the checkpoints are written to ("../../params/..."). It is
# aligned with the checkpoint directory here - confirm this is the intended
# location.
metrics_dir = "../../params/n_term/esm2finetune"
# open() does not create missing directories.
os.makedirs(metrics_dir, exist_ok=True)
with open(f"{metrics_dir}/metrics.pkl", "wb") as f:
    pickle.dump(to_save, f, pickle.HIGHEST_PROTOCOL)

print("Finished Saving Details.")