In [None]:
import copy
import os
import random
import time

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from transformers import AutoTokenizer, AutoModel, AutoConfig

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator used by this notebook.

    Args:
        seed: integer seed applied to Python's ``random`` module, NumPy's
            global RNG and PyTorch (CPU plus all CUDA devices). Defaults to 42,
            the value that used to be hard-coded here.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed also drives DataLoader shuffling via the default generator.
    torch.manual_seed(seed)
    # Covers every visible GPU (a superset of torch.cuda.manual_seed).
    torch.cuda.manual_seed_all(seed)

seed_everything()

Some of the ideas in this notebook are adapted from the following Kaggle notebook:

https://www.kaggle.com/rhtsingh/commonlit-readability-prize-roberta-torch-fit


In [None]:
class CONFIG:
    """Notebook-wide settings, used as a plain namespace (never instantiated).

    The tokenizer, model config and pretrained backbone are loaded once, when
    this cell is executed.
    """
    # --- pretrained backbone ---
    checkpoint='bert-base-uncased'
    tokenizer=AutoTokenizer.from_pretrained(checkpoint)
    bert_config=AutoConfig.from_pretrained(checkpoint)
    # A single module instance: anything that wraps it directly shares (and
    # trains) these same weights.
    bert_model=AutoModel.from_pretrained(checkpoint)
    
    # --- values derived from the checkpoint ---
    hidden_size=bert_config.hidden_size
    pad_token_id=tokenizer.pad_token_id
    # Hugging Face tokenizers may report a very large placeholder as
    # model_max_length when the checkpoint defines no limit. Items are padded
    # to max_seq_len, so clamp it to the model's position-embedding limit.
    # For bert-base-uncased both values agree.
    max_seq_len=min(tokenizer.model_max_length,
                    getattr(bert_config, 'max_position_embeddings', tokenizer.model_max_length))
    
    # --- training hyper-parameters ---
    batch_size=16
    folds=5                 # number of CV folds (not read elsewhere in this notebook)
    learning_rate=1e-5      # also the peak LR of the one-cycle schedule
    weight_decay=1e-2
    optimizer='AdamW'       # name only; not read to build the optimizer in this notebook
    epochs=8
    clip_gradient_norm=1.0  # gradient-clipping threshold (max total L2 norm)
    eval_every=60           # validate every N optimisation steps
    
    device=torch.device( 'cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
TRAIN_CSV='../input/commonlit-kfold-dataset/fold_train.csv'
train_df=pd.read_csv(TRAIN_CSV)
# Columns consumed later: `excerpt`/`target` by CommonLitDataset and `kfold`
# by the train/validation split. Fail here rather than deep inside training.
missing_cols={'excerpt', 'target', 'kfold'} - set(train_df.columns)
assert not missing_cols, f"{TRAIN_CSV} is missing columns: {sorted(missing_cols)}"
train_df.head()

In [None]:
# Write the tokenizer files to a local directory, presumably so they can be
# reloaded later without re-downloading the checkpoint.
tokenizer_dir='bert_tokenizer'
CONFIG.tokenizer.save_pretrained(tokenizer_dir)

In [None]:
class CommonLitDataset(torch.utils.data.Dataset):
    """Map-style dataset turning CommonLit excerpts into fixed-length model inputs.

    Every item is truncated/padded to exactly ``max_seq_len`` tokens so the
    default collate function can stack a batch. ``seq_len`` keeps the unpadded
    length so the training loop can trim each batch to its longest sequence.
    """
    def __init__(self, df):
        """
        Args:
            df: DataFrame with an ``excerpt`` text column and a numeric
                ``target`` column.
        """
        self.excerpts=df.excerpt.values
        self.targets=df.target.values
        self.tokenizer=CONFIG.tokenizer
        self.pad_token_id=CONFIG.pad_token_id
        self.max_seq_len=CONFIG.max_seq_len
    
    def get_tokenized_features(self, excerpt):
        """Tokenize one excerpt and pad it to ``max_seq_len``.

        Returns:
            dict with ``seq_len`` (number of real tokens) and the python lists
            ``input_ids``, ``token_type_ids`` and ``attention_mask``, each of
            length ``max_seq_len``.
        """
        # Truncate to max_seq_len explicitly instead of relying on the
        # tokenizer's own limit, so pad_len below is never negative and every
        # item has the same length.
        inputs=self.tokenizer(excerpt, truncation=True, max_length=self.max_seq_len)
        
        input_ids=list(inputs['input_ids'])
        attention_mask=list(inputs['attention_mask'])
        # Not every tokenizer emits token_type_ids; fall back to one segment of zeros.
        token_type_ids=inputs.get('token_type_ids')
        token_type_ids=list(token_type_ids) if token_type_ids is not None else [0]*len(input_ids)
        
        input_len=len(input_ids)
        pad_len=self.max_seq_len-input_len
        input_ids+=[self.pad_token_id]*pad_len
        attention_mask+=[0]*pad_len
        token_type_ids+=[0]*pad_len
        
        return {
            'seq_len': input_len,
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask
        }
        
    def __getitem__(self, idx):
        """Return the tensors for item ``idx`` plus its float32 regression label."""
        excerpt=self.excerpts[idx]
        target=self.targets[idx]
        features=self.get_tokenized_features(excerpt)
        
        return {
            'seq_len': features['seq_len'],
            'input_ids': torch.tensor(features['input_ids'], dtype=torch.long),
            'token_type_ids': torch.tensor(features['token_type_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(features['attention_mask'], dtype=torch.long),
            'labels': torch.tensor(target, dtype=torch.float32)
        }
    
    def __len__(self):
        return len(self.targets)

In [None]:
def _encoder_layer_index(param_name):
    """Return ``i`` from a parameter name containing ``...layer.<i>...``, else None."""
    parts=param_name.split('.')
    for part, nxt in zip(parts, parts[1:]):
        if part=='layer' and nxt.isdigit():
            return int(nxt)
    return None


def freeze_bert_layers(bert_model, max_freeze_layer=5):
    """Disable gradients for the embeddings and the lowest encoder layers.

    Args:
        bert_model: torch module whose parameter names follow the BERT layout
            (``embeddings.*``, ``encoder.layer.<i>.*``), optionally nested under
            a prefix such as ``bert.``.
        max_freeze_layer: encoder layers with index ``<= max_freeze_layer`` are
            frozen. The default of 5 freezes layers 0-5.
    """
    for name, param in bert_model.named_parameters():
        layer_idx=_encoder_layer_index(name)
        # The layer index is located by name rather than by a fixed position in
        # the dotted path, so prefixed modules do not crash int() parsing.
        if ('embedding' in name) or (layer_idx is not None and layer_idx <= max_freeze_layer):
            param.requires_grad=False

# Model

In [None]:
class ProjectionHead(nn.Module):
    """Regression MLP: Linear(h -> 2h) -> BatchNorm -> ReLU -> Dropout(0.3) -> Linear(2h -> 1).

    ``h`` is ``CONFIG.hidden_size``. Because of the BatchNorm1d layer, a batch
    holding a single sample cannot be passed through in training mode.
    """
    def __init__(self):
        super().__init__()
        widened=2*CONFIG.hidden_size
        self.linear1=nn.Linear(CONFIG.hidden_size, widened)
        self.bn=nn.BatchNorm1d(widened)
        self.relu=nn.ReLU()
        self.dropout=nn.Dropout(0.3)
        self.out=nn.Linear(widened, 1)

    def forward(self, x):
        """Map features whose last dim is ``hidden_size`` to one score per row."""
        activated=self.relu(self.bn(self.linear1(x)))
        return self.out(self.dropout(activated))
    
class CommonLitModel(nn.Module):
    """BERT backbone + pooled-output regression head producing one score per excerpt."""
    def __init__(self):
        super().__init__()
        # Deep-copy the pretrained backbone held in CONFIG so every
        # CommonLitModel (e.g. one per fold) starts from the pristine pretrained
        # weights. Using CONFIG.bert_model directly made all instances share,
        # and keep training, one module; freezing below would also mutate it.
        self.bert=copy.deepcopy(CONFIG.bert_model)
        freeze_bert_layers(self.bert)
        
        self.layer_norm=nn.LayerNorm(CONFIG.hidden_size)
        self.relu=nn.ReLU()
        self.dropout=nn.Dropout(0.4)
        
        self.proj_head=ProjectionHead()

    def forward(self, input_ids, attention_mask, token_type_ids,
                output_hidden_states=False):
        """Score a batch of tokenized excerpts.

        Args:
            input_ids, attention_mask, token_type_ids: token tensors as produced
                by CommonLitDataset (trimmed per batch by the training loop).
            output_hidden_states: forwarded to the backbone. The hidden states
                are not returned by this method.

        Returns:
            Tensor of shape (batch, 1) with the predicted targets.
        """
        bert_output=self.bert(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids,
                              output_hidden_states=output_hidden_states)
        
        # Pooled [CLS] representation -> LayerNorm -> ReLU -> Dropout -> MLP head.
        bert_pooler_output=bert_output.pooler_output
        bert_pooler_output=self.layer_norm(bert_pooler_output)
        bert_pooler_output=self.relu(bert_pooler_output)
        bert_pooler_output=self.dropout(bert_pooler_output)
        
        y=self.proj_head(bert_pooler_output)
        return y

# Training

In [None]:
class Trainer:
    """Train one model on one fold with periodic validation and checkpointing.

    Every ``CONFIG.eval_every`` optimisation steps (counted across epochs) the
    model is scored on ``test_dataloader`` with MSE. The whole model object is
    then saved to ``model_<iteration>.pt`` and, when the validation loss is the
    best so far, also to ``best_model.pt``.

    Attributes:
        train_loss_: training MSE of every optimisation step.
        val_loss_: validation MSE of every evaluation.
        best_loss, best_iter: lowest validation MSE seen and the step it occurred at.
    """
    def __init__(self, model, train_dataloader, test_dataloader):
        self.model=model
        self.optimizer=torch.optim.AdamW(model.parameters(),
                                         lr=CONFIG.learning_rate,
                                         weight_decay=CONFIG.weight_decay)
        # Stepped once per batch in train_ops, hence epochs * batches per epoch.
        # The attribute name "schedular" is kept for backward compatibility.
        self.schedular=torch.optim.lr_scheduler.OneCycleLR(self.optimizer,
                                                           max_lr=CONFIG.learning_rate,
                                                           epochs=CONFIG.epochs,
                                                           steps_per_epoch=len(train_dataloader))
        self.criterion=nn.MSELoss()
        self.train_dataloader=train_dataloader
        self.test_dataloader=test_dataloader
        self.train_loss_=[]
        self.val_loss_=[]
        self.best_loss=None
        self.iter_count=0
        self.best_iter=0

    @staticmethod
    def _model_inputs(batch):
        """Trim a padded batch to its longest sequence and move the inputs to CONFIG.device.

        Returns:
            (input_ids, attention_mask, token_type_ids)
        """
        # Items are padded to a fixed length; cutting at the longest real
        # sequence in this batch avoids computing over pure padding.
        batch_max_seq_len=int(batch['seq_len'].max())
        return tuple(batch[key][:, :batch_max_seq_len].to(CONFIG.device)
                     for key in ('input_ids', 'attention_mask', 'token_type_ids'))

    def train_ops(self, y, ypred):
        """Run one optimisation + scheduler step on MSE(ypred, y) and return the loss tensor."""
        self.optimizer.zero_grad()
        loss=self.criterion(ypred, y)
        loss.backward()
        # Apply the configured clipping threshold (previously defined but unused).
        if CONFIG.clip_gradient_norm:
            nn.utils.clip_grad_norm_(self.model.parameters(), CONFIG.clip_gradient_norm)
        self.optimizer.step()
        self.schedular.step()
        return loss

    def evaluate(self):
        """Return the MSE over ``test_dataloader`` (NaN if it yields no batches).

        The model is put in eval mode for the pass, and its previous train/eval
        mode is restored afterwards. Without this, a mid-epoch evaluation left
        dropout/BatchNorm in inference mode for the rest of the epoch.
        """
        was_training=self.model.training
        self.model.eval()
        ytrue=[]
        ypred=[]
        try:
            with torch.no_grad():
                for batch in self.test_dataloader:
                    input_ids, attention_mask, token_type_ids=self._model_inputs(batch)
                    ytrue.append(batch['labels'].float())
                    yhat=self.model(input_ids, attention_mask, token_type_ids)
                    ypred.append(yhat.view(-1).float().cpu())
        finally:
            self.model.train(was_training)
        if not ytrue:
            return float('nan')
        val_loss=self.criterion(torch.cat(ypred), torch.cat(ytrue))
        return val_loss.item()

    def _evaluate_and_checkpoint(self):
        """Evaluate, record the loss, save checkpoints and print a progress line."""
        val_loss=self.evaluate()
        self.val_loss_.append(val_loss)
        if (self.best_loss is None) or (self.best_loss > val_loss):
            self.best_loss=val_loss
            self.best_iter=self.iter_count
            torch.save(self.model, "best_model.pt")
        # NOTE(review): each file is a full copy of the model, so disk usage
        # grows with the number of evaluations.
        torch.save(self.model, "model_{}.pt".format(self.iter_count))
        print("==="*10)
        print("Iteration:{} | ValLoss:{:.4f} | BestIteration:{}".format(self.iter_count, val_loss, self.best_iter))

    def train_epoch(self):
        """One pass over ``train_dataloader``, evaluating every ``CONFIG.eval_every`` steps."""
        self.model.train()
        for batch in self.train_dataloader:
            self.iter_count+=1
            # Evaluation runs before this step's update, as in the original loop.
            if self.iter_count%CONFIG.eval_every==0:
                self._evaluate_and_checkpoint()
            
            input_ids, attention_mask, token_type_ids=self._model_inputs(batch)
            labels=batch['labels'].to(CONFIG.device)
            
            ypred=self.model(input_ids, attention_mask, token_type_ids).view(-1)
            loss=self.train_ops(labels, ypred)
            self.train_loss_.append(loss.item())

    def train(self):
        """Run ``CONFIG.epochs`` training epochs."""
        for _ in range(CONFIG.epochs):
            self.train_epoch()

In [None]:
# Fold held out for validation; all other folds are used for training.
FOLD=0

model=CommonLitModel()
model=model.to(CONFIG.device)

fold_train_df=train_df[train_df.kfold!=FOLD].copy()
fold_val_df  =train_df[train_df.kfold==FOLD].copy()
assert len(fold_train_df) > 0 and len(fold_val_df) > 0, f"fold {FOLD} gives an empty train or validation split"

train_dataset=CommonLitDataset(fold_train_df)
val_dataset=CommonLitDataset(fold_val_df)

# Pinned host memory only helps host->GPU copies.
use_pin_memory=torch.cuda.is_available()

# drop_last=True for training: the regression head contains BatchNorm1d, which
# raises in training mode when a (leftover final) batch holds a single sample.
train_dataloader=torch.utils.data.DataLoader(train_dataset, batch_size=CONFIG.batch_size,
                                             shuffle=True, pin_memory=use_pin_memory, drop_last=True)

val_dataloader=torch.utils.data.DataLoader(val_dataset, batch_size=CONFIG.batch_size,
                                           shuffle=False, pin_memory=use_pin_memory, drop_last=False)


In [None]:
# The held-out fold is the Trainer's test set and drives best-checkpoint selection.
trainer=Trainer(model=model,
                train_dataloader=train_dataloader,
                test_dataloader=val_dataloader)
trainer.train()

In [None]:
fig, ax = plt.subplots()
ax.plot(trainer.train_loss_)
ax.set(title='Training loss per optimisation step', xlabel='Iteration', ylabel='MSE loss');

In [None]:
# Evaluations happen whenever the global step count is a multiple of
# CONFIG.eval_every, so the i-th (0-based) one occurred at step (i+1)*eval_every.
eval_iters=[CONFIG.eval_every*(i+1) for i in range(len(trainer.val_loss_))]
fig, ax = plt.subplots()
ax.plot(eval_iters, trainer.val_loss_, marker='o')
ax.set(title='Validation loss', xlabel='Iteration', ylabel='MSE loss');