In [None]:
import os
import time

import numpy as np
from scipy.stats import norm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

from transformers import (AutoTokenizer, AutoModel, AutoConfig, 
                          AdamW,
                          get_linear_schedule_with_warmup,
                          get_cosine_schedule_with_warmup)

In [None]:
def seed_everything(s):
    """Seed every random number generator this notebook can draw from.

    Covers Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices).

    Args:
        s: Integer seed shared by all generators.
    """
    # Local import: the notebook's import cell does not bring in ``random``.
    import random

    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    
seed_everything(42)

https://www.kaggle.com/rhtsingh/commonlit-readability-prize-roberta-torch-infer-3



In [None]:
class CONFIG:
    """Notebook-wide settings, used as a plain namespace of class attributes.

    The class body runs at definition time, so defining it already calls
    ``AutoTokenizer.from_pretrained`` and ``AutoConfig.from_pretrained`` for
    ``checkpoint``. Later cells mutate some attributes in place (the test-mode
    overrides and ``bins``).
    """
    env='prod' # Set to 'test' for a tiny smoke-test run.
    checkpoint='roberta-base'  # name used to load the tokenizer and base config
    pretrain_path='../input/clrp-roberta-base-pretrain/clrp_roberta_base'  # weights loaded by CommonLitModel
    tokenizer=AutoTokenizer.from_pretrained(checkpoint)
    base_config=AutoConfig.from_pretrained(checkpoint)
    
    # Values derived from the tokenizer / base config loaded above.
    hidden_size=base_config.hidden_size
    pad_token_id=tokenizer.pad_token_id
    max_seq_len=tokenizer.model_max_length  # every sample is padded to this length
    
    # The layer-freezing condition in freeze_roberta_layers is commented out,
    # so this index currently has no effect on which parameters are frozen.
    FREEZE_LAYERS_START=0
    TRAIN_MAX_ITERS=680  # hard cap on optimisation steps; also the scheduler's total_steps
    TRAIN_WARMUP_STEPS=204  # only referenced from a commented-out scheduler call
    TRAIN_SAMPLES_PER_BATCH=20  # passed to TrainDataSampler, which uses it to compute batch_fraction
    
    batch_size=20  # validation DataLoader batch size
    folds=5  # number of k-fold splits iterated by the training loop
    bins=9  # placeholder; overwritten with len(bin_ranges) once the data is loaded
    train_sample_bins=10  # number of target quantile buckets used for batch sampling
    
    learning_rate=2e-5
    weight_decay=0.01
    optimizer='AdamW'  # informational; the Trainer instantiates AdamW directly
    epochs=8  # not read by the training loop visible in this notebook
    clip_gradient_norm=1.0  # NOTE(review): not applied anywhere in the visible code
    eval_every=10  # run validation every this many training iterations
    
    device=torch.device( 'cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# In 'test' mode shrink the run to a few iterations so the whole notebook
# can be smoke-tested quickly.
if CONFIG.env == 'test':
    for _attr, _value in (
        ('TRAIN_SAMPLES_PER_BATCH', 4),
        ('TRAIN_MAX_ITERS', 3),
        ('epochs', 1),
        ('eval_every', 1),
    ):
        setattr(CONFIG, _attr, _value)

print("Device:", CONFIG.device)

In [None]:
def get_qunatile_boundaries(df, bins=10):
    """Return the upper target boundary of each of ``bins`` equal-mass buckets.

    Args:
        df: DataFrame with a numeric ``target`` column.
        bins: Number of quantile buckets (must be >= 1).

    Returns:
        A list of ``bins`` values: the ``1/bins, 2/bins, ..., 1.0`` quantiles
        of ``df.target``. The last entry is the maximum target.

    Raises:
        ValueError: If ``bins`` is smaller than 1.
    """
    if bins < 1:
        raise ValueError("bins must be >= 1, got {}".format(bins))
    # linspace gives exactly `bins` levels ending at exactly 1.0; a float-step
    # arange can overshoot 1.0 or stop just short of it.
    levels = np.linspace(1.0 / bins, 1.0, bins)
    return [df.target.quantile(level) for level in levels]

def get_quantile(target, qs):
    """Return the index of the first boundary in ``qs`` that is >= ``target``.

    Args:
        target: Scalar value to bucket.
        qs: Ascending sequence of bucket upper boundaries.

    Returns:
        The bucket index. A target above every boundary is assigned to the
        last bucket instead of producing ``None``. ``None`` is returned only
        when ``qs`` is empty or ``target`` is NaN.
    """
    if target != target:  # NaN compares unequal to itself
        return None
    last_index = None
    for i, q in enumerate(qs):
        if target <= q:
            return i
        last_index = i
    # Beyond the final boundary (e.g. float round-off on the maximum): clamp.
    return last_index

        
def get_bin_ranges(df, bin_width=0.5):
    """Build contiguous ``(low, high)`` target bins covering the real line.

    The inner edges run from ``min(target) - se_min`` up to (but excluding)
    ``max(target) + se_max`` in steps of ``bin_width``. ``se_min`` / ``se_max``
    are the standard errors of the rows holding the minimum / maximum target.
    Two open-ended bins, bounded by -1e9 and 1e9, are added at the ends.

    Args:
        df: DataFrame with ``target`` and ``standard_error`` columns.
        bin_width: Width of each inner bin.

    Returns:
        List of ``(low, high)`` tuples, in ascending order.
    """
    min_target = df.target.min()
    max_target = df.target.max()

    min_std = df.loc[df.target == min_target, 'standard_error'].min()
    max_std = df.loc[df.target == max_target, 'standard_error'].max()

    edges = np.arange(min_target - min_std, max_target + max_std, bin_width)

    bin_ranges = [(-1e9, edges[0])]
    bin_ranges.extend(zip(edges[:-1], edges[1:]))
    bin_ranges.append((edges[-1], 1e9))
    return bin_ranges

def get_bin_distribution(row, bin_ranges):
    """Probability mass that N(row.target, 0.2) places in each bin.

    Args:
        row: Object exposing a ``target`` attribute (e.g. a DataFrame row).
        bin_ranges: Sequence of ``(low, high)`` bin bounds.

    Returns:
        List with one entry per bin: ``cdf(high) - cdf(low)``.
    """
    center = row.target
    spread = 0.2
    return [
        norm.cdf(bounds[1], center, spread) - norm.cdf(bounds[0], center, spread)
        for bounds in bin_ranges
    ]

In [None]:
# Load the fold-annotated training table. Later cells read its `id`,
# `excerpt`, `target`, `standard_error` and `kfold` columns.
train_df=pd.read_csv('../input/commonlit-kfold-dataset/fold_train.csv')
# Target bins used to turn each target into a soft distribution over bins.
bin_ranges=get_bin_ranges(train_df)
# Replace the placeholder CONFIG.bins with the real bin count so that the
# BinEstimator output width matches len(bin_ranges).
CONFIG.bins=len(bin_ranges)

print(bin_ranges)

In [None]:
# Quantile boundaries of the target, one per sampling bucket.
train_qs=get_qunatile_boundaries(train_df, CONFIG.train_sample_bins)
# `q`: index of the quantile bucket each row falls into (used by TrainDataSampler).
train_df['q']=train_df.target.apply(get_quantile, args=(train_qs, ))
# `ybin`: per-row list of bin probabilities derived from the target.
train_df['ybin']=train_df.apply(get_bin_distribution, args=(bin_ranges, ), axis=1)

train_df.head()

# Datasets

In [None]:
class CommonLitDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding fixed-length token tensors plus labels.

    Each item is a dict with ``seq_len`` (unpadded token count), ``input_ids``
    and ``attention_mask`` (padded to ``CONFIG.max_seq_len``), ``yreg``
    (target), ``ybin`` (bin distribution) and ``sigmas`` (standard error).
    """

    def __init__(self, df, phase='train'):
        # `phase` is accepted but not used.
        self.excerpts = df.excerpt.values
        self.targets = df.target.values
        self.standard_errors = df.standard_error.values
        self.ybin = df.ybin.values

        self.tokenizer = CONFIG.tokenizer
        self.pad_token_id = CONFIG.pad_token_id
        self.max_seq_len = CONFIG.max_seq_len

    def get_tokenized_features(self, excerpt):
        """Tokenize ``excerpt`` (with truncation) and right-pad to max_seq_len."""
        encoded = self.tokenizer(excerpt, truncation=True)

        token_ids = encoded['input_ids']
        mask = encoded['attention_mask']

        n_tokens = len(token_ids)
        n_pad = self.max_seq_len - n_tokens
        token_ids.extend([self.pad_token_id] * n_pad)
        mask.extend([0] * n_pad)

        return {
            'seq_len': n_tokens,
            'input_ids': token_ids,
            'attention_mask': mask,
        }

    def __getitem__(self, idx):
        feats = self.get_tokenized_features(self.excerpts[idx])

        sample = {'seq_len': feats['seq_len']}
        for key in ('input_ids', 'attention_mask'):
            sample[key] = torch.tensor(feats[key], dtype=torch.long)

        float_fields = (
            ('yreg', self.targets[idx]),
            ('ybin', self.ybin[idx]),
            ('sigmas', self.standard_errors[idx]),
        )
        for key, value in float_fields:
            sample[key] = torch.tensor(value, dtype=torch.float32)
        return sample

    def __len__(self):
        return len(self.targets)

# Train Data Sampler

In [None]:
class TrainDataSampler:
    """Single-pass batch generator that balances batches across target buckets.

    Row ids are grouped by their ``q`` bucket and shuffled once. Each batch
    takes up to ``SAMPLES_PER_BIN`` ids from every bucket that still has ids
    left. Iteration stops when all buckets are empty or a batch would hold
    fewer than half of ``batch_size`` samples. An instance is consumed by
    iterating it; build a new one for every epoch.
    """

    # Number of ids drawn from each bucket per batch.
    SAMPLES_PER_BIN=2

    def __init__(self, batch_size, df):
        """
        Args:
            batch_size: Nominal batch size; only used to compute
                ``batch_fraction`` (actual size / nominal size).
            df: DataFrame with ``id``, ``q``, ``excerpt``, ``target``,
                ``ybin`` and ``standard_error`` columns.
        """
        self.qmap={}
        self.batch_size=batch_size
        self.batch_fraction=1.0
        self.df=df.copy()

        self.tokenizer=CONFIG.tokenizer
        self.pad_token_id=CONFIG.pad_token_id
        self.max_seq_len=CONFIG.max_seq_len

        for i in range(CONFIG.train_sample_bins):
            ids=self.df[self.df.q==i].id.values
            np.random.shuffle(ids)
            self.qmap[i]=ids

    def get_tokenized_features(self, excerpt):
        """Tokenize ``excerpt`` (with truncation) and right-pad to max_seq_len."""
        inputs=self.tokenizer(excerpt, truncation=True)

        input_ids=inputs['input_ids']
        attention_mask=inputs['attention_mask']

        input_len=len(input_ids)
        pad_len=self.max_seq_len-input_len
        input_ids+=[self.pad_token_id]*pad_len
        attention_mask+=[0]*pad_len

        return {
            'seq_len': input_len,
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

    def get_mbs(self):
        """Assemble the next mini-batch and update ``batch_fraction``.

        Returns:
            Dict of tensors: ``seq_len``, ``input_ids``, ``attention_mask``,
            ``yreg`` (target plus uniform noise in [-0.1, 0.1]), ``ybin``
            and ``sigmas``.
        """
        sentences=[]
        yreg=[]; ybin=[]; sigmas=[]
        take=self.SAMPLES_PER_BIN
        for i in range(CONFIG.train_sample_bins):
            if i not in self.qmap:
                continue
            yids=self.qmap[i][-take:]

            # Filter the frame once per bucket and reuse the selection.
            rows=self.df[self.df.id.isin(yids)]
            sentences+=list(rows.excerpt.values)
            yreg+=list(rows.target.values)
            ybin+=list(rows.ybin.values)
            sigmas+=list(rows.standard_error.values)

            self.qmap[i]=self.qmap[i][:-take]
            if len(self.qmap[i]) == 0:
                self.qmap.pop(i)

        self.batch_fraction=len(yreg)/self.batch_size
        features={
            'seq_len': [],
            'input_ids': [],
            'attention_mask': [],
            'yreg': [],
            'ybin':[],
            'sigmas': []
        }

        for sentence, target, dist, sigma in zip(sentences, yreg, ybin, sigmas):
            data=self.get_tokenized_features(sentence)

            features['seq_len'].append(data['seq_len'])
            features['input_ids'].append(data['input_ids'])
            features['attention_mask'].append(data['attention_mask'])
            # Label jitter: small uniform noise added to the regression target.
            features['yreg'].append(target+np.random.uniform(-0.1, 0.1))
            features['ybin'].append(dist)
            features['sigmas'].append(sigma)

        for key in ('seq_len', 'input_ids', 'attention_mask'):
            features[key]=torch.tensor(features[key], dtype=torch.long)
        for key in ('yreg', 'ybin', 'sigmas'):
            features[key]=torch.tensor(features[key], dtype=torch.float32)
        return features

    def __iter__(self):
        # Generator-based iteration; a bogus generator-style __next__ that was
        # never used has been dropped.
        while len(self.qmap)>0:
            mbs=self.get_mbs()
            if self.batch_fraction < 0.5:
                break
            yield mbs

# Model

In [None]:
def freeze_roberta_layers(roberta):
    """Disable gradients for every parameter whose name contains 'embedding'.

    Only embedding parameters are frozen; encoder layers stay trainable.

    Args:
        roberta: Module whose ``named_parameters()`` are inspected.
    """
    for name, param in roberta.named_parameters():
        if 'embedding' in name:
            param.requires_grad = False
            
            
class TextRegressor(nn.Module):
    """Regression head: Linear -> LayerNorm -> ReLU -> Dropout -> Linear(1).

    Takes inputs whose last dimension is ``CONFIG.hidden_size`` and returns a
    single value per row.
    """

    def __init__(self):
        super().__init__()
        width = 1024
        self.linear = nn.Linear(CONFIG.hidden_size, width)
        self.layer_norm = nn.LayerNorm(width)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.regressor = nn.Linear(width, 1)

        # Small uniform initialisation of both projection weight matrices.
        for layer in (self.linear, self.regressor):
            nn.init.uniform_(layer.weight, -0.02, 0.02)

    def forward(self, x):
        hidden = self.layer_norm(self.linear(x))
        hidden = self.dropout(self.relu(hidden))
        return self.regressor(hidden)
    
class BinEstimator(nn.Module):
    """Classification head producing a probability per target bin.

    Linear -> LayerNorm -> ReLU -> Dropout -> Linear(CONFIG.bins) -> softmax
    over dimension 1.
    """

    def __init__(self):
        super().__init__()
        width = 1024
        self.linear = nn.Linear(CONFIG.hidden_size, width)
        self.layer_norm = nn.LayerNorm(width)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.logits = nn.Linear(width, CONFIG.bins)

        nn.init.uniform_(self.linear.weight, -0.025, 0.025)
        nn.init.uniform_(self.logits.weight, -0.02, 0.02)

    def forward(self, x):
        hidden = self.relu(self.layer_norm(self.linear(x)))
        scores = self.logits(self.dropout(hidden))
        return torch.softmax(scores, dim=1)

class AttentionHead(nn.Module):
    """Additive-attention pooling over dimension 1 of the input.

    Scores each position with ``V(tanh(W(x)))``, normalises the scores with a
    softmax over dimension 1 and returns the weighted sum of ``x``.
    """

    def __init__(self):
        super(AttentionHead, self).__init__()
        self.W=nn.Linear(CONFIG.hidden_size, CONFIG.hidden_size)
        self.V=nn.Linear(CONFIG.hidden_size, 1)

    def forward(self, x, attention_mask=None):
        """
        Args:
            x: Token representations; pooled over dimension 1
                (assumed (batch, seq, hidden) as produced by the encoder).
            attention_mask: Optional (batch, seq) mask, non-zero for real
                tokens. Masked positions receive zero attention weight.
                ``None`` keeps the unmasked behaviour.

        Returns:
            Pooled representation with dimension 1 reduced away.
        """
        score=self.V(torch.tanh(self.W(x)))
        if attention_mask is not None:
            padding=attention_mask.unsqueeze(-1)==0
            # Large negative finite value instead of -inf so that a fully
            # masked row yields a uniform distribution rather than NaN.
            score=score.masked_fill(padding, torch.finfo(score.dtype).min)
        attention_weights=torch.softmax(score, dim=1)
        return torch.sum(attention_weights * x, dim=1)


class CommonLitModel(nn.Module):
    """Pretrained encoder with a regression head and a bin-probability head.

    The pooled feature is the average of the encoder's ``pooler_output`` and
    an attention-pooled summary of ``last_hidden_state``, followed by dropout
    and layer normalisation. Embedding parameters of the encoder are frozen.
    """

    def __init__(self):
        super(CommonLitModel, self).__init__()
        self.roberta=AutoModel.from_pretrained(CONFIG.pretrain_path)

        self.attention_head=AttentionHead()
        self.dropout=nn.Dropout(0.25)
        self.layer_norm=nn.LayerNorm(CONFIG.hidden_size)

        self.regressor=TextRegressor()
        self.bin_estimator=BinEstimator()

        freeze_roberta_layers(self.roberta)

    def forward(self, input_ids, attention_mask,output_hidden_states=False):
        """
        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Mask passed to the encoder and to the attention
                pooling so that padding tokens are ignored.
            output_hidden_states: Forwarded to the encoder. The per-layer
                hidden states are not used here, so the default avoids
                keeping them in memory.

        Returns:
            Tuple ``(yhat_reg, yhat_bin)`` from the two heads.
        """
        roberta_output=self.roberta(input_ids,
                                 attention_mask=attention_mask,
                                 output_hidden_states=output_hidden_states)

        last_hidden_state=roberta_output.last_hidden_state
        cls_pool=roberta_output.pooler_output

        attention_pool=self.attention_head(last_hidden_state, attention_mask)
        x_pool=(cls_pool+attention_pool)/2
        x_pool=self.dropout(x_pool)
        x_pool=self.layer_norm(x_pool)

        yhat_reg=self.regressor(x_pool)
        yhat_bin=self.bin_estimator(x_pool)

        return yhat_reg, yhat_bin

# Custom Loss

In [None]:
class CustomLoss:
    """Collection of loss terms for the regression and bin heads.

    ``get_loss`` currently returns only the regression loss; the remaining
    methods are auxiliary terms that are available but not combined into it.
    """

    def __init__(self):
        self.criterion=nn.MSELoss()
        # With log_target=True both arguments must be log-probabilities.
        self.kl_divergence=nn.KLDivLoss(reduction='batchmean', log_target=True)

    def get_reg_loss(self, y, yreg, phase, iter_count):
        """Regression loss between ``y`` and ``yreg``.

        In the 'val' phase this is plain MSE. Otherwise the absolute errors
        are split into three buckets (<=0.1, (0.1, 0.5], >0.5); the mean
        squared error of each non-empty bucket is weighted by 0.1 / 0.5 / 0.4
        and the result is normalised by the sum of the weights used.
        ``iter_count`` is accepted but not used.
        """
        if phase=='val':
            return self.criterion(yreg, y)

        ydiff=torch.abs(yreg-y)
        if ydiff.numel()==0:
            # Empty batch: return a zero that is still attached to the graph.
            return ydiff.sum()

        buckets=(
            (ydiff<=0.1, 0.1),
            ((ydiff>0.1) & (ydiff<=0.5), 0.5),
            (ydiff>0.5, 0.4),
        )
        weighted=ydiff.new_zeros(())
        total_weight=0.0
        for in_bucket, alpha in buckets:
            if in_bucket.any():
                weighted=weighted+alpha*(ydiff[in_bucket]**2).mean()
                total_weight+=alpha
        return weighted/total_weight

    def get_bin_loss(self, ybin, yhat_bin, phase):
        """Mean over samples of the L1 distance between bin distributions."""
        if phase == 'val':
            ybin=ybin.view(-1, CONFIG.bins)
            yhat_bin=yhat_bin.view(-1, CONFIG.bins)

        return torch.abs(ybin-yhat_bin).sum(dim=1).mean()

    def get_distribution_loss(self, y_mus, y_sigmas, yhat_mus):
        """Symmetrised KL between N(y_mus, y_sigmas) and N(yhat_mus, y_sigmas)."""
        P=Normal(y_mus, y_sigmas)
        Q=Normal(yhat_mus, y_sigmas)

        loss=(kl_divergence(P, Q)+kl_divergence(Q, P))/2
        loss=loss.mean()
        loss=loss.to(CONFIG.device)
        return loss

    def get_consistency_kl_div_loss(self, ybin, yhat_bin, phase):
        """Symmetrised KL divergence between two bin distributions.

        Both inputs are probabilities; they are moved to log space (with a
        1e-9 floor) because the KLDivLoss above is built with log_target=True.
        """
        if phase == 'val':
            ybin=ybin.view(-1, 1+CONFIG.bins)
            yhat_bin=yhat_bin.view(-1, 1+CONFIG.bins)

        log_p=torch.log(ybin+1e-9)
        log_q=torch.log(yhat_bin+1e-9)
        loss1=self.kl_divergence(log_p, log_q)
        loss2=self.kl_divergence(log_q, log_p)
        return (loss1+loss2)/2

    def get_consistency_loss(self, sigmas, yhat_mus, yhat_bin, phase, qs_bins=None):
        """L1 distance between the bin head and bins implied by the regressor.

        The regressor output is treated as N(yhat_mus, sigmas) and integrated
        between consecutive ``qs_bins`` boundaries.

        Args:
            qs_bins: Indexable of at least ``CONFIG.bins`` boundaries accepted
                by ``Normal.cdf``. When omitted, a module-level ``qs_bins`` is
                looked up.

        Raises:
            NameError: If no boundaries are passed and no global exists.

        NOTE(review): the implied distribution has ``1+CONFIG.bins`` columns
        while ``get_bin_loss`` reshapes to ``CONFIG.bins`` in the 'val' phase;
        confirm the intended width before enabling this term.
        """
        if qs_bins is None:
            try:
                qs_bins=globals()['qs_bins']
            except KeyError:
                raise NameError("name 'qs_bins' is not defined; pass qs_bins explicitly") from None

        dist=Normal(yhat_mus, sigmas)
        cdfs=[dist.cdf(qs_bins[i]) for i in range(CONFIG.bins)]

        num_samples=yhat_mus.size(0)
        # Allocate next to the predictions to avoid a CPU/GPU mismatch.
        yreg_dist=torch.zeros(num_samples, 1+CONFIG.bins, device=yhat_mus.device)
        yreg_dist[:, 0]=cdfs[0]
        for i in range(1, CONFIG.bins):
            yreg_dist[:, i]=cdfs[i]-cdfs[i-1]
        yreg_dist[:, CONFIG.bins]=1-cdfs[CONFIG.bins-1]

        if phase=='train':
            yreg_dist=yreg_dist.to(CONFIG.device)
        return self.get_bin_loss(yreg_dist, yhat_bin, phase)

    def get_bin_cross_entropy_loss(self, ybin, yhat_bin, phase):
        """Cross-entropy of predicted bin probabilities against soft targets."""
        if phase == 'val':
            ybin=ybin.view(-1, CONFIG.bins)
            yhat_bin=yhat_bin.view(-1, CONFIG.bins)
        # Vectorised: no CPU scratch buffer, so it works on any device.
        per_sample=-(ybin*torch.log(yhat_bin+1e-9)).sum(dim=1)
        return per_sample.mean()

    # This loss combines the cumulative distributions of the top-2 bins
    def get_bin_cum_loss(self, ybin, yhat_bin, phase):
        """Mean absolute gap in probability mass on the two largest target bins."""
        if phase == 'val':
            ybin=ybin.view(-1, CONFIG.bins)
            yhat_bin=yhat_bin.view(-1, CONFIG.bins)

        topk=torch.topk(ybin, k=2, dim=1)
        target_mass=topk.values.sum(dim=1)
        predicted_mass=yhat_bin.gather(1, topk.indices).sum(dim=1)
        batch_size=ybin.size(0)
        return torch.abs(predicted_mass-target_mass).sum()/max(1, batch_size)

    def get_loss(self, inputs, phase, iter_count=0):
        """Compute the training/validation objective.

        Args:
            inputs: Dict with ``yreg``, ``ybin``, ``yhat_reg``, ``yhat_bin``.
            phase: 'val' for plain MSE, anything else for the bucketed loss.
            iter_count: Forwarded to ``get_reg_loss``.

        Returns:
            Dict with the tensor ``loss`` (regression loss only), its float
            value ``reg_loss`` and zero placeholders for the bin losses.
        """
        yreg=inputs['yreg']
        yhat_reg=inputs['yhat_reg']

        reg_loss=self.get_reg_loss(yreg, yhat_reg, phase, iter_count)
        loss=reg_loss
        return {
            'loss': loss,
            'reg_loss': reg_loss.item(),
            'bin_loss': 0,
            'bin_cum_loss': 0
        }

# Evaluator

In [None]:
class CustomEvaluator:
    """Runs a model over the validation loader and reports validation losses."""

    def __init__(self, val_dataloader):
        self.val_dataloader=val_dataloader
        self.criterion=nn.MSELoss()
        self.custom_loss=CustomLoss()

    def evaluate(self, model):
        """Evaluate ``model`` on the whole validation set.

        The model is switched to eval mode for the pass and then put back
        into whatever mode it was in, so calling this in the middle of a
        training epoch does not silently disable dropout afterwards.

        Args:
            model: Callable module returning ``(yhat_reg, yhat_bin)`` for
                ``(input_ids, attention_mask)``.

        Returns:
            The dict produced by ``CustomLoss.get_loss`` in the 'val' phase.
        """
        was_training=model.training
        model.eval()

        all_yreg=[]
        all_ybin=[]
        all_yhatreg=[]
        all_yhatbin=[]

        try:
            with torch.no_grad():
                for batch in self.val_dataloader:
                    # Trim the fixed-length padding to the longest sequence in the batch.
                    batch_max_seq_len=torch.max(batch['seq_len'])

                    input_ids=batch['input_ids'][:, :batch_max_seq_len].to(CONFIG.device)
                    attention_mask=batch['attention_mask'][:, :batch_max_seq_len].to(CONFIG.device)

                    all_yreg+=batch['yreg'].view(-1).tolist()
                    all_ybin+=batch['ybin'].view(-1).tolist()

                    yhat_reg, yhat_bin=model(input_ids, attention_mask)
                    all_yhatreg+=yhat_reg.view(-1).detach().cpu().tolist()
                    all_yhatbin+=yhat_bin.view(-1).detach().cpu().tolist()
        finally:
            model.train(was_training)

        all_yreg=torch.tensor(all_yreg, dtype=torch.float32)
        all_ybin=torch.tensor(all_ybin, dtype=torch.float32)

        all_yhatreg=torch.tensor(all_yhatreg, dtype=torch.float32)
        all_yhatbin=torch.tensor(all_yhatbin, dtype=torch.float32)

        ydiff=torch.abs(all_yhatreg - all_yreg).numpy()

        # np.std is a standard deviation, so label it as such.
        print("ydiff Std:", np.std(ydiff))

        print('Quantiles---')
        for level in (0.7, 0.8, 0.9, 0.95):
            print(np.quantile(ydiff, level))

        model_losses=self.custom_loss.get_loss({
            'yreg': all_yreg,
            'ybin': all_ybin,

            'yhat_reg': all_yhatreg,
            'yhat_bin': all_yhatbin
        }, 'val', 0)
        return model_losses

# Trainer

def get_optimizer_params(model):
    """Split the trainable parameters of ``model`` into three optimizer groups.

    Groups are selected by substring tests on the parameter name:
    backbone ('roberta' in the name) without 'LayerNorm', backbone with
    'LayerNorm' (weight decay 0), and everything else (learning rate 1e-3).
    Keys that a group omits are left to the optimizer's defaults.

    NOTE: a definition with the same name and the same three groups follows
    in the next notebook cell.
    """
    optimizer_parameters=[
        {
            # Backbone parameters that are not LayerNorm parameters.
            'params': [p for n, p in model.named_parameters() if (p.requires_grad and ('roberta' in n) and ('LayerNorm' not in n) )],
        },
        {
            # Backbone LayerNorm parameters: excluded from weight decay.
            'params': [p for n, p in model.named_parameters() if (p.requires_grad and ('roberta' in n) and 'LayerNorm' in n )],
            'weight_decay':0 
        },
        
        
        
        
        
        
        {
            # Parameters outside the backbone: larger learning rate.
            'params': [p for n, p in model.named_parameters() if (p.requires_grad and 'roberta' not in n)],
            'lr': 1e-3
        }
    ]
    return optimizer_parameters

In [None]:
def get_optimizer_params(model):
    """Build three optimizer parameter groups from the trainable parameters.

    Returns, in order:
      1. backbone parameters ('roberta' in the name) that are not LayerNorm,
      2. backbone LayerNorm parameters, with ``weight_decay`` 0,
      3. all remaining parameters, with ``lr`` 1e-3.
    """
    backbone_decay, backbone_no_decay, head = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'roberta' not in name:
            head.append(param)
        elif 'LayerNorm' in name:
            backbone_no_decay.append(param)
        else:
            backbone_decay.append(param)

    return [
        {'params': backbone_decay},
        {'params': backbone_no_decay, 'weight_decay': 0},
        {'params': head, 'lr': 1e-3},
    ]

https://www.kaggle.com/rhtsingh/commonlit-readability-prize-roberta-torch-fit

def get_optimizer_params(model):
    """Layer-wise learning-rate parameter groups for the backbone and head.

    Trainable backbone parameters are bucketed by encoder layer (matched by
    the substring 'layer.N.' in the parameter name), each bucket being split
    into a weight-decay group and a no-decay group:

      * layers 8-11: lr 5e-5
      * layers 4-7:  lr 3e-5
      * layers 1-3:  lr 2e-5

    Parameters outside ``model.roberta`` get lr 1e-3, split the same way.
    Backbone parameters that match none of the layer tags (for example layer
    0 and the embeddings) are not handed to the optimizer at all, as before.

    Returns:
        List of eight parameter-group dicts, ordered
        (decay, no-decay) x (layers 8-11, layers 4-7, layers 1-3, head).
    """
    no_decay=['LayerNorm', 'bias', 'beta', 'gamma']
    # Trailing dots keep 'layer.1.' from also matching 'layer.10.' / 'layer.11.'.
    layer_groups=[
        (['layer.8.', 'layer.9.', 'layer.10.', 'layer.11.'], 5e-5),
        (['layer.4.', 'layer.5.', 'layer.6.', 'layer.7.'], 3e-5),
        (['layer.1.', 'layer.2.', 'layer.3.'], 2e-5),
    ]

    def _skips_decay(name):
        # True when any no-decay marker occurs in the parameter name.
        return any(marker in name for marker in no_decay)

    def _split(named, lr):
        # One (decay, no-decay) pair of groups for the given named parameters.
        return [
            {
                'params': [p for n, p in named if not _skips_decay(n)],
                'lr': lr,
                'weight_decay': 0.01
            },
            {
                'params': [p for n, p in named if _skips_decay(n)],
                'lr': lr,
                'weight_decay': 0.0
            },
        ]

    backbone=[(n, p) for n, p in model.roberta.named_parameters() if p.requires_grad]

    optimizer_parameters=[]
    for tags, lr in layer_groups:
        members=[(n, p) for n, p in backbone if any(tag in n for tag in tags)]
        optimizer_parameters.extend(_split(members, lr))

    head=[(n, p) for n, p in model.named_parameters() if 'roberta' not in n]
    optimizer_parameters.extend(_split(head, 1e-3))
    return optimizer_parameters

In [None]:
class Trainer:
    """Iteration-based training loop with periodic validation and checkpoints.

    Training runs for at most ``CONFIG.TRAIN_MAX_ITERS`` optimisation steps,
    drawing batches from a fresh ``TrainDataSampler`` every epoch. Every
    ``CONFIG.eval_every`` steps the model is validated and, when a tracked
    validation loss improves, the whole model is saved to ``best_model.pt``,
    ``best_reg_model.pt`` or ``best_bin_model.pt``.

    NOTE(review): ``CONFIG.clip_gradient_norm`` is defined but no gradient
    clipping is applied here; confirm whether that is intended.
    """

    def __init__(self, model, fold_train_df,  val_dataloader):
        """
        Args:
            model: Model to train; must return ``(yhat_reg, yhat_bin)``.
            fold_train_df: Training rows for this fold (copied).
            val_dataloader: Loader over the validation rows.
        """
        self.df=fold_train_df.copy()
        self.val_dataloader=val_dataloader

        self.model=model
        self.optimizer=AdamW(get_optimizer_params(model),
                             lr=CONFIG.learning_rate,
                             weight_decay=CONFIG.weight_decay)

        # One peak learning rate per parameter group. A scalar max_lr would
        # give every group the same schedule and silently discard the
        # group-specific 'lr' values set by get_optimizer_params.
        self.schedular=torch.optim.lr_scheduler.OneCycleLR(self.optimizer,
                                                           max_lr=[group['lr'] for group in self.optimizer.param_groups],
                                                           total_steps=CONFIG.TRAIN_MAX_ITERS,
                                                           pct_start=0.25
                                                          )

        self.custom_loss=CustomLoss()
        self.custom_evaluator=CustomEvaluator(val_dataloader)

        # Per-iteration training history.
        self.train_loss=[]
        self.train_reg_loss=[]
        self.train_bin_loss=[]
        self.train_bin_cum_loss=[]

        # Per-evaluation validation history.
        self.val_loss=[]
        self.val_reg_loss=[]
        self.val_bin_loss=[]
        self.val_bin_cum_loss=[]

        self.iter_count=0
        self.best_iter=0
        self.best_reg_iter=0
        self.best_bin_iter=0

        self.best_loss=None
        self.best_reg_loss=None
        self.best_bin_loss=None

    def checkpoint(self, model_losses):
        """Save the model for every tracked validation loss that improved, then log."""
        val_loss=model_losses['loss'].item()
        val_reg_loss=model_losses['reg_loss']
        val_bin_loss=model_losses['bin_loss']

        if (self.best_loss is None) or (self.best_loss > val_loss):
            self.best_loss=val_loss
            self.best_iter=self.iter_count
            torch.save(self.model, "best_model.pt")

        if (self.best_reg_loss is None) or (self.best_reg_loss > val_reg_loss):
            self.best_reg_loss=val_reg_loss
            self.best_reg_iter=self.iter_count
            torch.save(self.model, "best_reg_model.pt")

        if (self.best_bin_loss is None) or (self.best_bin_loss > val_bin_loss):
            self.best_bin_loss=val_bin_loss
            self.best_bin_iter=self.iter_count
            torch.save(self.model, "best_bin_model.pt")

        print("==="*10)
        print("Iter:{} | BestIter:{} | Best Reg Iter:{} | Best Bin Iter:{} ".format(
            self.iter_count,self.best_iter, self.best_reg_iter, self.best_bin_iter
        ))

        print("Training Losses:")
        print("Total: {:.3f} | Reg Loss:{:.3f} | Bin Loss:{:.3f} | Bin cumloss:{:.3f}".format(
            self.train_loss[-1], self.train_reg_loss[-1], self.train_bin_loss[-1],
            self.train_bin_cum_loss[-1]
        ))
        print()
        print("Val Losses")
        print("Total: {:.3f} | Reg Loss:{:.3f} | Bin Loss:{:.3f} | Bin cumloss:{:.3f}".format(
            val_loss, val_reg_loss, val_bin_loss,
            self.val_bin_cum_loss[-1]
        ))

    def train_ops(self, inputs):
        """Compute the loss for one batch and take one optimizer/scheduler step."""
        self.optimizer.zero_grad()
        model_losses=self.custom_loss.get_loss(inputs, 'train', self.iter_count)
        model_losses['loss'].backward()
        self.optimizer.step()
        self.schedular.step()
        return model_losses

    def train_epoch(self):
        """Run one pass over a fresh sampler, stopping at the iteration cap."""
        self.model.train()
        for batch in TrainDataSampler(CONFIG.TRAIN_SAMPLES_PER_BATCH, self.df):
            self.iter_count+=1
            if self.iter_count > CONFIG.TRAIN_MAX_ITERS:
                break

            # Trim the fixed-length padding to the longest sequence in the batch.
            batch_max_seq_len=torch.max(batch['seq_len'])

            input_ids=batch['input_ids'][:, :batch_max_seq_len].to(CONFIG.device)
            attention_mask=batch['attention_mask'][:, :batch_max_seq_len].to(CONFIG.device)
            yreg=batch['yreg'].to(CONFIG.device)
            ybin=batch['ybin'].to(CONFIG.device)

            yhat_reg, yhat_bin=self.model(input_ids, attention_mask)
            yhat_reg=yhat_reg.view(-1)

            model_losses=self.train_ops({
                'yreg': yreg,
                'ybin': ybin,

                'yhat_reg': yhat_reg,
                'yhat_bin': yhat_bin
            })

            self.train_loss.append(model_losses['loss'].item())
            self.train_reg_loss.append(model_losses['reg_loss'])
            self.train_bin_loss.append(model_losses['bin_loss'])
            self.train_bin_cum_loss.append(model_losses['bin_cum_loss'])

            if self.iter_count%CONFIG.eval_every==0:
                # Evaluate the model owned by this trainer, not a global.
                model_losses=self.custom_evaluator.evaluate(self.model)
                # Evaluation may leave the model in eval mode; resume training mode.
                self.model.train()
                self.val_loss.append(model_losses['loss'].item())
                self.val_reg_loss.append(model_losses['reg_loss'])
                self.val_bin_loss.append(model_losses['bin_loss'])
                self.val_bin_cum_loss.append(model_losses['bin_cum_loss'])
                self.checkpoint(model_losses)

    def train(self):
        """Repeat epochs until the iteration cap has been exceeded.

        Raises:
            RuntimeError: If an epoch produces no batches, which would
                otherwise make this loop spin forever.
        """
        while self.iter_count <= CONFIG.TRAIN_MAX_ITERS:
            iters_before=self.iter_count
            self.train_epoch()
            if self.iter_count == iters_before:
                raise RuntimeError("TrainDataSampler produced no batches; cannot make training progress")

In [None]:
# Train one model per fold and keep that fold's best checkpoints under
# fold-numbered file names.
for k in range(CONFIG.folds):
    print("==="*10)
    print()
    print("Fold: ==> ", k+1)
    fold_train_df=train_df[train_df.kfold!=k].copy()
    fold_val_df  =train_df[train_df.kfold==k].copy()

    if CONFIG.env=='test':
        fold_train_df=fold_train_df.head(4)
        fold_val_df=fold_val_df.head(4)

    val_dataset=CommonLitDataset(fold_val_df)
    val_dataloader=torch.utils.data.DataLoader(val_dataset, batch_size=CONFIG.batch_size,
                                                 shuffle=False, pin_memory=True, drop_last=False)

    model=CommonLitModel()
    model=model.to(CONFIG.device)
    trainer=Trainer(model, fold_train_df, val_dataloader)
    trainer.train()

    # Rename the checkpoints written by Trainer.checkpoint rather than
    # loading and re-saving three full models. A missing file (no evaluation
    # happened in this fold) raises instead of silently reusing the previous
    # fold's checkpoint.
    for prefix in ('best_model', 'best_reg_model', 'best_bin_model'):
        os.replace("{}.pt".format(prefix), "{}{}.pt".format(prefix, k+1))

    print("Best Iteration:", trainer.best_iter)
    print("Best Reg Iteration:", trainer.best_reg_iter)
    print("Best Bin Iteration:", trainer.best_bin_iter)

    print("Best Loss:{}".format(trainer.best_loss))
    print("Best Reg Loss:{}".format(trainer.best_reg_loss))
    print("Best Bin Loss:{}".format(trainer.best_bin_loss))

In [None]:
# Total training loss of the last trained fold, one point per training iteration.
plt.plot(trainer.train_loss)
plt.title('Training loss (last fold)')
plt.xlabel('Training iteration')
plt.ylabel('Loss');

In [None]:
# Regression part of the training loss, one point per training iteration.
plt.plot(trainer.train_reg_loss)
plt.title('Training regression loss (last fold)')
plt.xlabel('Training iteration')
plt.ylabel('Regression loss');

In [None]:
# Bin loss recorded during training (CustomLoss.get_loss reports it as 0).
plt.plot(trainer.train_bin_loss)
plt.title('Training bin loss (last fold)')
plt.xlabel('Training iteration')
plt.ylabel('Bin loss');

In [None]:
# Cumulative bin loss recorded during training (CustomLoss.get_loss reports it as 0).
plt.plot(trainer.train_bin_cum_loss)
plt.title('Training cumulative bin loss (last fold)')
plt.xlabel('Training iteration')
plt.ylabel('Cumulative bin loss');

In [None]:
# Total validation loss of the last trained fold, one point per evaluation.
plt.plot(trainer.val_loss)
plt.title('Validation loss (last fold)')
plt.xlabel('Evaluation step')
plt.ylabel('Loss');

In [None]:
# Regression part of the validation loss, one point per evaluation.
plt.plot(trainer.val_reg_loss)
plt.title('Validation regression loss (last fold)')
plt.xlabel('Evaluation step')
plt.ylabel('Regression loss');

In [None]:
# Bin loss recorded at validation time (CustomLoss.get_loss reports it as 0).
plt.plot(trainer.val_bin_loss)
plt.title('Validation bin loss (last fold)')
plt.xlabel('Evaluation step')
plt.ylabel('Bin loss');


In [None]:
# Cumulative bin loss recorded at validation time (CustomLoss.get_loss reports it as 0).
plt.plot(trainer.val_bin_cum_loss)
plt.title('Validation cumulative bin loss (last fold)')
plt.xlabel('Evaluation step')
plt.ylabel('Cumulative bin loss');

# Inference

In [None]:
# Reload the training table from disk for inference. This rebinds `train_df`,
# so the derived `q` and `ybin` columns are recomputed with the quantile
# boundaries (`train_qs`) and bins (`bin_ranges`) built earlier.
train_df=pd.read_csv('../input/commonlit-kfold-dataset/fold_train.csv')
train_df['q']=train_df.target.apply(get_quantile, args=(train_qs, ))
train_df['ybin']=train_df.apply(get_bin_distribution, args=(bin_ranges, ), axis=1)

# Smoke-test mode: keep only the first five rows.
if CONFIG.env=='test':
    train_df=train_df.head(5)

In [None]:
# Inference batch size (separate from CONFIG.batch_size).
batch_size=32

# Despite the `test_` prefix, this iterates over the training table loaded above.
test_dataset=CommonLitDataset(train_df)
# shuffle=False keeps predictions in the same order as the rows of train_df.
test_dataloader=torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                            shuffle=False, 
                                            pin_memory=True, drop_last=False)

# Number of batches.
print(len(test_dataloader))

In [None]:
# Checkpoints to ensemble; predictions are averaged over this list.
# map_location lets a checkpoint saved on GPU load on the current device.
# NOTE(review): recent PyTorch releases may default torch.load to
# weights_only=True, which rejects fully pickled modules - confirm against
# the installed version.
models=[
    torch.load('./best_reg_model1.pt', map_location=CONFIG.device)
]

# Switch to eval mode once, not once per batch.
for model in models:
    model.eval()

ypreds=[]
ypred_bins=[]

with torch.no_grad():
    for batch in test_dataloader:
        # Trim the fixed-length padding to the longest sequence in the batch.
        batch_max_seq_len=torch.max(batch['seq_len'])

        input_ids=batch['input_ids'][:, :batch_max_seq_len].to(CONFIG.device)
        attention_mask=batch['attention_mask'][:, :batch_max_seq_len].to(CONFIG.device)
        # Local name: do not overwrite the notebook-level `batch_size`.
        n_rows=input_ids.size(0)

        batch_yhat=np.zeros(n_rows)
        batch_yhat_bin=np.zeros( (n_rows, CONFIG.bins))
        for model in models:
            yhat, yhat_bin=model(input_ids, attention_mask)

            batch_yhat+=yhat.view(-1).detach().cpu().numpy()
            batch_yhat_bin+=yhat_bin.detach().cpu().numpy()

        batch_yhat/=len(models)
        batch_yhat_bin/=len(models)

        ypreds+=batch_yhat.tolist()
        ypred_bins+=batch_yhat_bin.tolist()

In [None]:
# Attach the predictions to the table. `ypreds` / `ypred_bins` were filled in
# DataLoader order with shuffle=False, so they line up with the rows of train_df.
train_df['yhat_target']=ypreds
train_df['yhat_bins']=ypred_bins

train_df.head()

In [None]:
# Persist the table with predictions; note that the file name contains a space.
train_df.to_csv('Train Inference.csv', index=False)