In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader, Dataset

import numpy as np
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel

In [None]:
# Hugging Face checkpoint used for both the tokenizer and the encoder.
# NOTE(review): the name suggests 6 layers / hidden size 384; delete_some_layers
# indexes encoder layers up to 4, so at least 5 layers are assumed — confirm.
# See Session 11 for background on MiniLM.
MODEL_NAME = "nreimers/MiniLM-L6-H384-uncased"

In [None]:
# load tokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
# load data

ag_news = load_dataset("ag_news")['train']

In [None]:
ag_news

In [None]:
# split data: 10k labelled / 100k unlabelled / 10k validation examples
# (random_split requires the sizes to sum to len(ag_news)). A seeded generator
# keeps the split identical across kernel restarts.
SPLIT_SEED = 42
supervised, unsupervised, val = random_split(
    ag_news,
    [10000, 100000, 10000],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

### Train a TF-IDF augmentation (currently disabled — random word deletion is used instead)

In [None]:
# # adapted from https://github.com/makcedward/nlpaug/blob/master/example/tfidf-train_model.ipynb

import re
# pip install numpy requests nlpaug
import nltk
import nlpaug.augmenter.word as naw 
import nlpaug.model.word_stats as nmw

# def _tokenizer(text, token_pattern=r"(?u)\b\w\w+\b"):
#     token_pattern = re.compile(token_pattern)
#     return token_pattern.findall(text)

# # Tokenize input
# train_x_tokens = [_tokenizer(x['text'].lower()) for x in ag_news]

# # Train TF-IDF model
# tfidf_model = nmw.TfIdf()
# tfidf_model.train(train_x_tokens)
# tfidf_model.save('.')

# # Load TF-IDF augmenter
# tf_idf_aug = naw.TfIdfAug(model_path='.', tokenizer=_tokenizer, stopwords=nltk.corpus.stopwords.words('english'))

In [None]:
# tf_idf_aug.augment('my computer is broken')

In [None]:
# Augmenter that produces the perturbed view of each unlabelled text by
# randomly deleting words (nlpaug RandomWordAug with action='delete').
del_aug = naw.random.RandomWordAug(action='delete')

### DataLoaders

In [None]:
class UdaData(Dataset):
    """Wrap ``{'text', 'label'}`` records for UDA training.

    In supervised mode each item is ``(text, label)``. In unsupervised mode the
    label is ignored and each item is ``(text, augmented_text)``, where the
    second element comes from the module-level ``del_aug`` augmenter.

    Args:
        data: indexable collection whose items support ``item['text']`` (and
            ``item['label']`` in supervised mode), e.g. a ``Subset`` of a
            Hugging Face dataset.
        supervised: selects between the two item formats above.
    """

    def __init__(self, data, supervised = True):
        self.data = data
        self.supervised = supervised

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fetch the record once; indexing a Subset / HF dataset is not free.
        record = self.data[idx]
        text = record['text']

        if self.supervised:
            return text, record['label']

        augmented = del_aug.augment(text)
        # Newer nlpaug releases return a list of augmented strings instead of a
        # single string; the collate function needs one plain string per item.
        if isinstance(augmented, (list, tuple)):
            augmented = augmented[0] if augmented else text
        return text, augmented

In [None]:
def batch_encode(texts, max_length):
    """Turn a list of strings into a padded tensor of token ids.

    Uses the module-level ``tokenizer``. Texts are padded to the longest one in
    the batch and truncated to ``max_length`` tokens; only the ``input_ids``
    tensor is returned (no attention mask, no token-type ids).
    """
    encode_kwargs = dict(
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
        return_token_type_ids=False,
        return_attention_mask=False,
    )
    encoded = tokenizer.batch_encode_plus(texts, **encode_kwargs)
    return encoded['input_ids']

In [None]:
def collate_sup(batch, max_length=100):
    """Collate ``(text, label)`` samples into ``[input_ids, labels]``.

    ``input_ids`` comes from ``batch_encode`` (padded, truncated to
    ``max_length``); ``labels`` is a ``LongTensor`` of the sample labels.
    """
    labels = torch.LongTensor([sample[1] for sample in batch])
    input_ids = batch_encode([sample[0] for sample in batch], max_length)
    return [input_ids, labels]

def collate_unsup(batch, max_length=100):
    """Collate ``(text, augmented_text)`` pairs into two token-id tensors.

    Each side is encoded by its own ``batch_encode`` call, so the two returned
    tensors are padded independently and may differ in sequence length.
    """
    originals, augmented = [], []
    for sample in batch:
        originals.append(sample[0])
        augmented.append(sample[1])

    return [batch_encode(originals, max_length), batch_encode(augmented, max_length)]

In [None]:
sup_batch_size = 128
mu = 2  # ratio of unlabelled to labelled examples per training step
NUM_WORKERS = 15  # DataLoader worker processes; lower this on machines with fewer cores

sup_loader = DataLoader(UdaData(supervised, supervised=True), batch_size=sup_batch_size,
                        shuffle=True, collate_fn=collate_sup, num_workers=NUM_WORKERS)

# each training step pairs one labelled batch with an unlabelled batch `mu` times larger
unsup_loader = DataLoader(UdaData(unsupervised, supervised=False), batch_size=sup_batch_size * mu,
                          shuffle=True, collate_fn=collate_unsup, num_workers=NUM_WORKERS)

val_loader = DataLoader(UdaData(val, supervised=True), batch_size=100,
                        shuffle=False, collate_fn=collate_sup, num_workers=NUM_WORKERS)

In [None]:
# Inspect one labelled batch; `x` and `y` are reused by later sanity-check cells.
x, y = next(iter(sup_loader))
x

In [None]:
y

### Model

In [None]:
import copy

def delete_some_layers(model, keep_layers=(0, 2, 4)):
    """Return a copy of ``model`` whose encoder keeps only some transformer layers.

    The input model is left untouched. It is deep-copied first and the layers
    are then selected from the *copy*, so the returned model owns its own
    weights and never shares layer modules with ``model``.

    Args:
        model: a model exposing ``model.encoder.layer`` as an ``nn.ModuleList``
            (e.g. a Hugging Face BERT-style model).
        keep_layers: indices of the encoder layers to keep, in order. The
            default ``(0, 2, 4)`` keeps every other layer of a 6-layer encoder.

    Returns:
        The pruned deep copy.

    Raises:
        IndexError: if an index in ``keep_layers`` is out of range.
    """
    pruned = copy.deepcopy(model)

    old_layers = pruned.encoder.layer
    pruned.encoder.layer = nn.ModuleList([old_layers[i] for i in keep_layers])

    # Keep the config consistent with the new depth so code that reads
    # num_hidden_layers (or a saved/reloaded config) sees the real layer count.
    config = getattr(pruned, 'config', None)
    if config is not None and hasattr(config, 'num_hidden_layers'):
        config.num_hidden_layers = len(keep_layers)

    return pruned

In [None]:
class MinilmClassifier(nn.Module):
    """Text classifier: a layer-pruned MiniLM encoder followed by a linear head.

    The pretrained ``MODEL_NAME`` checkpoint is loaded and passed through
    ``delete_some_layers``. The encoder's ``pooler_output`` is mapped to
    ``num_classes`` logits after dropout.

    Args:
        num_classes: number of output logits.
        num_layers: currently unused — which encoder layers are kept is decided
            by ``delete_some_layers``. Kept only for call compatibility.
    """

    def __init__(self, num_classes, num_layers=2):
        super().__init__()

        minilm = AutoModel.from_pretrained(MODEL_NAME)
        self.encoder = delete_some_layers(minilm)

        hidden_size = self.encoder.config.hidden_size
        self.fc = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

        # Padding id used to rebuild the attention mask in forward(). Fall back
        # to 0 (the previously hard-coded value) if the config does not set it.
        pad_id = getattr(self.encoder.config, 'pad_token_id', None)
        self.pad_token_id = 0 if pad_id is None else pad_id

    def forward(self, x):
        """Return class logits for a batch of token ids ``x``.

        NOTE(review): assumes ``x`` is a ``(batch, seq_len)`` id tensor padded
        with ``pad_token_id``, as produced by ``batch_encode`` — confirm for
        other tokenizers.
        """
        # Non-padding positions get 1, padding gets 0.
        attention_mask = (x != self.pad_token_id).long()

        pooled = self.encoder(input_ids=x, attention_mask=attention_mask)['pooler_output']

        return self.fc(pooled)

In [None]:
# Throwaway instance for the shape sanity check below; num_classes matches the
# 4 AG News labels used for training.
model = MinilmClassifier(num_classes=4)

In [None]:
model(x).shape

### Training loop

In [None]:
def train_uda(model: nn.Module,
              opt: torch.optim.Optimizer,
              sup_loader: torch.utils.data.DataLoader,
              unsup_loader: torch.utils.data.DataLoader,
              alpha: float = 0.5):
    """Run one epoch of Unsupervised Data Augmentation (UDA) training.

    Each step combines
      * a supervised cross-entropy loss on a labelled batch ``(x, y)`` from
        ``sup_loader``, and
      * a consistency loss on an unlabelled batch ``(x, x_aug)`` from
        ``unsup_loader``: the per-example KL divergence between the model's
        gradient-free prediction on ``x`` and its prediction on ``x_aug``.

    The total loss is ``loss_sup + alpha * mean(KL)``. The epoch length is set
    by ``sup_loader``; ``unsup_loader`` is restarted whenever it runs out.

    Args:
        model: classifier mapping token ids to logits; all batches are moved to
            the device of its parameters.
        opt: optimizer over ``model``'s parameters.
        sup_loader: yields ``(x, y)`` labelled batches.
        unsup_loader: yields ``(x, x_aug)`` unlabelled batches.
        alpha: weight of the unsupervised consistency loss.

    Returns:
        Mean total loss over the epoch (``nan`` if ``sup_loader`` is empty).
    """
    model.train()
    device = next(model.parameters()).device

    loss_sum = 0.0
    n_steps = 0

    pbar = tqdm(sup_loader)
    unsup_iter = iter(unsup_loader)

    for x_sup, y in pbar:
        model.zero_grad()

        # supervised cross-entropy loss on labelled data
        x_sup, y = x_sup.to(device), y.to(device)
        loss_sup = F.cross_entropy(model(x_sup), y)

        # unlabelled data; restart the loader when it is exhausted
        try:
            x_unsup, x_aug = next(unsup_iter)
        except StopIteration:
            unsup_iter = iter(unsup_loader)
            x_unsup, x_aug = next(unsup_iter)
        x_unsup, x_aug = x_unsup.to(device), x_aug.to(device)

        # target distribution from the non-augmented text; no gradient flows
        # through it (the model is still in train mode, so dropout is active)
        with torch.no_grad():
            logits_unsup = model(x_unsup)

        # prediction for the augmented text (this one is trained)
        logits_aug = model(x_aug)

        # per-example KL(p(x_unsup) || p(x_aug)): kl_div takes log-probs of the
        # trained prediction as input and probabilities of the target
        loss_unsup = F.kl_div(F.log_softmax(logits_aug, dim=1),
                              F.softmax(logits_unsup, dim=1),
                              reduction='none').sum(1)

        loss = loss_sup + alpha * loss_unsup.mean()
        loss.backward()
        opt.step()

        # running mean in O(1) per step instead of re-averaging the whole history
        loss_sum += loss.item()
        n_steps += 1
        pbar.set_description(f'train_loss = {loss_sum / n_steps: .3f}')

    return loss_sum / n_steps if n_steps else float('nan')

@torch.no_grad()
def validate(model: nn.Module, dataloader: torch.utils.data.DataLoader):
    """Compute the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there). Each batch is an
    ``(inputs, labels)`` pair; inputs are moved to the device of the model's
    parameters and the predicted class is the argmax over the last logit axis.

    Returns:
        ``accuracy_score`` of the true labels against the predictions.
    """
    model.eval()
    device = next(model.parameters()).device

    true_chunks = []
    logit_chunks = []
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        true_chunks.append(labels.cpu().numpy())
        logit_chunks.append(outputs.cpu().numpy())

    predictions = np.concatenate(logit_chunks, axis=0).argmax(axis=-1)
    return accuracy_score(np.concatenate(true_chunks), predictions)

### Training

In [None]:
# initialize model (falls back to CPU when no GPU is available)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MinilmClassifier(num_classes=4).to(device)

# NOTE(review): lr=1e-3 is high for fine-tuning a pretrained transformer with
# AdamW (values around 1e-5 to 5e-5 are more common) — confirm it is intentional.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [None]:
N_EPOCHS = 5

# (train_loss, val_acc) per epoch, kept so results can be plotted afterwards
history = []
for epoch in range(N_EPOCHS):
    train_loss = train_uda(model, opt, sup_loader, unsup_loader, alpha=1.)
    val_acc = validate(model, val_loader)
    history.append((train_loss, val_acc))
    print(f'epoch {epoch + 1}/{N_EPOCHS}: train_loss={train_loss:.4f} val_acc={val_acc:.4f}')