## DistilBERT / XLM-RoBERTa fine-tuning with an ArcMargin (ArcFace) head

In [None]:
import os
import copy
import math
import pandas as pd
import numpy as np
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import random
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, StratifiedKFold

import transformers
from transformers import (XLMRobertaTokenizer, XLMRobertaModel,
                          DistilBertTokenizer, DistilBertModel)

In [None]:
# Raw training data; backslash-escaped fields and no quote handling.
TRAIN_CSV = "../input/amazon-ml-challenge-2021-hackerearth/train.csv"
train = pd.read_csv(TRAIN_CSV, quoting=csv.QUOTE_NONE, escapechar="\\")
train.head()

A histogram of whitespace-separated word counts per title gives a rough idea of how many words each title contains (the plotting cell is not included in this notebook). It is not a precise count of the tokens fed to the model, because the subword tokenizer (DistilBERT / XLM-RoBERTa) splits text more finely than simply splitting the sentence on white space.

`max_length` (currently 64 in `CFG`) caps the tokenized title length and should be chosen from that length distribution. You can safely change it.

In [None]:
def set_seed(seed=42):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Sets ``PYTHONHASHSEED`` (this only reaches processes started afterwards; the
    running interpreter's hash seed is fixed at startup), then seeds Python's
    ``random`` module, NumPy's global RNG and PyTorch.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [None]:
set_seed(seed=42)  # same value as the function's default

In [None]:
class CFG:
    """Notebook-wide hyperparameters and paths, read as class attributes (never instantiated)."""
    DistilBERT = False # if True load distilbert-base-uncased, otherwise xlm-roberta-base (see the model-loading cell)
    bert_hidden_size = 768  # width of the first-token hidden state fed to ArcMarginProduct; must match the backbone
    num_classes=9919  # ArcMargin output classes; every encoded label must be < this value
    batch_size = 192
    epochs = 4
    num_workers = 2  # DataLoader worker processes
    learning_rate = 1e-5 #3e-5  <- Adam learning rate; 3e-5 noted as an alternative value
    scheduler = "ReduceLROnPlateau"  # the only value for which the training cell builds a scheduler
    step = 'epoch'  # 'batch' makes one_epoch call lr_scheduler.step() after every training batch
    patience = 2  # ReduceLROnPlateau patience (epochs without improvement)
    factor = 0.8  # ReduceLROnPlateau learning-rate multiplier
    dropout = 0.5  # NOTE(review): not referenced anywhere in this notebook
    model_path = "/kaggle/working"  # directory of the best-validation checkpoint written by train_eval
    max_length = 64  # tokenizer pad/truncate length
    model_save_name = "model.pt"  # file name of that checkpoint inside model_path
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

Load the backbone and its tokenizer from the HuggingFace model hub: `distilbert-base-uncased` when `CFG.DistilBERT` is `True`, otherwise the multilingual `xlm-roberta-base`.

In [None]:
# Choose the checkpoint plus matching tokenizer/model classes, then load both once.
if not CFG.DistilBERT:
    model_name = 'xlm-roberta-base'
    tokenizer_cls, backbone_cls = XLMRobertaTokenizer, XLMRobertaModel
else:
    model_name = 'distilbert-base-uncased'
    tokenizer_cls, backbone_cls = DistilBertTokenizer, DistilBertModel

tokenizer = tokenizer_cls.from_pretrained(model_name)
bert_model = backbone_cls.from_pretrained(model_name)

Example: tokenize one random title, decode it back, and check the shape of the backbone's last hidden state.

In [None]:
# Sample from non-null titles only: some TITLE values are NaN (they are dropped later),
# and the tokenizer cannot encode them. randint's upper bound is exclusive, so
# len(titles) makes every row reachable (the previous len - 1 skipped the last one).
titles = train['TITLE'].dropna().values
text = titles[np.random.randint(0, len(titles))]
print(f"Text of the title: {text}")
encoded_input = tokenizer(text, return_tensors='pt')
print(f"Input tokens: {encoded_input['input_ids']}")
decoded_input = tokenizer.decode(encoded_input['input_ids'][0])
print(f"Decoded tokens: {decoded_input}")
with torch.no_grad():  # inspection only; no autograd graph needed
    output = bert_model(**encoded_input)
print(f"last layer's output shape: {output.last_hidden_state.shape}")

## Dataset

Encode the `BROWSE_NODE_ID` column as contiguous integer class indices so it can be fed to the model and the loss function.

In [None]:
# NOTE: despite the names, `id2lbl` maps raw BROWSE_NODE_ID -> contiguous class index and
# `lbl2id` maps class index -> raw BROWSE_NODE_ID. Names are kept because later cells use them.
# Built from the full training frame so every label seen in any fold gets an index.
id2lbl = {lbl: idx for idx, lbl in enumerate(train["BROWSE_NODE_ID"].unique())}
lbl2id = {idx: lbl for lbl, idx in id2lbl.items()}

# ArcMarginProduct one-hot encodes labels into CFG.num_classes columns, so every index must fit.
assert len(id2lbl) <= CFG.num_classes, (
    f"{len(id2lbl)} distinct BROWSE_NODE_IDs but CFG.num_classes={CFG.num_classes}"
)

In [None]:
def create_folds(data, num_splits, random_state=None):
    """Shuffle `data` and assign a stratified fold index to every row.

    Args:
        data: DataFrame with a ``BROWSE_NODE_ID`` column used for stratification.
        num_splits: number of folds, passed to ``StratifiedKFold``.
        random_state: seed for the row shuffle. ``None`` (the default) uses NumPy's
            global RNG, which ``set_seed`` seeds, matching the previous behaviour.

    Returns:
        A new shuffled DataFrame with a fresh RangeIndex and a ``kfold`` column in
        ``[0, num_splits)``. The input frame is left unmodified.
    """
    # sample() returns a copy, so the fold column is added to that copy rather than the
    # caller's frame (previously `data["kfold"] = -1` mutated the input before shuffling).
    shuffled = data.sample(frac=1, random_state=random_state).reset_index(drop=True)
    shuffled["kfold"] = -1
    kf = StratifiedKFold(n_splits=num_splits)
    for fold, (_, valid_idx) in enumerate(kf.split(X=shuffled, y=shuffled["BROWSE_NODE_ID"])):
        shuffled.loc[valid_idx, "kfold"] = fold
    return shuffled


In [None]:
train = create_folds(train, num_splits=5)  # 5 stratified folds on BROWSE_NODE_ID

In [None]:
# Keep 3 of the 5 folds (~60% of rows), presumably to fit the time/memory budget; TODO confirm intent.
train = train[train["kfold"].isin([1, 2, 3])].reset_index(drop=True)
train.head()

In [None]:
# TextDataset tokenizes every TITLE, so rows without a title are removed here.
temp = train.dropna(subset=["TITLE"]).reset_index(drop=True)

In [None]:
temp.head(n=5)

In [None]:
# Replace raw BROWSE_NODE_IDs by contiguous class indices. Validate before assigning:
# re-running this cell would map already-encoded indices and silently yield NaN labels.
encoded_labels = temp["BROWSE_NODE_ID"].map(id2lbl)
assert encoded_labels.notna().all(), "unmapped BROWSE_NODE_ID values (was this cell run twice?)"
temp["BROWSE_NODE_ID"] = encoded_labels

In [None]:
class TextDataset(Dataset):
    """Map-style dataset that tokenizes product titles on the fly.

    Args:
        data: DataFrame with a ``TITLE`` column and, unless ``mode == "test"``,
            an integer-encoded ``BROWSE_NODE_ID`` column. Any index is accepted.
        tokenizer: HuggingFace-style tokenizer exposing ``encode_plus``.
        mode: ``"test"`` omits labels; any other value returns them.
        max_length: pad/truncate length passed to the tokenizer.

    Each item is a dict with ``input_ids`` and ``attention_mask`` (long tensors)
    and, outside test mode, a scalar long ``labels`` tensor.
    """

    def __init__(self, data, tokenizer, mode="train", max_length=None):
        super().__init__()
        # Positional arrays: Series[idx] is label-based and raises KeyError when the
        # frame's index is not 0..n-1 (e.g. a split that was not reset_index'ed).
        self.sentence = data["TITLE"].to_numpy()
        if mode != "test":
            self.label = data["BROWSE_NODE_ID"].to_numpy()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mode = mode

    def __len__(self):
        return len(self.sentence)

    def __getitem__(self, idx):
        inp_tokens = self.tokenizer.encode_plus(self.sentence[idx],
                                                padding="max_length",
                                                add_special_tokens=True,
                                                max_length=self.max_length,
                                                truncation=True)
        item = {
            "input_ids": torch.tensor(inp_tokens.input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(inp_tokens.attention_mask, dtype=torch.long),
        }
        if self.mode != "test":
            item['labels'] = torch.tensor(self.label[idx], dtype=torch.long)
        return item

In [None]:
# Inspection-only dataset over all labelled titles; the training loaders are built further down.
dataset = TextDataset(temp, tokenizer, max_length=CFG.max_length)
dataloader = DataLoader(
    dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    num_workers=CFG.num_workers,
)

In [None]:
len(dataset)  # number of samples = rows of `temp` (titles remaining after dropna)

In [None]:
# Show the per-key tensor shapes of one batch instead of dumping whole batch tensors.
{key: tensor.shape for key, tensor in next(iter(dataloader)).items()}

In [None]:
# code from https://github.com/ronghuaiyang/arcface-pytorch/blob/47ace80b128042cd8d2efd408f55c5a3e156b032/models/metrics.py#L10
# (modified: device-agnostic one-hot, optional label for margin-free inference logits)

class ArcMarginProduct(nn.Module):
    r"""Additive angular margin (ArcFace) classification head.

    Logits are ``s * cos(theta_j)``, where ``theta_j`` is the angle between the
    L2-normalised input and the L2-normalised weight row of class ``j``. When
    labels are given, the target-class logit is replaced by ``s * cos(theta_y + m)``.

        Args:
            in_features: size of each input sample
            out_features: number of classes (size of each output sample)
            s: scale applied to the cosine logits
            m: additive angular margin, in radians
            easy_margin: if True, apply the margin only where cos(theta) > 0

        Forward:
            input: (batch, in_features) features.
            label: (batch,) integer class indices < out_features, or None to get
                plain ``s * cos(theta)`` logits without any margin (e.g. inference).
            Returns (batch, out_features) logits.
        """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(theta) > th  <=>  theta + m < pi. Past that point cos(theta + m) stops decreasing
        # in theta, so the non-easy branch uses the linear penalty cos(theta) - mm instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label=None):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        if label is None:
            # No targets: return scaled cosine similarities with no margin applied.
            return cosine * self.s
        # NOTE(review): sqrt has an infinite gradient at 0, i.e. when |cosine| == 1 exactly.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # Allocate on the logits' own device/dtype (was hard-wired to the global CFG.device).
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Margin-adjusted logit for the target class, plain cosine everywhere else.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

In [None]:
class Model(nn.Module):
    """Transformer backbone followed by an ArcMargin head on the first-token hidden state.

    Args:
        bert_model: HuggingFace encoder whose output exposes ``last_hidden_state``.
        num_classes: number of ArcMargin output classes.
        last_hidden_size: width of the backbone's hidden state.
    """
    def __init__(self, 
                 bert_model, 
                 num_classes=CFG.num_classes, 
                 last_hidden_size=CFG.bert_hidden_size):
        
        super().__init__()
        self.bert_model = bert_model
        self.arc_margin = ArcMarginProduct(last_hidden_size, 
                                           num_classes, 
                                           s=30.0, 
                                           m=0.50, 
                                           easy_margin=False)
    
    def get_bert_features(self, batch):
        """Return the hidden state of the first token for each sequence in `batch`."""
        output = self.bert_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
        last_hidden_state = output.last_hidden_state # shape: (batch_size, seq_length, bert_hidden_dim)
        CLS_token_state = last_hidden_state[:, 0, :] # obtaining CLS token state which is the first token.
        return CLS_token_state
    
    def forward(self, batch):
        """Return (batch_size, num_classes) logits.

        With ``batch['labels']`` the ArcMargin margin is applied to the target class.
        Without labels (e.g. ``TextDataset(mode="test")``) this returns plain scaled
        cosine logits, where it previously raised KeyError.
        """
        CLS_hidden_state = self.get_bert_features(batch)
        labels = batch.get('labels')
        if labels is None:
            head = self.arc_margin
            return head.s * F.linear(F.normalize(CLS_hidden_state), F.normalize(head.weight))
        return self.arc_margin(CLS_hidden_state, labels)

In [None]:
class AvgMeter:
    """Running, count-weighted average of a scalar metric, shown as 'name: avg'."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running average, sum and sample count."""
        self.avg = self.sum = self.count = 0

    def update(self, val, count=1):
        """Fold in a value `val` observed over `count` samples and refresh the average."""
        self.sum += val * count
        self.count += count
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)

def one_epoch(model, 
              criterion, 
              loader,
              optimizer=None, 
              lr_scheduler=None, 
              mode="train", 
              step="batch"):
    """Run one pass over `loader`, updating the model when ``mode == "train"``.

    Args:
        model: callable taking a batch dict and returning (batch, num_classes) logits.
        criterion: loss function applied as ``criterion(preds, batch['labels'])``.
        loader: iterable of dicts of tensors; each must contain ``labels`` and ``input_ids``.
        optimizer: required when ``mode == "train"``.
        lr_scheduler: stepped after every training batch when ``step == "batch"``;
            may be None (no per-batch stepping).
        mode: ``"train"`` to back-propagate; any other value only evaluates.
        step: ``"batch"`` for per-batch scheduler stepping, anything else to skip it here.

    Returns:
        (loss_meter, acc_meter): AvgMeter objects holding sample-weighted averages.

    NOTE(review): the model receives labels, so accuracy is measured on margin-adjusted
    logits (the target logit is lowered); it is likely a pessimistic estimate.
    """
    loss_meter = AvgMeter()
    acc_meter = AvgMeter()
    training = mode == "train"

    tqdm_object = tqdm(loader, total=len(loader))
    for batch in tqdm_object:
        batch = {k: v.to(CFG.device) for k, v in batch.items()}
        preds = model(batch)
        loss = criterion(preds, batch['labels'])
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Guard: the default lr_scheduler=None used to crash here with step="batch".
            if step == "batch" and lr_scheduler is not None:
                lr_scheduler.step()

        count = batch['input_ids'].size(0)
        loss_meter.update(loss.item(), count)

        accuracy = get_accuracy(preds.detach(), batch['labels'])
        acc_meter.update(accuracy.item(), count)
        if training:
            tqdm_object.set_postfix(train_loss=loss_meter.avg, accuracy=acc_meter.avg, lr=get_lr(optimizer))
        else:
            tqdm_object.set_postfix(valid_loss=loss_meter.avg, accuracy=acc_meter.avg)

    return loss_meter, acc_meter

def get_lr(optimizer):
    """Learning rate of the optimizer's first param group, or None if it has no groups."""
    return next((group["lr"] for group in optimizer.param_groups), None)

def get_accuracy(preds, targets):
    """Top-1 accuracy as a 0-dim float tensor.

    preds: (batch_size, num_labels) scores; targets: (batch_size,) class indices.
    """
    hits = preds.argmax(dim=1).eq(targets)
    return hits.float().mean()

In [None]:
def train_eval(epochs, model, train_loader, valid_loader, 
               criterion, optimizer, lr_scheduler=None):
    """Train for `epochs` epochs, validating and checkpointing after each one.

    The state dict with the lowest mean validation loss so far is written to
    ``CFG.model_path/CFG.model_save_name``. With ``ReduceLROnPlateau``, the model is
    reloaded from that checkpoint whenever the scheduler lowers the learning rate.
    Any other scheduler is stepped once per epoch when ``CFG.step == "epoch"``
    (per-batch stepping happens inside ``one_epoch``).

    Returns:
        The best (lowest) mean validation loss seen, or ``inf`` if no epoch improved
        on it (e.g. every loss was NaN).
    """
    best_loss = float('inf')
    checkpoint_path = os.path.join(CFG.model_path, CFG.model_save_name)
    is_plateau = isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    for epoch in range(epochs):
        print("*" * 30)
        print(f"Epoch {epoch + 1}")
        current_lr = get_lr(optimizer)

        model.train()
        train_loss, train_acc = one_epoch(model,
                                          criterion,
                                          train_loader,
                                          optimizer=optimizer,
                                          lr_scheduler=lr_scheduler,
                                          mode="train",
                                          step=CFG.step)
        model.eval()
        with torch.no_grad():
            valid_loss, valid_acc = one_epoch(model,
                                              criterion,
                                              valid_loader,
                                              optimizer=None,
                                              lr_scheduler=None,
                                              mode="valid")

        # Checkpoint straight to disk; the previous in-memory deepcopy of the state_dict
        # was never used and doubled the model's memory footprint.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved best model!")

        if is_plateau:
            lr_scheduler.step(valid_loss.avg)
            # Reload only if a checkpoint was actually written (best_loss became finite).
            if current_lr != get_lr(optimizer) and math.isfinite(best_loss):
                print("Loading best model weights!")
                model.load_state_dict(torch.load(checkpoint_path, map_location=CFG.device))
        elif lr_scheduler is not None and CFG.step == "epoch":
            # Epoch-level schedulers were previously never stepped anywhere.
            lr_scheduler.step()

        print("*" * 30)

    return best_loss

In [None]:
temp["TITLE"].shape[0]  # number of labelled titles

In [None]:
temp["BROWSE_NODE_ID"].size  # should equal the title count above

In [None]:
# 67/33 random (non-stratified) split of the labelled titles, each re-indexed from 0.
train_df, valid_df = (
    split.reset_index(drop=True)
    for split in train_test_split(temp, test_size=0.33, shuffle=True, random_state=42)
)

train_dataset = TextDataset(train_df, tokenizer, max_length=CFG.max_length)
valid_dataset = TextDataset(valid_df, tokenizer, max_length=CFG.max_length)

# Only the training loader shuffles; validation order stays fixed between epochs.
train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size,
                          num_workers=CFG.num_workers, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size,
                          num_workers=CFG.num_workers, shuffle=False)

In [None]:
model = Model(bert_model=bert_model).to(device=CFG.device)
print(model)

In [None]:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=CFG.learning_rate)
# Default to no scheduler: previously any other CFG.scheduler value left `lr_scheduler`
# undefined (NameError below) or silently reused a stale one from an earlier run.
lr_scheduler = None
if CFG.scheduler == "ReduceLROnPlateau":
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                              mode="min", 
                                                              factor=CFG.factor, 
                                                              patience=CFG.patience)

train_eval(CFG.epochs, model, train_loader, valid_loader,
           criterion, optimizer, lr_scheduler=lr_scheduler)

In [None]:
# Create the output directory without a shell call; exist_ok makes re-running the cell safe
# (the previous `!mkdir` printed an error when the directory already existed).
os.makedirs("./tokenizer", exist_ok=True)
tokenizer.save_pretrained("./tokenizer")
# Weights of the in-memory model after train_eval; not necessarily the best-validation
# checkpoint, which train_eval writes separately to CFG.model_path/CFG.model_save_name.
torch.save(model.state_dict(), "final.pt")

In [None]:
# Pickles the whole module: loading it later needs the Model/ArcMarginProduct classes importable.
torch.save(obj=model, f='RoBERTArcFace.pth')