In [1]:
import sys
sys.path.append('../denoise/')

import os
import csv
import pickle
import random
import math
import numpy as np
from time import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.parametrizations import spectral_norm

from sklearn.metrics import roc_auc_score
from sklearn.mixture import GaussianMixture

from divide_mix.data_handling import CleavageLoader

In [15]:
class BiLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, rnn_size1, rnn_size2, hidden_size, dropout1, dropout2):
        super().__init__()
        
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        
        self.dropout1=nn.Dropout(dropout1)
        self.dropout2=nn.Dropout(dropout2)
        
        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=rnn_size1,
            bidirectional=True,
            batch_first=True,
        )

        self.lstm2 = nn.LSTM(
            input_size=2 * rnn_size1,
            hidden_size=rnn_size2,
            bidirectional=True,
            batch_first=True,
        )
        
        self.fc1 = spectral_norm(nn.Linear(rnn_size2 * 2, hidden_size))
        self.fc2 = nn.Linear(hidden_size, 2)
        
    
    def no_embed_fw(self, embedded):
        # input shape: (batch_size, seq_len, embedding_dim)
        out, _ = self.lstm1(embedded)
        
        # input shape: (batch_size, seq_len, 2*rnn_size1)
        out, _ = self.lstm2(out)
        
        # input shape: (batch_size, seq_len, 2*hidden_size)
        pooled, _ = torch.max(out, dim=1)
        
        # input shape: (batch_size, 2*hidden_size)
        out = self.dropout1(gelu(self.fc1(pooled)))
        
        # input shape: (batch_size, hidden_size)
        # output shape: (batch_size, 2)
        return self.fc2(out)
    
    def forward(self, seq, seq2=None, lam=None, interpolate=False):
        if interpolate:
            # input shape: (batch_size, seq_len=10)
            embedded1 = self.embedding(seq)
            embedded2 = self.embedding(seq2)
            embedded_mixed = lam * embedded1 + (1 - lam) * embedded2
            return self.no_embed_fw(self.dropout2(embedded_mixed))
        else:
            # input shape: (batch_size, seq_len=10)
            embedded = self.dropout1(self.embedding(seq))
            return self.no_embed_fw(embedded)

In [3]:
def train(epoch, model1, model2, optimizer, labeled_loader, unlabeled_loader, named_model):
    epoch_loss, total = 0, 0
    
    # fix one model while the other one trains
    model1.train()
    model2.eval()
    
    unlabeled_train_iter = iter(unlabeled_loader)
    num_iter = (len(labeled_loader.dataset) // BATCH_SIZE) + 1
    
    for batch_idx, batch in enumerate(
        tqdm(
        labeled_loader,
        desc=f"train {named_model}: ",
        file=sys.stdout,
        unit="batches"
        )
    ):
        in_x, in_x2, labels_x, w_x = batch.seq1, batch.seq2, batch.lbl, batch.prob
        batch_size = in_x.shape[0]
        
        # re-use unlabeled_loader as long as there are batches in labeled_loader
        try:
            unlabeled_batch = unlabeled_train_iter.next()
            in_u, in_u2 = unlabeled_batch.seq1, unlabeled_batch.seq2
        except:
            unlabeled_train_iter = iter(unlabeled_loader)
            unlabeled_batch = unlabeled_train_iter.next()
            in_u, in_u2 = unlabeled_batch.seq1, unlabeled_batch.seq2
        
        
        w_x = w_x.view(-1, 1).to(torch.float)
        in_x, in_x2, labels_x, w_x = (
            in_x.to(device),
            in_x2.to(device),
            labels_x.to(device),
            w_x.to(device)
        )
        
        in_u, in_u2 = in_u.to(device), in_u2.to(device)
        
        with torch.no_grad():
            # for labeled samples: co-refinement + temperature sharpening
            out_x = model1(in_x)
            out_x2 = model1(in_x2)

            px = (torch.softmax(out_x, dim=1) + torch.softmax(out_x2, dim=1)) / 2
            px = w_x * labels_x + (1 - w_x) * px
                
            ptx = px ** (1 / TEMPERATURE)
            targets_x = (ptx / ptx.sum(dim=1, keepdim=True)).detach()
            
            # for unlabeled samples: co-guessing + temperature sharpening
            out_u11 = model1(in_u)
            out_u12 = model1(in_u2)
            out_u21 = model2(in_u)
            out_u22 = model2(in_u2)
            
            pu = (
                torch.softmax(out_u11, dim=1)
                + torch.softmax(out_u12, dim=1)
                + torch.softmax(out_u21, dim=1)
                + torch.softmax(out_u22, dim=1)
            ) / 4
            
            ptu = pu ** (1 / TEMPERATURE)
            targets_u = (ptu / ptu.sum(dim=1, keepdim=True)).detach()
            
        ### MixMatch
        # lambda interpolation factor for mixed input and targets
        lam = beta_dist.sample()
        lam = max(lam, 1-lam)
        
        all_ins = torch.cat([in_x, in_x2, in_u, in_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
        
        # shuffle all inputs to generate new pairs
        idx = torch.randperm(all_ins.shape[0])

        in_a, in_b = all_ins, all_ins[idx]
        target_a, target_b = all_targets, all_targets[idx]
        
        mixed_target = lam * target_a + (1 - lam) * target_b
        # inputs are mixed up in forward pass
        logits = model1(in_a, in_b, lam, interpolate=True)
        logits_x, logits_u = logits[:batch_size*2], logits[batch_size*2:]
        
        Lx, Lu, lambda_u = criterion(
            logits_x,
            mixed_target[:batch_size*2],
            logits_u,
            mixed_target[batch_size*2:],
            epoch + batch_idx/num_iter,
            NUM_WARM_UP_EPOCHS
        )
        
        # regularization
        prior = (torch.ones(NUM_CLASSES) / NUM_CLASSES).to(device)
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        penalty = torch.sum(prior * torch.log(prior / pred_mean))
        
        loss = Lx + lambda_u * Lu + penalty
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        total += idx.shape[0]
    return epoch_loss / total

In [4]:
def warmup(epoch, model, optimizer, dataloader):
    epoch_loss, num_correct, total = 0, 0, 0
    preds, lbls = [], []
    model.train()

    for batch_idx, batch in enumerate(
        tqdm(dataloader, desc="warmup: ", file=sys.stdout, unit="batches")
    ):
        in_x, lbl_x = batch.seq, batch.lbl
        in_x, lbl_x = in_x.to(device), lbl_x.to(device)

        out = model(in_x)
        loss = CEloss(out, lbl_x)

        # penalty for confident predictions for asymmetric noise
        loss += conf_penalty(out)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        non_ohe_labels = lbl_x.argmax(dim=1)
        num_correct += (
            (out.argmax(dim=1) == non_ohe_labels).sum().item()
        )
        total += in_x.shape[0]
        preds.extend(out[:, 1].detach().tolist())
        lbls.extend(non_ohe_labels.detach().tolist())
    return epoch_loss / total, num_correct / total, roc_auc_score(lbls, preds)

In [5]:
def evaluate(model1, model2, dataloader):
    num_correct, total = 0, 0
    preds, lbls, losses = [], [], []
    model1.eval()
    model2.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="evaluate: ", file=sys.stdout, unit="batches"):
            in_x, lbl_x = batch.seq, batch.lbl
            in_x, lbl_x = in_x.to(device), lbl_x.to(device)
            out1 = model1(in_x)
            out2 = model2(in_x)
            out = out1 + out2

            loss = CE(out, lbl_x)
            losses.extend(loss.detach().tolist())
            
            non_ohe_labels = lbl_x.argmax(dim=1)
            num_correct += (
                (out.argmax(dim=1) == non_ohe_labels).sum().item()
            )
            total += in_x.shape[0]
            preds.extend(out[:, 1].detach().tolist())
            lbls.extend(non_ohe_labels.detach().tolist())
            
    losses = np.array(losses)
    return num_correct / total, roc_auc_score(lbls, preds), losses.mean()

In [6]:
def process_gmm(model, dataloader):
    losses = []
    model.eval()

    with torch.no_grad():
        for batch in tqdm(
            dataloader, desc="GMM processing: ", file=sys.stdout, unit="batches"
        ):
            in_x, lbl_x = batch.seq, batch.lbl
            in_x, lbl_x = in_x.to(device), lbl_x.to(device)

            out = model(in_x)
            # nn.CrossEntropyLoss with reduction=None, returns per sample loss 
            loss = CE(out, lbl_x)  
            losses.extend(loss.detach().tolist())

    losses = np.array(losses)
    # normalize losses between 0 and 1
    norm_losses = ((losses - losses.min()) / losses.ptp())[:, np.newaxis]

    # fit two component GMM to loss
    gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
    gmm.fit(norm_losses)
    prob = gmm.predict_proba(norm_losses)
    # get value of smaller mean dist
    prob = prob[:, gmm.means_.argmin()]
    # out shape: (batch_size)
    return prob, losses.mean()

In [7]:
def linear_rampup(current_epoch, warm_up, rampup_len=15):
    current = np.clip((current_epoch - warm_up) / rampup_len, 0.0, 1.0)
    return LAMBDA_U * current

In [8]:
class SemiLoss:
    def __call__(self, out_x, lbl_x, out_u, lbl_u, epoch, warm_up):
        probs_u = torch.softmax(out_u, dim=1)

        Lx = -torch.mean(torch.sum(F.log_softmax(out_x, dim=1) * lbl_x, dim=1))
        Lu = torch.mean((probs_u - lbl_u) ** 2)

        return Lx, Lu, linear_rampup(epoch, warm_up)

In [9]:
class NegEntropy:
    def __call__(self, out):
        probs = torch.softmax(out, dim=1)
        return torch.mean(torch.sum(probs.log() * probs, dim=1))

In [10]:
# argparse params
BATCH_SIZE = 64
LR = 5e-5
ALPHA = 0.5
LAMBDA_U = 0
P_THRESHOLD = 0.5
TEMPERATURE = 0.5
NUM_EPOCHS = 50
NUM_WARM_UP_EPOCHS = 10
NUM_CLASSES = 2



device = "cuda:0"

def gelu(x):
    """
    Facebook Research implementation of the gelu activation function.
    
    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

In [11]:
### MAIN
vocab = torch.load('../data/vocab.pt')
beta_dist = torch.distributions.beta.Beta(ALPHA, ALPHA)
loader = CleavageLoader(batch_size=BATCH_SIZE, num_workers=10)

In [16]:
model1 = BiLSTM(
    vocab_size=len(vocab),
    embedding_dim=200,
    rnn_size1=256,
    rnn_size2=512,
    hidden_size=128,
    dropout1=0.5,
    dropout2=0.
).to(device)

model2 = BiLSTM(
    vocab_size=len(vocab),
    embedding_dim=20,
    rnn_size1=256,
    rnn_size2=512,
    hidden_size=128,
    dropout1=0.5,
    dropout2=0.
).to(device)

In [17]:
criterion = SemiLoss()
optimizer1 = optim.Adam(model1.parameters(), lr=LR)
optimizer2 = optim.Adam(model2.parameters(), lr=LR)
CE = nn.CrossEntropyLoss(reduction="none")
CEloss = nn.CrossEntropyLoss()
conf_penalty = NegEntropy()

In [18]:
warmup_loader = loader.load(terminus="n", mode="warmup")
train_gmm_loader = loader.load(terminus="n", mode="divide_by_GMM")
eval_loader = loader.load(terminus="n", mode="evaluate")

start = time()
highest_val_auc = 0
for epoch in range(1, NUM_EPOCHS + 1):
    if epoch < NUM_WARM_UP_EPOCHS + 1:
        # run warm up model 1 and 2 while adding penalty for confident predictions
        warmup_loss1, warmup_acc1, warmup_auc1 = warmup(
            epoch, model1, optimizer1, warmup_loader
        )
        warmup_loss2, warmup_acc2, warmup_auc2 = warmup(
            epoch, model2, optimizer2, warmup_loader
        )

        print(
            f"Warm-Up Model1: [Epoch {epoch:2d}, Loss: {warmup_loss1:8.6f}, Acc: {warmup_acc1:.4f}, AUC: {warmup_auc1:.4f}]"
        )
        print(
            f"Warm-Up Model2: [Epoch {epoch:2d}, Loss: {warmup_loss2:8.6f}, Acc: {warmup_acc2:.4f}, AUC: {warmup_auc2:.4f}]"
        )
        
        # evaluate on dev set
        val_acc, val_auc, val_loss = evaluate(model1, model2, eval_loader)
        print(
            f"Evaluation Set: [Epoch {epoch:2d}, Loss: {val_loss:.6f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}]"
        )

    else:
        
        prob1, train_loss1 = process_gmm(model1, train_gmm_loader)
        prob2, train_loss2 = process_gmm(model2, train_gmm_loader)

        pred1 = prob1 > P_THRESHOLD
        pred2 = prob2 > P_THRESHOLD

        # train both models
        labeled_trainloader, unlabeled_trainloader = loader.load(
            terminus="n", mode="train", pred=pred2, prob=prob2
        )
        divmix_loss1 = train(
            epoch,
            model1,
            model2,
            optimizer1,
            labeled_trainloader,
            unlabeled_trainloader,
            "model1"
        )

        labeled_trainloader, unlabeled_trainloader = loader.load(
            terminus="n", mode="train", pred=pred1, prob=prob1
        )
        divmix_loss2 = train(
            epoch,
            model2,
            model1,
            optimizer2,
            labeled_trainloader,
            unlabeled_trainloader,
            "model2"
        )

        # evaluate on dev set
        val_acc, val_auc, val_loss = evaluate(model1, model2, eval_loader)

        # if val_auc > highest_val_auc:
        #     highest_val_auc = val_auc
        #     path1 = f"../params/n_term/BiLSTM/auc{val_auc:.4f}_epoch{epoch}_model1.pt"
        #     path2 = f"../params/n_term/BiLSTM/auc{val_auc:.4f}_epoch{epoch}_model2.pt"
        #     torch.save(model1.state_dict(), path1)
        #     torch.save(model2.state_dict(), path2)

        print(
            f"Training Set: [Epoch {epoch:2d}, Loss1: {train_loss1:.6f}, Loss2: {train_loss2:.6f}]"
        )
        print(
            f"DivideMix Training: [Epoch {epoch:2d}, Loss1: {divmix_loss1:.6f}, Loss2: {divmix_loss2:.6f}]"
        )
        print(
            f"Evaluation Set: [Epoch {epoch:2d}, Loss: {val_loss:.6f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}]"
        )


print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time} minute.")

warmup: 100%|██████████████████████████████████████████████████| 2236/2236 [00:19<00:00, 114.46batches/s]
warmup: 100%|██████████████████████████████████████████████████| 2236/2236 [00:18<00:00, 118.95batches/s]
Warm-Up Model1: [Epoch  1, Loss: -0.000036, Acc: 0.6027, AUC: 0.6161]
Warm-Up Model2: [Epoch  1, Loss: -0.000029, Acc: 0.5936, AUC: 0.6005]
evaluate: 100%|███████████████████████████████████████████████████| 140/140 [00:01<00:00, 91.94batches/s]
Evaluation Set: [Epoch  1, Loss: 0.590031, Acc: 0.6800, AUC: 0.7453]
warmup: 100%|██████████████████████████████████████████████████| 2236/2236 [00:19<00:00, 113.72batches/s]
warmup: 100%|██████████████████████████████████████████████████| 2236/2236 [00:18<00:00, 118.85batches/s]
Warm-Up Model1: [Epoch  2, Loss: -0.000096, Acc: 0.6749, AUC: 0.7324]
Warm-Up Model2: [Epoch  2, Loss: -0.000071, Acc: 0.6445, AUC: 0.6962]
evaluate: 100%|███████████████████████████████████████████████████| 140/140 [00:01<00:00, 91.40batches/s]
Evaluation Set: