In [None]:
from argparse import ArgumentParser, Namespace
from datetime import datetime
from os import cpu_count
from pathlib import Path
from typing import Dict, List, Optional, Union
import gc
import pickle

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
from torch.distributions.beta import Beta
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_constant_schedule_with_warmup,
)
import numpy as np
import os
import pandas as pd
import random
import torch
import torch.nn as nn

In [None]:
# Run configuration: a Namespace stands in for parsed command-line arguments
# so the notebook can be converted into a script without touching call sites.
args = Namespace(**{
    # --- paths ---
    'data': Path('/kaggle/input/commonlitreadabilityprize'),
    'models': Path('models'),
    # Change this to the path of the dataset produced by the training run.
    'infer_path': Path('/kaggle/input/7-entrenamiento-transformer'),
    # --- reproducibility ---
    'seed': 2021,
    'kfold_seed': 2021,
    # --- data / batching ---
    'bs': 8,
    'n_folds': 5,
    'seq_len': 200,
    # --- model & optimisation ---
    'model_name': 'roberta-base',
    'trf_do': 0.0,
    'lr': 1e-5,
    'epochs': 1,
    'val_steps': 100,
    # --- run mode: 'train' or 'infer' ---
    'mode': 'infer',
})

# Funciones auxiliares

In [None]:
def seed_everything(seed=42):
    """Seed every RNG the pipeline uses so that runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and every
    visible CUDA device), and configures cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # PYTHONHASHSEED (with the H) is the variable Python actually reads. It only
    # takes effect in interpreters started after this point (e.g. subprocesses);
    # the current interpreter's hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick non-deterministic algorithms,
    # which would defeat the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Datos

In [None]:
class ReadabilityDataset(Dataset):
    """Pre-tokenized excerpts, optionally paired with regression targets.

    Every excerpt in ``df.excerpt`` is tokenized once in the constructor and
    padded/truncated to exactly ``max_len`` tokens. When ``df`` has a
    ``target`` column, items are ``(input_ids, attention_mask, target)``;
    otherwise they are ``(input_ids, attention_mask)``.

    Args:
        df: DataFrame with an ``excerpt`` column, plus ``target`` and
            ``standard_error`` columns for labeled data.
        tokenizer: Callable tokenizer returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_len: Fixed sequence length after padding/truncation.
    """

    def __init__(self, df, tokenizer, max_len=200):
        super().__init__()
        self.df = df
        self.max_len = max_len
        self.tokenizer = tokenizer
        # Membership on a DataFrame tests its column labels.
        self.labeled = 'target' in df

        if self.labeled:
            self.target = torch.tensor(df['target'].values, dtype=torch.float)
            # Kept for reference; not returned by __getitem__.
            self.stderr = torch.tensor(df['standard_error'].values, dtype=torch.float)

        self.tokens = tokenizer(
            list(df['excerpt'].values),
            max_length=max_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
            add_special_tokens=True,
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Clone so callers cannot mutate the cached token tensors.
        features = (
            self.tokens['input_ids'][idx].clone(),
            self.tokens['attention_mask'][idx].clone(),
        )
        if not self.labeled:
            return features
        return features + (self.target[idx],)

# Definición del modelo
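Dimensiones:
* `bs`: tamaño del batch (batch size)
* `max_len`: longitud de la secuencia (`args.seq_len`)
* `d_model`: dimensión del transformer (768 roberta-base, 1024 roberta-large)
* `hidden_states` es una tupla de `n_capas + 1` tensores (embeddings + una salida por capa): 13 en roberta-base, 25 en roberta-large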

In [None]:
class ReadabilityModel(nn.Module):
    """Transformer encoder followed by mean pooling and a linear regression head.

    Args:
        args: Run configuration; stored on the instance but not read here.
        body: Encoder module. It must accept ``input_ids`` (positionally or by
            keyword) and ``attention_mask``, and return an object exposing
            ``hidden_states`` both as an attribute and as a ``['hidden_states']``
            item, as Hugging Face model outputs do with
            ``output_hidden_states=True``.
        masked_pool: If True, average only over positions where ``mask`` is
            non-zero. Defaults to False, which averages over every position
            including padding; that is the behaviour existing checkpoints were
            trained with, so it must stay the default.
    """

    def __init__(self, args, body, masked_pool=False):
        super().__init__()

        self.args = args
        self.body = body
        self.masked_pool = masked_pool

        # Probe the encoder with one dummy token to discover the hidden size
        # (768 base, 1024 large). No autograd graph is needed for this.
        with torch.no_grad():
            probe = body(torch.zeros(1, 1, dtype=torch.long))
        self.out_features = probe.hidden_states[-1].shape[-1]
        self.head = nn.Linear(self.out_features, 1)
        # Initial output bias. NOTE(review): presumably the mean training
        # target, so the untrained head starts near the average label; confirm.
        self.head.bias.data.fill_(-0.9593187699947071)

    def forward(self, ids, mask):
        """Return one scalar prediction per sequence.

        Args:
            ids: Token ids, shape (bs, max_len).
            mask: Attention mask, shape (bs, max_len); assumed 1 for real
                tokens and 0 for padding.

        Returns:
            Tensor of shape (bs,).
        """
        x = self.body(input_ids=ids, attention_mask=mask)
        x = x['hidden_states']  # tuple of (bs, max_len, d_model), e.g. 13 x (8, 200, 768)
        x = x[-1]               # last layer: (bs, max_len, d_model)
        if self.masked_pool:
            weights = mask.unsqueeze(-1).to(x.dtype)                       # (bs, max_len, 1)
            x = (x * weights).sum(1) / weights.sum(1).clamp(min=1e-9)     # (bs, d_model)
        else:
            x = torch.mean(x, 1)  # (bs, d_model); padding positions included
        x = self.head(x)          # (bs, 1)
        x = x.squeeze(-1)         # (bs,)
        return x

# Función de pérdidas

In [None]:
def loss_fn(preds, targs):
    """Root-mean-squared error between ``preds`` and ``targs``, as a scalar tensor."""
    mse = nn.functional.mse_loss(preds, targs)
    return mse.sqrt()

# Guarda el mejor modelo

In [None]:
class Checkpointer:
    """Keep only the best-so-far model weights on disk.

    Whenever the validation loss improves, the model's ``state_dict`` is saved to
    ``<path>/<base_name>_epoch_<epoch>_loss_<loss:.6f>.pt`` and the previous best
    file is removed.

    Attributes:
        best_loss: Lowest loss seen so far, or None before the first save.
        best_path: File holding the weights for ``best_loss``, or None.
        base_name: Prefix for checkpoint file names.
        path: Directory checkpoints are written to (created on first save).
    """

    def __init__(self, base_name, path):
        self.best_loss = None
        self.best_path = None
        self.base_name = base_name
        self.path = path

    def on_validation_end(self, model, epoch, loss) -> None:
        """Save ``model`` if ``loss`` is lower than the best loss seen so far.

        Args:
            model: Module whose ``state_dict()`` is saved.
            epoch: Epoch index, embedded in the file name.
            loss: Validation loss as a Python float (lower is better).
        """
        if not (self.best_loss is None or loss < self.best_loss):
            return
        self.path.mkdir(exist_ok=True, parents=True)
        stem = f'{self.base_name}_epoch_{epoch}_loss_{loss:.6f}'
        out_p = self.path / f'{stem}.pt'
        # Write the new file first so a crash never leaves zero checkpoints.
        with out_p.open('wb') as f:
            torch.save(model.state_dict(), f)
        # An improvement smaller than the 6-decimal rounding in the file name
        # produces the same path; unlinking "the old best" would then delete
        # the checkpoint that was just written.
        if self.best_path is not None and self.best_path != out_p:
            self.best_path.unlink()
        self.best_loss = loss
        self.best_path = out_p

# Validación

In [None]:
def validate(model, valid_dl):
    """Compute ``loss_fn`` (RMSE) of ``model`` over a labeled dataloader.

    The model runs in eval mode without gradients. Its previous train/eval
    mode is restored afterwards, even if an exception is raised mid-pass.

    Args:
        model: Module called as ``model(ids, mask)``.
        valid_dl: Iterable of ``(ids, mask, target)`` batches.

    Returns:
        Scalar CPU tensor: the loss over all predictions concatenated
        (not an average of per-batch losses).
    """
    was_training = model.training
    # Send inputs to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    model.eval()
    preds = []
    targs = []
    try:
        with torch.no_grad():
            for batch in tqdm(valid_dl, leave=False):
                ids, mask, targ = (t.to(device) for t in batch)
                targs.append(targ)
                preds.append(model(ids, mask))
            loss = loss_fn(torch.cat(preds), torch.cat(targs)).cpu()
    finally:
        if was_training:
            model.train()
    return loss

# Predicción

In [None]:
def predict(model, test_dl):
    """Run ``model`` over ``test_dl`` and return all predictions on the CPU.

    The model runs in eval mode without gradients. Its previous train/eval
    mode is restored afterwards, even if an exception is raised mid-pass.

    Args:
        model: Module called as ``model(ids, mask)``.
        test_dl: Iterable of ``(ids, mask)`` or ``(ids, mask, target)``
            batches; any trailing target is ignored.

    Returns:
        1-D CPU tensor with the predictions in dataloader order.
    """
    was_training = model.training
    # Send inputs to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    model.eval()
    preds = []
    try:
        with torch.no_grad():
            for batch in tqdm(test_dl, leave=False):
                # Only ids and mask are needed; the target (if any) is never
                # copied to the device.
                ids, mask = (t.to(device) for t in batch[:2])
                preds.append(model(ids, mask).cpu())
            preds = torch.cat(preds)
    finally:
        if was_training:
            model.train()
    return preds

# Entrenamiento de 1 fold

In [None]:
def train_fold(fold, df, train_idx, valid_idx, ts):
    """Train one CV fold and write its out-of-fold predictions into ``df``.

    Builds a fresh model from ``args.model_name`` (config read from the local
    ``cfgs/`` copy) and trains it for ``args.epochs`` epochs. It validates every
    ``args.val_steps`` steps and on the final step, keeping the best checkpoint
    under ``args.models / ts``. Afterwards it pickles ``args`` next to the
    checkpoint, reloads the best weights and stores validation predictions
    in ``df['pred']``.

    Args:
        fold: Fold number, used in log output and file names.
        df: Full training frame; its ``'pred'`` column is modified in place.
        train_idx: Positional indices of the training rows.
        valid_idx: Positional indices of the validation rows.
        ts: Run timestamp, used as the output sub-directory name.

    Raises:
        RuntimeError: If no checkpoint was saved (no training steps ran).
    """
    print(f"Training fold {fold}")

    # Model
    model_config = AutoConfig.from_pretrained(f'cfgs/{args.model_name}', add_pooling_layer=False)
    model_config.output_hidden_states = True
    model_config.hidden_dropout_prob = args.trf_do
    model_config.attention_probs_dropout_prob = args.trf_do
    body = AutoModel.from_pretrained(args.model_name, config=model_config, add_pooling_layer=False)
    model = ReadabilityModel(args, body)

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(f'toks/{args.model_name}', config=model_config)

    # Datasets
    train_ds = ReadabilityDataset(df.iloc[train_idx], tokenizer, max_len=args.seq_len)
    valid_ds = ReadabilityDataset(df.iloc[valid_idx], tokenizer, max_len=args.seq_len)

    # Dataloaders
    train_dl = DataLoader(
        dataset=train_ds,
        batch_size=args.bs,
        num_workers=0,
        drop_last=True,
        shuffle=True,
    )
    # Not shuffled: prediction order must match valid_idx when writing back to df.
    valid_dl = DataLoader(
        dataset=valid_ds,
        batch_size=args.bs * 8,
        num_workers=0,
    )

    # Move the model to the GPU *before* building the optimizer, as PyTorch recommends.
    model.train().cuda()
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    model_p = args.models / ts
    prefix = f'{args.model_name}_fold_{fold}'
    checkpointer = Checkpointer(prefix, model_p)
    num_training_steps = len(train_dl) * args.epochs

    step = 0
    val_loss = None

    with tqdm(total=num_training_steps) as pbar:
        for epoch in range(args.epochs):
            for batch in train_dl:
                optimizer.zero_grad()

                ids, mask, targ = (t.cuda() for t in batch)

                pred = model(ids, mask)
                loss = loss_fn(pred, targ)

                loss.backward()
                optimizer.step()

                # Validate every val_steps steps, and always on the last step.
                step += 1
                if step % args.val_steps == 0 or step == num_training_steps:
                    val_loss = validate(model, valid_dl)
                    checkpointer.on_validation_end(model, epoch, val_loss.item())

                pbar.set_postfix({
                    'loss': loss.item(),
                    # `is not None`: a tensor loss of exactly 0 would be falsy.
                    'val_loss': val_loss.item() if val_loss is not None else '-',
                    'best': checkpointer.best_loss,
                })
                pbar.update(1)

    if checkpointer.best_path is None:
        raise RuntimeError(f"Fold {fold}: no checkpoint was saved (no training steps ran)")

    # Save the run configuration next to the checkpoint.
    with (model_p / f'{prefix}.args').open('wb') as f:
        pickle.dump(args, f)

    # Predict with the best checkpoint. valid_idx is positional, so translate it to
    # index labels for .loc; .numpy() hands pandas a plain float array.
    model.load_state_dict(torch.load(checkpointer.best_path))
    df.loc[df.index[valid_idx], 'pred'] = predict(model, valid_dl).numpy()


# Entrenamiento K folds

In [None]:
def train():
    """Run K-fold training and report the out-of-fold RMSE.

    Caches the model config and tokenizer locally (``cfgs/`` and ``toks/``) so
    inference can run offline. Trains one model per fold via ``train_fold``.
    Finally writes the out-of-fold predictions to
    ``args.models/<timestamp>/<model_name>_oof_preds_<cv>.csv``.
    """
    args.models.mkdir(exist_ok=True)
    seed_everything(args.seed)

    # Directories for the model config and tokenizer files.
    for d in ['cfgs', 'toks']:
        Path(d).mkdir(exist_ok=True)

    # Download the model config + tokenizer now; inference will have no internet.
    model_config = AutoConfig.from_pretrained(args.model_name, add_pooling_layer=False)
    model_config.save_pretrained(f'cfgs/{args.model_name}')
    tok = AutoTokenizer.from_pretrained(args.model_name)
    tok.save_pretrained(f'toks/{args.model_name}')

    # Split the dataset into train/validation folds.
    df = pd.read_csv(args.data / 'train.csv')
    # Float column (NaN until each fold fills its rows) rather than an object column.
    df['pred'] = np.nan
    ts = datetime.now().strftime('%Y%m%d_%H%M%S')
    kfold = KFold(n_splits=args.n_folds, shuffle=True, random_state=args.kfold_seed)

    for fold, (train_idx, valid_idx) in enumerate(kfold.split(df)):
        train_fold(fold, df, train_idx, valid_idx, ts)
        gc.collect()
        torch.cuda.empty_cache()

    # Cross-validation RMSE over the whole dataset. sqrt(MSE) is used instead of
    # mean_squared_error(squared=False), which newer scikit-learn versions removed.
    cv_rmse = float(np.sqrt(mean_squared_error(df.target, df.pred.astype(float))))
    print(f"cv={cv_rmse}")
    df.to_csv(args.models / ts / f'{args.model_name}_oof_preds_{cv_rmse:.6f}.csv', index=False)


# Inferencia

In [None]:
def infer():
    """Average the test predictions of every saved fold checkpoint into ``submission.csv``.

    Rebuilds the architecture from the config and tokenizer that ``train``
    saved under ``args.infer_path``, so no internet is needed. Loads each
    ``*.pt`` file found in the sub-directories of ``args.infer_path / args.models``
    and writes the mean prediction per test row.

    Raises:
        FileNotFoundError: If no ``*.pt`` checkpoint is found.
    """
    # Load the model config, model and tokenizer.
    model_config = AutoConfig.from_pretrained(args.infer_path / 'cfgs' / args.model_name, add_pooling_layer=False)
    model_config.output_hidden_states = True
    model_config.hidden_dropout_prob = args.trf_do
    model_config.attention_probs_dropout_prob = args.trf_do
    # Architecture only; the trained weights come from the .pt files below.
    body = AutoModel.from_config(model_config, add_pooling_layer=False)
    model = ReadabilityModel(args, body)
    model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(args.infer_path / 'toks' / args.model_name, config=model_config)

    # Load the test data.
    test_df = pd.read_csv(args.data / 'test.csv')

    # Dataset and dataloader (unshuffled, so predictions stay in row order).
    test_ds = ReadabilityDataset(test_df, tokenizer, max_len=args.seq_len)
    test_dl = DataLoader(dataset=test_ds, batch_size=args.bs * 8, num_workers=0)

    # Collect checkpoints up front, sorted for a reproducible visiting order.
    models_root = args.infer_path / args.models
    ckpt_paths = sorted(
        pt_p
        for model_p in models_root.iterdir()
        if model_p.is_dir()
        for pt_p in model_p.glob('*.pt')
    )
    if not ckpt_paths:
        raise FileNotFoundError(f"No *.pt checkpoints found under {models_root}")

    # One prediction tensor per fold checkpoint.
    preds_l = []
    for pt_p in ckpt_paths:
        print(f"Infer {str(pt_p)}")
        # map_location='cpu': load_state_dict copies into the model's (GPU) params
        # anyway, so this works whichever device the checkpoint was saved from.
        model.load_state_dict(torch.load(pt_p, map_location='cpu'))
        preds_l.append(predict(model, test_dl))

    all_preds = torch.stack(preds_l).mean(dim=0)

    test_df['target'] = all_preds.numpy()
    sub_df = test_df[['id', 'target']]

    sub_df.to_csv('submission.csv', index=False)

In [None]:
# Entry point: run training or inference depending on the configured mode.
# An unrecognised mode fails loudly instead of silently doing nothing.
if args.mode == 'train':
    train()
elif args.mode == 'infer':
    infer()
else:
    raise ValueError(f"Unknown mode {args.mode!r}; expected 'train' or 'infer'")