* https://github.com/agemagician/ProtTrans/tree/master/Fine-Tuning
* https://github.com/agemagician/ProtTrans/blob/master/Fine-Tuning/ProtBert_BFD_FineTuning_MS.ipynb
* https://github.com/huggingface/transformers/blob/04976a32dc555667afa994e8f918cbee88d84a4f/src/transformers/models/bert/modeling_bert.py#L1481
    * Build the classification head yourself on top of the T5 decoder hidden states
* https://huggingface.co/docs/transformers/training#train-in-native-pytorch

In [1]:
import os
import sys
import csv
import pickle
import random
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from transformers import T5Tokenizer, T5Model, get_scheduler

In [2]:
def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: The integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; this is a no-op
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different kernels from run
    # to run, which undermines the deterministic flag set above.
    torch.backends.cudnn.benchmark = False


seed_everything(1234)

In [3]:
def read_data(path):
    """Read a two-column CSV file of (sequence, label) rows.

    The first row is treated as the column header and skipped. Completely
    empty rows (e.g. a trailing blank line) are ignored.

    Args:
        path: Path to the CSV file.

    Returns:
        A tuple ``(sents, lbls)`` of two equally long lists holding the first
        and second column of every data row, as strings.

    Raises:
        ValueError: If a non-empty row does not have exactly two fields.
    """
    sents, lbls = [], []
    # newline="" is what the csv module expects so that it can handle line
    # endings (including ones embedded in quoted fields) itself.
    with open(path, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip col name; tolerate an empty file
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines
                continue
            s, l = row
            sents.append(s)
            lbls.append(l)
    return sents, lbls

def get_total_model_params(model):
    """Count the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set contribute; the count is the
    sum of their element counts.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

In [4]:
class CleavageDataset(Dataset):
    """Map-style dataset that pairs each sequence with its label."""

    def __init__(self, seq, lbl):
        # Both containers are stored as given and indexed in parallel.
        self.seq, self.lbl = seq, lbl

    def __len__(self):
        """Number of samples, taken from the label container."""
        n_samples = len(self.lbl)
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at ``idx``."""
        sequence = self.seq[idx]
        label = self.lbl[idx]
        return sequence, label
    

class CleavageBatch:
    """Collated mini-batch of tokenized sequences and their labels.

    Built from a list of ``(sequence, label)`` samples. Relies on the
    module-level ``tokenizer`` being defined before the first batch is
    collated.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, batch):
        """Tokenize the sequences and convert the labels to tensors.

        Args:
            batch: List of ``(sequence, label)`` pairs; labels must be
                convertible with ``int()``.
        """
        columns = list(zip(*batch))
        seqs, labels = columns[0], columns[1]
        encoded = tokenizer.batch_encode_plus(
            # the tokenizer is fed one whitespace-separated character per token
            [" ".join(seq) for seq in seqs],
            # pad to the longest sequence in the batch; without padding,
            # sequences of different length give ragged lists that cannot be
            # turned into a tensor below
            padding=True,
        )
        # token ids, one row per sample
        self.seq = torch.tensor(encoded["input_ids"], dtype=torch.int64)
        # attention mask matching ``self.seq``
        self.att = torch.tensor(encoded["attention_mask"], dtype=torch.int64)
        # labels as floats, the dtype expected by BCEWithLogitsLoss targets
        self.lbl = torch.tensor([int(l) for l in labels], dtype=torch.float)

    def pin_memory(self):
        """Pin all tensors in page-locked memory and return ``self``.

        Called by ``DataLoader`` when ``pin_memory=True`` is used with a
        custom batch type.
        """
        self.seq = self.seq.pin_memory()
        self.att = self.att.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self


def collate_wrapper(batch):
    """Collate function for ``DataLoader``: wrap the samples in a ``CleavageBatch``."""
    collated = CleavageBatch(batch)
    return collated

In [5]:
class T5FineTuner(nn.Module):
    """T5 encoder-decoder with a single-logit classification head.

    The head (dropout followed by a linear layer) is applied to the last
    hidden state returned by ``T5Model``.
    """

    def __init__(self, model_name):
        """Load the pretrained T5 backbone and create the head.

        Args:
            model_name: Checkpoint name or path passed to
                ``T5Model.from_pretrained``.
        """
        super().__init__()

        self.t5 = T5Model.from_pretrained(model_name)
        self.dropout = nn.Dropout(self.t5.config.dropout_rate)
        self.fc = nn.Linear(self.t5.config.d_model, 1)

    def forward(self, input_ids, attention_mask, decoder_input_ids=None):
        """Compute one logit per sample.

        Args:
            input_ids: Encoder token ids.
            attention_mask: Encoder attention mask matching ``input_ids``.
            decoder_input_ids: Optional decoder token ids. When omitted, a
                single decoder-start token per sample is used, so the decoder
                produces exactly one hidden state per sample.

        Returns:
            The head's output flattened to 1-D. This has one entry per sample
            only when the decoder input has length 1.

        Raises:
            ValueError: If ``decoder_input_ids`` is omitted and the model
                config defines no ``decoder_start_token_id``.
        """
        if decoder_input_ids is None:
            start_id = self.t5.config.decoder_start_token_id
            if start_id is None:
                raise ValueError(
                    "decoder_input_ids was not given and the model config has "
                    "no decoder_start_token_id"
                )
            # same dtype/device as input_ids, shape (num_samples, 1)
            decoder_input_ids = input_ids.new_full((input_ids.size(0), 1), start_id)

        # keyword arguments guard against changes in the positional order of
        # the backbone's forward signature
        outputs = self.t5(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
        )
        pooled = self.dropout(outputs.last_hidden_state)
        return self.fc(pooled).view(-1)

In [6]:
def process(model, loader, criterion, lr_scheduler=None, optim=None):
    """Run one pass over ``loader``, training when an optimizer is given.

    Relies on the module-level ``device`` and, in training mode, on the
    module-level ``scaler`` (a ``GradScaler``) being defined before the call.

    Args:
        model: Callable as ``model(input_ids, attention_mask,
            decoder_input_ids)`` returning one logit per sample.
        loader: Iterable of batches exposing ``seq``, ``att`` and ``lbl``
            tensors.
        criterion: Loss taking ``(logits, labels)``. The per-sample average
            below assumes mean reduction, as used by the
            ``nn.BCEWithLogitsLoss()`` created in this notebook.
        lr_scheduler: Optional scheduler stepped once per optimizer step.
        optim: Optimizer; ``None`` selects evaluation mode (no backward pass).

    Returns:
        ``(mean_loss, accuracy)`` over all samples, both Python floats.
    """
    is_training = optim is not None
    epoch_loss, num_correct, total = 0.0, 0, 0

    for batch in tqdm(
        loader,
        desc="Train: " if is_training else "Eval: ",
        file=sys.stdout,
        unit="batches"
    ):
        seq = batch.seq.to(device)
        att = batch.att.to(device)
        lbl = batch.lbl.to(device)
        batch_size = seq.size(0)

        # The decoder gets a single start token per sample and never the
        # label itself: feeding the label as decoder input would leak the
        # target into the model. 0 is the decoder start (pad) token id of T5
        # checkpoints -- confirm against the model config if the backbone changes.
        decoder_input = torch.zeros((batch_size, 1), dtype=torch.int64, device=device)

        with torch.cuda.amp.autocast():
            scores = model(seq, att, decoder_input)
            loss = criterion(scores, lbl)

        if is_training:
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            # step the schedule after the optimizer/scaler update
            if lr_scheduler is not None:
                lr_scheduler.step()

        # loss is a batch mean: weight it by the batch size so that the
        # division by ``total`` below yields a true per-sample mean
        epoch_loss += loss.item() * batch_size
        # logit > 0 <=> predicted probability > 0.5; .item() keeps the
        # counter a plain Python int instead of a device tensor
        num_correct += ((scores > 0) == lbl.bool()).sum().item()
        total += batch_size
    return epoch_loss / total, num_correct / total

In [7]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Checkpoint name used for both the tokenizer below and the T5FineTuner model.
MODEL_NAME = "Rostlab/prot_t5_xl_uniref50"

# load train and dev data
# NOTE(review): both splits are read from the same file (n_val.csv), so the
# "dev" metrics are computed on the training data -- confirm whether the
# training split was meant to come from a different file.
train_seqs, train_lbl = read_data("../../data/n_val.csv")
dev_seqs, dev_lbl = read_data("../../data/n_val.csv")

# Module-level tokenizer; CleavageBatch reads this global when collating.
tokenizer = T5Tokenizer.from_pretrained(
    MODEL_NAME, do_lower_case=False
)

In [8]:
# Hyperparameters and loader settings
NUM_EPOCHS = 1
BATCH_SIZE = 1
LEARNING_RATE = 5e-5
NUM_WARMUP_STEPS = 1000  # linear warm-up length, in optimizer steps
NUM_WORKERS = 10  # DataLoader worker processes per loader

# NOTE: every parameter of the checkpoint is left trainable here (see the
# count printed at the end of this cell); the recorded run of this notebook
# ended with a CUDA out-of-memory error on a ~24 GiB GPU.
model = T5FineTuner(
    model_name=MODEL_NAME
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

# create train and dev loader
train_data = CleavageDataset(train_seqs, train_lbl)
train_loader = DataLoader(
    train_data,
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=collate_wrapper,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

dev_data = CleavageDataset(dev_seqs, dev_lbl)
# evaluation does not benefit from shuffling; a fixed order keeps it repeatable
dev_loader = DataLoader(
    dev_data,
    shuffle=False,
    batch_size=BATCH_SIZE,
    collate_fn=collate_wrapper,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

# linear warm-up followed by linear decay over all training steps
lr_scheduler = get_scheduler(
    name='linear',
    optimizer=optimizer,
    num_warmup_steps=NUM_WARMUP_STEPS,
    num_training_steps=NUM_EPOCHS*len(train_loader)
)

print(f"Total trainable parameters: {get_total_model_params(model):,}")

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5Model: ['lm_head.weight']
- This IS expected if you are initializing T5Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing T5Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Total trainable parameters: 2,818,831,361


In [9]:
# GradScaler scales the loss before the backward pass of mixed-precision
# training; process() reads this module-level object in training mode.
scaler = torch.cuda.amp.GradScaler()

# make sure the checkpoint directory exists before the first torch.save
checkpoint_dir = "../../params/n_term/t5finetune"
os.makedirs(checkpoint_dir, exist_ok=True)

start = time()
print("Starting Training.")
highest_val_acc = 0
train_losses, train_accuracies= [], []
val_losses, val_accuracies = [], []

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss, train_acc = process(model, train_loader, criterion, lr_scheduler, optimizer)

    model.eval()
    with torch.no_grad():
        val_loss, val_acc = process(model, dev_loader, criterion)

    # store plain Python floats so the history does not hold on to
    # (possibly GPU-resident) tensors
    train_loss, train_acc = float(train_loss), float(train_acc)
    val_loss, val_acc = float(val_loss), float(val_acc)

    # save current acc, loss
    train_losses.append((epoch, train_loss))
    train_accuracies.append((epoch, train_acc))
    val_losses.append((epoch, val_loss))
    val_accuracies.append((epoch, val_acc))

    # keep a checkpoint whenever the validation accuracy improves
    if val_acc > highest_val_acc:
        highest_val_acc = val_acc
        path = f"{checkpoint_dir}/acc{val_acc:.4f}_epoch{epoch}.pt"
        torch.save(model.state_dict(), path)

    print(
        f"Training:   [Epoch {epoch:2d}, Loss: {train_loss:8.4f}, Acc: {train_acc:.4f}]"
    )
    print(f"Evaluation: [Epoch {epoch:2d}, Loss: {val_loss:8.4f}, Acc: {val_acc:.4f}]")

print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time} minutes.")

Starting Training.
Train:   0%|                                                             | 0/143059 [00:01<?, ?batches/s]


RuntimeError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 23.69 GiB total capacity; 20.73 GiB already allocated; 33.94 MiB free; 21.34 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF