In [None]:
# Execution environment; selects the input/output directories configured further down.
ENV = 'kaggle'
# Pipeline stage: train the folds, score out-of-fold predictions, or predict the test set.
PHASE = 'train'

# Fail fast on typos before any expensive cell runs.
assert ENV in ('colab', 'kaggle')
assert PHASE in ('train', 'eval_oof', 'inference')

In [None]:
# !pip install transformers==4.5.1
!pip install torch==1.9.0

In [None]:
# Shell escape: print the nvidia-smi report for this session.
!nvidia-smi

In [None]:
# On Colab, mount Google Drive; the Colab paths configured later in this
# notebook all live under /content/drive.
if ENV=='colab':
    from google.colab import drive
    drive.mount('/content/drive')

In [None]:
import os
import math
import random
import time
 
import numpy as np
import pandas as pd
 
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
 
import transformers
from transformers import AdamW
from transformers import AutoTokenizer
from transformers import AutoModel
from transformers import AutoConfig
from transformers import get_cosine_schedule_with_warmup
 
from sklearn.model_selection import KFold
 
import gc, json, pickle, shutil
gc.enable()

from tqdm.auto import tqdm
from matplotlib import pyplot as plt

In [None]:
# Record the installed transformers version in the notebook output.
print(transformers.__version__)

In [None]:
# Record the installed PyTorch version in the notebook output.
print(torch.__version__)

In [None]:
def set_random_seed(random_seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and asks cuDNN for deterministic kernels.

    Args:
        random_seed: Integer seed applied to all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(random_seed)

    # Host-side generators.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(random_seed)

    # PyTorch generators: CPU, current CUDA device, then every CUDA device.
    for seed_fn in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(random_seed)

    torch.backends.cudnn.deterministic = True

# Dataset

In [None]:
class LitDataset(Dataset):
    """Tokenized excerpts with optional targets and self-attention features.

    Relies on module-level globals: ``tokenizer``, ``MAX_LEN`` and
    ``NO_TOKEN_TYPE`` (read at construction / item access) and ``sa_complex``
    (read on every item access; ``None``, the string ``'hdd'`` or an indexable
    collection of per-row features).

    Args:
        df: Frame with an ``excerpt`` column, plus ``target`` and ``bins``
            columns unless ``inference_only`` is set.
        inference_only: When True no targets are loaded and the dataset is
            treated as a complete frame (see ``sa_rows`` below).

    Each item is a dict with ``input_ids`` and ``attention_mask`` and, depending
    on the configuration, ``token_type_ids``, ``sa_complex``, ``target``, ``bins``.
    """
    def __init__(self, df, inference_only=False):
        super().__init__()

        self.df = df
        self.inference_only = inference_only
        self.text = df.excerpt.tolist()

        if not self.inference_only:
            self.target = torch.tensor(df.target.values, dtype=torch.float32)
            self.bins = torch.tensor(df.bins.values, dtype=torch.long)
            # Training/validation datasets are fold subsets of the frame that
            # ``sa_complex`` was computed on. Their row labels identify the row
            # in that full frame, so the feature lookup must go through the
            # label, not the position inside the subset.
            # Assumes the source frame carries a 0..N-1 index (reset_index).
            self.sa_rows = df.index.tolist()
        else:
            # Inference-only datasets (test set, benchmark excerpt, feature
            # pre-pass) are complete frames and are addressed positionally.
            self.sa_rows = list(range(len(df)))

        self.encoded = tokenizer.batch_encode_plus(
            self.text,
            padding = 'max_length',
            max_length = MAX_LEN,
            truncation = True,
            return_attention_mask = True,
            return_token_type_ids = not NO_TOKEN_TYPE
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        input_ids = torch.tensor(self.encoded['input_ids'][index])
        attention_mask = torch.tensor(self.encoded['attention_mask'][index])

        out_dict = {'input_ids':input_ids, 'attention_mask':attention_mask}

        if not NO_TOKEN_TYPE:
            out_dict['token_type_ids'] = torch.tensor(self.encoded['token_type_ids'][index])

        if sa_complex is not None:
            sa_row = self.sa_rows[index]
            # isinstance guard: avoids comparing a tensor against a string.
            if isinstance(sa_complex, str) and sa_complex == 'hdd':
                # Files are produced by SelfAttention_Complexity in this notebook;
                # never point this at pickle files from an untrusted source.
                with open(f'SelfAttComplex/{str(sa_row).zfill(4)}.pkl','rb') as f:
                    out_dict['sa_complex'] = pickle.load(f)
            else:
                out_dict['sa_complex'] = sa_complex[sa_row]

        if not self.inference_only:
            out_dict['target'] = self.target[index]
            out_dict['bins'] = self.bins[index]

        return out_dict

# Self Attention Complexity in Pretrained Model

In [None]:
def SelfAttention_Complexity(df: pd.DataFrame, output_device):
    """Compute per-token "self-attention complexity" features with the backbone.

    For every layer/head and every query token, the attention weights are
    multiplied by the normalized signed offset between key and query position
    and summed over keys, separately for keys after the query (positive
    offsets) and keys before it (negative offsets, taken as absolute values).

    Relies on module-level globals: BATCH_SIZE, PHASE, MODEL_NAME,
    LOAD_BACKBONE_DIR, DEVICE, tokenizer, NO_TOKEN_TYPE, HAS_DECODER.

    Args:
        df: Frame with an ``excerpt`` column; rows are processed in order.
        output_device: ``'hdd'`` to pickle one tensor per row into
            ``SelfAttComplex/<row position, zero-padded to 4>.pkl``, otherwise
            a torch device on which the stacked result is returned.

    Returns:
        The string ``'hdd'`` in HDD mode, else a tensor of shape
        [rows, seq, layer*head*2] on ``output_device``.
    """
    pre_dataset = LitDataset(df, inference_only=True)
    pre_loader = DataLoader(pre_dataset, batch_size=BATCH_SIZE,
                            drop_last=False, shuffle=False)

    if output_device == 'hdd':
        os.makedirs('SelfAttComplex', exist_ok=True)

    cfg_update = {"output_attentions":True, "hidden_dropout_prob": 0.0,
                  "layer_norm_eps": 1e-7}
    if PHASE=='train':
        config = AutoConfig.from_pretrained(MODEL_NAME)
        config.update(cfg_update)
        backbone = AutoModel.from_pretrained(MODEL_NAME, config=config).to(DEVICE)
    elif PHASE=='eval_oof' or PHASE=='inference':
        config = AutoConfig.from_pretrained(LOAD_BACKBONE_DIR)
        config.update(cfg_update)
        backbone = AutoModel.from_pretrained(LOAD_BACKBONE_DIR, config=config).to(DEVICE)

    backbone.resize_token_embeddings(len(tokenizer))

    output_sa_complex = []
    backbone.eval()
    idx = 0  # running row position, used as the file name in HDD mode
    with torch.no_grad():
        for batch_num, dsargs in enumerate(tqdm(pre_loader)):

            kwargs = {}
            kwargs['input_ids'] = dsargs['input_ids'].to(DEVICE)
            if not NO_TOKEN_TYPE:
                kwargs['token_type_ids'] = dsargs['token_type_ids'].to(DEVICE)
            kwargs['attention_mask'] = dsargs['attention_mask'].to(DEVICE)

            if 't5' in MODEL_NAME.lower() and HAS_DECODER:
                # shift to right
                kwargs['decoder_input_ids'] = torch.cat([tokenizer.pad_token_id * torch.ones(kwargs['input_ids'].size(0), 1).long().to(DEVICE),
                                                        kwargs['input_ids'][:,:-1]], dim=1)

            # self attention
            output_backbone = backbone(**kwargs)
            self_att = torch.stack(output_backbone.attentions, dim=1) #[batch, layer, head, seq, seq]
            seq_len = self_att.size(-1)
            self_att = self_att.view(self_att.size(0), -1, seq_len, seq_len) #[batch, layer*head, seq, seq]
            # Zero the rows that belong to padded query positions.
            self_att *= kwargs['attention_mask'].unsqueeze(1).unsqueeze(-1)

            # self attention complexity
            # distance_from_diag[i, j] = (j - i) / (seq_len - 1): signed, normalized key-query offset.
            distance_from_diag = (torch.arange(seq_len).view(1, -1) - torch.arange(seq_len).view(-1, 1)) / (seq_len - 1)
            distance_from_diag = distance_from_diag.to(DEVICE)
            directional = []
            # Keys located after the query.
            temp = self_att * distance_from_diag.unsqueeze(0).unsqueeze(1).clip(min=0)
            temp = temp.sum(dim=-1) #[batch, layer*head, seq]
            directional.append(temp)
            # Keys located before the query.
            temp = self_att * distance_from_diag.unsqueeze(0).unsqueeze(1).clip(max=0).abs()
            temp = temp.sum(dim=-1) #[batch, layer*head, seq]
            directional.append(temp)
            batch_sa_complex = torch.cat(directional, dim=1).transpose(-2,-1) #[batch, seq, layer*head*2]

            if output_device == 'hdd':
                for batch_item in batch_sa_complex:
                    with open(f'SelfAttComplex/{str(idx).zfill(4)}.pkl','wb') as f:
                        # Pickling a view serializes its whole underlying storage
                        # (the full batch); an independent CPU copy keeps each file
                        # to one row and loadable without a GPU.
                        pickle.dump(batch_item.cpu().clone(), f)
                    idx += 1
            else:
                # Move each batch right away so the full feature tensor never has
                # to be held (and concatenated) on the compute device.
                output_sa_complex.append(batch_sa_complex.to(output_device))

    if output_device == 'hdd':
        return 'hdd'
    return torch.cat(output_sa_complex, dim=0)

# Model
The model is inspired by the one from [Maunish](https://www.kaggle.com/maunish/clrp-roberta-svm).

In [None]:
class LitModel(nn.Module):
    """Transformer backbone with attention pooling, a regression head and a bin classifier.

    Relies on module-level globals: PHASE, MODEL_NAME, SAVE_DIR,
    LOAD_BACKBONE_DIR, NUM_HIDDEN_LAYERS, HIDDEN_SIZE, NUM_BINS, NO_TOKEN_TYPE,
    HAS_DECODER, DEVICE, tokenizer and (when both ``benchmark_token`` and
    ``sa_complex_dim`` are used) ``benchmark_sa_complex``.

    Args:
        benchmark_token: Optional tuple ``(input_ids, token_type_ids,
            attention_mask)`` of one reference excerpt. When given, it is
            appended to every batch and the regression output is reported
            relative to its score.
        use_max_pooling: Also concatenate the element-wise maximum over the
            stacked hidden layers to the weighted layer average.
        sa_complex_dim: Width of the per-token self-attention features that
            are concatenated to the token representation (0 disables them).
    """
    def __init__(self, benchmark_token=None, use_max_pooling=False, sa_complex_dim=0):
        super().__init__()

        self.benchmark_token = benchmark_token
        self.use_max_pooling = use_max_pooling
        self.sa_complex_dim = sa_complex_dim

        cfg_update = {"output_hidden_states":True, "hidden_dropout_prob": 0.0,
                      "layer_norm_eps": 1e-7}
        if PHASE=='train':
            # Training starts from the hub checkpoint and stores a local copy
            # of the backbone under SAVE_DIR.
            config = AutoConfig.from_pretrained(MODEL_NAME)
            config.save_pretrained(f'{SAVE_DIR}/backbone')
            config.update(cfg_update)
            self.backbone = AutoModel.from_pretrained(MODEL_NAME, config=config)
            self.backbone.save_pretrained(f'{SAVE_DIR}/backbone')
        elif PHASE=='eval_oof' or PHASE=='inference':
            config = AutoConfig.from_pretrained(LOAD_BACKBONE_DIR)
            config.update(cfg_update)
            self.backbone = AutoModel.from_pretrained(LOAD_BACKBONE_DIR, config=config)

        # One learnable logit per stacked hidden layer; zeros give a uniform
        # softmax, i.e. a plain mean over layers at initialization.
        self.hidden_layer_weights = nn.Parameter(torch.zeros(NUM_HIDDEN_LAYERS).view(-1, 1, 1, 1))

        # Dropout layers
        self.dropouts_regr = nn.ModuleList([
            nn.Dropout(0.5) for _ in range(5)
        ])
        self.dropouts_clsi = nn.ModuleList([
            nn.Dropout(0.5) for _ in range(5)
        ])

        if self.use_max_pooling:
            num_pool = 2
        else:
            num_pool = 1
        feature_dim = HIDDEN_SIZE * num_pool + sa_complex_dim
        self.attention_layer_norm = nn.LayerNorm(feature_dim)
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, 512 * num_pool),
            nn.Tanh(),
            nn.Linear(512 * num_pool, 1),
            nn.Softmax(dim=1)
            )
        self.head_regressor = nn.Linear(feature_dim, 1)
        self.head_classifier = nn.Linear(feature_dim, NUM_BINS)

    def forward(self, input_ids, token_type_ids, attention_mask, self_att_complex):
        """Return ``(regression output, softmax bin probabilities)`` for a batch.

        ``token_type_ids`` is ignored when NO_TOKEN_TYPE is set and
        ``self_att_complex`` is ignored when ``sa_complex_dim`` is 0.
        With a benchmark excerpt the regression output is the difference to
        the benchmark's score and the benchmark row is dropped from both outputs.
        """
        kwargs = {}
        if self.benchmark_token is None:
            kwargs['input_ids'] = input_ids
            if not NO_TOKEN_TYPE:
                kwargs['token_type_ids'] = token_type_ids
            kwargs['attention_mask'] = attention_mask
        else:
            # The benchmark excerpt rides along as the last row of the batch.
            benchmark_input_ids, benchmark_token_type_ids, benchmark_attention_mask = self.benchmark_token
            kwargs['input_ids'] = torch.cat((input_ids, benchmark_input_ids), dim = 0)
            if not NO_TOKEN_TYPE:
                kwargs['token_type_ids'] = torch.cat((token_type_ids, benchmark_token_type_ids), dim = 0)
            kwargs['attention_mask'] = torch.cat((attention_mask, benchmark_attention_mask), dim = 0)

        if 't5' in MODEL_NAME.lower() and HAS_DECODER:
            # shift to right
            kwargs['decoder_input_ids'] = torch.cat([tokenizer.pad_token_id * torch.ones(kwargs['input_ids'].size(0), 1).long().to(DEVICE),
                                                     kwargs['input_ids'][:,:-1]], dim=1)
        output_backbone = self.backbone(**kwargs)

        # Extract output
        if HAS_DECODER:
            hidden_states = output_backbone.encoder_hidden_states + output_backbone.decoder_hidden_states[1:]
        else:
            hidden_states = output_backbone.hidden_states

        # Stack every entry except hidden_states[0], last layer first, then
        # combine them with the learned softmax weights (and optional max).
        hidden_states = torch.stack(tuple(hidden_states[-i-1] for i in range(len(hidden_states) - 1)), dim = 0)
        layer_weight = nn.functional.softmax(self.hidden_layer_weights, dim = 0)
        output_backbone = torch.sum(hidden_states * layer_weight, dim = 0)
        if self.use_max_pooling:
            out_max, _ = torch.max(hidden_states, dim = 0)
            output_backbone = torch.cat((output_backbone, out_max), dim = -1)
        if self.sa_complex_dim != 0:
            # Only append the benchmark features when the benchmark row was
            # actually added to the batch; otherwise the batch sizes disagree.
            if self.benchmark_token is not None:
                self_att_complex = torch.cat((self_att_complex, benchmark_sa_complex), dim = 0)
            output_backbone = torch.cat((output_backbone, self_att_complex), dim = -1)

        output_backbone = self.attention_layer_norm(output_backbone)

        # Attention Pooling
        weights = self.attention(output_backbone)
        context_vector = torch.sum(weights * output_backbone, dim=1)

        # Multiple dropout: average the heads over several dropout masks.
        for i, dropout in enumerate(self.dropouts_regr):
            if i == 0:
                output_regr = self.head_regressor(dropout(context_vector))
                output_clsi = self.head_classifier(self.dropouts_clsi[i](context_vector))
            else:
                output_regr += self.head_regressor(dropout(context_vector))
                output_clsi += self.head_classifier(self.dropouts_clsi[i](context_vector))

        output_regr /= len(self.dropouts_regr)
        output_clsi /= len(self.dropouts_clsi)

        if self.benchmark_token is not None:
            # Score relative to the benchmark (last row), then drop that row.
            output_regr = output_regr[:-1] - output_regr[-1]
            output_clsi = output_clsi[:-1]

        # Now we reduce the context vector to the prediction score.
        return output_regr, nn.functional.softmax(output_clsi, dim=-1)

# Loss functions

In [None]:
class QuadraticWeightedKappaLoss(nn.Module):
    """Differentiable quadratic-weighted-kappa style loss on soft predictions.

    Returns ``COEF_QWK * sum(w * O) / sum(w * E)`` where ``O`` is the soft
    confusion matrix, ``E`` the outer product of the true and predicted
    category histograms (both normalized to sum to one) and
    ``w[i, j] = (i - j)**2 / (num_cat - 1)**2``. The ratio equals one minus
    the quadratic weighted kappa of the soft predictions.

    Reads the module-level global COEF_QWK.

    Args:
        num_cat: Number of categories.
        device: Device on which the weight matrix is created.
    """
    def __init__(self, num_cat, device = 'cpu'):
        super(QuadraticWeightedKappaLoss, self).__init__()
        self.num_cat = num_cat
        cats = torch.arange(num_cat).to(device)
        self.weights = (cats.view(-1,1) - cats.view(1,-1)).pow(2) / (num_cat - 1)**2

    def _confusion_matrix(self, pred_smax, true_cat):
        """Sum the predicted probability rows into the row of their true category."""
        confusion_matrix = torch.zeros((self.num_cat, self.num_cat), device=pred_smax.device)
        # A single index_add replaces the per-sample Python loop.
        return confusion_matrix.index_add(0, true_cat.reshape(-1).long(),
                                          pred_smax.to(confusion_matrix.dtype))

    def forward(self, pred_smax, true_cat):
        """Compute the loss.

        Args:
            pred_smax: [batch, num_cat] predicted category probabilities.
            true_cat: Integer category per sample; any shape with ``batch``
                elements (flattened internally).
        """
        # bincount only accepts 1-D integer input.
        true_cat = true_cat.reshape(-1).long()

        # Confusion matrix
        O = self._confusion_matrix(pred_smax, true_cat)

        # Count elements in each category
        true_hist = torch.bincount(true_cat, minlength = self.num_cat)
        pred_hist = pred_smax.sum(dim = 0)

        # Expected values
        E = torch.outer(true_hist.to(pred_hist.dtype), pred_hist)

        # Normalization
        O = O / torch.sum(O)
        E = E / torch.sum(E)

        # Weighted Kappa
        numerator = torch.sum(self.weights * O)
        denominator = torch.sum(self.weights * E)

        return COEF_QWK * numerator / denominator

In [None]:
class BradleyTerryLoss(nn.Module):
    """Pairwise ranking loss over all unordered pairs in a batch.

    For each pair (i, j) with i > j the term is
    ``log(1 + exp(-(true_i - true_j) * (pred_i - pred_j)))``; the terms are
    averaged over the ``n * (n - 1) / 2`` pairs and scaled by the
    module-level global COEF_BT.
    """
    def __init__(self):
        super(BradleyTerryLoss, self).__init__()

    def forward(self, pred_mean, true_mean):
        """Return the scaled mean pairwise loss.

        Args:
            pred_mean: 1-D tensor of predicted scores.
            true_mean: 1-D tensor of target scores, same length.

        Returns:
            A scalar tensor; zero (still attached to ``pred_mean``) when the
            batch holds fewer than two samples and therefore no pair.
        """
        batch_size = len(pred_mean)
        if batch_size < 2:
            # No pairs to compare: avoid dividing by zero.
            return COEF_BT * (pred_mean.sum() * 0.0)

        true_comparison = true_mean.view(-1,1) - true_mean.view(1,-1)
        pred_comparison = pred_mean.view(-1,1) - pred_mean.view(1,-1)

        # softplus(x) == log(1 + exp(x)) without overflowing for large x.
        pair_losses = nn.functional.softplus(-true_comparison * pred_comparison)
        # diagonal=-1 keeps the strict lower triangle only: the i == j entries
        # are not pairs and would otherwise add a constant log(2) each.
        num_pairs = batch_size * (batch_size - 1) / 2
        return COEF_BT * (torch.tril(pair_losses, diagonal=-1).sum() / num_pairs)

In [None]:
def eval_mse(model, data_loader):
    """Evaluate the mean squared error of ``model`` on ``data_loader``.

    Relies on module-level globals DEVICE, NO_TOKEN_TYPE and USE_SELF_ATT.

    Args:
        model: Callable as ``model(input_ids, token_type_ids, attention_mask,
            self_att_complex)`` returning ``(regression output, ...)``; put
            into eval mode here.
        data_loader: Yields dict batches with ``input_ids``,
            ``attention_mask`` and ``target`` (plus ``token_type_ids`` /
            ``sa_complex`` when the corresponding flags require them).

    Returns:
        ``(mse, predictions)``: the squared error summed over all batches and
        divided by ``len(data_loader.dataset)``, and a 1-D tensor with every
        regression output in loader order.
    """
    model.eval()
    sum_squared_error = nn.MSELoss(reduction="sum")
    mse_sum = 0

    all_pred_r = []
    with torch.no_grad():
        for dsargs in data_loader:
            input_ids = dsargs['input_ids'].to(DEVICE)
            attention_mask = dsargs['attention_mask'].to(DEVICE)
            target = dsargs['target'].to(DEVICE)

            token_type_ids = None
            if not NO_TOKEN_TYPE:
                token_type_ids = dsargs['token_type_ids'].to(DEVICE)

            self_att_complex = None
            if USE_SELF_ATT:
                self_att_complex = dsargs['sa_complex'].to(DEVICE)

            pred_r, _ = model(input_ids, token_type_ids, attention_mask, self_att_complex)

            mse_sum += sum_squared_error(pred_r.flatten(), target).item()
            all_pred_r.append(pred_r)

    # flatten (not squeeze) so a single-sample dataset still yields a 1-D tensor.
    return mse_sum / len(data_loader.dataset), torch.cat(all_pred_r, dim=0).flatten()

# Training, Validation

In [None]:
def train(model, model_path, train_loader, val_loader,
          optimizer, num_epochs, fold, scheduler=None):
    """Fine-tune ``model`` and checkpoint the weights with the best validation RMSE.

    The loss is MSE plus the module-level ``QWKloss`` and ``BTloss`` terms.
    Validation runs every ``eval_period`` steps; the period is re-derived from
    EVAL_SCHEDULE after each evaluation and capped at 8 while 50-80% of the
    planned steps have elapsed. Each evaluation appends to a history dict that
    is rewritten to ``{SAVE_DIR}/{MODEL_VER}_fold{fold+1}_history.json``.

    Relies on module-level globals: DEVICE, NO_TOKEN_TYPE, USE_SELF_ATT,
    EVAL_SCHEDULE, QWKloss, BTloss, SAVE_DIR, MODEL_VER.

    Args:
        model: Model to train; must expose ``hidden_layer_weights``.
        model_path: File that receives the best ``state_dict``.
        train_loader: Training batches (dicts as produced by LitDataset).
        val_loader: Validation batches passed to ``eval_mse``.
        optimizer: Optimizer stepped once per batch.
        num_epochs: Number of passes over ``train_loader``.
        fold: Zero-based fold index, used in the history file name.
        scheduler: Optional LR scheduler stepped once per batch.

    Returns:
        The best validation RMSE, or None if no evaluation took place.
    """
    best_val_rmse = None
    best_epoch = 0
    step = 0
    last_eval_step = 0
    eval_period = EVAL_SCHEDULE[0][1]

    start = time.time()

    history = {'step':[], 'epoch':[], 'batch_num':[], 'val_rmse':[],
               'trn_rmse':[], 'trn_qwk':[], 'trn_bt':[]}

    for epoch in range(num_epochs):
        val_rmse = None

        # Targets/predictions seen so far in this epoch; only used for the
        # training metrics reported at each evaluation.
        epoch_target, epoch_bins, epoch_pred_r, epoch_pred_c = (torch.tensor([]),)*4
        epoch_bins = epoch_bins.long()

        for batch_num, dsargs in enumerate(train_loader):
            input_ids = dsargs['input_ids'].to(DEVICE)
            attention_mask = dsargs['attention_mask'].to(DEVICE)
            target = dsargs['target'].to(DEVICE)
            bins = dsargs['bins'].to(DEVICE)

            token_type_ids = None
            if not NO_TOKEN_TYPE:
                token_type_ids = dsargs['token_type_ids'].to(DEVICE)

            self_att_complex = None
            if USE_SELF_ATT:
                self_att_complex = dsargs['sa_complex'].to(DEVICE)

            optimizer.zero_grad()

            # eval_mse switches the model to eval mode, so re-enable training every step.
            model.train()

            pred_r, pred_c = model(input_ids, token_type_ids, attention_mask, self_att_complex)

            loss = (nn.MSELoss(reduction="mean")(pred_r.flatten(), target)
                    + QWKloss(pred_c, bins) + BTloss(pred_r.flatten(), target))

            loss.backward()

            epoch_target = torch.cat([epoch_target.to(DEVICE), target.clone().detach()], dim=0)
            epoch_bins = torch.cat([epoch_bins.to(DEVICE), bins.clone().detach()], dim=0)
            epoch_pred_r = torch.cat([epoch_pred_r.to(DEVICE), pred_r.clone().detach()], dim=0)
            epoch_pred_c = torch.cat([epoch_pred_c.to(DEVICE), pred_c.clone().detach()], dim=0)

            optimizer.step()
            if scheduler:
                scheduler.step()

            if step >= last_eval_step + eval_period:
                # Evaluate the model on val_loader.
                elapsed_seconds = time.time() - start
                num_steps = step - last_eval_step
                print(f"\n{num_steps} steps took {elapsed_seconds:0.3} seconds")
                last_eval_step = step

                mse, _ = eval_mse(model, val_loader)
                val_rmse = math.sqrt(mse)
                # Square root so the value is an RMSE like val_rmse (it used to be the raw MSE).
                trn_rmse = math.sqrt(nn.MSELoss(reduction="mean")(epoch_pred_r.flatten(), epoch_target).item())
                trn_qwk  = QWKloss(epoch_pred_c, epoch_bins).item()
                trn_bt  = BTloss(epoch_pred_r.flatten(), epoch_target).item()

                print(f"Epoch: {epoch} batch_num: {batch_num}",
                      f"val_rmse: {val_rmse:0.4}", f"train_rmse: {trn_rmse:0.4}",
                      f"train_qwk: {trn_qwk:0.4}", f"train_bt: {trn_bt:0.4}")

                # First schedule entry whose threshold the current RMSE reaches
                # determines how often to evaluate from now on.
                for rmse, period in EVAL_SCHEDULE:
                    if val_rmse >= rmse:
                        eval_period = period
                        break
                percent = step / (num_epochs * len(train_loader))
                if 0.5 <= percent and percent <= 0.8:
                    eval_period = min([eval_period, 8])

                # "is None" so a validation RMSE of exactly 0.0 is not mistaken for "no best yet".
                if best_val_rmse is None or val_rmse < best_val_rmse:
                    best_val_rmse = val_rmse
                    best_epoch = epoch
                    torch.save(model.state_dict(), model_path)
                    print(f"New best_val_rmse: {best_val_rmse:0.4}")
                else:
                    print(f"Still best_val_rmse: {best_val_rmse:0.4}",
                          f"(from epoch {best_epoch})")

                # history json dump
                history['step'].append(step)
                history['epoch'].append(epoch)
                history['batch_num'].append(batch_num)
                history['val_rmse'].append(val_rmse)
                history['trn_rmse'].append(trn_rmse)
                history['trn_qwk'].append(trn_qwk)
                history['trn_bt'].append(trn_bt)
                with open(f'{SAVE_DIR}/{MODEL_VER}_fold{fold+1}_history.json', 'w') as f:
                    json.dump(history, f, indent=4)

                start = time.time()

            step += 1

        del epoch_target, epoch_bins, epoch_pred_r, epoch_pred_c

        print('\nHidden Layer Weights:')
        print(model.hidden_layer_weights.squeeze())
        print(nn.functional.softmax(model.hidden_layer_weights.squeeze(),dim=0))

    return best_val_rmse

In [None]:
def predict(model, data_loader):
    """Run ``model`` over ``data_loader`` and collect its regression outputs.

    Relies on module-level globals DEVICE, NO_TOKEN_TYPE and USE_SELF_ATT.

    Returns:
        A NumPy array of length ``len(data_loader.dataset)``, filled in loader
        order; positions the loader never reaches stay zero.
    """
    model.eval()

    predictions = np.zeros(len(data_loader.dataset))
    offset = 0

    with torch.no_grad():
        for batch in data_loader:
            # Optional inputs default to None and are filled in per configuration.
            token_type_ids = None if NO_TOKEN_TYPE else batch['token_type_ids'].to(DEVICE)
            self_att_complex = batch['sa_complex'].to(DEVICE) if USE_SELF_ATT else None

            pred_r, _ = model(batch['input_ids'].to(DEVICE),
                              token_type_ids,
                              batch['attention_mask'].to(DEVICE),
                              self_att_complex)

            stop = offset + pred_r.shape[0]
            predictions[offset:stop] = pred_r.flatten().to("cpu")
            offset = stop

    return predictions

In [None]:
def create_optimizer(model):
    """Build an AdamW optimizer with layer-wise decayed learning rates.

    Parameter groups, in order:
      1. parameters whose name starts with ``attention`` (optimizer defaults),
      2. ``hidden_layer_weights`` (no weight decay, lr HIDDEN_WTS_LR),
      3. parameters whose name starts with ``head`` (optimizer defaults),
      4. two groups per backbone layer, from the last layer to the first:
         weights (weight decay 0.01) and bias/layer-norm parameters (no weight
         decay). The learning rate starts at BACKBONE_LR and is multiplied by
         ``LAYERWISE_LR_DECAY ** (1 / number of layers)`` before every layer.

    Relies on module-level globals MODEL_NAME, HIDDEN_WTS_LR, BACKBONE_LR,
    LAYERWISE_LR_DECAY and AdamW.

    Raises:
        RuntimeError: If MODEL_NAME matches none of the supported backbones.
    """
    def params_with_prefix(prefix):
        # Parameters of the non-backbone parts are selected by name prefix.
        return [param for name, param in model.named_parameters() if name.startswith(prefix)]

    parameters = [
        {"params": params_with_prefix('attention')},
        {"params": params_with_prefix('hidden_layer_weights'), 'weight_decay': 0.0, 'lr': HIDDEN_WTS_LR},
        {"params": params_with_prefix('head')},
    ]

    no_decay = ['bias', 'LayerNorm.weight', 'layer_norm']

    backbone = model.backbone
    model_name = MODEL_NAME.lower()

    # Ordered list of backbone blocks, input side first.
    if 'roberta' in model_name or 'electra' in model_name:
        layers = [backbone.embeddings] + list(backbone.encoder.layer)
    elif 'gpt2' in model_name:
        layers = [backbone.wte] + list(backbone.h)
    elif 'xlnet' in model_name:
        layers = [backbone.word_embedding] + list(backbone.layer)
    elif 'bart' in model_name or 't5' in model_name:
        encoder, decoder = backbone.encoder, backbone.decoder
        if 'bart' in model_name:
            enc_layers = [encoder.embed_positions, *encoder.layers, encoder.layernorm_embedding]
            dec_layers = [decoder.embed_positions, *decoder.layers, decoder.layernorm_embedding]
        else:
            enc_layers = [*encoder.block, encoder.final_layer_norm]
            dec_layers = [*decoder.block, decoder.final_layer_norm]
        assert len(enc_layers)==len(dec_layers)
        # Interleave encoder and decoder blocks after the shared embedding.
        layers = [backbone.shared]
        for enc_layer, dec_layer in zip(enc_layers, dec_layers):
            layers.extend((enc_layer, dec_layer))
    else:
        raise RuntimeError('specify the parameters for backbone.')

    layerwise_learning_rate_decay = LAYERWISE_LR_DECAY**(1.0/len(layers))
    lr = BACKBONE_LR
    for layer in reversed(layers):
        lr *= layerwise_learning_rate_decay
        with_decay, without_decay = [], []
        for name, param in layer.named_parameters():
            if any(marker in name for marker in no_decay):
                without_decay.append(param)
            else:
                with_decay.append(param)
        parameters.append({'params': with_decay, 'weight_decay': 0.01, 'lr': lr})
        parameters.append({'params': without_decay, 'weight_decay': 0.0, 'lr': lr})

    return AdamW(parameters)

In [None]:
def convert_examples_to_features(text, tokenizer, max_len, is_test = False, return_tensor = False):
    """Tokenize one excerpt, padded/truncated to ``max_len`` tokens.

    Adapted from https://www.kaggle.com/rhtsingh/commonlit-readability-prize-roberta-torch-fit

    Args:
        text: Raw excerpt; newline characters are removed before tokenizing.
        tokenizer: Object exposing ``encode_plus``.
        max_len: Target sequence length.
        is_test: Unused; kept for call compatibility.
        return_tensor: When True, request PyTorch tensors from the tokenizer.

    Returns:
        Whatever ``tokenizer.encode_plus`` returns. Token type ids are
        requested unless the module-level NO_TOKEN_TYPE flag is set.
    """
    encode_options = dict(
        max_length = max_len,
        padding = 'max_length',
        truncation = True,
        return_attention_mask = True,
        return_token_type_ids = not NO_TOKEN_TYPE,
    )
    if return_tensor:
        encode_options['return_tensors'] = 'pt'

    return tokenizer.encode_plus(text.replace('\n', ''), **encode_options)

In [None]:
def Train_or_Validation():
    """Run every fold in the current PHASE ('train' or 'eval_oof').

    In the 'train' phase each fold's model is trained and its best checkpoint
    is written to ``{SAVE_DIR}/model_{fold+1}.bin``. In the 'eval_oof' phase
    the checkpoint ``{MODEL_DIR}/model_{fold+1}.bin`` is loaded and scored on
    the fold's validation rows.

    Relies on module-level globals: NUM_FOLDS, SEED, train_df, BATCH_SIZE,
    USE_SELF_ATT, benchmark_sa_complex, benchmark_token, USE_MAX_POOLING,
    DEVICE, tokenizer, PHASE, SAVE_DIR, MODEL_DIR, NUM_EPOCHS.

    Returns:
        In the 'eval_oof' phase, the validation rows of all folds concatenated
        with an added ``pred`` column; otherwise an empty list.
    """
    list_val_rmse = []

    oof = []
    for fold in range(NUM_FOLDS):
        print(f"\nFold {fold + 1}/{NUM_FOLDS}")

        set_random_seed(SEED + fold)

        train_dataset = LitDataset(train_df[train_df['kfold'] != fold])
        val_dataset = LitDataset(train_df[train_df['kfold'] == fold])
        val_df = train_df[train_df['kfold'] == fold].copy()

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  drop_last=True, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                drop_last=False, shuffle=False, num_workers=0)

        sa_complex_dim = 0
        if USE_SELF_ATT:
            sa_complex_dim = benchmark_sa_complex.size(-1)

        model = LitModel(benchmark_token = benchmark_token, use_max_pooling = USE_MAX_POOLING,
                         sa_complex_dim = sa_complex_dim).to(DEVICE)

        # Update vocabulary size
        model.backbone.resize_token_embeddings(len(tokenizer))

        if PHASE=='train':
            model_path = f"{SAVE_DIR}/model_{fold + 1}.bin"
            # Re-seed so training does not depend on how much randomness model
            # construction consumed.
            set_random_seed(SEED + fold)

            optimizer = create_optimizer(model)
            # The schedule is laid out for 10% more steps than are actually run.
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_training_steps = NUM_EPOCHS * len(train_loader) * 11//10,
                num_warmup_steps = 50)

            list_val_rmse.append(train(model, model_path, train_loader, val_loader, optimizer,
                                       num_epochs=NUM_EPOCHS, fold=fold, scheduler=scheduler, ))

        elif PHASE=='eval_oof':
            model_path = f"{MODEL_DIR}/model_{fold + 1}.bin"
            # map_location lets checkpoints saved on a GPU load when DEVICE is "cpu".
            model.load_state_dict(torch.load(model_path, map_location=DEVICE))
            model.to(DEVICE)

            mse, pred_r = eval_mse(model, val_loader)
            val_df['pred'] = pred_r.to('cpu').detach().numpy().copy()
            oof.append(val_df)
            list_val_rmse.append(math.sqrt(mse))

        del model
        gc.collect()

        print("\nPerformance estimates:")
        print(list_val_rmse)
        print("Mean:", np.array(list_val_rmse).mean())

    if PHASE=='eval_oof':
        oof = pd.concat(oof)

    return oof

In [None]:
def Inference():
    """Predict the test set with every fold's checkpoint and average the results.

    Loads ``{MODEL_DIR}/model_{fold+1}.bin`` for each fold, predicts
    ``test_df`` and writes the fold mean into the ``target`` column of a copy
    of ``submission_df``.

    Relies on module-level globals: NUM_FOLDS, test_df, submission_df,
    BATCH_SIZE, USE_SELF_ATT, benchmark_sa_complex, benchmark_token,
    USE_MAX_POOLING, DEVICE, tokenizer, MODEL_DIR.

    Returns:
        The filled-in copy of ``submission_df``.
    """
    all_predictions = np.zeros((NUM_FOLDS, len(test_df)))

    test_dataset = LitDataset(test_df, inference_only=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             drop_last=False, shuffle=False, num_workers=0)

    # Identical for every fold, so computed once.
    sa_complex_dim = 0
    if USE_SELF_ATT:
        sa_complex_dim = benchmark_sa_complex.size(-1)

    for fold in range(NUM_FOLDS):

        model = LitModel(benchmark_token = benchmark_token, use_max_pooling = USE_MAX_POOLING,
                         sa_complex_dim = sa_complex_dim).to(DEVICE)

        # Update vocabulary size
        model.backbone.resize_token_embeddings(len(tokenizer))

        model_path = f"{MODEL_DIR}/model_{fold + 1}.bin"
        print(f"\nUsing {model_path}")

        # map_location lets checkpoints saved on a GPU load when DEVICE is "cpu".
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))

        all_predictions[fold] = predict(model, test_loader)

        del model
        gc.collect()

    predictions = all_predictions.mean(axis=0)
    output_df = submission_df.copy()
    # Item assignment always targets the column, unlike attribute assignment.
    output_df['target'] = predictions
    print(output_df)

    return output_df

# Main

In [None]:
# Use the GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Input locations per environment.
if ENV=='kaggle':
    BASE_DIR = '../input/commonlitreadabilityprize'
    TRAIN_DATA_DIR = '../input/step-1-create-folds'
elif ENV=='colab':
    BASE_DIR = '/content/drive/MyDrive/Colab Notebooks/CLR/input'
    TRAIN_DATA_DIR = BASE_DIR

train_df = pd.read_csv(f'{TRAIN_DATA_DIR}/train_folds.csv')

# Rows with target == 0 and standard_error == 0 are split off as the benchmark
# and removed from the training frame, which is then re-indexed from 0.
benchmark = train_df[(train_df.target == 0) & (train_df.standard_error == 0)].copy()
train_df = train_df.drop(index=benchmark.index).reset_index(drop=True)

test_df = pd.read_csv(f"{BASE_DIR}/test.csv")
submission_df = pd.read_csv(f"{BASE_DIR}/sample_submission.csv")

In [None]:
# ---- Run configuration -------------------------------------------------
SEED = 1000          # base seed; fold k is seeded with SEED + k
NUM_FOLDS = 5
NUM_EPOCHS = 4
BATCH_SIZE = 8
MAX_LEN = 248        # tokenizer max_length (inputs are padded/truncated to it)
# (val_rmse threshold, evaluation period in steps): after each evaluation the
# first entry whose threshold is <= the current val_rmse sets the next period.
EVAL_SCHEDULE = [(0.52, 32), (0.49, 16), (0.48, 8), (0.47, 4), (-1., 2)]
MODEL_NAME = 'roberta-large'
MODEL_VER = 'CLRP_LightBase_031s_RoBERTaL'
 
# Sizes used to build the heads; presumably these must match the backbone
# selected by MODEL_NAME -- verify when changing the model.
NUM_HIDDEN_LAYERS = 24
HIDDEN_SIZE = 1024
NUM_BINS = 29        # number of target bins (pd.cut below, classifier head width)
COEF_QWK = 0.0 # coefficient of QWK loss
COEF_BT = 1.0 # coefficient of Bradley-Terry loss

USE_MAX_POOLING = True   # concatenate the max over hidden layers to the weighted mean
USE_SELF_ATT = True      # feed self-attention complexity features to the model
NO_TOKEN_TYPE = False    # True: do not request/pass token_type_ids
HAS_DECODER = False      # True: read encoder and decoder hidden states from the backbone

BACKBONE_LR = 2e-5
HIDDEN_WTS_LR = 1e-2
LAYERWISE_LR_DECAY = 0.1 # total LR decay factor spread across the backbone layers

# MODEL_DIR: where trained fold checkpoints are read from (eval_oof/inference).
# SAVE_DIR: where training writes checkpoints, history and the backbone copy.
# LOAD_BACKBONE_DIR: local backbone/tokenizer used outside the train phase.
if ENV=='colab':
    MODEL_DIR = f'/content/drive/MyDrive/Colab Notebooks/CLR/{MODEL_VER}'
    SAVE_DIR = MODEL_DIR
    LOAD_BACKBONE_DIR = f'{MODEL_DIR}/backbone'
elif ENV=='kaggle':
    MODEL_DIR = '../input/clrp-lightbase-031s-robertal-dat'
    SAVE_DIR = '.'
    LOAD_BACKBONE_DIR = '../input/robertalarge'

QWKloss = QuadraticWeightedKappaLoss(num_cat=NUM_BINS, device=DEVICE)
BTloss = BradleyTerryLoss()
# Discretize the target into NUM_BINS equal-width bins (integer codes) for the classifier head.
train_df['bins'] = pd.cut(train_df['target'], bins=NUM_BINS, labels=False)

# Setup Tokenizer
# Training downloads the tokenizer by name and stores a copy next to the
# backbone; the other phases load it from LOAD_BACKBONE_DIR.
if PHASE=='train':
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.save_pretrained(f'{SAVE_DIR}/backbone')
elif PHASE=='eval_oof' or PHASE=='inference':
    tokenizer = AutoTokenizer.from_pretrained(LOAD_BACKBONE_DIR)
if 'gpt2' in MODEL_NAME.lower():
    # Register a padding token for GPT-2; the models later call
    # resize_token_embeddings(len(tokenizer)) to account for it.
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Tokenize the benchmark text
# benchmark_token is the tuple (input_ids, token_type_ids or None, attention_mask)
# that LitModel.forward unpacks in this order.
benchmark_token = convert_examples_to_features(benchmark['excerpt'].iloc[0], tokenizer, MAX_LEN, return_tensor = True)
if NO_TOKEN_TYPE:
    benchmark_token = (benchmark_token['input_ids'].to(DEVICE), None, benchmark_token['attention_mask'].to(DEVICE))
else:
    benchmark_token = (benchmark_token['input_ids'].to(DEVICE), benchmark_token['token_type_ids'].to(DEVICE), benchmark_token['attention_mask'].to(DEVICE))

# Main
# NOTE: `sa_complex` is a module-level global that LitDataset.__getitem__ reads
# on every item access, so it has to exist (as None) before the first dataset
# is iterated inside SelfAttention_Complexity.
if PHASE=='train' or PHASE=='eval_oof':
    sa_complex = None # Self-Attention Complexity in Pretrained Model
    if USE_SELF_ATT:
        # Training features are kept on the CPU; the single benchmark row goes to DEVICE.
        sa_complex = SelfAttention_Complexity(train_df, 'cpu')
        benchmark_sa_complex = SelfAttention_Complexity(benchmark, DEVICE)
    oof_df = Train_or_Validation()

if PHASE=='eval_oof':
    oof_df.to_csv(f'oof_{MODEL_VER}.csv')

if PHASE=='inference':
    sa_complex = None # Self-Attention Complexity in Pretrained Model
    if USE_SELF_ATT:
        sa_complex = SelfAttention_Complexity(test_df, 'cpu')
        benchmark_sa_complex = SelfAttention_Complexity(benchmark, DEVICE)
    submission_df = Inference()
    submission_df.to_csv("submission.csv", index=False)

# Remove the on-disk feature cache that SelfAttention_Complexity creates in 'hdd' mode.
if os.path.isdir('SelfAttComplex'):
    shutil.rmtree('SelfAttComplex')