# Libraries

In [None]:
import os
import random
import gc
from pprint import pprint
from tqdm import tqdm
import more_itertools
from collections import OrderedDict

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import seaborn as sns
sns.set(style='darkgrid')

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
from torch.utils.data import RandomSampler

from transformers import AutoModel
from transformers import AutoTokenizer
from transformers import AutoConfig
from transformers import AutoModelForMaskedLM
from transformers import DataCollatorForLanguageModeling
from transformers import DataCollatorWithPadding

%matplotlib inline

import sys


# Environment switch: True selects the '/content/drive/...' paths, False the
# relative '../input/...' paths used throughout this file.
COLAB = False

if COLAB:
    # Make the private helper library importable from this Drive folder.
    # NOTE(review): when COLAB is False none of these names are imported by this
    # cell, yet SmartBatchingSampler is referenced later by
    # get_SmartBatchingSampler -- confirm it is provided some other way.
    sys.path.append('/content/drive/MyDrive/Colab Notebooks/CommonLit/lib')
    from clrp_private_lib import SmartBatchingSampler
    from clrp_private_lib import get_slanted_triangular_lr
    from clrp_private_lib import EarlyStopping

# Configuration

In [None]:
# NOTE(review): DEBUG is not read anywhere in the visible code.
DEBUG = False

# Input Files
# NOTE: TRAIN is assigned twice; the first value is overwritten immediately,
# so only the second path (train_oof_stratified.csv) is ever used.
TRAIN = '/content/drive/MyDrive/Colab Notebooks/CommonLit/input/train.csv' if COLAB else '../input/commonlitreadabilityprize/train.csv'
TRAIN = '/content/drive/MyDrive/Colab Notebooks/CommonLit/input/train_oof_stratified.csv' if COLAB else '../input/clrp-stratify-on-predictability/train_oof_stratified.csv'   # changed
TEST = '/content/drive/MyDrive/Colab Notebooks/CommonLit/input/test.csv' if COLAB else '../input/commonlitreadabilityprize/test.csv'

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` and configures cuDNN with the autotuner
    enabled and deterministic mode disabled, i.e. the cuDNN settings favour
    speed over reproducibility.

    Args:
        seed: Value handed to every generator (stringified for the
            environment variable).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: ', device.type)

# Global seed, applied once here and again at the start of main_cv / main_infer.
SEED = 28
seed_everything(SEED)

# Placeholder; the real configuration is loaded from cfg.pt further down.
cfg ={}

In [None]:
# BERT
BERT = 'bert-base-uncased'

# Distilbert
DISTILBERT = 'distilbert-base-uncased'

# Roberta
# Hub identifiers on Colab, local dataset directories otherwise.
ROBERTA = 'roberta-base' if COLAB else '../input/huggingface-roberta-variants/roberta-base'
ROBERTA_LARGE = 'roberta-large' if COLAB else '../input/huggingface-roberta-variants/roberta-large/roberta-large'



# Placeholder (repeated from the previous cell); replaced once cfg.pt is loaded.
cfg ={}

# Backbone used for both the model and the tokenizer; written into
# cfg['model']['name'] and cfg['tokenizer']['name'] after cfg.pt is loaded.
ARCH_PATH = ROBERTA_LARGE

# Directory holding cfg.pt and the per-fold checkpoints, which are named
# MODEL_NAME + '_fold{fold}.tar'.
CV_PATH = '../input/clrp-robertalarge-attentions-mask-act'
MODEL_NAME = 'CLRPModelLarge'

# CV=True recomputes out-of-fold predictions; CV=False runs test inference
# and writes submission.csv.
CV = False
# POST=True fits and applies a linear regression of target on the OOF prediction.
POST = False

# Data

## Tokenizer

In [None]:
def get_tokenizer():
    """Load the tokenizer identified by ``cfg['tokenizer']['name']``."""
    tokenizer_name = cfg['tokenizer']['name']
    return AutoTokenizer.from_pretrained(tokenizer_name)

## Dataset

In [None]:
def clean_text(text):
    """Return ``text`` with every newline character removed.

    Newlines are deleted, not replaced by a space, so words separated only by
    a line break end up joined together.
    """
    return ''.join(text.split('\n'))

class CLRPDataset(Dataset):
    """Dataset that tokenizes one excerpt per item.

    The frame must provide the columns ``excerpt``, ``target`` and
    ``standard_error``. Each item is the tokenizer output for the cleaned
    excerpt, extended with the keys ``'target'`` and ``'se'``.

    Args:
        df: Frame with the three columns listed above.
        tokenizer: Object exposing ``encode_plus``; it is called with the
            keyword arguments in ``cfg['tokenizer']['params']``.
    """

    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer
        # Pull the columns out once as plain lists for positional access.
        self.texts, self.targets, self.se = (
            df[column].tolist()
            for column in ('excerpt', 'target', 'standard_error')
        )

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        encoding = self.tokenizer.encode_plus(
            clean_text(self.texts[index]),
            **cfg['tokenizer']['params']
        )
        encoding['target'] = self.targets[index]
        encoding['se'] = self.se[index]
        return encoding

## Dataloader

### Datacollator

In [None]:
def get_collator(tokenizer, phase='train'):
    """Build the batch collator configured for ``phase``.

    Reads ``cfg['collator'][phase]``: the name ``'MLM'`` yields a
    ``DataCollatorForLanguageModeling`` built with the configured ``params``,
    ``'padding'`` yields a ``DataCollatorWithPadding``, and any other name
    yields ``None``.

    Args:
        tokenizer: Tokenizer handed to the collator.
        phase: Key into ``cfg['collator']``.
    """
    collator_cfg = cfg['collator'][phase]
    collator_name = collator_cfg['name']

    if collator_name == 'MLM':
        return DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                               **collator_cfg['params'])
    if collator_name == 'padding':
        return DataCollatorWithPadding(tokenizer=tokenizer)
    return None

### Sampler

In [None]:
def get_SmartBatchingSampler(df, batch_size):
    """Create a ``SmartBatchingSampler`` over the space-split excerpts of ``df``.

    Args:
        df: Frame with an ``excerpt`` column of strings.
        batch_size: Forwarded to ``SmartBatchingSampler``.
    """
    def split_on_spaces(excerpt):
        return excerpt.split(' ')

    words_per_excerpt = df.excerpt.apply(split_on_spaces)
    return SmartBatchingSampler(data_source=words_per_excerpt, batch_size=batch_size)

def get_RandomSampler(df):
    """Create a ``RandomSampler`` over the space-split excerpts of ``df``.

    Args:
        df: Frame with an ``excerpt`` column of strings.
    """
    def split_on_spaces(excerpt):
        return excerpt.split(' ')

    words_per_excerpt = df.excerpt.apply(split_on_spaces)
    return RandomSampler(data_source=words_per_excerpt)

# Model

In [None]:
def get_extended_attention_mask(attention_mask):
    """Turn a 0/1 mask into an additive mask for pre-softmax scores.

    Entries equal to 1 map to 0 and entries equal to 0 map to -1e4. A new
    axis of size 1 is inserted at position 2, so a mask of shape ``(B, S)``
    comes back as ``(B, S, 1)``.

    Args:
        attention_mask: Tensor of zeros and ones.
    """
    penalty = (1.0 - attention_mask) * (-1e+4)
    # Inserting the axis at index 2 gives the same shape as
    # unsqueeze(1).transpose(2, 1).
    return penalty.unsqueeze(2)

class Attention(nn.Module):
    """Attention pooling: a learned weighted sum of ``x`` along dim 1.

    A two-layer scorer (Linear -> Tanh -> Linear(…, 1)) produces one score per
    position; masked positions receive a large negative offset before the
    softmax over dim 1.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_state: Width of the scorer's hidden layer.
    """

    def __init__(self, in_features=768, hidden_state=514):
        super().__init__()
        self.in_features = in_features
        self.hidden_state = hidden_state

        # Kept inside one Sequential so the parameters stay registered as
        # attention.0.* and attention.2.*.
        projection = nn.Linear(in_features, hidden_state)
        scorer = nn.Linear(hidden_state, 1)
        self.attention = nn.Sequential(projection, nn.Tanh(), scorer)

        self.softmax = nn.Softmax(dim=1)

        for linear in (projection, scorer):
            torch.nn.init.kaiming_normal_(linear.weight)

    def forward(self, x, attention_mask):
        """Return the attention-weighted sum of ``x`` over dim 1.

        Assumes ``x`` is (batch, seq, in_features) and ``attention_mask`` is
        (batch, seq) -- TODO confirm against the caller.
        """
        scores = self.attention(x)
        # Added in place so the result keeps the dtype of ``scores``.
        scores += get_extended_attention_mask(attention_mask)
        weights = self.softmax(scores)
        # Open question from the author: would summing along a different
        # axis reveal which word positions are important?
        return torch.sum(weights * x, dim=1)

class CLRPModel(nn.Module):
    """Transformer backbone with attention pooling and a linear regressor.

    One ``Attention`` head is applied to each of the last ``n_attentions``
    entries of the backbone's ``hidden_states``. The pooled vectors are
    concatenated, passed through GELU and mapped to a single value.

    Args:
        name: Model name or path given to ``AutoConfig`` / ``AutoModel``.
        p: Stored on the instance; not otherwise used in this class.
        path: Optional checkpoint of a masked-LM model whose embeddings and
            encoder replace those of the freshly loaded backbone.
        n_attentions: Number of attention heads / hidden layers pooled.
        hidden_state: Hidden width of every ``Attention`` head.
    """

    def __init__(self, name, p=0.2, path=None, n_attentions=4, hidden_state=514):
        super().__init__()

        self.name = name
        self.path = path
        self.p = p
        self.n_attentions = n_attentions
        self.hidden_state = hidden_state

        # This setting is from https://www.kaggle.com/andretugan/lightweight-roberta-solution-in-pytorch
        backbone_config = AutoConfig.from_pretrained(name)
        backbone_config.update(cfg['bert'])
        self.bert = AutoModel.from_pretrained(name, config=backbone_config)

        if path:
            self._load_pretrained_weights()

        self.in_features = self.bert.pooler.dense.out_features

        heads = [
            Attention(in_features=self.in_features, hidden_state=self.hidden_state)
            for _ in range(self.n_attentions)
        ]
        self.attentions = nn.ModuleList(heads)
        self.regressor = nn.Linear(self.n_attentions * self.in_features, 1)

        torch.nn.init.kaiming_normal_(self.regressor.weight)

    def _load_pretrained_weights(self):
        """Swap in embeddings and encoder from the masked-LM checkpoint at ``self.path``.

        Reads the ``.roberta`` attribute of the masked-LM model, so it only
        works for architectures that expose one.
        """
        mlm_model = AutoModelForMaskedLM.from_pretrained(self.name)
        mlm_state = torch.load(self.path, map_location=device)['model']
        mlm_model.load_state_dict(mlm_state)

        roberta = mlm_model.roberta
        self.bert.embeddings = roberta.embeddings
        self.bert.encoder = roberta.encoder

        # Drop the temporary references before collecting garbage.
        del roberta, mlm_model
        gc.collect()

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return ``(last_hidden_state, output)``.

        ``token_type_ids`` is accepted but not forwarded to the backbone.
        The backbone call must yield exactly three values; presumably
        ``cfg['bert']`` enables ``output_hidden_states`` -- verify.
        """
        backbone_outputs = self.bert(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=False)
        last_hidden_state, _pooled, hidden_states = backbone_outputs

        # Head k pools the k-th hidden state counted from the end.
        pooled = []
        for depth in range(1, self.n_attentions + 1):
            head = self.attentions[depth - 1]
            pooled.append(head(hidden_states[-depth], attention_mask))

        features = F.gelu(torch.cat(pooled, dim=-1))
        output = self.regressor(features)

        return last_hidden_state, output

In [None]:
def get_sampler(df=None, batch_size=None):
    """Return the sampler selected by ``cfg['sampler']['name']``.

    Only the name ``'SmartBatchSampler'`` is recognised; any other value
    yields ``None``.
    """
    if cfg['sampler']['name'] != 'SmartBatchSampler':
        return None
    return get_SmartBatchingSampler(df, batch_size)

### Dataloader

# Model

In [None]:
def get_model(pretrained=True, fold=None):
    """Instantiate ``CLRPModel`` from ``cfg['model']``.

    Side effect: overwrites ``cfg['model']['path']`` with the masked-LM
    checkpoint path (``pretrained=True``) or ``None`` (``pretrained=False``).

    Args:
        pretrained: Whether to load the masked-LM checkpoint ``CLRPModelMLM.tar``.
        fold: Unused.
    """
    # NOTE(review): PRETRAINED_PATH is not defined in the visible code, so
    # pretrained=True raises NameError unless it is defined elsewhere.
    mlm_checkpoint = os.path.join(PRETRAINED_PATH, 'CLRPModelMLM.tar') if pretrained else None
    cfg['model']['path'] = mlm_checkpoint

    return CLRPModel(**cfg['model'])

# Loss / Metric

In [None]:
def RMSE(y_pred, y_gt):
    """Return the root of the mean squared error between two tensors.

    Args:
        y_pred: Predicted values.
        y_gt: Ground-truth values; must have exactly the same size as ``y_pred``.

    Raises:
        AssertionError: If the two sizes differ.
    """
    assert y_pred.size() == y_gt.size()

    mean_squared = F.mse_loss(y_pred, y_gt)
    return torch.sqrt(mean_squared)

In [None]:
def get_loss_fn():
    """Return the loss callable selected by ``cfg['loss']['name']``.

    Recognised names are ``'RMSE'``, ``'MSE'`` and ``'KLdiv'``; any other
    name yields ``None``.
    """
    loss_name = cfg['loss']['name']

    if loss_name == 'RMSE':
        return RMSE
    if loss_name == 'MSE':
        return nn.MSELoss(reduction='mean')
    if loss_name == 'KLdiv':
        # NOTE(review): KLdiv_for_normal is not defined in the visible code.
        return KLdiv_for_normal
    return None
    
def get_metric_fn():
    """Return the metric callable selected by ``cfg['metric']['name']``.

    Recognised names are ``'RMSE'`` and ``'MSE'``; any other name yields
    ``None``.

    Both branches read ``cfg['metric']``. The ``'MSE'`` branch previously
    tested ``cfg['loss']['name']``, so the metric silently followed the loss
    setting.
    """
    metric_name = cfg['metric']['name']

    if metric_name == 'RMSE':
        return RMSE
    if metric_name == 'MSE':
        return nn.MSELoss(reduction='mean')
    return None

# Optimizer

In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

from transformers import get_cosine_schedule_with_warmup
from transformers import get_constant_schedule
from transformers import get_constant_schedule_with_warmup
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW

from clrp_private_lib import get_slanted_triangular_lr

def get_optim(model_parameters):
    """Build the optimizer selected by ``cfg['optim']['name']``.

    Recognised names are ``'RAdam'``, ``'AdamW'`` and ``'Adam'``; the
    optimizer receives ``cfg['optim']['params']`` as keyword arguments. Any
    other name yields ``None``.

    Args:
        model_parameters: Parameters or parameter groups to optimize.
    """
    optim_name = cfg['optim']['name']

    if optim_name == 'RAdam':
        # NOTE(review): RAdam is not imported in the visible code.
        optimizer_cls = RAdam
    elif optim_name == 'AdamW':
        optimizer_cls = AdamW
    elif optim_name == 'Adam':
        optimizer_cls = Adam
    else:
        return None

    return optimizer_cls(model_parameters, **cfg['optim']['params'])
    
def get_scheduler(optim):
    """Build the LR scheduler selected by ``cfg['scheduler']['name']``.

    ``'constant'`` takes no extra arguments; every other recognised name
    receives ``cfg['scheduler']['params']`` as keyword arguments. Any
    unrecognised name yields ``None``.

    Args:
        optim: Optimizer the scheduler is attached to.
    """
    scheduler_name = cfg['scheduler']['name']

    if scheduler_name == 'constant':
        return get_constant_schedule(optimizer=optim)

    # Each factory name is resolved only inside its own branch.
    if scheduler_name == 'cosine_with_warmup':
        factory = get_cosine_schedule_with_warmup
    elif scheduler_name == 'constant_with_warmup':
        factory = get_constant_schedule_with_warmup
    elif scheduler_name == 'linear_with_warmup':
        factory = get_linear_schedule_with_warmup
    elif scheduler_name == 'slanted_triangular':
        factory = get_slanted_triangular_lr
    else:
        return None

    return factory(optimizer=optim, **cfg['scheduler']['params'])

In [None]:
def get_layerwise_params_to_optimize(model, lr_reg, lr_bert, wd_reg=0, wd_bert=1e-1):
    """Build optimizer parameter groups with layer-wise learning rates.

    The result holds, in order:

    * one group with the ``regressor`` non-bias parameters (``wd_reg``, ``lr_reg``),
    * one group with the ``regressor`` bias parameters (no weight decay, ``lr_reg``),
    * one group per parameter of ``model.bert`` that is not frozen.

    Parameters of ``model.bert`` whose name contains ``'embeddings'`` are
    frozen via ``requires_grad_(False)`` and left out. Encoder layers are
    paired (0-1, 2-3, ..., 22-23) and pair ``i`` uses ``lr_bert[i]``;
    parameters whose name contains ``'pooler'`` use ``lr_reg``. Bias
    parameters never get weight decay.

    NOTE(review): only ``regressor`` and ``model.bert`` parameters are
    collected here; other submodules of ``model`` (e.g. ``attentions``) end
    up in no group -- confirm this is intended.

    Args:
        model: Module exposing ``named_parameters()`` and a ``bert`` submodule.
        lr_reg: Learning rate for the regressor and the pooler.
        lr_bert: Sequence of learning rates indexed by layer pair.
        wd_reg: Weight decay for the regressor weights.
        wd_bert: Weight decay for non-bias backbone parameters.

    Returns:
        List of dicts with the keys ``'params'``, ``'weight_decay'`` and ``'lr'``.
    """
    # 12 pairs of consecutive layers: ('layer.0.', 'layer.1.'), ('layer.2.', 'layer.3.'), ...
    layer_pairs = [(f'layer.{2 * k}.', f'layer.{2 * k + 1}.') for k in range(12)]
    frozen_markers = ('embeddings',)

    def is_bias(param_name):
        return 'bias' in param_name

    def backbone_lr(param_name):
        # First matching pair wins; names outside every pair get None.
        for pair_index, markers in enumerate(layer_pairs):
            if any(marker in param_name for marker in markers):
                return lr_bert[pair_index]
        return None

    # regressor
    head_weights, head_biases = [], []
    for param_name, param in model.named_parameters():
        if 'regressor' not in param_name:
            continue
        bucket = head_biases if is_bias(param_name) else head_weights
        bucket.append(param)

    param_groups = [
        {'params': head_weights, 'weight_decay': wd_reg, 'lr': lr_reg},
        {'params': head_biases, 'weight_decay': 0.0, 'lr': lr_reg},
    ]

    # bert layer
    for param_name, param in model.bert.named_parameters():
        if any(marker in param_name for marker in frozen_markers):
            param.requires_grad_(False)
            continue
        param_groups.append({
            'params': param,
            'weight_decay': 0.0 if is_bias(param_name) else wd_bert,
            'lr': lr_reg if 'pooler' in param_name else backbone_lr(param_name),
        })

    return param_groups

# Training / Inference

In [None]:
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast

## Utils

In [None]:
class Gauge:
    """Running accumulator that reports the mean of the values it was fed."""

    def __init__(self):
        # Running sum and number of accumulated values.
        self.gauge, self.count = 0, 0

    def accumulate(self, gauge):
        """Add one value to the running sum and count it."""
        self.count += 1
        self.gauge += gauge

    def get_mean(self, root=False):
        """Return the mean, or its square root when ``root`` is true.

        Raises:
            ZeroDivisionError: If nothing has been accumulated yet.
        """
        mean = self.gauge / self.count
        if root:
            return np.sqrt(mean)
        return mean

In [None]:
def get_val_spacing(val_metric):
    """Map a validation metric to a spacing value.

    The spacing halves at each threshold: 16 above 0.5, 8 above 0.49,
    4 above 0.48, 2 above 0.47 and 1 otherwise. All comparisons are strict.
    Presumably the spacing is the number of steps between validations --
    verify against the training loop, which is not part of this file.
    """
    thresholds = ((0.5, 16), (0.49, 8), (0.48, 4), (0.47, 2))
    for lower_bound, spacing in thresholds:
        if val_metric > lower_bound:
            return spacing
    return 1

def get_dls_for_n_fold(df, fold, tokenizer):
    """Build the train and validation DataLoaders for one fold.

    Rows whose ``fold`` column equals ``fold`` form the validation split; all
    other rows form the training split.

    Side effect: writes either ``'sampler'`` or ``'shuffle'`` into
    ``cfg['dl']['train']``, and the entry stays there for later calls.

    Args:
        df: Frame with a ``fold`` column plus the columns ``CLRPDataset`` needs.
        fold: Fold id used for validation.
        tokenizer: Tokenizer handed to both datasets.

    Returns:
        Tuple ``(train_dl, val_dl)``.
    """
    train_df = df.loc[df.fold != fold].reset_index(drop=True)
    val_df = df.loc[df.fold == fold].reset_index(drop=True)

    train_ds, val_ds = (
        CLRPDataset(split, tokenizer=tokenizer) for split in (train_df, val_df)
    )

    if cfg['sampler']['name'] is None:
        cfg['dl']['train']['shuffle'] = True
    else:
        train_sampler = get_sampler(df=train_df,
                                    batch_size=cfg['dl']['train']['batch_size'])
        cfg['dl']['train']['sampler'] = train_sampler

    train_dl = DataLoader(train_ds, **cfg['dl']['train'])
    val_dl = DataLoader(val_ds, **cfg['dl']['val'])
    return train_dl, val_dl

def get_modules(model):
    """Assemble the training components configured in ``cfg``.

    Args:
        model: Model whose parameters are grouped via
            ``get_layerwise_params_to_optimize`` using
            ``cfg['optim']['params_control']``.

    Returns:
        Tuple ``(loss_fn, metric_fn, optim, scheduler)``.
    """
    loss_fn, metric_fn = get_loss_fn(), get_metric_fn()

    param_groups = get_layerwise_params_to_optimize(
        model, **cfg['optim']['params_control']
    )
    optimizer = get_optim(param_groups)

    return loss_fn, metric_fn, optimizer, get_scheduler(optimizer)

def get_inputs(batch):
    """Pick the model inputs out of ``batch`` and move them to ``device``.

    Only the keys ``input_ids``, ``attention_mask`` and ``token_type_ids``
    are kept; keys absent from the batch are simply skipped.
    """
    model_keys = ('input_ids', 'attention_mask', 'token_type_ids')

    inputs = {}
    for key, tensor in batch.items():
        if key in model_keys:
            inputs[key] = tensor.to(device)
    return inputs

def get_targets(batch):
    """Return the regression targets of ``batch`` on ``device``.

    The entries ``target`` and ``se`` are each reshaped to a column
    (``view(-1, 1)``). When ``cfg['loss']['name']`` is ``'KLdiv'`` the two
    columns are concatenated along dim 1; otherwise only the ``target``
    column is returned.
    """
    columns = {}
    for key, tensor in batch.items():
        if key in ('target', 'se'):
            columns[key] = tensor.view(-1, 1).to(device)

    if cfg['loss']['name'] == 'KLdiv':
        return torch.cat((columns['target'], columns['se']), dim=1)
    return columns['target']

## Components

In [None]:
def val_fn(model, dl, loss_fn, metric_fn):
    """Evaluate ``model`` on ``dl`` and return the mean loss and metric.

    The model is switched to eval mode and moved to ``device``; batches are
    processed without gradients and under autocast. No ``GradScaler`` is
    created because nothing is back-propagated here.

    Args:
        model: Module returning ``(last_hidden_states, outputs)``.
        dl: Iterable of batches accepted by ``get_inputs`` / ``get_targets``.
        loss_fn: Callable ``(outputs, targets) -> scalar tensor``.
        metric_fn: Callable ``(outputs, targets) -> scalar tensor``.

    Returns:
        Tuple ``(mean_loss, root_of_mean_metric)``; both are unweighted means
        over batches.
    """
    losses = Gauge()
    metrics = Gauge()

    model.eval()
    model.to(device)

    with torch.no_grad():
        for batch in dl:
            inputs = get_inputs(batch)
            targets = get_targets(batch)

            with autocast():
                _, outputs = model(**inputs)
                loss = loss_fn(outputs, targets)
                metric = metric_fn(outputs, targets)

            losses.accumulate(loss.item())
            metrics.accumulate(metric.item())

    # NOTE(review): the square root is applied to the mean metric, which
    # presumably expects metric_fn to return a squared error (MSE) -- confirm
    # against the 'metric' entry of cfg.
    return losses.get_mean(), metrics.get_mean(root=True)

# Calculate CV Score

In [None]:
from sklearn.metrics import mean_squared_error

# Replace the placeholder cfg with the configuration stored in CV_PATH.
# NOTE(security): torch.load unpickles the file, so cfg.pt must come from a
# trusted source.
cfg = torch.load(os.path.join(CV_PATH, "cfg.pt"), map_location=device)


# Point the model and the tokenizer at the backbone selected by ARCH_PATH.
cfg['model']['name'] = ARCH_PATH
cfg['tokenizer']['name'] = ARCH_PATH

pprint(cfg)

In [None]:
def pooled_last_hidden_state(last_hidden_states, pool='mean'):
    """Reduce a hidden-state tensor along axis 1 and return a NumPy array.

    Args:
        last_hidden_states: Tensor; it is detached and copied to the CPU first.
        pool: ``'max'`` or ``'mean'``. Any other value yields ``None``.
    """
    states = last_hidden_states.detach().cpu().numpy()

    if pool == 'max':
        return states.max(axis=1)
    if pool == 'mean':
        return states.mean(axis=1)
    return None

    
def val_fn_cv(model, dl):
    """Run ``model`` over ``dl`` and collect predictions and pooled states.

    The model is switched to eval mode and moved to ``device``; batches are
    processed without gradients and under autocast. No ``GradScaler`` is
    created because nothing is back-propagated here.

    Args:
        model: Module returning ``(last_hidden_states, outputs)``.
        dl: Iterable of batches accepted by ``get_inputs``.

    Returns:
        Tuple ``(lhs, preds)`` of NumPy arrays concatenated over batches along
        axis 0; ``lhs`` is the max-pooled last hidden state.
    """
    preds = []
    lhs = []  # last hidden state

    model.eval()
    model.to(device)

    with torch.no_grad():
        for batch in tqdm(dl, desc='cv'):
            inputs = get_inputs(batch)

            with autocast():
                last_hidden_states, outputs = model(**inputs)

            preds.append(outputs.detach().cpu().numpy())
            lhs.append(pooled_last_hidden_state(last_hidden_states, pool='max'))

    preds = np.concatenate(preds)
    lhs = np.concatenate(lhs)

    return lhs, preds

def main_cv():
    """Compute out-of-fold predictions for every fold of the training set.

    For each fold the matching checkpoint is loaded from ``CV_PATH`` and run
    on that fold's validation rows. Predictions are written to the ``oof``
    column and the pooled last hidden states to the ``lhs_<i>`` columns.

    Returns:
        The training frame extended with ``oof`` and ``lhs_<i>`` columns.
    """
    seed_everything(SEED)

    df = pd.read_csv(TRAIN)
    tokenizer = get_tokenizer()
    lhs_cols = None

    for fold in range(cfg['train']['n_folds']):
        # Only the validation loader is needed for scoring.
        _, val_dl = get_dls_for_n_fold(df, fold, tokenizer)

        model = get_model(pretrained=False)

        checkpoint_path = os.path.join(CV_PATH, MODEL_NAME + f'_fold{fold}.tar')
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model'])

        if fold == 0:
            # Print the configuration stored with the checkpoint, once.
            cfg_for_train = checkpoint['cfg']
            print('Configuration for training:')
            print()
            pprint(cfg_for_train)
            print()

        print('Fold:', fold)

        lhs, preds = val_fn_cv(model=model, dl=val_dl)

        in_fold = df.fold == fold
        df.loc[in_fold, 'oof'] = preds
        if lhs_cols is None:
            # Create the hidden-state columns when the first fold reports its width.
            lhs_cols = [f'lhs_{i}' for i in range(lhs.shape[1])]
            df[lhs_cols] = np.nan
        df.loc[in_fold, lhs_cols] = lhs

    return df


def RMSE_(y_pred, y_gt):
    """Return the root mean squared error of two array-likes (NumPy/sklearn version)."""
    return np.sqrt(mean_squared_error(y_pred, y_gt))

def oof_vs_target(df, y='oof'):
    """Scatter column ``y`` against ``target``, coloured by fold.

    An orange ``y = x`` line from -3.5 to 1.5 is drawn as a reference.

    Args:
        df: Frame with the columns ``target``, ``fold`` and ``y``.
        y: Name of the prediction column plotted on the vertical axis.
    """
    diagonal = np.linspace(-3.5, 1.5, 10)
    temp_df = pd.DataFrame({'x': diagonal, 'y': diagonal})

    plt.figure(figsize=(8, 8))
    sns.scatterplot(data=df, x='target', y=y, hue='fold', palette='bright',
                    label=f'{y} vs target')
    sns.lineplot(data=temp_df, x='x', y='y', color='orange')
    plt.title('OOF Prediction vs Target')
    plt.legend()
    plt.show()

In [None]:
%%time
# Cross-validation run: compute OOF predictions, save them, then report the
# overall RMSE and plot predictions against targets.
if CV:
    df = main_cv()
    # Saved into CV_PATH on Colab, into the working directory otherwise.
    df.to_csv(os.path.join(CV_PATH if COLAB else '.', 'oof_df.csv'), index=False)

    print('CV score: ', RMSE_(df['target'], df['oof']))
    oof_vs_target(df, y='oof')

# Postprocessing

In [None]:
from sklearn import linear_model

# Fit a linear regression of target on the OOF prediction.
# NOTE(review): `df` is only assigned by the CV cell (CV=True); with CV=False
# and POST=True nothing visible defines it before this point.
if POST:
    lm = linear_model.LinearRegression()
    lm.fit(df.loc[:, ['oof']], df['target'].values)

    df['oof_post'] = lm.predict(df.loc[:, ['oof']])

    score_oof_post = RMSE_(df['oof_post'], df['target'])
    print('RMSE (oof post): ', score_oof_post)

# Inference

In [None]:
def main_infer():
    """Predict the test set with every fold checkpoint and average the results.

    Returns:
        The test frame whose ``target`` column holds the mean prediction over
        ``cfg['train']['n_folds']`` checkpoints loaded from ``CV_PATH``.
    """
    seed_everything(SEED)

    df = pd.read_csv(TEST)
    # CLRPDataset reads these columns; the test file gets zero placeholders.
    df['target'] = 0.
    df['standard_error'] = 0.

    tokenizer = get_tokenizer()
    n_folds = cfg['train']['n_folds']

    # The test data is the same for every fold, so build the dataset and the
    # loader once instead of once per fold.
    test_ds = CLRPDataset(df, tokenizer=tokenizer)
    test_dl = DataLoader(test_ds, **cfg['dl']['val'])

    for fold in range(n_folds):
        print('Fold:', fold)

        model = get_model(pretrained=False)
        PATH = os.path.join(CV_PATH, MODEL_NAME + f'_fold{fold}.tar')
        state_dict = torch.load(PATH, map_location=device)['model']
        model.load_state_dict(state_dict)

        _, preds = val_fn_cv(model=model, dl=test_dl)
        # preds has one column per row; flatten it explicitly before adding.
        df['target'] = df['target'] + np.ravel(preds)

        # Release this fold's weights before the next model is built.
        del model, state_dict
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df['target'] = df['target'] / n_folds
    return df

In [None]:
# Inference run: average the fold predictions and write the submission file.
if not CV:
    df = main_infer()

    if POST:
        # Apply the linear model fitted in the post-processing cell.
        # NOTE(review): `lm` only exists if that cell ran with POST=True, and
        # that cell needs the OOF frame from a CV run -- confirm the intended
        # workflow for CV=False, POST=True.
        df['target'] = lm.predict(df.loc[:, ['target']])

    df = df[['id', 'target']]
    df.to_csv('submission.csv', index=False)