In [1]:
import csv
import json
import os
from functools import partial
from typing import Any, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import ByteLevelBPETokenizer, BertWordPieceTokenizer
from torch.utils.data import Dataset, DataLoader
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm
from transformers import T5Tokenizer

from denoise import NoiseAdaptation, CoteachingLoss, JoCoRLoss
from loaders import CleavageLoader
from models import (
    BiLSTM, BiLSTMDivideMix,
    BiLSTMPadded, BiLSTMPaddedDivideMix,
    BiLSTMAttention, BiLSTMAttentionDivideMix,
    BiLSTMProt2Vec, BiLSTMProt2VecDivideMix,
    CNNAttention, CNNAttentionDivideMix,
    MLP, MLPDivideMix,
    ESM2BiLSTM, ESM2BiLSTMDivideMix,
    ESM2, ESM2DivideMix,
    T5BiLSTM, T5BiLSTMDivideMix
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def read_data(path):
    """Read a two-column CSV of (sequence, label) rows, skipping the header row.

    Args:
        path: Path to the CSV file. The first row is treated as column names
            and discarded; every following row must have exactly two fields.

    Returns:
        A tuple ``(sents, lbls)`` of two parallel lists of strings, in file order.

    Raises:
        ValueError: If a data row does not unpack into exactly two fields.
    """
    sents, lbls = [], []
    # newline='' hands raw line endings to the csv module, as its documentation
    # requires, so newlines embedded in quoted fields are parsed correctly.
    with open(path, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip col name (no-op on an empty file)
        for s, l in reader:
            sents.append(s)
            lbls.append(l)
    return sents, lbls

In [3]:
# Each split is the (sequences, labels) tuple returned by read_data.
train_data = read_data('../data/c_train.csv')
val_data = read_data('../data/c_val.csv')
test_data = read_data('../data/c_test.csv')

In [None]:
# run tests for bbpm, wp, especially same performance on test set after reload, did correct fold/vocab get reloaded? - done
# run test for some k-fold epochs - done
# run test for T5 with scaler stuff
# run test with ESM
# run tests with coteaching
# run tests with coteaching plus
# run tests with jocor
# load prot2vec
# run prot2vec


In [None]:
# if denoising method is taken, add to naming path - done

In [None]:
# which dtype are labels in NAD?
#    should be .long()
#    don't forget --nad flag

In [None]:
# c_bilstm3.cfg (with nad) had an error: it trained with CEL, but it didn't do argmax


# if we are running with NAD, we need to calculate argmax, not logits > 0
# also need to calculate pos_pred then
# then calculate roc_auc_score with pos_preds

In [None]:
# check how we implemented loaders nad?
# what does args.nad return? Is it a string or not?

In [None]:
# change the test loader to also have nad if nad is given

In [None]:
# when training coteaching:
# run_epochs_coteach gets cot_criterion and criterion
# but in training, criterion needs to be None

In [None]:
# make early-stopping epochs controllable via

In [None]:
# MLP loads wrong model? why is there an embedding

In [None]:
# optim with duplicate parameters? but only for ESM?

# Model tests

In [None]:
def read_embeddings(path):
    """Load a whitespace-separated embedding text file.

    The first two lines of the file are skipped. Every remaining line is split
    on whitespace: the first field is the token (upper-cased), the rest are
    parsed as floats and form that token's vector.

    Args:
        path: Path to the embedding file.

    Returns:
        A tuple ``(vocab, prot2vec)`` where ``vocab`` maps each upper-cased
        token to its row index and ``prot2vec`` is a float tensor holding one
        row per parsed line.
    """
    with open(path, 'r') as handle:
        vocab = {}
        vectors = []
        # skip first special chars
        for row_index, raw_line in enumerate(handle.readlines()[2:]):
            fields = raw_line.split()
            # A repeated token keeps its first position in the dict but takes
            # the index of its last occurrence.
            vocab[fields[0].upper()] = row_index
            vectors.append([float(value) for value in fields[1:]])
        prot2vec = torch.tensor(vectors, dtype=torch.float)
    return vocab, prot2vec

In [None]:
# Dummy inputs for the model smoke tests below: random integer ids in [0, 20)
# with shape (512, 10).
x1 = torch.randint(0, 20, (512, 10))
x1_1 = torch.randint(0, 20, (512, 10))
x2 = torch.randint(0, 20, (512, 10))
x2_1 = torch.randint(0, 20, (512, 10))
# All-ones tensor with the shape of x1; passed as the second argument to the T5 models.
att = torch.ones_like(x1)
# Interpolation coefficient passed to the DivideMix variants.
lam = 0.6
# NOTE(review): `length` is not referenced by any other visible cell.
length = torch.arange(1, 513)
# One length per row (512 values cycling through 4..7); used with the padded BiLSTM models.
lengths = torch.tensor([4, 5, 6, 7] * 128)

In [None]:
# Hyperparameters shared by the model smoke tests below.
VOCAB_SIZE = 20
EMBEDDING_DIM = 128
RNN_SIZE1 = 123
RNN_SIZE2 = 222
HIDDEN_SIZE = 100
DROPOUT = 0.5
# OUT1 is passed to the plain models, OUT2 to the DivideMix variants.
OUT1 = 1
OUT2 = 2

In [None]:
# Smoke test: build the plain and DivideMix BiLSTM and print the output shape
# of each supported call signature.
bilstm = BiLSTM(VOCAB_SIZE, EMBEDDING_DIM, RNN_SIZE1, RNN_SIZE2, HIDDEN_SIZE, DROPOUT, OUT1)
bilstmdividemix = BiLSTMDivideMix(VOCAB_SIZE, EMBEDDING_DIM, RNN_SIZE1, RNN_SIZE2, HIDDEN_SIZE, DROPOUT, OUT2)

print(bilstm(x1).shape)
print(bilstmdividemix(x1).shape)
# DivideMix variant called with a second batch and `lam` in interpolation mode.
print(bilstmdividemix(x1, x2, lam, interpolate=True).shape)

In [None]:
# Smoke test for the padded BiLSTM variants; these also take the per-row `lengths`.
bilstmpadded = BiLSTMPadded(VOCAB_SIZE, EMBEDDING_DIM, RNN_SIZE1, RNN_SIZE2, HIDDEN_SIZE, DROPOUT, OUT1, pad_idx=1)
bilstmpaddeddividemix = BiLSTMPaddedDivideMix(VOCAB_SIZE, EMBEDDING_DIM, RNN_SIZE1, RNN_SIZE2, HIDDEN_SIZE, DROPOUT, OUT2, pad_idx=1)

print(bilstmpadded(x1, lengths).shape)
print(bilstmpaddeddividemix(x1, lengths).shape)
# Interpolation mode: each of the two batches is followed by its own lengths tensor.
print(bilstmpaddeddividemix(x1, lengths, x2, lengths, lam, interpolate=True).shape)

In [None]:
# Smoke test for the attention BiLSTM variants.
# NOTE(review): the literal `1` is a positional hyperparameter whose meaning is
# defined in models.BiLSTMAttention, not visible here.
bilstmattention = BiLSTMAttention(VOCAB_SIZE, EMBEDDING_DIM, RNN_SIZE1, HIDDEN_SIZE, 1, DROPOUT, OUT1)
bilstmattentiondividemix = BiLSTMAttentionDivideMix(VOCAB_SIZE, EMBEDDING_DIM, RNN_SIZE1, HIDDEN_SIZE, 1, DROPOUT, OUT2)

print(bilstmattention(x1).shape)
print(bilstmattentiondividemix(x1).shape)
print(bilstmattentiondividemix(x1, x2, lam, interpolate=True).shape)

In [None]:
vocab, embeddings = read_embeddings("../params/uniref_3M.vec")

In [None]:
tokenizer = lambda seq: [vocab.get(s, 0) for s in seq.split()]

In [None]:
# Build loaders that tokenize with the prot2vec vocabulary lookup (`tokenizer`).
# NOTE(review): 32 and 4 are presumably batch size and worker count — confirm
# against loaders.CleavageLoader.
loader = CleavageLoader(train_data, val_data, test_data, tokenizer, 32, 4)
# Only the first of the three returned loaders is kept.
train_loader, _, _ = loader.load("BiLSTM", nad=False, unk_idx=0)

In [None]:
# Keep the first batch only, for inspection in the following cells.
for seq, lbl in train_loader:
    seq1, lbl1 = seq, lbl
    break

In [None]:
vocab

In [None]:
seq1

In [None]:
# Smoke test for the prot2vec BiLSTM variants, built from the `embeddings`
# tensor returned by read_embeddings.
bilstmprot2vec = BiLSTMProt2Vec(embeddings, RNN_SIZE1, HIDDEN_SIZE, DROPOUT, OUT1)
bilstmprot2vecdividemix = BiLSTMProt2VecDivideMix(embeddings, RNN_SIZE1, HIDDEN_SIZE, DROPOUT, OUT2)

print(bilstmprot2vec(x1).shape)
print(bilstmprot2vecdividemix(x1, x2, lam, interpolate=True).shape)
print(bilstmprot2vecdividemix(x1).shape)

In [None]:
# Smoke test for the CNN attention variants.
# NOTE(review): the leading numeric arguments are positional hyperparameters
# whose meaning is defined in models.CNNAttention, not visible here.
cnn = CNNAttention(10, 2, 2, 4, 4, 5, 5, 2, 2, 10, 10, 1, 1, 23, 23, DROPOUT, OUT1)
cnndividemix = CNNAttentionDivideMix(10, 2, 2, 4, 4, 5, 5, 2, 2, 10, 10, 1, 1, 23, 23, DROPOUT, OUT2)

# The integer ids are cast to float before being fed to the CNN models.
print(cnn(x1.float()).shape)
print(cnndividemix(x1.float()).shape)
print(cnndividemix(x1.float(), x2.float(), lam, interpolate=True).shape)

In [None]:
# Smoke test for the MLP variants on flattened one-hot inputs.
mlp = MLP(VOCAB_SIZE, 10, HIDDEN_SIZE, DROPOUT, OUT1)
mlpdividemix = MLPDivideMix(VOCAB_SIZE, 10, HIDDEN_SIZE, DROPOUT, OUT2)

# Pin num_classes: without it one_hot infers the width from the largest id that
# happens to be sampled, so the flattened size could vary between x1 and x2.
x1_onehot = torch.nn.functional.one_hot(x1, num_classes=VOCAB_SIZE).view(512, -1).float()
x2_onehot = torch.nn.functional.one_hot(x2, num_classes=VOCAB_SIZE).view(512, -1).float()

print(mlp(x1_onehot).shape)
print(mlpdividemix(x1_onehot).shape)
print(mlpdividemix(x1_onehot, x2_onehot, lam, interpolate=True).shape)

In [None]:
# Load the pre-trained ESM2 model and its alphabet through torch.hub.
# Note: this rebinds the notebook-level name `vocab`, which an earlier cell
# bound to the prot2vec vocabulary dict.
esm2, vocab = torch.hub.load('facebookresearch/esm:main', 'esm2_t30_150M_UR50D')

# Both wrappers are built around the same `esm2` object.
esm2bilstm = ESM2BiLSTM(esm2, RNN_SIZE1, HIDDEN_SIZE, DROPOUT, OUT1)
esm2bilstmdividemix = ESM2BiLSTMDivideMix(esm2, RNN_SIZE1, HIDDEN_SIZE, DROPOUT, OUT2)

print(esm2bilstm(x1).shape)
print(esm2bilstmdividemix(x1).shape)
print(esm2bilstmdividemix(x1, x2, lam, interpolate=True).shape)

In [None]:
# Smoke test for the ESM2 wrappers, run with gradient tracking disabled.
with torch.no_grad():
    esm2model = ESM2(esm2, DROPOUT, OUT1)
    esm2dividemix = ESM2DivideMix(esm2, DROPOUT, OUT2)

    print(esm2model(x1).shape)
    print(esm2dividemix(x1).shape)
    print(esm2dividemix(x1, x2, lam, interpolate=True).shape)

In [None]:
# Smoke test for the T5 BiLSTM variants. Models and inputs are moved with
# .cuda(), so this cell needs a CUDA device.
t5 = T5BiLSTM(RNN_SIZE1, HIDDEN_SIZE, DROPOUT, OUT1).cuda()
t5dividemix = T5BiLSTMDivideMix(RNN_SIZE1, HIDDEN_SIZE, DROPOUT, OUT2).cuda()

# Each sequence tensor is followed by the all-ones `att` tensor.
print(t5(x1.long().cuda(), att.cuda()).shape)
print(t5dividemix(x1.long().cuda(), att.cuda()).shape)
print(t5dividemix(x1.long().cuda(), att.cuda(), x2.long().cuda(), att.cuda(), lam, interpolate=True).shape)

# Losses

In [None]:
class NoiseAdaptation(nn.Module):
    """Learnable k x k transition layer applied to a batch of class scores.

    Defining this class in the notebook shadows the ``NoiseAdaptation``
    imported from ``denoise`` in the first cell.

    Args:
        theta: Tensor of shape (k, k) used to initialise the layer weights.
        k: Number of classes (size of the square transition matrix).
    """
    def __init__(self, theta, k):
        super().__init__()
        self.theta = nn.Linear(k, k, bias=False)
        # Copy the values instead of re-pointing `.weight.data`, so the
        # parameter never shares storage with the caller's tensor.
        with torch.no_grad():
            self.theta.weight.copy_(theta)
        # Registered as a non-persistent buffer: it follows .to()/.cuda() with
        # the module but stays out of the state_dict, so checkpoint keys are
        # unchanged.
        self.register_buffer("eye", torch.eye(k), persistent=False)

    def forward(self, x):
        """Return ``x @ softmax(W.T, dim=0)`` where ``W`` is the layer weight.

        Args:
            x: Tensor whose last dimension has size k.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Linear applied to the identity yields the transposed weight matrix.
        theta = self.theta(self.eye)
        # Normalise along dim 0 so every column sums to one.
        theta = torch.softmax(theta, dim=0)
        out = x @ theta
        return out
    
class NoiseAdaptation2(nn.Module):
    """Variant of the noise-adaptation layer that uses ``torch.matmul``.

    Kept next to ``NoiseAdaptation`` so the two forward passes can be compared.

    Args:
        theta: Tensor of shape (k, k) used to initialise the layer weights.
        k: Number of classes (size of the square transition matrix).
    """
    def __init__(self, theta, k):
        super().__init__()
        self.theta = nn.Linear(k, k, bias=False)
        # Copy the values instead of re-pointing `.weight.data`, so the
        # parameter never shares storage with the caller's tensor.
        with torch.no_grad():
            self.theta.weight.copy_(theta)
        # Registered as a non-persistent buffer: it follows .to()/.cuda() with
        # the module but stays out of the state_dict, so checkpoint keys are
        # unchanged.
        self.register_buffer("eye", torch.eye(k), persistent=False)

    def forward(self, x):
        """Return ``matmul(x, softmax(W.T, dim=0))`` where ``W`` is the layer weight.

        Args:
            x: Tensor whose last dimension has size k.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Linear applied to the identity yields the transposed weight matrix.
        theta = self.theta(self.eye)
        # Normalise along dim 0 so every column sums to one.
        theta = torch.softmax(theta, dim=0)
        out = torch.matmul(x, theta)
        return out

In [None]:
# Random 2x2 matrix, each row divided by its own sum, used to initialise both
# layers so their outputs can be compared.
theta = torch.randn(2, 2)
theta = theta / theta.sum(dim=1, keepdim=True)
nad = NoiseAdaptation(theta, 2)
nad2 = NoiseAdaptation2(theta, 2)

In [None]:
x = torch.randn(512, 2)

In [None]:
nad(x) == nad2(x)

In [None]:
import numpy as np

def loss_coteaching(y_1, y_2, t, forget_rate):
    """Reference co-teaching loss using numpy argsort for the sample ranking.

    Args:
        y_1: Logits of the first network.
        y_2: Logits of the second network.
        t: Targets passed to BCEWithLogitsLoss together with the logits.
        forget_rate: Fraction of samples to drop; ``int((1 - forget_rate) * n)``
            samples with the smallest loss are kept.

    Returns:
        A tuple ``(loss_1, loss_2)``: the loss of each network on the samples
        selected by the other network, summed and divided by the number kept.
    """
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    # Per-sample loss of network 1 and the ascending ranking of its samples.
    loss_1 = criterion(y_1, t)
    ind_1_sorted = np.argsort(loss_1.data.cpu())
    loss_1_sorted = loss_1[ind_1_sorted]

    # Same for network 2.
    loss_2 = criterion(y_2, t)
    ind_2_sorted = np.argsort(loss_2.data.cpu())
    loss_2_sorted = loss_2[ind_2_sorted]

    remember_rate = 1 - forget_rate
    num_remember = int(remember_rate * len(loss_1_sorted))

    # Indices of the smallest-loss samples according to each network.
    ind_1_update = ind_1_sorted[:num_remember]
    ind_2_update = ind_2_sorted[:num_remember]

    # exchange
    loss_1_update = criterion(y_1[ind_2_update], t[ind_2_update])
    loss_2_update = criterion(y_2[ind_1_update], t[ind_1_update])

    # NOTE(review): num_remember is 0 when forget_rate is 1 (or the batch is
    # empty), which makes these divisions by zero.
    return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember


class CoteachingLoss:
    """Co-teaching loss for two networks, implemented with torch ops only.

    Each network is trained on the small-loss samples picked by its peer.
    Defining this class in the notebook shadows the ``CoteachingLoss``
    imported from ``denoise`` in the first cell.
    """
    def __init__(self):
        # Per-sample BCE-with-logits; reduction is done manually in __call__.
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

    def __call__(self, y1, y2, t, forget_rate):
        """Compute the exchanged small-loss objective.

        Args:
            y1: Logits of the first network.
            y2: Logits of the second network.
            t: Targets passed to the BCE criterion together with the logits.
            forget_rate: Fraction of samples to drop; ``int((1 - forget_rate) * n)``
                samples are kept.

        Returns:
            A tuple ``(loss_1, loss_2)``: each network's loss on the samples
            selected by the other network, summed and divided by the number kept.
        """
        per_sample_1 = self.criterion(y1, t)
        per_sample_2 = self.criterion(y2, t)

        num_remember = int((1 - forget_rate) * per_sample_1.shape[0])

        # Smallest-loss samples according to each network.
        picked_by_1 = torch.argsort(per_sample_1)[:num_remember]
        picked_by_2 = torch.argsort(per_sample_2)[:num_remember]

        # exchange the samples
        exchanged_1 = self.criterion(y1[picked_by_2], t[picked_by_2])
        exchanged_2 = self.criterion(y2[picked_by_1], t[picked_by_1])

        return exchanged_1.sum() / num_remember, exchanged_2.sum() / num_remember


# Fixtures for comparing the loss implementations in the following cells.
cot = CoteachingLoss()

# Random logits for two networks and random binary targets (integer dtype;
# later cells cast them with .float() where needed).
y1 = torch.randn(512)
y2 = torch.randn(512)
t = torch.randint(0, 2, (512,))
# Forget rate passed to the co-teaching losses.
fgt = 0.2

In [None]:
a = loss_coteaching(y1, y2, t.float(), fgt)
b = cot(y1, y2, t.float(), fgt)

In [None]:
def kl(pred, lbl):
    """Summed ``F.kl_div`` between ``logsigmoid(pred)`` and ``sigmoid(lbl)``.

    Args:
        pred: Logits used as the (log-space) input of ``F.kl_div``.
        lbl: Logits that are passed through a sigmoid to form the target.

    Returns:
        A scalar tensor (``reduction='sum'``).
    """
    log_probs = F.logsigmoid(pred)
    target_probs = F.sigmoid(lbl)
    return F.kl_div(log_probs, target_probs, reduction='sum')

def kl2(pred, lbl):
    """Element-wise ``F.kl_div`` between ``logsigmoid(pred)`` and ``sigmoid(lbl)``, then summed.

    Args:
        pred: Logits used as the (log-space) input of ``F.kl_div``.
        lbl: Logits that are passed through a sigmoid to form the target.

    Returns:
        A scalar tensor holding the sum over all elements.
    """
    pointwise = F.kl_div(F.logsigmoid(pred), F.sigmoid(lbl), reduction='none')
    return pointwise.sum()

In [None]:
kl(y1, t)

In [None]:
kl2(y1, t)

In [None]:
class JoCoRLoss:
    """
    Joint small-loss objective with a co-regularisation term, for two networks.

    Based on:

    Wei, H., Feng, L., Chen, X., & An, B. (2020).
    Combating noisy labels by agreement: A joint training method with co-regularization.
    In Proceedings of the IEEE/CVF conference on
    computer vision and pattern recognition (pp. 13726-13735).

    Defining this class in the notebook shadows the ``JoCoRLoss`` imported
    from ``denoise`` in the first cell.
    """
    def __init__(self):
        # Per-sample BCE-with-logits; selection and averaging happen in __call__.
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')
        # Weight of the agreement term; the supervised terms get (1 - co_lambda).
        self.co_lambda = 0.1

    def kl_loss(self, pred, soft_target):
        """Return the summed ``F.kl_div`` of ``logsigmoid(pred)`` against ``sigmoid(soft_target)``."""
        log_probs = F.logsigmoid(pred)
        target_probs = F.sigmoid(soft_target)
        return F.kl_div(log_probs, target_probs, reduction='sum')

    def __call__(self, y1, y2, lbls, forget_rate):
        """Compute the joint loss over the kept samples.

        Args:
            y1: Logits of the first network.
            y2: Logits of the second network.
            lbls: Targets passed to the BCE criterion.
            forget_rate: Fraction of samples to drop; ``int((1 - forget_rate) * n)``
                samples with the smallest joint loss are kept.

        Returns:
            The same scalar loss twice, as a tuple ``(loss, loss)``.
        """
        supervised_weight = 1 - self.co_lambda
        l1 = self.criterion(y1, lbls) * supervised_weight
        l2 = self.criterion(y2, lbls) * supervised_weight

        # NOTE(review): kl_loss reduces with 'sum', so both agreement terms are
        # scalars added to every per-sample entry; they shift all losses
        # equally and do not influence which samples are kept — confirm that
        # this is intended.
        agreement_12 = self.co_lambda * self.kl_loss(y1, y2)
        agreement_21 = self.co_lambda * self.kl_loss(y2, y1)
        losses = l1 + l2 + agreement_12 + agreement_21

        num_remember = int((1 - forget_rate) * losses.shape[0])
        kept = torch.argsort(losses)[:num_remember]

        loss = losses[kept].mean()
        return loss, loss
    
def kl_loss_compute(pred, soft_targets):
    """Sum of the element-wise ``F.kl_div`` between two sets of logits.

    Args:
        pred: Logits; ``logsigmoid(pred)`` is the log-space input.
        soft_targets: Logits; ``sigmoid(soft_targets)`` is the target.

    Returns:
        A scalar tensor holding the sum over all elements.
    """
    # adjusted for binary case
    log_probs = F.logsigmoid(pred)
    target_probs = torch.sigmoid(soft_targets)
    return F.kl_div(log_probs, target_probs, reduction='none').sum()


class JoCoRLoss2:
    """Reference JoCoR-style loss using numpy argsort, kept for comparison with JoCoRLoss."""
    def __call__(self, y1, y2, lbls, forget_rate, loss_func, kl_loss, co_lambda=0.1):
        """Compute the joint loss over the kept samples.

        Args:
            y1: Logits of the first network.
            y2: Logits of the second network.
            lbls: Targets passed to ``loss_func``.
            forget_rate: Fraction of samples to drop; ``int((1 - forget_rate) * n)``
                samples with the smallest joint loss are kept.
            loss_func: Callable applied as ``loss_func(y, lbls)``.
                NOTE(review): presumably must return per-sample losses so that
                the ranking below is meaningful — confirm.
            kl_loss: Accepted but not used; the body calls the module-level
                ``kl_loss_compute`` directly.
            co_lambda: Weight of the agreement terms; the supervised terms get
                ``1 - co_lambda``.

        Returns:
            The same scalar loss twice, as a tuple ``(loss, loss)``.
        """
        loss_pick_1 = loss_func(y1, lbls) * (1 - co_lambda)
        loss_pick_2 = loss_func(y2, lbls) * (1 - co_lambda)
        # kl_loss_compute returns a scalar, so the agreement terms are added
        # to every entry of the supervised losses.
        loss_pick = (
            loss_pick_1
            + loss_pick_2
            + co_lambda * kl_loss_compute(y1, y2)
            + co_lambda * kl_loss_compute(y2, y1)
        ).cpu()

        # Ascending ranking of the joint losses.
        ind_sorted = np.argsort(loss_pick.data)
        loss_sorted = loss_pick[ind_sorted]

        remember_rate = 1 - forget_rate
        num_remember = int(remember_rate * len(loss_sorted))

        ind_update = ind_sorted[:num_remember]

        loss = torch.mean(loss_pick[ind_update])

        return loss, loss

In [None]:
jocor = JoCoRLoss()
jocor2 = JoCoRLoss2()

In [None]:
a = jocor(y1, y2, t.float(), 0.2)
b = jocor2(y1, y2, t.float(), 0.2, nn.BCEWithLogitsLoss(reduction='none'), kl_loss_compute)

In [None]:
a

In [None]:
b

# Collators

In [None]:
def apply_random_masking(seq, unk_idx):
    """
    Replace randomly chosen tokens of every sequence with `unk_idx`.

    For each row, `seq_len // 10` positions are drawn uniformly at random
    *with replacement*, so up to `seq_len // 10` distinct tokens per sequence
    are masked (fewer when a position is drawn twice). Sequences shorter than
    10 tokens are returned unmasked.

    Args:
        seq: Integer tensor of shape (num_samples, seq_len).
        unk_idx: Value written at the masked positions.

    Returns:
        A new tensor with the same shape as `seq`; `seq` itself is not modified.
    """
    num_samples, seq_len = seq.shape
    # Draw the indices on the same device as `seq`; scatter requires index and
    # input to live on the same device.
    mask_idx = torch.randint(
        0, seq_len, (num_samples, seq_len // 10), device=seq.device
    )
    masked_seq = torch.scatter(seq, 1, mask_idx, unk_idx)
    return masked_seq

In [None]:
class CollateFunctions:
    """Factory that hands out model-specific collate functions.

    The per-model collators are placeholders (their bodies are ``...``); the
    factory binds the denoising method and training mode so the result only
    needs the batch, as ``DataLoader(collate_fn=...)`` expects.
    """

    def __init__(self):
        pass

    def collate_fn(self, model_name: str, denoising_method: str, training_mode: str) -> Callable:
        """Return the collate function for ``model_name`` with its options pre-bound.

        Args:
            model_name: One of ``"model1"``, ``"model2"``, ``"model3"``.
            denoising_method: Forwarded as the first argument of the collator.
            training_mode: Forwarded as the second argument of the collator.

        Returns:
            A ``functools.partial`` that takes the batch as its only argument.

        Raises:
            ValueError: If ``model_name`` is not supported.
        """
        collators = {
            "model1": self._collate_model1,
            "model2": self._collate_model2,
            "model3": self._collate_model3,
        }
        if model_name not in collators:
            raise ValueError(f"Unsupported model name: {model_name}")
        # `partial` is imported by name at the top of the notebook; the module
        # `functools` itself is never imported.
        return partial(collators[model_name], denoising_method, training_mode)

    def _collate_model1(self, denoising_method: str, training_mode: str, batch) -> Any:
        # Collation logic for model1
        # Use denoising_method and training_mode as needed
        ...

    def _collate_model2(self, denoising_method: str, training_mode: str, batch) -> Any:
        # Collation logic for model2
        # Use denoising_method and training_mode as needed
        ...

    def _collate_model3(self, denoising_method: str, training_mode: str, batch) -> Any:
        # Collation logic for model3
        # Use denoising_method and training_mode as needed
        ...

# Usage example:
collate_fns = CollateFunctions()
model_name = "model1"
denoising_method = "denoise1"
training_mode = "train"

# Generate the collate function
collate_fn = collate_fns.collate_fn(model_name, denoising_method, training_mode)

# Use the collate function in a DataLoader
# NOTE(review): `dataset` is not defined in any cell visible before this point,
# so this line raises NameError on a fresh kernel — bind a Dataset first.
data_loader = DataLoader(dataset, collate_fn=collate_fn)


In [None]:
def collate_batch(batch, mode):
    """Turn a list of (sequence, label) samples into a pair of tensors.

    Relies on the notebook-level names ``encode_text`` and
    ``apply_random_masking``.

    Args:
        batch: Iterable of ``(sequence, label)`` samples.
        mode: When equal to ``'train'``, random masking with ``unk_idx=0`` is
            applied to the encoded sequences.

    Returns:
        A tuple ``(seq, lbl)``: an int64 tensor of encoded sequences and a
        float tensor of labels.
    """
    # Transpose the samples into one column of sequences and one of labels.
    columns = list(zip(*batch))
    seq = torch.tensor(
        [encode_text(raw_seq) for raw_seq in columns[0]],
        dtype=torch.int64,
    )
    lbl = torch.tensor(
        [int(raw_lbl) for raw_lbl in columns[1]],
        dtype=torch.float,
    )

    if mode == 'train':
        return apply_random_masking(seq, unk_idx=0), lbl
    return seq, lbl

# Create partial functions with the mode pre-bound, so each takes only the batch:
# 'train' enables random masking in collate_batch, any other mode leaves the
# sequences untouched.
collate_fn_train = partial(collate_batch, mode='train')
collate_fn_test = partial(collate_batch, mode='test')

In [None]:
def read_data(path):
    """Read a two-column CSV of (sequence, label) rows, skipping the header row.

    This redefines the ``read_data`` helper declared in an earlier cell of
    this notebook.

    Args:
        path: Path to the CSV file. The first row is treated as column names
            and discarded; every following row must have exactly two fields.

    Returns:
        A tuple ``(sents, lbls)`` of two parallel lists of strings, in file order.

    Raises:
        ValueError: If a data row does not unpack into exactly two fields.
    """
    sents, lbls = [], []
    # newline='' hands raw line endings to the csv module, as its documentation
    # requires, so newlines embedded in quoted fields are parsed correctly.
    with open(path, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip col name (no-op on an empty file)
        for s, l in reader:
            sents.append(s)
            lbls.append(l)
    return sents, lbls

# Raw training sequences and labels (both lists of strings).
train_seq, train_lbl = read_data('../data/c_train.csv')


class CleavageDataset(Dataset):
    """Map-style dataset pairing each sequence with its label.

    Args:
        seq: Indexable collection of sequences.
        lbl: Indexable collection of labels, aligned with ``seq``.
    """
    def __init__(self, seq, lbl):
        self.seq, self.lbl = seq, lbl

    def __len__(self):
        """Number of samples, taken from the label collection."""
        return len(self.lbl)

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at ``idx``."""
        sequence = self.seq[idx]
        label = self.lbl[idx]
        return sequence, label
    
train_data = CleavageDataset(train_seq, train_lbl)

In [None]:
# create vocab from train seqs
# Each training sequence is a string, so iterating it yields single characters
# as tokens; '<UNK>' is added as a special token.
vocab_base = build_vocab_from_iterator(train_seq, specials=['<UNK>'])
# vocab = build_vocab_from_iterator(train_seqs)
# Unknown characters map to the '<UNK>' index.
vocab_base.set_default_index(vocab_base['<UNK>'])
# Encode a sequence string as a list of per-character vocabulary indices.
encode_text = lambda x: vocab_base(list(x))

# load pre-trained esm2 model and vocab
# Note: rebinds the notebook-level name `vocab`.
esm2, vocab = torch.hub.load('facebookresearch/esm:main', 'esm2_t30_150M_UR50D')
tokenizer_esm = vocab.get_batch_converter()

# ProtT5 tokenizer; do_lower_case=False keeps the input casing.
tokenizer_t5 = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
)

In [None]:
# load all tokenizers here
# Paths of the pre-trained tokenizer files (byte-level BPE: vocab + merges;
# WordPiece: vocab only).
bbpe1_vocab, bbpe1_merges = '../params/c_bbpe1k-vocab.json', '../params/c_bbpe1k-merges.txt'
bbpe50_vocab, bbpe50_merges = '../params/c_bbpe50k-vocab.json', '../params/c_bbpe50k-merges.txt'
wp50_vocab = '../params/c_wp50k-vocab.txt'

# Byte-level BPE tokenizers are padded with '<PAD>' (id 1).
bbpe1 = ByteLevelBPETokenizer.from_file(bbpe1_vocab, bbpe1_merges, lowercase=False)
bbpe1.enable_padding(pad_id=1, pad_token='<PAD>')

bbpe50 = ByteLevelBPETokenizer.from_file(bbpe50_vocab, bbpe50_merges, lowercase=False)
bbpe50.enable_padding(pad_id=1, pad_token='<PAD>')

# The WordPiece tokenizer is padded with '[PAD]' (id 0).
wp50 = BertWordPieceTokenizer(wp50_vocab)
wp50.enable_padding(pad_id=0, pad_token='[PAD]')

In [None]:
# test all functionality

In [None]:
from loaders import BatchCollator

In [None]:
# One BatchCollator per tokenizer under test.
base_collator = BatchCollator(encode_text)
esm2_collator = BatchCollator(tokenizer_esm)
t5_collator = BatchCollator(tokenizer_t5)
bbpe1_collator = BatchCollator(bbpe1)
bbpe50_collator = BatchCollator(bbpe50)
wp50_collator = BatchCollator(wp50)

In [None]:
# works for BiLSTM, BiLSTMAttention, BiLSTMProt2Vec
# Four collate functions from the character-vocab collator: nad off/on x train on/off.
base_train = base_collator.collate_fn('BiLSTM', nad=False, train=True, unk_idx=0)
base_test = base_collator.collate_fn('BiLSTMAttention', nad=False, train=False, unk_idx=0)
base_train_nad = base_collator.collate_fn('BiLSTM', nad=True, train=True, unk_idx=0)
base_test_nad = base_collator.collate_fn('BiLSTMProt2Vec', nad=True, train=False, unk_idx=0)

# All four loaders read `train_data`; no shuffle argument is passed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=base_train, num_workers=4)
test_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=base_test, num_workers=4)
train_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=base_train_nad, num_workers=4)
test_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=base_test_nad, num_workers=4)

# Keep only the first batch of each loader for the assertions in the next cell.
for seq, lbl in train_loader:
    seq1, lbl1 = seq, lbl
    break

for seq, lbl in test_loader:
    seq2, lbl2 = seq, lbl
    break

for seq, lbl in train_loader_nad:
    seq1_nad, lbl1_nad = seq, lbl
    break

for seq, lbl in test_loader_nad:
    seq2_nad, lbl2_nad = seq, lbl
    break

In [None]:
# Checks on the first batches: dtypes, the UNK id (0) appearing in every
# training row, and label agreement between the train and test batches.
assert seq1.dtype == torch.int64, "wrong sequence dtype"
assert all([0 in s[n] for s in [seq1, seq1_nad] for n in range(s.shape[0])]), "masking didn't work"
# NOTE(review): `not all(...)` only fails when *every* row contains the id; if
# the intent is "no row is masked", `not any(...)` would be the stricter check
# — confirm.
assert not all([0 in s[n] for s in [seq2, seq2_nad] for n in range(s.shape[0])]), "masking found in test set"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
assert (lbl1 == lbl2).sum() == lbl1.shape[0], "lbl1 and lbl2 are not equal"
assert (lbl1_nad == lbl2_nad).sum() == lbl1.shape[0], "lbl1_nad and lbl2_nad are not equal"
# With nad=True the labels are expected as int64 instead of float32.
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad llb dtype wrong"

In [None]:
# works for CNN
# Four collate functions for the 'CNN' model name: nad off/on x train on/off.
cnn_train = base_collator.collate_fn('CNN', nad=False, train=True, unk_idx=0)
cnn_test = base_collator.collate_fn('CNN', nad=False, train=False, unk_idx=0)
cnn_train_nad = base_collator.collate_fn('CNN', nad=True, train=True, unk_idx=0)
cnn_test_nad = base_collator.collate_fn('CNN', nad=True, train=False, unk_idx=0)

# All four loaders read `train_data`; no shuffle argument is passed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=cnn_train, num_workers=4)
test_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=cnn_test, num_workers=4)
train_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=cnn_train_nad, num_workers=4)
test_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=cnn_test_nad, num_workers=4)

# Keep only the first batch of each loader for the assertions in the next cell.
for seq, lbl in train_loader:
    seq1, lbl1 = seq, lbl
    break

for seq, lbl in test_loader:
    seq2, lbl2 = seq, lbl
    break

for seq, lbl in train_loader_nad:
    seq1_nad, lbl1_nad = seq, lbl
    break

for seq, lbl in test_loader_nad:
    seq2_nad, lbl2_nad = seq, lbl
    break

In [None]:
# Same checks as for the base collator, except that CNN sequences are
# expected as float32.
assert seq1.dtype == torch.float32, "wrong sequence dtype"
assert all([0 in s[n] for s in [seq1, seq1_nad] for n in range(s.shape[0])]), "masking didn't work"
# NOTE(review): `not all(...)` only fails when *every* row contains the id; if
# the intent is "no row is masked", `not any(...)` would be the stricter check
# — confirm.
assert not all([0 in s[n] for s in [seq2, seq2_nad] for n in range(s.shape[0])]), "masking found in test set"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
assert (lbl1 == lbl2).sum() == lbl1.shape[0], "lbl1 and lbl2 are not equal"
assert (lbl1_nad == lbl2_nad).sum() == lbl1.shape[0], "lbl1_nad and lbl2_nad are not equal"
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad llb dtype wrong"

In [None]:
# works for MLP
# Four collate functions for the 'MLP' model name: nad off/on x train on/off.
mlp_train = base_collator.collate_fn('MLP', nad=False, train=True, unk_idx=0)
mlp_test = base_collator.collate_fn('MLP', nad=False, train=False, unk_idx=0)
mlp_train_nad = base_collator.collate_fn('MLP', nad=True, train=True, unk_idx=0)
mlp_test_nad = base_collator.collate_fn('MLP', nad=True, train=False, unk_idx=0)

# All four loaders read `train_data`; no shuffle argument is passed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=mlp_train, num_workers=4)
test_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=mlp_test, num_workers=4)
train_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=mlp_train_nad, num_workers=4)
test_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=mlp_test_nad, num_workers=4)

# Keep only the first batch of each loader for the assertions in the next cell.
for seq, lbl in train_loader:
    seq1, lbl1 = seq, lbl
    break

for seq, lbl in test_loader:
    seq2, lbl2 = seq, lbl
    break

for seq, lbl in train_loader_nad:
    seq1_nad, lbl1_nad = seq, lbl
    break

for seq, lbl in test_loader_nad:
    seq2_nad, lbl2_nad = seq, lbl
    break

In [None]:
# MLP batches are indexed as [sample][position][channel]; the checks below
# count positions whose channel 0 is set.
assert seq1.dtype == torch.float32, "wrong sequence dtype"
# NOTE(review): 32 presumably equals 2 batches x 16 samples x 1 masked
# position each, with channel 0 encoding UNK — confirm against the collator.
assert sum([
    s[sample][n][0] == 1
    for s in [seq1, seq1_nad]
    for sample in range(s.shape[0])
    for n in range(s.shape[1])
]) == 32, "masking didn't work"
# No position of the test batches may have channel 0 set.
assert sum([
    s[sample][n][0] == 1
    for s in [seq2, seq2_nad]
    for sample in range(s.shape[0])
    for n in range(s.shape[1])
]) == 0, "mask found in test batch"
# Train and test batches must have the same total sum.
assert seq1.sum() == seq2.sum(), "different UNK/data strucutre"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
assert (lbl1 == lbl2).sum() == lbl1.shape[0], "lbl1 and lbl2 are not equal"
assert (lbl1_nad == lbl2_nad).sum() == lbl1.shape[0], "lbl1_nad and lbl2_nad are not equal"
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad llb dtype wrong"

In [None]:
# works for BBPE1k
# Four collate functions for the padded BiLSTM with the 1k byte-level BPE tokenizer.
bbpe1_train = bbpe1_collator.collate_fn('BiLSTMPadded', nad=False, train=True, unk_idx=0)
bbpe1_test = bbpe1_collator.collate_fn('BiLSTMPadded', nad=False, train=False, unk_idx=0)
bbpe1_train_nad = bbpe1_collator.collate_fn('BiLSTMPadded', nad=True, train=True, unk_idx=0)
bbpe1_test_nad = bbpe1_collator.collate_fn('BiLSTMPadded', nad=True, train=False, unk_idx=0)

# All four loaders read `train_data`; no shuffle argument is passed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=bbpe1_train, num_workers=4)
test_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=bbpe1_test, num_workers=4)
train_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=bbpe1_train_nad, num_workers=4)
test_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=bbpe1_test_nad, num_workers=4)

# Padded batches are unpacked as (seq, lbl, lengths); keep only the first
# batch of each loader.
for seq, lbl, lengths in train_loader:
    seq1, lbl1, lengths1 = seq, lbl, lengths
    break

for seq, lbl, lengths in test_loader:
    seq2, lbl2, lengths2 = seq, lbl, lengths
    break

for seq, lbl, lengths in train_loader_nad:
    seq1_nad, lbl1_nad, lengths_nad = seq, lbl, lengths
    break

for seq, lbl, lengths in test_loader_nad:
    seq2_nad, lbl2_nad, lengths2_nad = seq, lbl, lengths
    break

In [None]:
# Checks on the first padded batches: dtypes, identical total sum of the train
# and test sequences, and label agreement.
assert seq1.dtype == torch.int64, "wrong sequence dtype"
assert lengths1.dtype == torch.int64, "wrong lenghts dtype"
assert seq1.sum() == seq2.sum(), "different UNK/data strucutre"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
assert (lbl1 == lbl2).sum() == lbl1.shape[0], "lbl1 and lbl2 are not equal"
assert (lbl1_nad == lbl2_nad).sum() == lbl1.shape[0], "lbl1_nad and lbl2_nad are not equal"
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad llb dtype wrong"

In [None]:
# works for BBPE50
# Four collate functions for the padded BiLSTM with the 50k byte-level BPE tokenizer.
bbpe50_train = bbpe50_collator.collate_fn('BiLSTMPadded', nad=False, train=True, unk_idx=0)
bbpe50_test = bbpe50_collator.collate_fn('BiLSTMPadded', nad=False, train=False, unk_idx=0)
bbpe50_train_nad = bbpe50_collator.collate_fn('BiLSTMPadded', nad=True, train=True, unk_idx=0)
bbpe50_test_nad = bbpe50_collator.collate_fn('BiLSTMPadded', nad=True, train=False, unk_idx=0)

# All four loaders read `train_data`; no shuffle argument is passed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=bbpe50_train, num_workers=4)
test_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=bbpe50_test, num_workers=4)
train_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=bbpe50_train_nad, num_workers=4)
test_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=bbpe50_test_nad, num_workers=4)

# Padded batches are unpacked as (seq, lbl, lengths); keep only the first
# batch of each loader.
for seq, lbl, lengths in train_loader:
    seq1, lbl1, lengths1 = seq, lbl, lengths
    break

for seq, lbl, lengths in test_loader:
    seq2, lbl2, lengths2 = seq, lbl, lengths
    break

for seq, lbl, lengths in train_loader_nad:
    seq1_nad, lbl1_nad, lengths_nad = seq, lbl, lengths
    break

for seq, lbl, lengths in test_loader_nad:
    seq2_nad, lbl2_nad, lengths2_nad = seq, lbl, lengths
    break

In [None]:
# Checks on the first padded batches: dtypes, identical total sum of the train
# and test sequences, and label agreement.
assert seq1.dtype == torch.int64, "wrong sequence dtype"
assert lengths1.dtype == torch.int64, "wrong lenghts dtype"
assert seq1.sum() == seq2.sum(), "different UNK/data strucutre"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
assert (lbl1 == lbl2).sum() == lbl1.shape[0], "lbl1 and lbl2 are not equal"
assert (lbl1_nad == lbl2_nad).sum() == lbl1.shape[0], "lbl1_nad and lbl2_nad are not equal"
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad llb dtype wrong"

In [None]:
# works for WP
# Four collate functions for the padded BiLSTM with the WordPiece tokenizer;
# note that unk_idx is 1 here (0 is the pad id of this tokenizer).
wp_train = wp50_collator.collate_fn('BiLSTMPadded', nad=False, train=True, unk_idx=1)
wp_test = wp50_collator.collate_fn('BiLSTMPadded', nad=False, train=False, unk_idx=1)
wp_train_nad = wp50_collator.collate_fn('BiLSTMPadded', nad=True, train=True, unk_idx=1)
wp_test_nad = wp50_collator.collate_fn('BiLSTMPadded', nad=True, train=False, unk_idx=1)

# All four loaders read `train_data`; no shuffle argument is passed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=wp_train, num_workers=4)
test_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=wp_test, num_workers=4)
train_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=wp_train_nad, num_workers=4)
test_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=wp_test_nad, num_workers=4)

# Padded batches are unpacked as (seq, lbl, lengths); keep only the first
# batch of each loader.
for seq, lbl, lengths in train_loader:
    seq1, lbl1, lengths1 = seq, lbl, lengths
    break

for seq, lbl, lengths in test_loader:
    seq2, lbl2, lengths2 = seq, lbl, lengths
    break

for seq, lbl, lengths in train_loader_nad:
    seq1_nad, lbl1_nad, lengths_nad = seq, lbl, lengths
    break

for seq, lbl, lengths in test_loader_nad:
    seq2_nad, lbl2_nad, lengths2_nad = seq, lbl, lengths
    break

In [None]:
# Checks on the first padded batches: dtypes, identical total sum of the train
# and test sequences, and label agreement.
assert seq1.dtype == torch.int64, "wrong sequence dtype"
assert lengths1.dtype == torch.int64, "wrong lenghts dtype"
assert seq1.sum() == seq2.sum(), "different UNK/data strucutre"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
assert (lbl1 == lbl2).sum() == lbl1.shape[0], "lbl1 and lbl2 are not equal"
assert (lbl1_nad == lbl2_nad).sum() == lbl1.shape[0], "lbl1_nad and lbl2_nad are not equal"
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad llb dtype wrong"

In [None]:
# works for ESM
# Four collate functions from the ESM2 batch converter, covering both ESM2
# model names; unk_idx is 3 here.
esm2_train = esm2_collator.collate_fn('ESM2BiLSTM', nad=False, train=True, unk_idx=3)
esm2_test = esm2_collator.collate_fn('ESM2', nad=False, train=False, unk_idx=3)
esm2_train_nad = esm2_collator.collate_fn('ESM2', nad=True, train=True, unk_idx=3)
esm2_test_nad = esm2_collator.collate_fn('ESM2BiLSTM', nad=True, train=False, unk_idx=3)

# All four loaders read `train_data`; no shuffle argument is passed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=esm2_train, num_workers=4)
test_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=esm2_test, num_workers=4)
train_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=esm2_train_nad, num_workers=4)
test_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=esm2_test_nad, num_workers=4)

# Keep only the first batch of each loader for the assertions in the next cell.
for seq, lbl in train_loader:
    seq1, lbl1 = seq, lbl
    break

for seq, lbl in test_loader:
    seq2, lbl2 = seq, lbl
    break

for seq, lbl in train_loader_nad:
    seq1_nad, lbl1_nad = seq, lbl
    break

for seq, lbl in test_loader_nad:
    seq2_nad, lbl2_nad = seq, lbl
    break

In [None]:
# Checks on the first batches: dtypes, the UNK id (3) appearing in every
# training row, and label agreement between the train and test batches.
assert seq1.dtype == torch.int64, "wrong sequence dtype"
assert all([3 in s[n] for s in [seq1, seq1_nad] for n in range(s.shape[0])]), "masking didn't work"
# NOTE(review): `not all(...)` only fails when *every* row contains the id; if
# the intent is "no row is masked", `not any(...)` would be the stricter check
# — confirm.
assert not all([3 in s[n] for s in [seq2, seq2_nad] for n in range(s.shape[0])]), "masking found in test set"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
assert (lbl1 == lbl2).sum() == lbl1.shape[0], "lbl1 and lbl2 are not equal"
assert (lbl1_nad == lbl2_nad).sum() == lbl1.shape[0], "lbl1_nad and lbl2_nad are not equal"
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad llb dtype wrong"

In [None]:
# works for T5
# Four collate functions from the T5 tokenizer; unk_idx is 2 here.
t5_train = t5_collator.collate_fn('T5BiLSTM', nad=False, train=True, unk_idx=2)
t5_test = t5_collator.collate_fn('T5BiLSTM', nad=False, train=False, unk_idx=2)
t5_train_nad = t5_collator.collate_fn('T5BiLSTM', nad=True, train=True, unk_idx=2)
t5_test_nad = t5_collator.collate_fn('T5BiLSTM', nad=True, train=False, unk_idx=2)

# All four loaders read `train_data`; no shuffle argument is passed.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=t5_train, num_workers=4)
test_loader = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=t5_test, num_workers=4)
train_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=t5_train_nad, num_workers=4)
test_loader_nad = torch.utils.data.DataLoader(train_data, batch_size=16, collate_fn=t5_test_nad, num_workers=4)

# T5 batches are unpacked as (seq, att, lbl); keep only the first batch of
# each loader.
for seq, att, lbl in train_loader:
    seq1, att1, lbl1 = seq, att, lbl
    break

for seq, att, lbl in test_loader:
    seq2, att2, lbl2 = seq, att, lbl
    break

for seq, att, lbl in train_loader_nad:
    seq1_nad, att1_nad, lbl1_nad = seq, att, lbl
    break

for seq, att, lbl in test_loader_nad:
    seq2_nad, att2_nad, lbl2_nad = seq, att, lbl
    break

In [None]:
# Checks on the first batches: dtypes, the UNK id (2) appearing in every
# training row, and label agreement between the train and test batches.
assert seq1.dtype == torch.int64, "wrong sequence dtype"
assert all([2 in s[n] for s in [seq1, seq1_nad] for n in range(s.shape[0])]), "masking didn't work"
# NOTE(review): `not all(...)` only fails when *every* row contains the id; if
# the intent is "no row is masked", `not any(...)` would be the stricter check
# — confirm.
assert not all([2 in s[n] for s in [seq2, seq2_nad] for n in range(s.shape[0])]), "masking found in test set"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
assert (lbl1 == lbl2).sum() == lbl1.shape[0], "lbl1 and lbl2 are not equal"
assert (lbl1_nad == lbl2_nad).sum() == lbl1.shape[0], "lbl1_nad and lbl2_nad are not equal"
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad llb dtype wrong"

# Loaders

In [None]:
# Reload the character vocabulary from disk instead of rebuilding it.
# torch.load unpickles the file, so only load files from a trusted source.
vocab_base = torch.load('../params/vocab.pt')
# Encode a sequence string as a list of per-character vocabulary indices.
encode_text = lambda x: vocab_base(list(x))

# load pre-trained esm2 model and vocab
# Note: rebinds the notebook-level name `vocab`.
esm2, vocab = torch.hub.load('facebookresearch/esm:main', 'esm2_t30_150M_UR50D')
tokenizer_esm = vocab.get_batch_converter()

# ProtT5 tokenizer; do_lower_case=False keeps the input casing.
tokenizer_t5 = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
)

# load all tokenizers here
bbpe1_vocab, bbpe1_merges = '../params/c_bbpe1k-vocab.json', '../params/c_bbpe1k-merges.txt'
bbpe50_vocab, bbpe50_merges = '../params/c_bbpe50k-vocab.json', '../params/c_bbpe50k-merges.txt'
wp50_vocab = '../params/c_wp50k-vocab.txt'

# Byte-level BPE tokenizers are padded with '<PAD>' (id 1).
bbpe1 = ByteLevelBPETokenizer.from_file(bbpe1_vocab, bbpe1_merges, lowercase=False)
bbpe1.enable_padding(pad_id=1, pad_token='<PAD>')

bbpe50 = ByteLevelBPETokenizer.from_file(bbpe50_vocab, bbpe50_merges, lowercase=False)
bbpe50.enable_padding(pad_id=1, pad_token='<PAD>')

# The WordPiece tokenizer is padded with '[PAD]' (id 0).
wp50 = BertWordPieceTokenizer(wp50_vocab)
wp50.enable_padding(pad_id=0, pad_token='[PAD]')

In [None]:
# Prefix of the dataset files to read (the files are named "<terminus>_<split>.csv").
terminus = 'c'

# Read the three splits in train/val/test order; each read_data call returns (sequences, labels).
train_data, val_data, test_data = (
    read_data(f"../data/{terminus}_train.csv"),
    read_data(f"../data/{terminus}_val.csv"),
    read_data(f"../data/{terminus}_test.csv"),
)

In [None]:
# BiLSTM
# Character-level loaders built from encode_text. 16 and 4 are passed positionally —
# presumably batch size and number of workers (TODO confirm against CleavageLoader).
loader = CleavageLoader(train_data, val_data, test_data, encode_text, 16, 4)
train_loader, val_loader, test_loader = loader.load('BiLSTMProt2Vec', nad=False, unk_idx=0)
# NOTE(review): the NAD loaders are requested under a different model key
# ('BiLSTMAttention') than the plain loaders above — confirm this is intended.
train_loader_nad, val_loader_nad, test_loader_nad = loader.load('BiLSTMAttention', nad=True, unk_idx=0)

In [None]:
# Walks the complete training loader, reading every batch through its .seq/.lbl
# attributes; only the last batch stays bound to seq/lbl afterwards.
# NOTE(review): the following cell unpacks batches as (seq, lbl) tuples — confirm the
# batches also expose .seq/.lbl, otherwise this cell raises AttributeError.
for batch in train_loader:
    seq, lbl = batch.seq, batch.lbl

In [None]:
# Grab only the first batch of each loader (break after one iteration) so the
# assertions in the next cell can inspect dtypes and masking.
# Suffix 1 = train loader, 2 = test loader, _nad = loaders created with nad=True.
for seq, lbl in train_loader:
    seq1, lbl1 = seq, lbl
    break
    
for seq, lbl in test_loader:
    seq2, lbl2 = seq, lbl
    break
    
for seq, lbl in train_loader_nad:
    seq1_nad, lbl1_nad = seq, lbl
    break
    
for seq, lbl in test_loader_nad:
    seq2_nad, lbl2_nad = seq, lbl
    break   

In [None]:
# Sanity checks on the first batch of each character-level loader.
# Token id 0 is the unk_idx these loaders were created with.
assert seq1.dtype == torch.int64, "wrong sequence dtype"
# every training sequence (plain and NAD) must contain at least one masked position
assert all(0 in s[n] for s in [seq1, seq1_nad] for n in range(s.shape[0])), "masking didn't work"
# no test sequence may contain a masked position; `not all` would only fail if
# *every* sequence were masked, so use `not any` to catch a single leaked mask
# NOTE(review): this also fires if a raw test sequence contains a residue that the
# vocabulary maps to index 0 — confirm the data has none
assert not any(0 in s[n] for s in [seq2, seq2_nad] for n in range(s.shape[0])), "masking found in test set"
assert lbl1.dtype == torch.float32, "wrong lbl dtype"
# NAD labels are checked to be int64 (class indices) instead of float32
assert lbl2_nad.dtype == torch.int64, "nad lbl dtype wrong"
assert lbl1_nad.dtype == torch.int64, "nad lbl dtype wrong"

In [None]:
# BBPE50
# Loaders that tokenize with the bbpe50 byte-level BPE tokenizer; the plain and the
# NAD variants are both requested under the 'BiLSTMPadded' key.

loader = CleavageLoader(train_data, val_data, test_data, bbpe50, 16, 4)
train_loader, val_loader, test_loader = loader.load('BiLSTMPadded', nad=False, unk_idx=0)
train_loader_nad, val_loader_nad, test_loader_nad = loader.load('BiLSTMPadded', nad=True, unk_idx=0)

In [None]:
# Set before the loaders are iterated — presumably to silence the tokenizers
# fork/parallelism warning when DataLoader workers are spawned (TODO confirm).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Grab only the first batch of each padded loader; these batches carry a third
# element (lengths) in addition to sequences and labels.
for seq, lbl, lengths in train_loader:
    seq1, lbl1, lengths1 = seq, lbl, lengths
    break
    
for seq, lbl, lengths in test_loader:
    seq2, lbl2, lengths2 = seq, lbl, lengths
    break
    
# NOTE(review): this one is stored as lengths_nad, not lengths1_nad like its siblings
for seq, lbl, lengths in train_loader_nad:
    seq1_nad, lbl1_nad, lengths_nad = seq, lbl, lengths
    break
    
for seq, lbl, lengths in test_loader_nad:
    seq2_nad, lbl2_nad, lengths2_nad = seq, lbl, lengths
    break   

# Train/Eval Processors

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Setup for the Train/Eval smoke tests: data splits plus every tokenizer/vocabulary.
# Each read_data call returns a (sequences, labels) pair.
train_data = read_data('../data/c_train.csv')
val_data = read_data('../data/c_val.csv')
test_data = read_data('../data/c_test.csv')

# character-level vocabulary; encode_text splits a sequence string into characters
# and looks each one up
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files
vocab_base = torch.load('../params/vocab.pt')
encode_text = lambda x: vocab_base(list(x))

# load pre-trained esm2 model and vocab
esm2, vocab = torch.hub.load('facebookresearch/esm:main', 'esm2_t30_150M_UR50D')
tokenizer_esm = vocab.get_batch_converter()

tokenizer_t5 = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
)

# load all tokenizers here
bbpe1_vocab, bbpe1_merges = '../params/c_bbpe1k-vocab.json', '../params/c_bbpe1k-merges.txt'
bbpe50_vocab, bbpe50_merges = '../params/c_bbpe50k-vocab.json', '../params/c_bbpe50k-merges.txt'
wp50_vocab = '../params/c_wp50k-vocab.txt'

# byte-level BPE tokenizers pad with id 1 / '<PAD>'
bbpe1 = ByteLevelBPETokenizer.from_file(bbpe1_vocab, bbpe1_merges, lowercase=False)
bbpe1.enable_padding(pad_id=1, pad_token='<PAD>')

bbpe50 = ByteLevelBPETokenizer.from_file(bbpe50_vocab, bbpe50_merges, lowercase=False)
bbpe50.enable_padding(pad_id=1, pad_token='<PAD>')

# WordPiece tokenizer pads with id 0 / '[PAD]'; lowercase=False keeps input case
wp50 = BertWordPieceTokenizer.from_file(wp50_vocab, lowercase=False)
wp50.enable_padding(pad_id=0, pad_token='[PAD]')

In [None]:
# normal loader, no padding
# 512 and 10 are positional — presumably batch size and worker count (TODO confirm)
loader = CleavageLoader(train_data, val_data, test_data, encode_text, 512, 10)

In [None]:
# Smoke test: plain BiLSTM through train_or_eval_base, one pass over the training
# loader and one over the validation loader.
train_loader, val_loader, test_loader = loader.load('BiLSTM', nad=False, unk_idx=0)

# last positional arg is 1 here and 2 in the NAD cell — presumably the number of
# output units (TODO confirm against models.BiLSTM)
model = BiLSTM(21, 150, 128, 256, 64, 0.5, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# training pass: optimizer is supplied
model.train()
train_loss, train_acc, train_auc = train_or_eval_base(model, 'BiLSTM', train_loader, criterion, device, optimizer)

# evaluation pass: no optimizer; disable autograd like the ESM2/T5 cells do so no
# graph is built while validating
model.eval()
with torch.no_grad():
    val_loss, val_acc, val_auc = train_or_eval_base(model, 'BiLSTM', val_loader, criterion, device)

train_auc, val_auc

In [None]:
# Smoke test: BiLSTM through the NAD code path (train_or_eval_nad) with
# CrossEntropyLoss and a model built with 2 as last argument.
train_loader_nad, val_loader_nad, test_loader_nad = loader.load('BiLSTM', nad=True, unk_idx=0)
model_nad = BiLSTM(21, 150, 128, 256, 64, 0.5, 2).to(device)
criterion2 = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_nad.parameters(), lr=1e-4)

# toggle the mode of the model that is actually trained here (model_nad); toggling
# the unrelated `model` from the previous cell left model_nad in training mode
# during validation and made this cell depend on earlier kernel state
model_nad.train()
train_loss, train_acc, train_auc = train_or_eval_nad(model_nad, 'BiLSTM', train_loader_nad, criterion2, device, conf=None, optim=optimizer)

# evaluation pass: no optimizer, autograd disabled like in the ESM2/T5 cells
model_nad.eval()
with torch.no_grad():
    val_loss, val_acc, val_auc = train_or_eval_nad(model_nad, 'BiLSTM', val_loader_nad, criterion2, device)

train_auc, val_auc

In [None]:
# test padded on nad, normal
# padded loader
# tokenizes with bbpe50; 512 and 10 are positional — presumably batch size and
# worker count (TODO confirm)
loader = CleavageLoader(train_data, val_data, test_data, bbpe50, 512, 10)

In [None]:
# Smoke test: padded BiLSTM (bbpe50 tokens) through train_or_eval_base.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# normal loading
train_loader, val_loader, test_loader = loader.load('Padded', nad=False, unk_idx=0)

# pad_idx=1 matches the pad_id bbpe50 was configured with
model = BiLSTMPadded(bbpe50.get_vocab_size(), 150, 128, 256, 64, 0.5, 1, pad_idx=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# training pass: optimizer is supplied
model.train()
train_loss, train_acc, train_auc = train_or_eval_base(model, 'Padded', train_loader, criterion, device, optimizer)

# evaluation pass: no optimizer; disable autograd like the ESM2/T5 cells do
model.eval()
with torch.no_grad():
    val_loss, val_acc, val_auc = train_or_eval_base(model, 'Padded', val_loader, criterion, device)

train_auc, val_auc

In [None]:
# nad loading
# Smoke test: padded BiLSTM (bbpe50 tokens) through the NAD code path.
train_loader, val_loader, test_loader = loader.load('Padded', nad=True, unk_idx=0)

# last positional arg before pad_idx is 2 here (1 in the non-NAD cell), paired
# with CrossEntropyLoss
model = BiLSTMPadded(bbpe50.get_vocab_size(), 150, 128, 256, 64, 0.5, 2, pad_idx=1).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# training pass: optimizer is supplied, no confusion matrix (conf=None)
model.train()
train_loss, train_acc, train_auc = train_or_eval_nad(model, 'Padded', train_loader, criterion, device, conf=None, optim=optimizer)

# evaluation pass: no optimizer; disable autograd like the ESM2/T5 cells do
model.eval()
with torch.no_grad():
    val_loss, val_acc, val_auc = train_or_eval_nad(model, 'Padded', val_loader, criterion, device)

train_auc, val_auc

In [None]:
# test esm on nad and normal
# esm loader
# uses the ESM2 batch converter as tokenizer; 512 and 10 are positional —
# presumably batch size and worker count (TODO confirm)
loader = CleavageLoader(train_data, val_data, test_data, tokenizer_esm, 512, 10)

In [None]:
# normal loading
# Smoke test: ESM2 model through train_or_eval_base; loaders use unk_idx=3.
train_loader, val_loader, test_loader = loader.load('ESM2', nad=False, unk_idx=3)

# wraps the pre-trained esm2 network loaded in the setup cell
model = ESM2(esm2, 0.5, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
# NOTE(review): the training pass runs on test_loader, not train_loader — presumably
# to keep the smoke test short; confirm before trusting the reported numbers.
train_loss, train_acc, train_auc = train_or_eval_base(model, 'ESM2', test_loader, criterion, device, optimizer)

# evaluation pass: no optimizer, autograd disabled
model.eval()
with torch.no_grad():
    val_loss, val_acc, val_auc = train_or_eval_base(model, 'ESM2', val_loader, criterion, device)

train_auc, val_auc

In [None]:
# nad loading
# Smoke test: ESM2 model through the NAD code path with CrossEntropyLoss.
train_loader, val_loader, test_loader = loader.load('ESM2', nad=True, unk_idx=3)

# NOTE(review): the same esm2 instance is wrapped again here, so it may already
# carry updates from the previous cell's training pass — confirm this is acceptable.
model = ESM2(esm2, 0.5, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
# NOTE(review): the training pass runs on test_loader, not train_loader.
train_loss, train_acc, train_auc = train_or_eval_nad(model, 'ESM2', test_loader, criterion, device, conf=None, optim=optimizer)

# evaluation pass: no optimizer, autograd disabled
model.eval()
with torch.no_grad():
    val_loss, val_acc, val_auc = train_or_eval_nad(model, 'ESM2', val_loader, criterion, device)

train_auc, val_auc

In [None]:
# test t5 on nad and normal
# uses the ProtT5 tokenizer; 512 and 10 are positional — presumably batch size and
# worker count (TODO confirm)
loader = CleavageLoader(train_data, val_data, test_data, tokenizer_t5, 512, 10)

In [None]:
# scale everything to fp16
# gradient scaler handed to train_or_eval_base for the T5 model
scaler = torch.cuda.amp.GradScaler()

# normal loading
train_loader, val_loader, test_loader = loader.load('T5', nad=False, unk_idx=2)

model = T5BiLSTM(512, 128, 0.5, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# training pass: optimizer and scaler are supplied positionally
model.train()
train_loss, train_acc, train_auc = train_or_eval_base(model, 'T5', train_loader, criterion, device, optimizer, scaler)

# evaluation pass: no optimizer, autograd disabled; the scaler is still passed
model.eval()
with torch.no_grad():
    val_loss, val_acc, val_auc = train_or_eval_base(model, 'T5', val_loader, criterion, device, scaler=scaler)

train_auc, val_auc

In [None]:
# scale everything to fp16
# fresh gradient scaler for the NAD run
scaler = torch.cuda.amp.GradScaler()

# nad loading
train_loader, val_loader, test_loader = loader.load('T5', nad=True, unk_idx=2)

model = T5BiLSTM(512, 128, 0.5, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
# NOTE(review): the training pass runs on test_loader here, while the non-NAD T5
# cell trains on train_loader — confirm which one is intended.
train_loss, train_acc, train_auc = train_or_eval_nad(model, 'T5', test_loader, criterion, device, optim=optimizer, scaler=scaler)

# evaluation pass: no optimizer, autograd disabled
model.eval()
with torch.no_grad():
    val_loss, val_acc, val_auc = train_or_eval_nad(model, 'T5', val_loader, criterion, device, scaler=scaler)

train_auc, val_auc

In [None]:
# test everything from above now with coteaching / plus
# check print statements especially if correct choice ran

In [None]:
# BiLSTM
# character-level loader for the co-teaching smoke tests below
loader = CleavageLoader(train_data, val_data, test_data, encode_text, 512, 10)

In [None]:
# coteaching
# Smoke test: two independently initialised BiLSTMs trained with CoteachingLoss.
train_loader, val_loader, test_loader = loader.load('BiLSTM', nad=False, unk_idx=0)

model1 = BiLSTM(21, 150, 128, 256, 64, 0.5, 1).to(device)
model2 = BiLSTM(21, 150, 128, 256, 64, 0.5, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
cot_criterion = CoteachingLoss()
# one optimizer per model
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-4)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4)

# coteaching training: only cot_criterion!
# NOTE(review): both the training and the evaluation pass below run on test_loader.
model1.train()
model2.train()
train_res = train_or_eval_coteaching(
    model_name='BiLSTM',
    loader=test_loader,
    model1=model1,
    model2=model2,
    device=device,
    forget_rate=0.2,
    cot_criterion=cot_criterion,
    optim1=optimizer1,
    optim2=optimizer2,
)


# evaluation: plain criterion only, no optimizers, autograd disabled
model1.eval()
model2.eval()
with torch.no_grad():
    val_res = train_or_eval_coteaching(
        model_name='BiLSTM',
        loader=test_loader,
        model1=model1,
        model2=model2,
        device=device,
        forget_rate=0.2,
        criterion=criterion,
    )

train_res, val_res

In [None]:
# coteaching plus
# Smoke test: same two-model setup, with cot_plus_train=True.
train_loader, val_loader, test_loader = loader.load('BiLSTM', nad=False, unk_idx=0)

model1 = BiLSTM(21, 150, 128, 256, 64, 0.5, 1).to(device)
model2 = BiLSTM(21, 150, 128, 256, 64, 0.5, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
cot_criterion = CoteachingLoss()
# one optimizer per model
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-4)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4)

# coteaching plus training: both cot_criterion and criterion are passed
# NOTE(review): both the training and the evaluation pass below run on test_loader.
model1.train()
model2.train()
train_res = train_or_eval_coteaching(
    model_name='BiLSTM',
    loader=test_loader,
    model1=model1,
    model2=model2,
    device=device,
    forget_rate=0.2,
    cot_criterion=cot_criterion,
    criterion=criterion,
    optim1=optimizer1,
    optim2=optimizer2,
    cot_plus_train=True
)


# evaluation: plain criterion only, no optimizers, autograd disabled
model1.eval()
model2.eval()
with torch.no_grad():
    val_res = train_or_eval_coteaching(
        model_name='BiLSTM',
        loader=test_loader,
        model1=model1,
        model2=model2,
        device=device,
        forget_rate=0.2,
        criterion=criterion,
    )

train_res, val_res

In [None]:
# WordPiece 50k (wp50) — this loader tokenizes with wp50, not with bbpe50
loader = CleavageLoader(train_data, val_data, test_data, wp50, 512, 10)
# set before the loaders are iterated
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [None]:
# coteaching
# Smoke test: two padded BiLSTMs on wp50 tokens trained with CoteachingLoss.
train_loader, val_loader, test_loader = loader.load('Padded', nad=False, unk_idx=1)

# final positional arg 0 — presumably pad_idx, matching wp50's pad_id=0 (TODO confirm)
model1 = BiLSTMPadded(wp50.get_vocab_size(), 150, 128, 256, 64, 0.5, 1, 0).to(device)
model2 = BiLSTMPadded(wp50.get_vocab_size(), 150, 128, 256, 64, 0.5, 1, 0).to(device)
criterion = nn.BCEWithLogitsLoss()
cot_criterion = CoteachingLoss()
# one optimizer per model
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-4)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4)

# coteaching training: only cot_criterion!
# NOTE(review): both the training and the evaluation pass below run on test_loader.
model1.train()
model2.train()
train_res = train_or_eval_coteaching(
    model_name='Padded',
    loader=test_loader,
    model1=model1,
    model2=model2,
    device=device,
    forget_rate=0.2,
    cot_criterion=cot_criterion,
    optim1=optimizer1,
    optim2=optimizer2,
)


# evaluation: plain criterion only, no optimizers, autograd disabled
model1.eval()
model2.eval()
with torch.no_grad():
    val_res = train_or_eval_coteaching(
        model_name='Padded',
        loader=test_loader,
        model1=model1,
        model2=model2,
        device=device,
        forget_rate=0.2,
        criterion=criterion,
    )

train_res, val_res

In [None]:
# coteaching plus
# Smoke test: two padded BiLSTMs on wp50 tokens, with cot_plus_train=True.
train_loader, val_loader, test_loader = loader.load('Padded', nad=False, unk_idx=1)

# final positional arg 0 — presumably pad_idx, matching wp50's pad_id=0 (TODO confirm)
model1 = BiLSTMPadded(wp50.get_vocab_size(), 150, 128, 256, 64, 0.5, 1, 0).to(device)
model2 = BiLSTMPadded(wp50.get_vocab_size(), 150, 128, 256, 64, 0.5, 1, 0).to(device)
criterion = nn.BCEWithLogitsLoss()
cot_criterion = CoteachingLoss()
# one optimizer per model
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-4)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4)

# coteaching plus training: both criterions needed
# NOTE(review): both the training and the evaluation pass below run on test_loader.
model1.train()
model2.train()
train_res = train_or_eval_coteaching(
    model_name='Padded',
    loader=test_loader,
    model1=model1,
    model2=model2,
    device=device,
    forget_rate=0.2,
    cot_criterion=cot_criterion,
    criterion=criterion,
    optim1=optimizer1,
    optim2=optimizer2,
    cot_plus_train=True
)


# evaluation: plain criterion only, no optimizers, autograd disabled
model1.eval()
model2.eval()
with torch.no_grad():
    val_res = train_or_eval_coteaching(
        model_name='Padded',
        loader=test_loader,
        model1=model1,
        model2=model2,
        device=device,
        forget_rate=0.2,
        criterion=criterion,
    )

train_res, val_res

In [None]:
# T5
loader = CleavageLoader(train_data, val_data, test_data, tokenizer_t5, 512, 10)
# one gradient scaler per co-teaching model; both following T5 cells reuse them
scaler1 = torch.cuda.amp.GradScaler()
scaler2 = torch.cuda.amp.GradScaler()

In [None]:
# coteaching
# Smoke test: two T5BiLSTM models trained with CoteachingLoss and per-model scalers.
train_loader, val_loader, test_loader = loader.load('T5', nad=False, unk_idx=2)

model1 = T5BiLSTM(512, 128, 0.5, 1).to(device)
model2 = T5BiLSTM(512, 128, 0.5, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
cot_criterion = CoteachingLoss()
# one optimizer per model
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-4)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4)

# coteaching training: only cot_criterion!
# NOTE(review): both the training and the evaluation pass below run on test_loader.
model1.train()
model2.train()
train_res = train_or_eval_coteaching(
    model_name='T5',
    loader=test_loader,
    model1=model1,
    model2=model2,
    device=device,
    forget_rate=0.2,
    cot_criterion=cot_criterion,
    scaler1=scaler1,
    scaler2=scaler2,
    optim1=optimizer1,
    optim2=optimizer2,
)


# evaluation: plain criterion only, no optimizers, autograd disabled
model1.eval()
model2.eval()
with torch.no_grad():
    val_res = train_or_eval_coteaching(
        model_name='T5',
        loader=test_loader,
        model1=model1,
        model2=model2,
        device=device,
        forget_rate=0.2,
        criterion=criterion,
        scaler1=scaler1,
        scaler2=scaler2,
    )

train_res, val_res

In [None]:
# coteaching plus
# Smoke test: two T5BiLSTM models, with cot_plus_train=True and per-model scalers.
train_loader, val_loader, test_loader = loader.load('T5', nad=False, unk_idx=2)

model1 = T5BiLSTM(512, 128, 0.5, 1).to(device)
model2 = T5BiLSTM(512, 128, 0.5, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
cot_criterion = CoteachingLoss()
# one optimizer per model
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-4)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4)

# coteaching plus training: both criterions
# NOTE(review): both the training and the evaluation pass below run on test_loader.
model1.train()
model2.train()
train_res = train_or_eval_coteaching(
    model_name='T5',
    loader=test_loader,
    model1=model1,
    model2=model2,
    device=device,
    forget_rate=0.2,
    cot_criterion=cot_criterion,
    criterion=criterion,
    scaler1=scaler1,
    scaler2=scaler2,
    optim1=optimizer1,
    optim2=optimizer2,
    cot_plus_train=True,
)


# evaluation: plain criterion only, no optimizers, autograd disabled
model1.eval()
model2.eval()
with torch.no_grad():
    val_res = train_or_eval_coteaching(
        model_name='T5',
        loader=test_loader,
        model1=model1,
        model2=model2,
        device=device,
        forget_rate=0.2,
        criterion=criterion,
        scaler1=scaler1,
        scaler2=scaler2,
    )

train_res, val_res

In [None]:
# jocor
# normal tokenizer
# character-level loader for the JoCoR smoke test in the next cell
loader = CleavageLoader(train_data, val_data, test_data, encode_text, 512, 10)

In [None]:
# Smoke test: two BiLSTMs trained jointly with JoCoRLoss.
train_loader, val_loader, test_loader = loader.load('BiLSTM', nad=False, unk_idx=0)

model1 = BiLSTM(21, 150, 128, 256, 64, 0.5, 1).to(device)
model2 = BiLSTM(21, 150, 128, 256, 64, 0.5, 1).to(device)
jocor_criterion = JoCoRLoss()
# a single optimizer updates the parameters of both models
optimizer = torch.optim.Adam(list(model1.parameters()) + list(model2.parameters()), lr=1e-4)

# set the modes explicitly, as the co-teaching cells do; without this both models
# stay in training mode during the evaluation pass
model1.train()
model2.train()
res_train = train_or_eval_jocor('BiLSTM', train_loader, model1, model2, device, 0.2, jocor_criterion, optim=optimizer)

# evaluation pass: no criterion/optimizer, autograd disabled, and run on the
# validation loader (as the T5 JoCoR cell does) rather than on the training loader
model1.eval()
model2.eval()
with torch.no_grad():
    res_val = train_or_eval_jocor('BiLSTM', val_loader, model1, model2, device, 0.2)

In [None]:
# Smoke test: two padded BiLSTMs on bbpe1 tokens trained jointly with JoCoRLoss.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
loader = CleavageLoader(train_data, val_data, test_data, bbpe1, 512, 10)

train_loader, val_loader, test_loader = loader.load('Padded', nad=False, unk_idx=0)

# final positional arg 1 — presumably pad_idx, matching bbpe1's pad_id=1 (TODO confirm)
model1 = BiLSTMPadded(bbpe1.get_vocab_size(), 150, 128, 256, 64, 0.5, 1, 1).to(device)
model2 = BiLSTMPadded(bbpe1.get_vocab_size(), 150, 128, 256, 64, 0.5, 1, 1).to(device)
jocor_criterion = JoCoRLoss()
# a single optimizer updates the parameters of both models
optimizer = torch.optim.Adam(list(model1.parameters()) + list(model2.parameters()), lr=1e-4)

# set the modes explicitly, as the co-teaching cells do; without this both models
# stay in training mode during the evaluation pass
model1.train()
model2.train()
res_train = train_or_eval_jocor('Padded', train_loader, model1, model2, device, 0.2, jocor_criterion, optim=optimizer)

# evaluation pass: no criterion/optimizer, autograd disabled, and run on the
# validation loader (as the T5 JoCoR cell does) rather than on the training loader
model1.eval()
model2.eval()
with torch.no_grad():
    res_val = train_or_eval_jocor('Padded', val_loader, model1, model2, device, 0.2)

In [None]:
# Smoke test: two T5BiLSTM models trained jointly with JoCoRLoss and one scaler.
scaler = torch.cuda.amp.GradScaler()
loader = CleavageLoader(train_data, val_data, test_data, tokenizer_t5, 512, 10)

# every other T5 loader in this notebook is created with unk_idx=2; this cell used
# 0, so it is aligned with the rest
train_loader, val_loader, test_loader = loader.load('T5', nad=False, unk_idx=2)

model1 = T5BiLSTM(512, 128, 0.5, 1).to(device)
model2 = T5BiLSTM(512, 128, 0.5, 1).to(device)
jocor_criterion = JoCoRLoss()
# a single optimizer updates the parameters of both models
optimizer = torch.optim.Adam(list(model1.parameters()) + list(model2.parameters()), lr=1e-4)

# set the modes explicitly, as the co-teaching cells do; without this both models
# stay in training mode during the evaluation pass
model1.train()
model2.train()
# NOTE(review): the training pass runs on test_loader, not train_loader — kept as
# is; presumably chosen to shorten the smoke test (confirm).
res_train = train_or_eval_jocor('T5', test_loader, model1, model2, device, 0.2, jocor_criterion, optim=optimizer, scaler=scaler)

# evaluation pass: no criterion/optimizer, autograd disabled
model1.eval()
model2.eval()
with torch.no_grad():
    res_val = train_or_eval_jocor('T5', val_loader, model1, model2, device, 0.2, scaler=scaler)

In [None]:
# train hybrid test
# normal model
# Smoke test: BiLSTM trained together with a NoiseAdaptation module via train_hybrid_nad.
loader = CleavageLoader(train_data, val_data, test_data, encode_text, 512, 10)

train_loader, val_loader, test_loader = loader.load('BiLSTM', nad=True, unk_idx=0)
model = BiLSTM(21, 150, 128, 256, 64, 0.5, 2).to(device)

# hard-coded 2x2 matrix — presumably a confusion matrix from an earlier run
# (TODO record where these numbers come from)
conf = torch.tensor([[0.3770, 0.0820], [0.0492, 0.4918]])
# divide each row by its sum so that every row sums to 1
conf_norm = conf / conf.sum(dim=1, keepdim=True)
noisemodel = NoiseAdaptation(theta=conf_norm, k=2, device=device).to(device)
# separate optimizers for the noise module and the classifier
noise_optimizer = torch.optim.Adam(noisemodel.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

model.train()
noisemodel.train()
train_res = train_hybrid_nad('BiLSTM', model, noisemodel, train_loader, optimizer, noise_optimizer, criterion, device)

In [None]:
# train hybrid test
# bbpe
# Smoke test: padded BiLSTM (bbpe50 tokens) trained together with a NoiseAdaptation module.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
loader = CleavageLoader(train_data, val_data, test_data, bbpe50, 512, 10)

train_loader, val_loader, test_loader = loader.load('Padded', nad=True, unk_idx=0)
model = BiLSTMPadded(bbpe50.get_vocab_size(), 150, 128, 256, 64, 0.5, 2, pad_idx=1).to(device)

# hard-coded 2x2 matrix — presumably a confusion matrix from an earlier run
# (TODO record where these numbers come from)
conf = torch.tensor([[0.3770, 0.0820], [0.0492, 0.4918]])
# divide each row by its sum so that every row sums to 1
conf_norm = conf / conf.sum(dim=1, keepdim=True)
noisemodel = NoiseAdaptation(theta=conf_norm, k=2, device=device).to(device)
# separate optimizers for the noise module and the classifier
noise_optimizer = torch.optim.Adam(noisemodel.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

model.train()
noisemodel.train()
train_res = train_hybrid_nad('Padded', model, noisemodel, train_loader, optimizer, noise_optimizer, criterion, device)

In [None]:
# train hybrid test
# t5
# Smoke test: T5BiLSTM trained together with a NoiseAdaptation module and a grad scaler.
loader = CleavageLoader(train_data, val_data, test_data, tokenizer_t5, 512, 10)

train_loader, val_loader, test_loader = loader.load('T5', nad=True, unk_idx=2)
model = T5BiLSTM(512, 128, 0.5, 2).to(device)

scaler = torch.cuda.amp.GradScaler()
# hard-coded 2x2 matrix — presumably a confusion matrix from an earlier run
# (TODO record where these numbers come from)
conf = torch.tensor([[0.3770, 0.0820], [0.0492, 0.4918]])
# divide each row by its sum so that every row sums to 1
conf_norm = conf / conf.sum(dim=1, keepdim=True)
noisemodel = NoiseAdaptation(theta=conf_norm, k=2, device=device).to(device)
# separate optimizers for the noise module and the classifier
noise_optimizer = torch.optim.Adam(noisemodel.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

model.train()
noisemodel.train()
# NOTE(review): trains on test_loader, unlike the two hybrid cells above which use
# train_loader — confirm this is intended.
train_res = train_hybrid_nad('T5', model, noisemodel, test_loader, optimizer, noise_optimizer, criterion, device, scaler)

# Whole run_train loop

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Both AUC trackers start at zero; the next cell compares them.
highest_val_auc = reg_auc = 0

In [None]:
# Scratch check of the comparison: prints 'y' when reg_auc does not exceed
# highest_val_auc (both are 0 after the previous cell, so it prints).
if reg_auc <= highest_val_auc:
    print('y')

In [None]:
# Setup for the run_train experiments: data splits plus every tokenizer/vocabulary.
# Each read_data call returns a (sequences, labels) pair.
train_data = read_data('../data/c_train.csv')
val_data = read_data('../data/c_val.csv')
test_data = read_data('../data/c_test.csv')

# character-level vocabulary (presumably a torchtext Vocab — confirm); encode_text
# splits a sequence string into characters and looks each one up
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files
vocab_base = torch.load('../params/vocab.pt')
encode_text = lambda x: vocab_base(list(x))

# load pre-trained esm2 model and vocab
esm2, vocab = torch.hub.load('facebookresearch/esm:main', 'esm2_t30_150M_UR50D')
tokenizer_esm = vocab.get_batch_converter()

tokenizer_t5 = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
)

# load all tokenizers here
bbpe1_vocab, bbpe1_merges = '../params/c_bbpe1k-vocab.json', '../params/c_bbpe1k-merges.txt'
bbpe50_vocab, bbpe50_merges = '../params/c_bbpe50k-vocab.json', '../params/c_bbpe50k-merges.txt'
wp50_vocab = '../params/c_wp50k-vocab.txt'

bbpe1 = ByteLevelBPETokenizer.from_file(bbpe1_vocab, bbpe1_merges, lowercase=False)
bbpe1.enable_padding(pad_id=1, pad_token='<PAD>')

bbpe50 = ByteLevelBPETokenizer.from_file(bbpe50_vocab, bbpe50_merges, lowercase=False)
bbpe50.enable_padding(pad_id=1, pad_token='<PAD>')

# BertWordPieceTokenizer lowercases its input by default; pass lowercase=False so the
# sequences are tokenized case-sensitively, like the BBPE tokenizers above and like
# the wp50 tokenizer built in the Train/Eval setup cell
wp50 = BertWordPieceTokenizer.from_file(wp50_vocab, lowercase=False)
wp50.enable_padding(pad_id=0, pad_token='[PAD]')

In [None]:
# 15-entry schedule with every entry set to 0.1; displayed as the cell output.
rate_schedule = torch.full((15,), 0.1)
rate_schedule

In [None]:
# Overwrite the first 10 entries in place with a linear ramp from 0 to 0.1
# (10 evenly spaced points, endpoints included); entries 10-14 keep their value.
rate_schedule[:10] = torch.linspace(0, 0.1, 10)
rate_schedule

For the main script, we need:
* --device
* --reading in data
    * --with special selection of term
* --os.environ
* --model loading 
    * --tokenizer, depending on model choice
    * --all hyperparams
        * --make the argparse simply take all hyperparams, also those we don't use
    * --check especially for unk and pad token indices
    * --create dicts with params per each model conf 
* K-Fold
* train loop
    * use del for models
    * re-create tokenizer per split in train loop, check loops from BBPE re-make
    * make paths
    * early stopping after 5 epochs overfitting or decreasing val loss
    * include print statements for security (check how to not have empty prints when a metric is missing?), maybe just print the result objects
    * check how to save results in file with fold and varying numbers of items.
        * Check that train and val always return the same number of args, and write util functions to save each kind of denoising technique
* somehow save test results too
* create option for k-fold or not, if not, it should run base train set and eval on val set