In [1]:
from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    DistilBertModel
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import inspect
# Index of the active CUDA device when one is present. torch.cuda.current_device()
# raises on a machine without CUDA, so fall back to the CPU there instead of
# failing while the notebook is still importing its dependencies.
device=torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
from EDA.augment import gen_eda
from Utils.data_loader import get_dataloader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd

In [3]:
import time

In [4]:
# Sanity check: the models below are created with device="cuda", so confirm
# that PyTorch can actually see a CUDA device before going any further.
torch.cuda.is_available()

True

In [5]:
# Checkpoint name handed to every `from_pretrained` call in this notebook
# (the tokenizer and the DistilBERT classifier).
pretrained_model_name = "distilbert-base-uncased"

In [7]:
# The result of get_dataloader is indexed with [0]; presumably it returns one
# loader per requested batch size -- TODO confirm against Utils.data_loader.
train_dataloader = get_dataloader(
    'Datasets/IMDB_500_sentiment.csv',
    batch_sizes=[16],
)[0]
validation_dataloader = get_dataloader('Datasets/IMDB_1000_ssmba_val.csv')[0]
ood_dataloader = get_dataloader('Datasets/SST-2_1000_ssmba_test.csv')[0]

In [8]:
# Load the full IMDB table. Passing `names` makes pandas treat the file as
# headerless and label its two columns explicitly.
df = pd.read_csv(
    'Datasets/IMDB_Full.csv',
    names=['labels', 'text'],
)

In [21]:
# Fast DistilBERT tokenizer for the checkpoint named in `pretrained_model_name`.
tokenizer = DistilBertTokenizerFast.from_pretrained(
    pretrained_model_name
)

In [23]:
from transformers import logging
# Only let error-level messages from the transformers library through.
# (A set_verbosity_warning() call placed directly before this one has no lasting
# effect, because the level is overwritten immediately.)
logging.set_verbosity_error()

In [24]:
class emixBERTClassifier(torch.nn.Module):
    """DistilBERT sequence classifier that can mix pairs of hidden states (Emix).

    With mixing disabled, ``forward`` runs the embeddings, every transformer layer
    and the classification head of the wrapped
    ``DistilBertForSequenceClassification`` one after another.  With mixing
    enabled, one layer index is drawn from ``mixup_layers`` and, right before that
    layer, consecutive samples (0,1), (2,3), ... are blended into a single
    representation each, so the output has ``batch // 2`` rows.

    Args:
        model_name: checkpoint name passed to ``from_pretrained``.
        num_labels: size of the classification head.
        alpha: both concentration parameters of the Beta distribution the raw
            mixing coefficient is sampled from.
        mixing: initial state of the mixing switch (see ``do_emix``).
        device: device the wrapped model and all returned tensors are moved to.
    """

    def __init__(self,model_name,num_labels,alpha=1,mixing=False,device="cuda"):
        super().__init__()
        self.num_labels=num_labels
        self.device=device
        self.model_name=model_name
        self.dbert=DistilBertForSequenceClassification.from_pretrained(self.model_name,num_labels=self.num_labels).to(self.device)
        # Not used by forward(), which applies the dropout owned by self.dbert;
        # kept so the attribute stays available to existing code.
        self.dropout=torch.nn.Dropout(0.1).to(self.device)
        # Candidate layers for mixing: every transformer layer except the first.
        self.mixup_layers=torch.arange(1,self.dbert.distilbert.transformer.n_layers)
        self.alpha=alpha
        self.mixing=mixing

    def get_mixing_ratio(self,std1,std2):
        """Sample the mixing ratio for one pair of samples.

        Draws ``lam ~ Beta(alpha, alpha)`` and rescales it by the ratio of the two
        standard deviations: ``t = 1 / (1 + (std1/std2) * ((1-lam)/lam))``.

        Returns:
            The ratio ``t`` as a tensor on ``self.device``.
        """
        lam=torch.distributions.beta.Beta(self.alpha,self.alpha).sample()
        t=1/(1+(std1/std2)*((1-lam)/lam))
        return t.to(self.device)

    def emix(self,h1,h2,a1,a2,t):
        """Blend two hidden states and their attention masks.

        The representation is ``(t*h1 + (1-t)*h2) / sqrt(t^2 + (1-t)^2)``.

        Returns:
            ``(mixed_representation, mixed_attention)``, both on ``self.device``.
        """
        mixed_representation=(t*h1+(1-t)*h2)/torch.sqrt(t**2+(1-t)**2)
        ## Original paper gives no information on handling attention masks.
        ## Assumption: take the element-wise maximum of the two masks, i.e. their
        ## union (logical OR), so a position that is attended to in either
        ## sample stays visible in the mixed sample.
        mixed_attention=torch.max(a1,a2)
        return mixed_representation.to(self.device),mixed_attention.to(self.device)

    def do_emix(self,option=None):
        """Set the mixing switch to ``option``; toggle it when ``option`` is None."""
        if option is None:
            self.mixing= not self.mixing
        else:
            self.mixing=option

    def _mix_pairs(self,hidden_rep,attention_mask):
        """Mix samples (0,1), (2,3), ...; a trailing unpaired sample is dropped.

        Returns:
            ``(mixed_hidden, mixed_mask, t_list)`` where the first two have one
            row per pair and ``t_list`` holds the ratio used for each pair.
        """
        mixed_states,mixed_masks,t_list=[],[],[]
        for first in range(0,hidden_rep.shape[0]-1,2):
            h_i,h_j=hidden_rep[first],hidden_rep[first+1]
            t=self.get_mixing_ratio(torch.std(h_i),torch.std(h_j))
            mixed_state,mixed_mask=self.emix(h_i,h_j,attention_mask[first],attention_mask[first+1],t)
            t_list.append(t)
            mixed_states.append(mixed_state)
            mixed_masks.append(mixed_mask)
        return torch.stack(mixed_states).to(self.device),torch.stack(mixed_masks).to(self.device),t_list

    def forward(self,input_ids,attention_mask):
        """Classify a batch, optionally mixing consecutive samples.

        Args:
            input_ids: token ids; the first dimension is the batch.
            attention_mask: mask indexed per sample along its first dimension and
                passed to every transformer layer.

        Returns:
            ``(logits, t_list)``. Without mixing ``t_list`` is empty and there is
            one row of logits per sample; with mixing there is one row per pair.

        Raises:
            ValueError: if mixing is enabled and the batch has fewer than two
                samples, so that no pair can be formed.
        """
        if self.mixing:
            if input_ids.shape[0]<2:
                raise ValueError("Emix needs at least two samples in the batch to form a pair")
            mixing_layer=np.random.choice(self.mixup_layers,size=1)[0]
        else:
            # Layer 0 is never visited by the loop below, so 0 means "no mixing".
            mixing_layer=0

        transformer=self.dbert.distilbert.transformer
        emb=self.dbert.distilbert.embeddings(input_ids)
        hidden_rep=transformer.layer[0](emb,attention_mask)[0]
        t_list=[]
        for layer_idx in range(1,transformer.n_layers):
            if layer_idx==mixing_layer:
                hidden_rep,attention_mask,t_list=self._mix_pairs(hidden_rep,attention_mask)

            # Performs dropout by default
            hidden_rep=transformer.layer[layer_idx](hidden_rep,attention_mask)[0]

        # Classification head applied to the representation at position 0.
        pooled_output=hidden_rep[:, 0]
        pooled_output=self.dbert.pre_classifier(pooled_output)
        pooled_output=F.relu(pooled_output).to(self.device)  # (bs, dim)
        pooled_output=self.dbert.dropout(pooled_output)  # (bs, dim)
        logits=self.dbert.classifier(pooled_output)
        return logits.to(self.device),t_list
        

In [25]:
class emixRNNClassifier(torch.nn.Module):
    """Three-layer GRU/LSTM text classifier that can mix pairs of hidden states.

    Token ids are embedded, passed through ``num_layers`` single-layer recurrent
    modules, and the output at the last time step is fed to a linear classifier.
    With mixing enabled, one layer index is drawn from ``mixup_layers`` and, right
    before that layer, consecutive samples (0,1), (2,3), ... are blended into one
    representation each, so the output has ``batch // 2`` rows.

    Args:
        model_name: stored on the instance only; no checkpoint is loaded here.
        num_labels: number of output classes.
        vocab_size: number of rows in the embedding table.
        model_type: ``'GRU'`` builds GRU layers; any other value builds LSTMs.
        hidden_size: hidden size of every recurrent layer.
        embedding_dim: size of the token embeddings.
        alpha: both concentration parameters of the Beta distribution the raw
            mixing coefficient is sampled from.
        mixing: initial state of the mixing switch (see ``do_emix``).
        device: device the sub-modules and all returned tensors are moved to.
    """

    def __init__(self,model_name,num_labels,vocab_size,model_type='GRU',hidden_size=128,embedding_dim=256,alpha=1,mixing=False,device="cuda"):
        super().__init__()
        self.num_labels=num_labels
        self.device=device
        self.model_name=model_name
        self.uses_attention=False
        self.vocab_size=vocab_size
        self.model_type=model_type
        self.num_layers=3
        self.embedding=nn.Embedding(vocab_size,embedding_dim).to(device)
        rnn_cls=nn.GRU if self.model_type=='GRU' else nn.LSTM
        self.rnns=nn.ModuleList()
        for i in range(self.num_layers):
            input_size=embedding_dim if i==0 else hidden_size
            # forward() indexes dim 0 as the batch and dim 1 as time, so the
            # recurrent layers must be batch-first. With the default
            # (batch_first=False) the recurrence would run across the samples of
            # a batch instead of across the tokens of each sample.
            self.rnns.append(rnn_cls(input_size,hidden_size,num_layers=1,batch_first=True).to(device))
        self.dropout=torch.nn.Dropout(0.1).to(self.device)
        # Candidate layers for mixing: every recurrent layer except the first.
        self.mixup_layers=torch.arange(1,self.num_layers)
        self.alpha=alpha
        self.mixing=mixing
        self.fc=nn.Linear(hidden_size,num_labels).to(device)

    def get_mixing_ratio(self,std1,std2):
        """Sample the mixing ratio for one pair of samples.

        Draws ``lam ~ Beta(alpha, alpha)`` and rescales it by the ratio of the two
        standard deviations: ``t = 1 / (1 + (std1/std2) * ((1-lam)/lam))``.

        Returns:
            The ratio ``t`` as a tensor on ``self.device``.
        """
        lam=torch.distributions.beta.Beta(self.alpha,self.alpha).sample()
        t=1/(1+(std1/std2)*((1-lam)/lam))
        return t.to(self.device)

    def emix(self,h1,h2,t):
        """Return ``(t*h1 + (1-t)*h2) / sqrt(t^2 + (1-t)^2)`` on ``self.device``."""
        mixed_representation=(t*h1+(1-t)*h2)/torch.sqrt(t**2+(1-t)**2)
        return mixed_representation.to(self.device)

    def do_emix(self,option=None):
        """Set the mixing switch to ``option``; toggle it when ``option`` is None."""
        if option is None:
            self.mixing= not self.mixing
        else:
            self.mixing=option

    def forward(self,input_ids,attention_mask):
        """Classify a batch, optionally mixing consecutive samples.

        Args:
            input_ids: token ids of shape (batch, seq).
            attention_mask: accepted so the model is interchangeable with the
                BERT variant; it is not used here.

        Returns:
            ``(logits, t_list)``. Without mixing ``t_list`` is empty and there is
            one row of logits per sample; with mixing there is one row per pair
            (a trailing unpaired sample is dropped).

        Raises:
            ValueError: if mixing is enabled and the batch has fewer than two
                samples, so that no pair can be formed.
        """
        if self.mixing:
            if input_ids.shape[0]<2:
                raise ValueError("Emix needs at least two samples in the batch to form a pair")
            mixing_layer=np.random.choice(self.mixup_layers,size=1)[0]
        else:
            # Layer 0 is never visited by the loop below, so 0 means "no mixing".
            mixing_layer=0

        emb=self.embedding(input_ids)
        hidden_rep=self.rnns[0](emb)[0]
        t_list=[]
        for layer_idx in range(1,self.num_layers):
            if layer_idx==mixing_layer:
                mixed_states=[]
                for first in range(0,hidden_rep.shape[0]-1,2):
                    h_i,h_j=hidden_rep[first],hidden_rep[first+1]
                    t=self.get_mixing_ratio(torch.std(h_i),torch.std(h_j))
                    t_list.append(t)
                    mixed_states.append(self.emix(h_i,h_j,t))
                hidden_rep=torch.stack(mixed_states).to(self.device)
            hidden_rep=self.rnns[layer_idx](hidden_rep)[0]
            hidden_rep=self.dropout(hidden_rep)

        # NOTE(review): this reads the output at the last time step. When inputs
        # are padded to a fixed length that position may be a padding token,
        # because attention_mask is ignored -- confirm this is intended.
        final_state=hidden_rep[:,-1,:]
        logits=self.fc(final_state)
        return logits.to(self.device),t_list

In [26]:
class emixTrainer():
    """Training / evaluation loop for the Emix classifiers.

    The model is expected to expose ``do_emix(bool)`` and to return
    ``(logits, t_list)`` from ``model(input_ids, attention_mask)``.

    Args:
        tokenizer: callable that turns a batch of texts into ``input_ids`` and
            ``attention_mask`` tensors.
        model: classifier to train.
        device: device the tokenized batches and labels are moved to.
        criterion: loss taking ``(logits, labels)``.
        optimizer: optimizer over the model parameters.
        train_dataloader: yields ``(texts, labels)`` training batches.
        ood_dataloader: yields ``(texts, labels)`` batches evaluated once after
            training.
        max_length: padding / truncation length passed to the tokenizer.
        epochs: number of passes over ``train_dataloader``.
        validation_dataloader: optional loader evaluated after every epoch.
        mixing: stored on the instance; ``train`` is controlled by its own
            ``mix`` argument.
    """

    def __init__(self,tokenizer,model,device,criterion,optimizer,train_dataloader,ood_dataloader,max_length=512,
                 epochs=1,validation_dataloader=None,mixing=True):
        self.model=model
        self.tokenizer=tokenizer
        self.device=device
        self.criterion=criterion
        self.optimizer=optimizer
        self.ood_dataloader=ood_dataloader
        self.train_dataloader=train_dataloader
        self.validation_dataloader=validation_dataloader
        self.epochs=epochs
        self.mixing=mixing
        self.max_length=max_length

    def _encode(self,texts):
        """Tokenize ``texts`` and return ``(input_ids, attention_mask)`` on the device."""
        inputs=self.tokenizer(texts,padding='max_length',truncation=True,return_tensors='pt',
                              max_length=self.max_length)
        return inputs['input_ids'].to(self.device),inputs['attention_mask'].to(self.device)

    def _per_sample_loss(self,preds,targets):
        """Return one loss value per row of ``preds``.

        A criterion that exposes a ``reduction`` attribute is switched to
        ``'none'`` for the call and restored afterwards; any other criterion is
        evaluated one row at a time.
        """
        reduction=getattr(self.criterion,'reduction',None)
        if reduction is None:
            rows=[self.criterion(preds[i:i+1],targets[i:i+1]).mean() for i in range(preds.shape[0])]
            return torch.stack(rows) if rows else preds.new_zeros(0)
        self.criterion.reduction='none'
        try:
            return self.criterion(preds,targets)
        finally:
            self.criterion.reduction=reduction

    def mixup_criterion(self,preds,labels,t_list):
        """Loss for mixed predictions.

        Row ``i`` of ``preds`` was mixed from samples ``2i`` and ``2i+1`` with
        ratio ``t_list[i]``; its loss is
        ``(t*L(pred, y_2i) + (1-t)*L(pred, y_2i+1)) / sqrt(t^2 + (1-t)^2)``.
        Each ratio is applied to the loss of its own pair (not to a batch mean),
        and the result is averaged over the pairs. The ratios are treated as
        constants: no gradient flows through them.
        """
        n_pairs=preds.shape[0]
        labels=torch.as_tensor(labels)
        # Even positions are the first member of each pair, odd positions the
        # second; a trailing unpaired label is ignored.
        a_labs=labels[0:2*n_pairs:2].to(self.device)
        b_labs=labels[1:2*n_pairs:2].to(self.device)
        t=torch.tensor([float(x) for x in t_list],dtype=preds.dtype,device=self.device)
        loss_a=self._per_sample_loss(preds,a_labs)
        loss_b=self._per_sample_loss(preds,b_labs)
        loss=(t*loss_a+(1-t)*loss_b)/torch.sqrt(t**2+(1-t)**2)
        return torch.mean(loss)

    def train(self,mix=True):
        """Train for ``self.epochs`` epochs.

        Every batch gets one optimizer step without mixing and, when ``mix`` is
        true and the batch holds at least two samples, a second step on the mixed
        batch. Progress is printed every fifth epoch.

        Returns:
            ``(max_val_acc, ood_acc, best_epoch)``: the best validation accuracy,
            the accuracy on ``ood_dataloader`` after the last epoch, and the
            1-based epoch of the best validation accuracy (-1 if none).
        """
        max_val_acc=0
        best_epoch=-1
        for epoch in range(self.epochs):
            # evaluate() leaves the model in eval mode, so switch training mode
            # (dropout) back on at the start of every epoch.
            self.model.train()
            epoch_loss=0
            n_steps=0
            for text,labels in self.train_dataloader:
                input_ids,attention_mask=self._encode(text)
                labels=labels.to(self.device)
                ## Without mixing
                self.optimizer.zero_grad()
                self.model.do_emix(False)
                preds,_=self.model(input_ids,attention_mask)
                loss=self.criterion(preds,labels).to(self.device)
                epoch_loss+=loss.item()
                n_steps+=1
                loss.backward()
                self.optimizer.step()
                # A pair needs two samples: skip the mixed step for a singleton batch.
                if mix and input_ids.shape[0]>=2:
                    ## With mixing
                    self.optimizer.zero_grad()
                    self.model.do_emix(True)
                    preds,t_list=self.model(input_ids,attention_mask)
                    loss=self.mixup_criterion(preds,labels,t_list).to(self.device)
                    epoch_loss+=loss.item()
                    n_steps+=1
                    loss.backward()
                    self.optimizer.step()

            # Average over the number of loss terms that were actually accumulated.
            avg_train_loss=epoch_loss/max(n_steps,1)

            if self.validation_dataloader is not None:
                val_acc=self.evaluate(self.validation_dataloader)
                if val_acc>max_val_acc:
                    max_val_acc=val_acc
                    best_epoch=epoch+1
            else:
                val_acc=0

            if (epoch+1)%5==0:
                # Training accuracy is only reported, so compute it only when printing.
                train_acc=self.evaluate(self.train_dataloader)
                print(f'Epoch {epoch + 1}/{self.epochs}, Train Loss: {avg_train_loss}, Train Accuracy:{train_acc} Validation Accuracy: {val_acc}')
        ood_acc=self.evaluate(self.ood_dataloader)
        print(f"Ood accuracy:{ood_acc}")
        return max_val_acc,ood_acc,best_epoch

    def evaluate(self,dataloader):
        """Return the accuracy of the un-mixed model on ``dataloader``.

        The model is put in eval mode (and stays there) and mixing is switched
        off. Returns 0.0 when the loader yields no examples.
        """
        total_correct=0
        total_examples=0
        self.model.eval()
        self.model.do_emix(False)
        with torch.no_grad():
            for texts,labels in dataloader:
                input_ids,attention_mask=self._encode(texts)
                labels=labels.to(self.device)
                outputs,_=self.model(input_ids,attention_mask=attention_mask)
                predictions=torch.argmax(outputs,dim=-1)
                total_correct+=(predictions==labels).sum().item()
                total_examples+=labels.size(0)
        if total_examples==0:
            return 0.0
        return total_correct/total_examples


In [27]:
# trainer=emixTrainer(tokenizer,emixBert,device="cuda",criterion=criterion,optimizer=optimizer,train_dataloader=train_dataloader,validation_dataloader=validation_dataloader,epochs=6)

In [28]:
# class emixRNNClassifier(torch.nn.Module):
#    def __init__(self,model_name,num_labels,vocab_size,model_type='GRU',hidden_size=128,embedding_dim=256,alpha=1,mixing=False,device="cuda"):

### BERT Testing

In [37]:
def run_mixed_BERT_tests(train_dataloader,validation_dataloader,ood_dataloader,max_length,num_labels=2,epochs=10,mix=True,num_runs=5,return_ood=False):
    """Train ``num_runs`` fresh ``emixBERTClassifier`` models and aggregate the results.

    Uses the notebook-level ``tokenizer`` and ``pretrained_model_name`` and runs
    every model on "cuda" with AdamW (lr=3e-5) and cross-entropy loss.

    Args:
        train_dataloader, validation_dataloader, ood_dataloader: loaders handed
            to ``emixTrainer``.
        max_length: tokenizer padding / truncation length.
        num_labels: size of the classification head.
        epochs: training epochs per run.
        mix: forwarded to ``emixTrainer.train`` (enables the mixed training step).
        num_runs: number of independent runs.
        return_ood: when true, additionally return the mean OOD accuracy.

    Returns:
        ``(mean best validation accuracy, std of best validation accuracy,
        mean wall-clock seconds per run)``. Note that the second element is a
        standard deviation, not an OOD accuracy. With ``return_ood=True`` the
        mean OOD accuracy is appended as a fourth element.
    """
    import gc

    lr=3e-5
    max_vals=[]
    oods=[]
    times=[]
    for i in range(0,num_runs):
        emixBert=emixBERTClassifier(pretrained_model_name,num_labels,device='cuda')
        emixBert=emixBert.to('cuda')
        optimizer=torch.optim.AdamW(emixBert.parameters(),lr=lr)
        criterion=nn.CrossEntropyLoss()
        trainer=emixTrainer(tokenizer,emixBert,device="cuda",criterion=criterion,optimizer=optimizer,
                            train_dataloader=train_dataloader,ood_dataloader=ood_dataloader,
                            validation_dataloader=validation_dataloader,epochs=epochs,max_length=max_length)

        start_time=time.time()
        max_val,ood_acc,best_epoch=trainer.train(mix=mix)
        end_time=time.time()
        print("-"*50)
        print(f"{i+1} model: Max validation={max_val}, Best Epoch={best_epoch}, ood_acc={ood_acc}, time taken={end_time-start_time}")
        print("-"*50)
        times.append(end_time-start_time)
        max_vals.append(max_val)
        oods.append(ood_acc)
        # The trainer also references the model and the optimizer, so it has to be
        # dropped as well before the GPU memory of this run can be reclaimed.
        del trainer
        del emixBert
        del optimizer
        del criterion
        gc.collect()
        torch.cuda.empty_cache()
    if return_ood:
        return np.mean(max_vals),np.std(max_vals),np.mean(times),np.mean(oods)
    return np.mean(max_vals),np.std(max_vals),np.mean(times)


### GRU and LSTM Testing function

In [38]:
def run_mixed_RNN_tests(model_type,train_dataloader,validation_dataloader,ood_dataloader,max_length,num_labels=2,mix=True,epochs=10,num_runs=5,return_ood=False):
    """Train ``num_runs`` fresh ``emixRNNClassifier`` models and aggregate the results.

    Uses the notebook-level ``tokenizer`` (for tokenization and the vocabulary
    size) and ``pretrained_model_name`` and runs every model on "cuda" with AdamW
    (lr=3e-5) and cross-entropy loss.

    Args:
        model_type: forwarded to ``emixRNNClassifier`` (``'GRU'`` or ``'LSTM'``).
        train_dataloader, validation_dataloader, ood_dataloader: loaders handed
            to ``emixTrainer``.
        max_length: tokenizer padding / truncation length.
        num_labels: number of output classes.
        mix: forwarded to ``emixTrainer.train`` (enables the mixed training step).
        epochs: training epochs per run.
        num_runs: number of independent runs.
        return_ood: when true, additionally return the mean OOD accuracy.

    Returns:
        ``(mean best validation accuracy, std of best validation accuracy,
        mean wall-clock seconds per run)``. Note that the second element is a
        standard deviation, not an OOD accuracy. With ``return_ood=True`` the
        mean OOD accuracy is appended as a fourth element.
    """
    import gc

    lr=3e-5
    max_vals=[]
    oods=[]
    times=[]
    for i in range(0,num_runs):
        emixRNN=emixRNNClassifier(model_name=pretrained_model_name,num_labels=num_labels,
                                  vocab_size=tokenizer.vocab_size,model_type=model_type,device='cuda')
        emixRNN=emixRNN.to('cuda')
        optimizer=torch.optim.AdamW(emixRNN.parameters(),lr=lr)
        criterion=nn.CrossEntropyLoss()
        trainer=emixTrainer(tokenizer,emixRNN,device="cuda",criterion=criterion,optimizer=optimizer,
                            train_dataloader=train_dataloader,ood_dataloader=ood_dataloader,
                            validation_dataloader=validation_dataloader,epochs=epochs,max_length=max_length)

        start_time=time.time()
        max_val,ood_acc,best_epoch=trainer.train(mix=mix)
        end_time=time.time()
        print("-"*50)
        print(f"{i+1} model: Max validation={max_val}, Best Epoch={best_epoch}, ood_acc={ood_acc}, time taken={end_time-start_time}")
        print("-"*50)
        times.append(end_time-start_time)
        max_vals.append(max_val)
        oods.append(ood_acc)
        # The trainer also references the model and the optimizer, so it has to be
        # dropped as well before the GPU memory of this run can be reclaimed.
        del trainer
        del emixRNN
        del optimizer
        del criterion
        gc.collect()
        torch.cuda.empty_cache()
    if return_ood:
        return np.mean(max_vals),np.std(max_vals),np.mean(times),np.mean(oods)
    return np.mean(max_vals),np.std(max_vals),np.mean(times)

## IMDB

In [39]:
# IMDB training split (batch size 16), IMDB validation split, and SST-2 as the
# out-of-distribution test set. Only the first loader of each call is kept.
train_dataloader = get_dataloader(
    'Datasets/IMDB_500.csv', batch_sizes=[16]
)[0]
validation_dataloader = get_dataloader(
    'Datasets/IMDB_1000_ssmba_val.csv'
)[0]
ood_dataloader = get_dataloader(
    'Datasets/SST-2_1000_ssmba_test.csv'
)[0]

### BERT Tests

In [None]:
# Sweep the tokenizer window size for Emix-BERT on the IMDB loaders and append
# one summary row per configuration to outputs.txt.
max_lengths=[32,256]
mixes=[True]
for length in max_lengths:
    for mix in mixes:
        # The IMDB comparison cells in this notebook train on the same loaders
        # with num_labels=2, so a two-way head is used here as well.
        avg_val,std_val,avg_time=run_mixed_BERT_tests(train_dataloader,validation_dataloader,ood_dataloader,
                                              max_length=length,num_labels=2,epochs=10,mix=mix,num_runs=5)
        prefix="e" if mix else "no"
        with open("outputs.txt",'a') as file:
            file.write(f"\n{prefix}mix-BERT, Unbiased, IMDB, val_acc={avg_val}, stdc={std_val}")

### RNN tests

In [None]:
# Sweep the tokenizer window size for the Emix GRU and LSTM models on the IMDB
# loaders and append one summary row per configuration.
max_lengths=[32,256]
mixes=[True]
model_types=['GRU','LSTM']
for mtype in model_types:
    for length in max_lengths:
        for mix in mixes:
            # The IMDB comparison cells in this notebook train on the same loaders
            # with num_labels=2, so a two-way head is used here as well.
            avg_val,std,avg_time=run_mixed_RNN_tests(mtype,train_dataloader,validation_dataloader,ood_dataloader,
                                        max_length=length,num_labels=2,mix=mix,epochs=10,num_runs=5)
            prefix="e" if mix else "no"
            # NOTE(review): rows tagged "IMDB" are appended to a file named
            # MNLI_bias_outputs.txt -- confirm the file name is intended.
            with open("MNLI_bias_outputs.txt",'a') as file:
                file.write(f"\n{prefix}mix-{mtype}, Unbiased, IMDB, val_acc={avg_val}, stdc={std}")

## MNLI

In [35]:
# MNLI train (batch size 16), validation and test splits; the test split is
# what the later cells pass as the OOD loader.
train_dataloader = get_dataloader(
    'Datasets/MNLI/MNLI_ssmba_train.csv', batch_sizes=[16]
)[0]
validation_dataloader = get_dataloader(
    'Datasets/MNLI/MNLI_ssmba_val.csv'
)[0]
ood_dataloader = get_dataloader(
    'Datasets/MNLI/MNLI_ssmba_test.csv'
)[0]

### BERT

In [None]:
# MNLI: compare Emix-BERT against plain fine-tuning for two window sizes.
max_lengths=[512,32]
mixes=[True,False]

for length in max_lengths:
    for mix in mixes:
        # run_mixed_BERT_tests returns (mean val acc, std of val acc, mean time);
        # the second value is a standard deviation, so it is logged as such
        # rather than as an OOD accuracy.
        avg_val,std_val,avg_time=run_mixed_BERT_tests(train_dataloader,validation_dataloader,ood_dataloader,
                                              max_length=length,num_labels=3,epochs=10,mix=mix,num_runs=5)
        prefix="e" if mix else "no"
        with open("MNLI_outputs.txt",'a') as file:
            file.write(f"\n{prefix}mix-BERT, Unbiased, window_size={length}, MNLI, val_acc={avg_val}, val_acc_std={std_val}, time_taken={avg_time}")

### RNN

In [None]:
# MNLI: compare Emix against plain training for the GRU and LSTM models and two
# window sizes.
max_lengths=[512,32]
mixes=[True,False]
model_types=['GRU','LSTM']
for mtype in model_types:
    for length in max_lengths:
        for mix in mixes:
            # run_mixed_RNN_tests returns (mean val acc, std of val acc, mean time);
            # the second value is a standard deviation, so it is logged as such
            # rather than as an OOD accuracy.
            avg_val,std_val,avg_time=run_mixed_RNN_tests(mtype,train_dataloader,validation_dataloader,ood_dataloader,
                                        max_length=length,num_labels=3,mix=mix,epochs=10,num_runs=5)
            prefix="e" if mix else "no"
            with open("MNLI_outputs.txt",'a') as file:
                file.write(f"\n{prefix}mix-{mtype}, Unbiased, window_size={length}, MNLI, val_acc={avg_val}, val_acc_std={std_val}, time_taken={avg_time}")

## Comparison with EDA and SSMBA

In [31]:
# Baseline IMDB loaders for the augmentation comparison (SST-2 is the OOD set).
train_dataloader = get_dataloader(
    'Datasets/IMDB_500.csv',
    batch_sizes=[16],
)[0]
validation_dataloader = get_dataloader('Datasets/IMDB_1000_ssmba_val.csv')[0]
ood_dataloader = get_dataloader('Datasets/SST-2_1000_ssmba_test.csv')[0]

In [32]:
# Training sets produced offline by the two augmentation baselines.
eda_dataloader = get_dataloader(
    'Datasets/IMDB_4500_eda8.csv'
)[0]
ssmba_dataloader = get_dataloader(
    'Datasets/IMDB_no_bias_8_ssmba_train.csv'
)[0]

In [33]:
# Number of independently initialised training runs per configuration in the
# comparison cells below.
num_runs = 5

In [35]:
# Training regimes to compare: plain training, Emix, and the two offline
# augmentation baselines.
methods = ['Normal', 'Emix', 'eda', 'ssmba']

In [None]:
# IMDB / BERT: train once per method and append one summary row per method.
for method in methods:
    # The augmentation baselines bring their own training set; 'Normal' and
    # 'Emix' both use the plain training loader.
    if method=='eda':
        training_data=eda_dataloader
    elif method=='ssmba':
        training_data=ssmba_dataloader
    else:
        training_data=train_dataloader

    # Only 'Emix' enables the mixed training step. The second value returned by
    # run_mixed_BERT_tests is the standard deviation of the validation accuracy,
    # so it is logged as such rather than as an OOD accuracy.
    avg_val,std_val,avg_time=run_mixed_BERT_tests(training_data,validation_dataloader,ood_dataloader,
                                            max_length=256,num_labels=2,epochs=10,mix=(method=='Emix'),num_runs=num_runs)
    with open("IMDB_comparison_outputs.txt",'a') as file:
        file.write(f"\n{method}-BERT, Unbiased, window_size=256, IMDB, val_acc={avg_val}, val_acc_std={std_val}, time_taken={avg_time}")

In [None]:
# IMDB / GRU and LSTM: train once per (model type, method) and append one
# summary row per combination.
model_types=['GRU','LSTM']
for mtype in model_types:
    for method in methods:
        # The augmentation baselines bring their own training set; 'Normal' and
        # 'Emix' both use the plain training loader.
        if method=='eda':
            training_data=eda_dataloader
        elif method=='ssmba':
            training_data=ssmba_dataloader
        else:
            training_data=train_dataloader

        # Train on the loader selected above: passing train_dataloader here would
        # make the 'eda' and 'ssmba' rows repeat the plain baseline.
        # Only 'Emix' enables the mixed training step. The second returned value
        # is the standard deviation of the validation accuracy, so it is logged
        # as such rather than as an OOD accuracy.
        avg_val,std_val,avg_time=run_mixed_RNN_tests(mtype,training_data,validation_dataloader,ood_dataloader,
                                    max_length=256,num_labels=2,mix=(method=='Emix'),epochs=10,num_runs=num_runs)
        with open("IMDB_comparison_outputs.txt",'a') as file:
            file.write(f"\n{method}-{mtype}, Unbiased, window_size=256, IMDB, val_acc={avg_val}, val_acc_std={std_val}, time_taken={avg_time}")

## MNLI SSMBA and EDA Comparison

In [69]:
# Baseline MNLI loaders for the augmentation comparison.
train_dataloader = get_dataloader(
    'Datasets/MNLI/MNLI_ssmba_train.csv',
    batch_sizes=[16],
)[0]
validation_dataloader = get_dataloader('Datasets/MNLI/MNLI_ssmba_val.csv')[0]
ood_dataloader = get_dataloader('Datasets/MNLI/MNLI_ssmba_test.csv')[0]

In [70]:
# MNLI training sets produced offline by the two augmentation baselines.
eda_dataloader = get_dataloader(
    'Datasets/MNLI_eda1.csv',
    batch_sizes=[16],
)[0]
ssmba_dataloader = get_dataloader(
    'Datasets/MNLI_no_bias_1_ssmba_train.csv'
)[0]

In [None]:
# MNLI / BERT: train once per method and append one summary row per method.
methods=[
         'Normal',
         'Emix',
         'eda','ssmba']

for method in methods:
    # The augmentation baselines bring their own training set; 'Normal' and
    # 'Emix' both use the plain training loader.
    if method=='eda':
        training_data=eda_dataloader
    elif method=='ssmba':
        training_data=ssmba_dataloader
    else:
        training_data=train_dataloader

    # Only 'Emix' enables the mixed training step. The second value returned by
    # run_mixed_BERT_tests is the standard deviation of the validation accuracy,
    # so it is logged as such rather than as an OOD accuracy.
    avg_val,std_val,avg_time=run_mixed_BERT_tests(training_data,validation_dataloader,ood_dataloader,
                                            max_length=256,num_labels=3,epochs=10,mix=(method=='Emix'),num_runs=num_runs)
    with open("MNLI_comparison_outputs.txt",'a') as file:
        file.write(f"\n{method}-BERT, Unbiased, window_size=256, MNLI, val_acc={avg_val}, val_acc_std={std_val}, time_taken={avg_time}")

In [74]:
# Replace the augmented MNLI training sets with the "8" variants of the files.
eda_dataloader = get_dataloader(
    'Datasets/MNLI_eda8.csv',
    batch_sizes=[16],
)[0]
ssmba_dataloader = get_dataloader(
    'Datasets/MNLI_no_bias_8_ssmba_train.csv'
)[0]

In [None]:
# MNLI / GRU and LSTM: train on the two augmented training sets and append one
# summary row per (model type, method).
methods=['eda','ssmba']
model_types=['GRU','LSTM']
for mtype in model_types:
    for method in methods:
        if method=='eda':
            training_data=eda_dataloader
        elif method=='ssmba':
            training_data=ssmba_dataloader
        else:
            training_data=train_dataloader

        # Train on the loader selected above: passing train_dataloader here would
        # make both rows repeat the un-augmented baseline.
        # Only 'Emix' enables the mixed training step. The second returned value
        # is the standard deviation of the validation accuracy.
        avg_val,std_val,avg_time=run_mixed_RNN_tests(mtype,training_data,validation_dataloader,ood_dataloader,
                                    max_length=256,num_labels=3,mix=(method=='Emix'),epochs=10,num_runs=num_runs)
        with open("MNLI_comparison_outputs.txt",'a') as file:
            file.write(f"\n{method}-{mtype}, Unbiased, window_size=256, MNLI, val_acc={avg_val}, std={std_val}, time_taken={avg_time}")