# CommonLit - PyTorch Lightning & BERT Benchmark

---

## Library

- PyTorch Lightning
- Transformers (Pretrained BERT)

---

## Library Import

In [None]:
%matplotlib inline
import os
import re
import random
import string
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

import pytorch_lightning as pl
from pytorch_lightning import metrics
from pytorch_lightning.trainer import Trainer
import transformers

## Config

In [None]:
class cfg:
    """Static settings (paths and hyper-parameters) shared by the notebook cells."""

    # --- Paths ---------------------------------------------------------
    # Directory holding train.csv / test.csv
    data_dir = '../input/commonlitreadabilityprize'
    # Local directory of the pretrained BERT weights
    BERT_MODEL = '../input/huggingface-bert/bert-base-uncased'

    # --- Reproducibility -----------------------------------------------
    seed = 0

    # --- Tokenization / data loading -----------------------------------
    max_len = 256       # token sequence length handed to the tokenizer
    batch_size = 8
    num_workers = 8     # DataLoader worker processes

    # --- Optimization --------------------------------------------------
    epoch = 30          # maximum number of training epochs
    lr = 1e-4           # initial learning rate

## Utils

In [None]:
def seed_everything(seed):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.
    ------------------------------------
    Parameters
    seed: int
        Seed value applied to all generators
    """
    random.seed(seed)
    # Hash randomization is fixed at interpreter start-up, so this only
    # influences Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (safe without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm and may select a
    # different one on each run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

## Preprocessing

Text Preprocessing Function

In [None]:
def text_preprocessing(text):
    """
    Normalize a raw excerpt into a lower-cased, lemmatized string without stop words.
    ------------------------------------
    Parameters
    text: str
        Raw text

    Returns
    str
        Remaining lemmatized tokens joined by single spaces
    """
    # Keep ASCII letters only: digits, punctuation, newlines and any other
    # character are replaced by a blank, so no separate punctuation pass is needed.
    e = re.sub("[^a-zA-Z]", " ", text)

    # Collapse runs of blanks
    e = re.sub(r" +", r" ", e).strip()

    e = e.lower()

    # Stop words: build the lookup set once per call instead of once per token
    stop_words = set(stopwords.words("english"))
    e = [word for word in nltk.word_tokenize(e) if word not in stop_words]

    # Standardize
    lemma = nltk.WordNetLemmatizer()
    return " ".join(lemma.lemmatize(word) for word in e)

## Dataset

In [None]:
class LitDataset(Dataset):
    """
    Dataset that converts rows of a DataFrame into tokenizer outputs.

    Each item is a tuple ``(inp, label)`` where ``inp`` is a dict with the
    keys 'ids', 'mask' and 'ttis' (long tensors built from the tokenizer's
    'input_ids', 'attention_mask' and 'token_type_ids').  ``label`` is the
    float tensor of the row's 'target' column, or the row's 'id' value when
    ``phase == 'test'``.
    ------------------------------------
    Parameters
    df: pandas.DataFrame
        Frame with an 'excerpt' column plus 'target' (train) or 'id' (test)
    tokenizer:
        Object exposing ``encode_plus``
    text_preprocessing_fn: func
        Optional text preprocessing function: str -> str
    max_len: int
        ``max_length`` passed to the tokenizer
    phase: str
        'test' returns ids instead of targets; any other value returns targets
    """

    def __init__(self, df, tokenizer, text_preprocessing_fn=None, max_len=256, phase='train'):
        self.df = df
        self.tokenizer = tokenizer
        self.text_preprocessing_fn = text_preprocessing_fn
        self.max_len = max_len
        self.phase = phase

    def __len__(self):
        """Number of rows in the wrapped frame."""
        return len(self.df)

    def _tokenize(self, text):
        """Encode ``text`` and return (input_ids, attention_mask, token_type_ids)."""
        encoded = self.tokenizer.encode_plus(
            text,
            None,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length'
        )
        return encoded['input_ids'], encoded['attention_mask'], encoded['token_type_ids']

    def __getitem__(self, idx):
        """Return the encoded excerpt at position ``idx`` with its target or id."""
        record = self.df.iloc[idx]

        excerpt = record['excerpt']
        if self.text_preprocessing_fn is not None:
            excerpt = self.text_preprocessing_fn(excerpt)

        # Pair every tokenizer output with the key used by the model inputs
        fields = zip(('ids', 'mask', 'ttis'), self._tokenize(excerpt))
        inp = {key: torch.tensor(values, dtype=torch.long) for key, values in fields}

        if self.phase == 'test':
            return inp, record['id']
        return inp, torch.tensor(record['target'], dtype=torch.float)

In [None]:
# Sanity check: push one training row through preprocessing + tokenization
df = pd.read_csv(os.path.join(cfg.data_dir, 'train.csv'))

# BERT tokenizer loaded from the local (offline) copy
tokenizer = transformers.AutoTokenizer.from_pretrained("../input/bert-base-uncased")
# Online alternative:
# tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Dataset with a shorter sequence length, just for inspection
dataset = LitDataset(
    df,
    tokenizer=tokenizer,
    text_preprocessing_fn=text_preprocessing,
    max_len=128,
)

# Show the three encoded tensors of a single sample
text, target = dataset[7]
for key in ('ids', 'mask', 'ttis'):
    print(text[key])

## Lightning Data Module

In [None]:
class LitDataModule(pl.LightningDataModule):
    """
    Data module that reads train.csv / test.csv and serves train, validation
    and test dataloaders built on ``LitDataset``.

    The first 80% of the training rows (in file order) are used for training
    and the remaining rows for validation.
    """

    def __init__(self, cfg, tokenizer, text_preprocessing_fn=None):
        """
        ------------------------------------
        Parameters
        cfg:
            Config object; ``data_dir``, ``max_len``, ``batch_size`` and
            ``num_workers`` are read from it
        tokenizer:
            Pretrained Tokenizer
        text_preprocessing_fn: func
            Text Preprocessing Function: str -> str
        """
        super(LitDataModule, self).__init__()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.text_preprocessing_fn = text_preprocessing_fn
        # Raw frames, filled by _load_frames()
        self.train = None
        self.test = None

    def _load_frames(self):
        """Read the train and test CSV files from ``cfg.data_dir``."""
        self.train = pd.read_csv(os.path.join(self.cfg.data_dir, 'train.csv'))
        self.test = pd.read_csv(os.path.join(self.cfg.data_dir, 'test.csv'))

    def _make_dataset(self, df, phase):
        """Wrap ``df`` in a LitDataset that shares this module's tokenizer settings."""
        return LitDataset(
            df,
            tokenizer=self.tokenizer,
            text_preprocessing_fn=self.text_preprocessing_fn,
            max_len=self.cfg.max_len,
            phase=phase
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the batch size / worker count from the config."""
        return DataLoader(dataset,
                          batch_size=self.cfg.batch_size,
                          shuffle=shuffle,
                          num_workers=self.cfg.num_workers,
                          pin_memory=True)

    def prepare_data(self):
        """Load the CSV files (kept so existing callers still see ``train`` / ``test``)."""
        # Load Data
        self._load_frames()

    def setup(self, stage=None):
        """Create the train / validation / test datasets."""
        # Lightning runs prepare_data() on a single process only, so state
        # assigned there is missing on the other processes.  Load here when
        # the frames are not available yet.
        if self.train is None or self.test is None:
            self._load_frames()

        # Split train / valid
        train_val_rate = 0.8
        split = int(len(self.train) * train_val_rate)

        self.train_dataset = self._make_dataset(self.train.iloc[:split], phase='train')
        # Validation rows carry targets too, hence phase='train'
        self.val_dataset = self._make_dataset(self.train.iloc[split:], phase='train')
        self.test_dataset = self._make_dataset(self.test, phase='test')

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test set."""
        return self._make_loader(self.test_dataset, shuffle=False)

In [None]:
# Sanity check: draw a single batch from the data module
# BERT tokenizer loaded from the local (offline) copy
tokenizer = transformers.AutoTokenizer.from_pretrained("../input/bert-base-uncased")
# Online alternative:
# tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# DataModule
dm = LitDataModule(cfg=cfg, tokenizer=tokenizer, text_preprocessing_fn=text_preprocessing)

# Run the Lightning hooks by hand, in the order the Trainer would
dm.prepare_data()
dm.setup()

dataloader = dm.train_dataloader()
batch, tar = next(iter(dataloader))

# Tensor sizes of the encoded inputs, then the batch targets
for key in ('ids', 'mask'):
    print(batch[key].size())
print(tar)

## LightningModule

In [None]:
class LitLightningSystem(pl.LightningModule):
    """
    Lightning module wrapping the regression network.

    Tracks the validation loss / RMSE of every validation epoch, saves the
    network weights whenever the RMSE improves and writes ``submission.csv``
    at the end of the test loop.
    """

    def __init__(self, net, cfg, criterion, optimizer, scheduler=None):
        """
        ------------------------------------
        Parameters
        net: torch.nn.Module
            Model Network
        cfg:
            Config object (``cfg.seed`` is used in checkpoint file names)
        criterion: callable
            Loss function called as ``criterion(prediction, target)``
        optimizer: torch.optim
            Optimizer
        scheduler: torch.optim.lr_scheduler
            Learning Rate Scheduler
        """
        super(LitLightningSystem, self).__init__()
        self.net = net
        self.cfg = cfg
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        # File names of every checkpoint written so far (latest = best)
        self.weight_paths = []
        self.best_rmse = 1e+9
        # Per-epoch validation history
        self.losses = []
        self.rmses = []

    def configure_optimizers(self):
        """Return the optimizer and, when one was given, the scheduler."""
        schedulers = [] if self.scheduler is None else [self.scheduler]
        return [self.optimizer], schedulers

    def forward(self, ids, mask, ttis):
        """Run the wrapped network on one batch of encoded inputs."""
        output = self.net(ids=ids, mask=mask, token_type_ids=ttis)
        return output

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch."""
        inp, target = batch
        out = self.forward(inp['ids'], inp['mask'], inp['ttis'])

        # Reshape the network output to the target's shape before the loss
        loss = self.criterion(out.view_as(target), target)

        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and keep outputs / targets."""
        inp, target = batch
        out = self.forward(inp['ids'], inp['mask'], inp['ttis'])

        loss = self.criterion(out.view_as(target), target)

        return {'val_loss': loss, 'outputs': out, 'targets': target}

    def validation_epoch_end(self, outputs):
        """Aggregate the validation epoch, record history and save improved weights."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        res = torch.cat([x['outputs'] for x in outputs]).reshape((-1))
        targets = torch.cat([x['targets'] for x in outputs]).reshape((-1))

        # RMSE over the whole validation set (square root of the criterion)
        rmse = torch.sqrt(self.criterion(res.view_as(targets), targets)).item()
        avg_loss = avg_loss.item()

        # Logging
        self.losses.append(avg_loss)
        self.rmses.append(rmse)

        # Save Weights
        if rmse < self.best_rmse:
            filename = 'seed_{}_epoch_{}_loss_{:.3f}_rmse_{:.3f}.pth'.format(
                self.cfg.seed, self.current_epoch, avg_loss, rmse
            )
            torch.save(self.net.state_dict(), filename)
            # Remember where the checkpoint went so it can be reloaded later
            self.weight_paths.append(filename)

            self.best_rmse = rmse

        return None

    def test_step(self, batch, batch_idx):
        """Predict one test batch; the batch carries ids instead of targets."""
        inp, ids = batch
        out = self.forward(inp['ids'], inp['mask'], inp['ttis'])

        return {'preds': out, 'ids': ids}

    def test_epoch_end(self, outputs):
        """Collect all predictions into ``self.res`` and write ``submission.csv``."""
        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()
        res = pd.DataFrame(preds, columns=['target'])

        # Flatten the per-batch id collections into one list, in batch order
        ids = list(itertools.chain.from_iterable(x['ids'] for x in outputs))

        res.insert(0, 'id', ids)

        sub_file = 'submission.csv'
        res.to_csv(sub_file, index=False)
        self.res = res

        return None

## BERT Network

In [None]:
class BERT_Network(nn.Module):
    """
    Pretrained BERT followed by dropout and a single-output linear head.
    ------------------------------------
    Parameters
    cfg:
        Config object; ``cfg.BERT_MODEL`` is the path / name of the
        pretrained model to load
    """

    def __init__(self, cfg):
        super(BERT_Network, self).__init__()
        self.bert = transformers.BertModel.from_pretrained(cfg.BERT_MODEL)
        self.drop = nn.Dropout(0.5)
        # Size the head from the loaded checkpoint instead of hard-coding 768,
        # so other BERT sizes work without editing this class.
        self.out = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, ids, mask, token_type_ids):
        """Return one regression value per input sequence."""
        # With return_dict=False the model returns a tuple whose second entry
        # is the pooled output; index it rather than unpacking so extra tuple
        # entries (e.g. hidden states) do not break the call.
        output = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)[1]
        output = self.drop(output)
        output = self.out(output)
        return output

## Train

In [None]:
# Fix all RNG seeds before any weights are created
seed_everything(cfg.seed)

# BERT tokenizer loaded from the local (offline) copy
tokenizer = transformers.AutoTokenizer.from_pretrained("../input/bert-base-uncased")
# Online alternative:
# tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Network
net = BERT_Network(cfg)

# DataModule
dm = LitDataModule(cfg=cfg, tokenizer=tokenizer, text_preprocessing_fn=text_preprocessing)

# Criterion - MSE (from PyTorch Lightning); also reused to derive the RMSE
criterion = metrics.MeanSquaredError()

# Optimizer & scheduler: cosine decay of the learning rate over cfg.epoch steps
optimizer = transformers.AdamW(net.parameters(), lr=cfg.lr)
scheculer = CosineAnnealingLR(optimizer, T_max=cfg.epoch, eta_min=0)

# LightningSystem
model = LitLightningSystem(
    net=net,
    cfg=cfg,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheculer,
)

# Trainer: no logger, no sanity-check validation pass before training
trainer = Trainer(logger=None, max_epochs=cfg.epoch, gpus=-1, num_sanity_val_steps=0)

In [None]:
# Train: fit the model on the train/validation loaders provided by the data module
trainer.fit(model, datamodule=dm)

In [None]:
# History plot: one loss / RMSE entry per validation epoch recorded by the model
loss_df = pd.DataFrame({
    'epoch': np.arange(1, len(model.losses) + 1),
    'loss': model.losses,
    'rmse': model.rmses
})

# Single wide axes showing the validation RMSE curve
fig, ax = plt.subplots(figsize=(16, 6))
sns.lineplot(x='epoch', y='rmse', data=loss_df, ax=ax)
ax.set_title('RMSE Plot')

plt.show()

## Inference

In [None]:
# Inference: run the test loop; test_epoch_end stores the predictions in model.res
trainer.test(model, datamodule=dm)

In [None]:
# Display the prediction frame (columns 'id' and 'target') built during the test loop
model.res

In [None]:
# Create CSV: write the predictions to submission.csv
# (test_epoch_end already writes the same file; this repeats it explicitly)
model.res.to_csv('submission.csv', index=False)