# Sources
* BiLSTM model architecture based on [Ozols et al., 2021](https://www.mdpi.com/1422-0067/22/6/3071/htm)
* DivideMix sources are listed under `cleavage_prediction/denoise/divide_mix/data_handling.py` and `cleavage_prediction/denoise/divide_mix/train_utils.py`

In [1]:
import sys
sys.path.append('../../denoise/')

import os
import math
import random
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from torchtext.vocab import build_vocab_from_iterator

from divide_mix.data_handling import CleavageLoader
from divide_mix.train_utils import (
    NegEntropy,
    SemiLoss,
    warmup,
    train,
    evaluate,
    process_gmm
)

In [2]:
def seed_everything(seed: int) -> None:
    """
    Seed every RNG used in this notebook and configure cuDNN for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    and sets ``PYTHONHASHSEED`` in the environment. Setting ``PYTHONHASHSEED``
    here only affects Python processes started afterwards; the running
    interpreter's hash seed is fixed at startup.

    Args:
        seed: the seed value applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without CUDA; seeds every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune per input shape and pick possibly
    # non-deterministic algorithms, which defeats `deterministic = True`.
    torch.backends.cudnn.benchmark = False


seed_everything(1234)

In [3]:
def gelu(x):
    """
    Exact (erf-based) GELU activation, ``x * Phi(x)`` with ``Phi`` the standard
    normal CDF, computed as ``0.5 * x * (1 + erf(x / sqrt(2)))``.

    This is the Facebook Research formulation. OpenAI GPT uses a tanh
    approximation instead, which gives slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    erf_term = torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * (1.0 + erf_term)


def regularized_auc(train_auc, dev_auc, threshold=0.0025):
    """
    Penalize overfitting when selecting models by AUC.

    Returns ``dev_auc`` when the train/dev gap (``train_auc - dev_auc``) is
    strictly below ``threshold``, otherwise 0. With the default (positive)
    threshold, a dev AUC at or above the train AUC always passes. A NaN gap
    yields 0.
    """
    generalization_gap = train_auc - dev_auc
    if generalization_gap < threshold:
        return dev_auc
    return 0
            
def trainable_model_params(model):
    """Count the scalar parameters of ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def total_model_params(model):
    """Count every scalar parameter of ``model``, trainable or frozen."""
    return sum(param.numel() for _, param in model.named_parameters())

In [4]:
class BiLSTM(nn.Module):
    """
    Two-layer bidirectional LSTM classifier over token sequences.

    Pipeline: embedding -> BiLSTM -> BiLSTM -> max-pool over the sequence
    dimension -> Linear + GELU + dropout -> Linear producing 2 logits.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        rnn_size1: hidden size (per direction) of the first BiLSTM.
        rnn_size2: hidden size (per direction) of the second BiLSTM.
        hidden_size: width of the fully connected layer before the output layer.
        dropout1: dropout rate applied to the embeddings in the standard
            forward pass and after the hidden fully connected layer.
        dropout2: dropout rate applied to the mixed embeddings in the
            interpolating forward pass.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_size1, rnn_size2, hidden_size, dropout1, dropout2):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)

        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=rnn_size1,
            bidirectional=True,
            batch_first=True,
        )

        # Bidirectional output of lstm1 concatenates both directions.
        self.lstm2 = nn.LSTM(
            input_size=2 * rnn_size1,
            hidden_size=rnn_size2,
            bidirectional=True,
            batch_first=True,
        )

        self.fc1 = nn.Linear(rnn_size2 * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 2)

    def no_embed_fw(self, embedded):
        """
        Run every layer after the embedding.

        Args:
            embedded: float tensor of shape (batch_size, seq_len, embedding_dim).

        Returns:
            Logits of shape (batch_size, 2).
        """
        # (batch_size, seq_len, embedding_dim) -> (batch_size, seq_len, 2*rnn_size1)
        out, _ = self.lstm1(embedded)

        # (batch_size, seq_len, 2*rnn_size1) -> (batch_size, seq_len, 2*rnn_size2)
        out, _ = self.lstm2(out)

        # Max-pool over the sequence dimension: -> (batch_size, 2*rnn_size2)
        pooled, _ = torch.max(out, dim=1)

        # (batch_size, 2*rnn_size2) -> (batch_size, hidden_size)
        out = self.dropout1(gelu(self.fc1(pooled)))

        # (batch_size, hidden_size) -> (batch_size, 2)
        return self.fc2(out)

    def forward(self, seq, seq2=None, lam=None, interpolate=False):
        """
        Classify ``seq``, optionally mixing its embeddings with those of ``seq2``.

        Args:
            seq: integer token ids of shape (batch_size, seq_len).
            seq2: second batch of token ids. Its embeddings are added
                elementwise to those of ``seq``, so the shapes must be
                compatible (normally identical). Required when
                ``interpolate`` is True.
            lam: mixing coefficient. The mixed embedding is
                ``lam * emb(seq) + (1 - lam) * emb(seq2)``. Required when
                ``interpolate`` is True.
            interpolate: if True, classify the embedding-level mixture of
                ``seq`` and ``seq2``.

        Returns:
            Logits of shape (batch_size, 2).

        Raises:
            ValueError: if ``interpolate`` is True but ``seq2`` or ``lam`` is None.
        """
        if interpolate:
            if seq2 is None or lam is None:
                raise ValueError("interpolate=True requires both `seq2` and `lam`.")
            embedded1 = self.embedding(seq)
            embedded2 = self.embedding(seq2)
            embedded_mixed = lam * embedded1 + (1 - lam) * embedded2
            return self.no_embed_fw(self.dropout2(embedded_mixed))

        embedded = self.dropout1(self.embedding(seq))
        return self.no_embed_fw(embedded)

In [5]:
# --- optimisation / schedule ---
BATCH_SIZE = 512
LR = 3e-4
NUM_EPOCHS = 15
NUM_WARM_UP_EPOCHS = 1
RAMPUP_LEN = 5
NUM_CLASSES = 2

# --- DivideMix knobs (consumed by train() / the GMM split below) ---
ALPHA = 0.5          # concentration of beta_dist = Beta(ALPHA, ALPHA)
LAMBDA_U = 0         # passed to train() as lambda_u
P_THRESHOLD = 0.5    # cut-off applied to the probabilities returned by process_gmm
TEMPERATURE = 0.5    # passed to train() as temp

# --- shared runtime objects ---
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
vocab = torch.load('../../data/vocab.pt')
beta_dist = torch.distributions.Beta(ALPHA, ALPHA)
loader = CleavageLoader(batch_size=BATCH_SIZE, num_workers=10)

In [6]:
# The two co-trained networks share one architecture; they differ only in
# their random initialisation.
_BILSTM_HPARAMS = dict(
    vocab_size=len(vocab),
    embedding_dim=91,
    rnn_size1=228,
    rnn_size2=506,
    hidden_size=164,
    dropout1=0.5,
    dropout2=0.,
)
model1 = BiLSTM(**_BILSTM_HPARAMS).to(device)
model2 = BiLSTM(**_BILSTM_HPARAMS).to(device)

optimizer1 = optim.Adam(model1.parameters(), lr=LR)
optimizer2 = optim.Adam(model2.parameters(), lr=LR)

criterion = SemiLoss()                       # loss_func passed to train()
conf_penalty = NegEntropy()                  # confidence penalty used during warm-up
CEloss = nn.CrossEntropyLoss()               # mean-reduced CE used during warm-up
CE = nn.CrossEntropyLoss(reduction="none")   # per-sample CE for GMM processing and evaluation

In [7]:
warmup_loader = loader.load(terminus="c", mode="warmup")
train_gmm_loader = loader.load(terminus="c", mode="divide_by_GMM")
eval_loader = loader.load(terminus="c", mode="evaluate")

# Checkpoints of the best (highest dev AUC) model pair are written here.
CKPT_DIR = "../../params/c_term/BiLSTM_dividemix/"
# torch.save does not create missing directories.
os.makedirs(CKPT_DIR, exist_ok=True)

start = time()
highest_val_auc = 0
for epoch in range(1, NUM_EPOCHS + 1):
    if epoch <= NUM_WARM_UP_EPOCHS:
        # run warm up model 1 and 2 while adding penalty for confident predictions
        warmup_loss1, warmup_acc1, warmup_auc1 = warmup(
            model=model1,
            optimizer=optimizer1,
            loss_func=CEloss,
            conf_penalty=conf_penalty,
            dataloader=warmup_loader
        )
        warmup_loss2, warmup_acc2, warmup_auc2 = warmup(
            model=model2,
            optimizer=optimizer2,
            loss_func=CEloss,
            conf_penalty=conf_penalty,
            dataloader=warmup_loader
        )

        print(
            f"Warm-Up Model1: [Epoch {epoch:2d}, Loss: {warmup_loss1:8.6f}, Acc: {warmup_acc1:.4f}, AUC: {warmup_auc1:.4f}]"
        )
        print(
            f"Warm-Up Model2: [Epoch {epoch:2d}, Loss: {warmup_loss2:8.6f}, Acc: {warmup_acc2:.4f}, AUC: {warmup_auc2:.4f}]"
        )

        # evaluate on dev set
        val_acc, val_auc, val_loss = evaluate(
            model1=model1,
            model2=model2,
            loss_func=CE,
            dataloader=eval_loader
        )
        print(
            f"Evaluation Set: [Epoch {epoch:2d}, Loss: {val_loss:.6f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}]"
        )

    else:
        prob1, train_loss1, raw_losses1, norm_losses1 = process_gmm(model1, train_gmm_loader, CE)
        prob2, train_loss2, raw_losses2, norm_losses2 = process_gmm(model2, train_gmm_loader, CE)

        pred1 = prob1 > P_THRESHOLD
        pred2 = prob2 > P_THRESHOLD

        # train model1 on the split produced by model2 ...
        labeled_trainloader, unlabeled_trainloader = loader.load(
            terminus="c", mode="train", pred=pred2, prob=prob2
        )

        divmix_loss1 = train(
            epoch=epoch,
            model1=model1,
            model2=model2,
            optimizer=optimizer1,
            loss_func=criterion,
            num_warm_up_epochs=NUM_WARM_UP_EPOCHS,
            num_classes=NUM_CLASSES,
            lambda_u=LAMBDA_U,
            temp=TEMPERATURE,
            beta_dist=beta_dist,
            batch_size=BATCH_SIZE,
            labeled_loader=labeled_trainloader,
            unlabeled_loader=unlabeled_trainloader,
            rampup_len=RAMPUP_LEN,
            named_model="model1"
        )

        # ... and model2 on the split produced by model1
        labeled_trainloader, unlabeled_trainloader = loader.load(
            terminus="c", mode="train", pred=pred1, prob=prob1
        )

        divmix_loss2 = train(
            epoch=epoch,
            model1=model2,
            model2=model1,
            optimizer=optimizer2,
            loss_func=criterion,
            num_warm_up_epochs=NUM_WARM_UP_EPOCHS,
            num_classes=NUM_CLASSES,
            lambda_u=LAMBDA_U,
            temp=TEMPERATURE,
            beta_dist=beta_dist,
            batch_size=BATCH_SIZE,
            labeled_loader=labeled_trainloader,
            unlabeled_loader=unlabeled_trainloader,
            rampup_len=RAMPUP_LEN,
            named_model="model2"
        )

        # evaluate on dev set
        val_acc, val_auc, val_loss = evaluate(
            model1=model1,
            model2=model2,
            loss_func=CE,
            dataloader=eval_loader
        )

        print(
            f"Training Set: [Epoch {epoch:2d}, Loss1: {train_loss1:.6f}, Loss2: {train_loss2:.6f}]"
        )
        print(
            f"DivideMix Training: [Epoch {epoch:2d}, Loss1: {divmix_loss1:.6f}, Loss2: {divmix_loss2:.6f}]"
        )
        print(
            f"Evaluation Set: [Epoch {epoch:2d}, Loss: {val_loss:.6f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}]"
        )

        # Only post-warm-up epochs are checkpointed. The model1 file name
        # starts with the dev AUC so the best run sorts last lexicographically.
        if val_auc > highest_val_auc:
            highest_val_auc = val_auc
            path1 = os.path.join(CKPT_DIR, f"auc{val_auc:.4f}_epoch{epoch}_model1.pt")
            path2 = os.path.join(CKPT_DIR, f"model2_epoch{epoch}.pt")
            torch.save(model1.state_dict(), path1)
            torch.save(model2.state_dict(), path2)


print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time:.2f} minutes.")

warmup: 100%|██████████████████████████████████████████████████| 2218/2218 [00:20<00:00, 109.07batches/s]
warmup: 100%|██████████████████████████████████████████████████| 2218/2218 [00:20<00:00, 110.18batches/s]
Warm-Up Model1: [Epoch  1, Loss: -0.000282, Acc: 0.8436, AUC: 0.7955]
Warm-Up Model2: [Epoch  1, Loss: -0.000279, Acc: 0.8423, AUC: 0.7882]
evaluate: 100%|██████████████████████████████████████████████████| 278/278 [00:01<00:00, 167.52batches/s]
Evaluation Set: [Epoch  1, Loss: 0.331895, Acc: 0.8623, AUC: 0.8714]
GMM processing: 100%|██████████████████████████████████████████| 2218/2218 [00:05<00:00, 379.23batches/s]
GMM processing: 100%|██████████████████████████████████████████| 2218/2218 [00:05<00:00, 371.59batches/s]
train model1: 100%|█████████████████████████████████████████████| 1638/1638 [01:10<00:00, 23.09batches/s]
train model2: 100%|█████████████████████████████████████████████| 1646/1646 [01:11<00:00, 22.95batches/s]
evaluate: 100%|██████████████████████████████████

In [8]:
# load the best checkpoint pair, evaluate on the test set
test_loader = loader.load(terminus="c", mode="test")

ckpt_dir = "../../params/c_term/BiLSTM_dividemix/"
model1_ckpts = [f for f in os.listdir(ckpt_dir) if f.endswith("model1.pt")]
if not model1_ckpts:
    raise FileNotFoundError(f"No model1 checkpoints found in {ckpt_dir}")
# Names look like "auc0.8607_epoch12_model1.pt": the fixed-width AUC prefix
# makes the lexicographically largest name the highest dev AUC.
# NOTE(review): this also considers checkpoints left over from earlier runs.
best_model1 = max(model1_ckpts)
print("Loaded model1: ", best_model1)

# "auc0.8607_epoch12_model1.pt" -> "model2_epoch12.pt"
best_model2 = 'model2_' + best_model1.split('_')[1] + '.pt'
print("Loaded model2: ", best_model2)

# map_location lets GPU-trained checkpoints load on a CPU-only machine.
model1.load_state_dict(torch.load(os.path.join(ckpt_dir, best_model1), map_location=device))
model2.load_state_dict(torch.load(os.path.join(ckpt_dir, best_model2), map_location=device))
model1.eval()
model2.eval()

# evaluate on the test set
test_acc, test_auc, test_loss = evaluate(
    model1=model1,
    model2=model2,
    loss_func=CE,
    dataloader=test_loader
)
print(
    f"Test Set Performance: Loss: {test_loss:.6f}, Acc: {test_acc:.4f}, AUC: {test_auc:.4f}"
)
print(
    f"Total model params: {total_model_params(model1)}, trainable model params: {trainable_model_params(model1)}"
)

Loaded model1:  auc0.8607_epoch12_model1.pt
Loaded model2:  model2_epoch12.pt
evaluate: 100%|██████████████████████████████████████████████████| 278/278 [00:01<00:00, 166.80batches/s]
Test Set Performance: Loss: 1.407796, Acc: 0.8402, AUC: 0.8625
Total model params: 4656149, trainable model params: 4656149
