# End to End Example: Fine Tuning RoBERTa for Sequence Classification with PyTorch Lightning and Optuna

## Why
### I needed to learn how to optimize and organize networks more efficiently. To that end, I put together this project to share some general lessons on optimizing deep networks with Optuna and PyTorch Lightning, based on the [CommonLit Readability Prize](https://www.kaggle.com/c/commonlitreadabilityprize) data. To really learn these tools, I went through some pain to refactor a kernel that I originally created for the same competition, [available here](https://www.kaggle.com/justinchae/crp-regression-with-roberta-and-lightgbm). Specifically, in this kernel I use [PyTorch Lightning](https://www.pytorchlightning.ai/) to organize the code and [Optuna](https://optuna.org/) to automate the tuning process. Although this kernel is NLP-centric, it should demonstrate a general framework for tuning hyperparameters in a transfer learning approach to deep learning.

## You may be interested in this kernel if...

### - You are in the CommonLit Readability Competition and need some help fine tuning your networks
### - You want to know how to refactor a PyTorch project into a PyTorch Lightning Project
### - You need to see some working examples of fine tuning a BERT model with PyTorch Lightning
### - You are looking for some integrated project samples that combine Optuna with PyTorch Lightning on an open dataset
### - You want to combine k-fold cross validation with PyTorch Lightning and Optuna optimization

## Bottom Line Results

### With all other things equal, after I learned the optimal parameters by running this kernel, I applied them to a prior submission that scored .497 on the CommonLit Public Leaderboard. The results? After applying the optimal hyperparameters from this Optuna-PyTorch Lightning integration, the public score of my [prior kernel](https://www.kaggle.com/justinchae/crp-regression-with-roberta-and-lightgbm) improved from .497 to .491 (lower is better). Although a gain of .006 is not huge, it shows the utility of automating hyperparameter tuning to improve a model's performance without fundamentally changing what you do with your data. For context, the gain in performance is about 1%, so the question to consider is how much value a 1% gain in model performance represents for your specific problem. I should note that since I created this kernel to learn new tools, it does not perform at a competitive level on its own, with a public score of about .52; instead, it forms a solid foundation for searching hyperparameters and gaining a slight edge as part of a bigger prediction scheme.

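As a quick check on the "about 1%" figure, the relative gain can be computed from the two leaderboard scores quoted above. This is only a small arithmetic sketch; the helper name `relative_improvement` is made up for illustration.

```python
def relative_improvement(before: float, after: float) -> float:
    """Fractional reduction of a lower-is-better score (hypothetical helper)."""
    return (before - after) / before


# public leaderboard scores quoted in the text above
gain = relative_improvement(0.497, 0.491)
print(f"absolute gain: {0.497 - 0.491:.3f}, relative gain: {gain:.1%}")
```

The relative gain works out to roughly 1.2%.
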
In [None]:
%%capture

# install necessary libraries from input
# import progressbar library for offline usage
!ls ../input/progresbar2local
!pip install progressbar2 --no-index --find-links=file:///kaggle/input/progresbar2local/progressbar2

# import text stat library for additional ml data prep
!ls ../input/textstat-local
!pip install textstat --no-index --find-links=file:///kaggle/input/textstat-local/textstat 

In [None]:
# set to 16 bit precision to cut compute requirements/increase batch size capacity
USE_16_BIT_PRECISION = True
# set a seed value for consistent experimentation; optional, else leave as None
SEED_VAL = 42
# set a train-validation split, .8 means 80% of train data and 20% to validation set
TRAIN_VALID_SPLIT = .8
# set some hyperparameters as global variables here
N_OPTUNA_TRIALS = 12
MAX_EPOCHS = 4
BATCH_SIZE = 16

In [None]:
# if running with TPU, uncomment this cell and run to install 
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
# TODO: make it easy to toggle between gpu and tpu

In [None]:
import kaggle_config
from kaggle_config import (WORKFLOW_ROOT, DATA_PATH, CACHE_PATH, FIG_PATH, 
                           MODEL_PATH, ANALYSIS_PATH, KAGGLE_INPUT, 
                           CHECKPOINTS_PATH, LOGS_PATH)

INPUTS, DEVICE = kaggle_config.run()
KAGGLE_TRAIN_PATH = kaggle_config.get_train_path(INPUTS)
KAGGLE_TEST_PATH = kaggle_config.get_test_path(INPUTS)

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find

import torchmetrics

import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.samplers import TPESampler, RandomSampler, CmaEsSampler
from optuna.visualization import (plot_intermediate_values
                                  , plot_optimization_history
                                  , plot_param_importances)

from sklearn.model_selection import KFold

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split

import tensorflow as tf

from transformers import (RobertaForSequenceClassification
                          , RobertaTokenizer
                          , AdamW
                          , get_linear_schedule_with_warmup)

import os
import pandas as pd
import numpy as np

import gc
from functools import partial

from typing import List, Dict
from typing import Optional
from argparse import ArgumentParser

import random

if SEED_VAL:
    random.seed(SEED_VAL)
    np.random.seed(SEED_VAL)
    seed_everything(SEED_VAL)
    
NUM_DATALOADER_WORKERS = os.cpu_count()

try: 
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    n_tpus = len(tf.config.list_logical_devices('TPU'))
except ValueError:
    n_tpus = 0

ACCELERATOR_TYPE = {}
ACCELERATOR_TYPE.update({'gpus': torch.cuda.device_count() if torch.cuda.is_available() else None})
ACCELERATOR_TYPE.update({'tpu_cores': n_tpus if n_tpus > 0 else None})
# still debugging how to best toggle between tpu and gpu; there's too much code to configure to work simply
print("ACCELERATOR_TYPE:\n", ACCELERATOR_TYPE)

PRETTRAINED_ROBERTA_BASE_MODEL_PATH = "/kaggle/input/pre-trained-roberta-base"
PRETRAINED_ROBERTA_BASE_TOKENIZER_PATH = "/kaggle/input/tokenizer-roberta"
PRETRAINED_ROBERTA_BASE_TOKENIZER = RobertaTokenizer.from_pretrained(PRETRAINED_ROBERTA_BASE_TOKENIZER_PATH)

In [None]:
# %%capture
# !pip install 'neptune-client[pytorch-lightning]'
# !pip install neptune-client
# from neptune.new.integrations.pytorch_lightning import NeptuneLogger
# from pytorch_lightning.loggers.neptune import NeptuneLogger

# neptune_api_token = """<token>"""
# neptune_project="""justinhchae/kaggle-crp-pytorchlightning-plus-optuna"""
# neptune_name='test-run'

# NEPTUNE_LOGGER = NeptuneLogger(api_token=neptune_api_token
#                                , project=neptune_project
#                                , name=neptune_name)

# debugging this neptune implementation; seems to have compatibility issues

In [None]:
"""Implementing Lightning instead of torch.nn.Module
"""
class LitRobertaLogitRegressor(pl.LightningModule):
    def __init__(self, pre_trained_path: str
                     , output_hidden_states: bool = False
                     , num_labels: int = 1
                     , layer_1_output_size: int = 64
                     , layer_2_output_size: int = 1
                     , learning_rate: float = 1e-5
                     , task_name: Optional[str] = None
                     , warmup_steps: int = 100
                     , weight_decay: float = 0.0
                     , adam_epsilon: float = 1e-8
                     , batch_size: Optional[int] = None
                     , train_size: Optional[int] = None
                     , max_epochs: Optional[int] = None
                     , n_gpus: Optional[int] = 0
                     , n_tpus: Optional[int] = 0 
                     , accumulate_grad_batches = None
                ):
        """refactored from: https://www.kaggle.com/justinchae/my-bert-tuner and https://www.kaggle.com/justinchae/roberta-tuner
        """
        super(LitRobertaLogitRegressor, self).__init__()
        
        # this saves class params as self.hparams
        self.save_hyperparameters()
        
        self.model = RobertaForSequenceClassification.from_pretrained(self.hparams.pre_trained_path
                                                                        , output_hidden_states=self.hparams.output_hidden_states
                                                                        , num_labels=self.hparams.num_labels
                                                                        )

        # n_gpus may arrive as None when no GPU is available, so avoid comparing it to an int
        self.accelerator_multiplier = n_gpus if n_gpus else 1
        
        self.config = self.model.config
        self.parameters = self.model.parameters
        self.save_pretrained = self.model.save_pretrained
        # these layers are not currently used, tbd in future iteration
        self.layer_1 = torch.nn.Linear(768, layer_1_output_size)
        self.layer_2 = torch.nn.Linear(layer_1_output_size, layer_2_output_size)
        
        def rmse_loss(x, y):
            criterion = F.mse_loss
            loss = torch.sqrt(criterion(x, y))
            return loss
        
        # TODO: enable toggle for various loss funcs and torchmetrics package
        self.loss_func = rmse_loss
#         self.eval_func = rmse_loss   
        
    def setup(self, stage=None) -> None:
        if stage == 'fit':
            # when this class is called by trainer.fit, this stage runs and so on
            # Calculate total optimizer steps for the scheduler:
            # (batches per epoch // gradient accumulation) * number of epochs
            tb_size = self.hparams.batch_size * self.accelerator_multiplier
            accumulate = self.hparams.accumulate_grad_batches or 1
            self.total_steps = (self.hparams.train_size // tb_size) // accumulate * self.hparams.max_epochs
        
    def extract_logit_only(self, input_ids, attention_mask) -> np.ndarray:
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logit = output.logits
        logit = logit.cpu().numpy().astype(float)
        return logit
    
    def extract_hidden_only(self, input_ids, attention_mask) -> np.ndarray:
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = output.hidden_states
        x = torch.stack(hidden_states[-4:]).sum(0)
        m1 = torch.nn.Sequential(self.layer_1
                                 , self.layer_2
                                 , torch.nn.Flatten())
        x = m1(x)
        x = torch.squeeze(x).cpu().numpy()
        
        return x
        
    def forward(self, input_ids, attention_mask) -> torch.Tensor:
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        x = output.logits
        return x
    
    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        # refactored from: https://www.kaggle.com/justinchae/epoch-utils
        labels, encoded_batch, kaggle_ids = batch
        input_ids = encoded_batch['input_ids']
        attention_mask = encoded_batch['attention_mask']
        # per docs, keep train step separate from forward call
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        y_hat = output.logits
        # quick reshape to align labels to predictions
        labels = labels.view(-1, 1)
        loss = self.loss_func(y_hat, labels)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        # refactored from: https://www.kaggle.com/justinchae/epoch-utils
        labels, encoded_batch, kaggle_ids = batch
        input_ids = encoded_batch['input_ids']
        attention_mask = encoded_batch['attention_mask']
        # this self call is calling the forward method
        y_hat = self(input_ids, attention_mask)
        # quick reshape to align labels to predictions
        labels = labels.view(-1, 1)
        loss = self.loss_func(y_hat, labels)
        self.log('val_loss', loss)
        return loss
    
    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
        # creating this predict method overrides the pl predict method
        _, encoded_batch, kaggle_ids = batch
        
        input_ids = encoded_batch['input_ids']
        attention_mask = encoded_batch['attention_mask']
        # this self call is calling the forward method
        y_hat = self(input_ids, attention_mask)
        # convert to numpy then list like struct to zip with ids
        y_hat = y_hat.cpu().numpy().ravel()
        # customizing the predict behavior to account for unique ids
        predictions = list(zip(kaggle_ids, y_hat))
        predictions = pd.DataFrame(predictions, columns=['id', 'target'])
        
        return predictions
    
    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Reference: https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/text-transformers.html
        model = self.model
        
        no_decay = ["bias", "LayerNorm.weight"]
        
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step', 'frequency': 1}
        
        return [optimizer], [scheduler]

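The scheduler in `configure_optimizers` depends on two numbers: the total count of optimizer steps and the number of warmup steps. The sketch below works through that arithmetic in plain Python, assuming steps are counted as batches per epoch divided by gradient accumulation, times the number of epochs, and assuming the usual linear warmup followed by linear decay. The function names and the dataset size of 2000 are illustrative only, not part of the kernel.

```python
def total_optimizer_steps(train_size: int, batch_size: int, max_epochs: int,
                          accumulate_grad_batches: int = 1, n_devices: int = 1) -> int:
    """Count optimizer steps over a whole run (hypothetical helper)."""
    effective_batch = batch_size * max(1, n_devices)
    batches_per_epoch = train_size // effective_batch
    return (batches_per_epoch // accumulate_grad_batches) * max_epochs


def linear_warmup_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Learning-rate multiplier: ramp from 0 to 1 during warmup, then decay linearly to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# assumed example: 2000 training rows, batch size 16, 4 epochs
steps = total_optimizer_steps(train_size=2000, batch_size=16, max_epochs=4)
for step in (0, 50, 100, 300, steps):
    print(step, linear_warmup_multiplier(step, warmup_steps=100, total_steps=steps))
```

Under these assumptions, a warmup longer than the total step count means the peak learning rate is never reached, which is worth keeping in mind when the warmup range in a search is wide.
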
In [None]:
def my_collate_fn(batch
                 , tokenizer
                 , max_length: int = 100
                 , return_tensors: str = 'pt'
                 , padding: str = "max_length"
                 , truncation: bool = True
                 ):
    # source: https://www.kaggle.com/justinchae/nn-utils
    labels = []
    batch_texts = []
    kaggle_ids = []

    for (_label, batch_text, kaggle_id) in batch:
        if _label is not None:
            labels.append(_label)
        
        batch_texts.append(batch_text)
        kaggle_ids.append(kaggle_id)
    
            
    # only build a label tensor when the batch actually carries labels (train/valid data)
    if labels:
        labels = torch.tensor(labels, dtype=torch.float)
    
    encoded_batch = tokenizer(batch_texts
                              , return_tensors=return_tensors
                              , padding=padding
                              , max_length=max_length
                              , truncation=truncation)

    return labels, encoded_batch, kaggle_ids


class CommonLitDataset(Dataset):
    def __init__(self
                 , df
                 , text_col: str = 'excerpt'
                 , label_col: str = 'target'
                 , kaggle_id: str = 'id'
                 , sample_size: Optional[int] = None
                ):
        self.df = df if sample_size is None else df.sample(sample_size)
        self.text_col = text_col
        self.label_col = label_col
        self.kaggle_id = kaggle_id
        self.num_labels = len(df[label_col].unique()) if label_col in df.columns else None
        # source: https://www.kaggle.com/justinchae/nn-utils
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = row[self.text_col]
        kaggle_id = row[self.kaggle_id]
        # test data has no label column, so return None in its place
        target = row[self.label_col] if self.label_col in self.df.columns else None
        return target, text, kaggle_id


class CommonLitDataModule(pl.LightningDataModule):
    def __init__(self
                 , tokenizer
                 , train_path
                 , collate_fn=None
                 , max_length: int = 280
                 , batch_size: int = 16
                 , valid_path: Optional[str] = None
                 , test_path: Optional[str] = None
                 , train_valid_split: float = .6
                 , dtypes=None
                 , shuffle_dataloader: bool = True
                 , num_dataloader_workers: int = NUM_DATALOADER_WORKERS
                 , kfold: Optional[dict] = None):
        super(CommonLitDataModule, self).__init__()
        self.train_path = train_path
        self.valid_path = valid_path
        self.test_path = test_path
        self.train_valid_split = train_valid_split
        self.dtypes = {'id': str} if dtypes is None else dtypes
        self.train_size = None
        self.train_df, self.train_data = None, None
        self.valid_df, self.valid_data = None, None
        self.test_df, self.test_data = None, None
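        # bind the tokenizer settings now; the DataLoader supplies `batch` positionally at call time
        # (binding `batch` here as a keyword would raise a TypeError when the DataLoader calls the function)
        self.collate_fn = partial(my_collate_fn if collate_fn is None else collate_fn
                                  , tokenizer=tokenizer
                                  , max_length=max_length)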
            
        self.shuffle_dataloader = shuffle_dataloader
        self.batch_size = batch_size
        self.num_dataloader_workers = num_dataloader_workers
        # refactored from: https://www.kaggle.com/justinchae/nn-utils
    
    def _strip_extraneous(self, df):
        strip_cols = ['url_legal', 'license']
        if all(col in df.columns for col in strip_cols):
            extraneous_data = strip_cols
            return df.drop(columns=extraneous_data)
        else: 
            return df
    
    def prepare(self, prep_type=None):
        if prep_type == 'train':
            # creates just an instance of the train data as a pandas df
            self.train_df = self.train_path if isinstance(self.train_path, pd.DataFrame) else pd.read_csv(self.train_path, dtype=self.dtypes)
            self.train_df = self._strip_extraneous(self.train_df)
        
    def setup(self, stage: Optional[str] = None) -> None:
        if stage == 'fit':
            # when this class is called by trainer.fit, this stage runs and so on
            self.train_df = self.train_path if isinstance(self.train_path, pd.DataFrame) else pd.read_csv(self.train_path, dtype=self.dtypes)
            self.train_df = self._strip_extraneous(self.train_df)
            self.train_size = int(len(self.train_df))
            self.train_data = CommonLitDataset(df=self.train_df)
        
            if self.train_valid_split is not None and self.valid_path is None:
                self.train_size = int(len(self.train_df) * self.train_valid_split)
                self.train_data, self.valid_data = random_split(self.train_data, [self.train_size, len(self.train_df) - self.train_size])
            elif self.valid_path is not None:
                self.valid_df = self.valid_path if isinstance(self.valid_path, pd.DataFrame) else pd.read_csv(self.valid_path, dtype=self.dtypes)
                self.valid_data = CommonLitDataset(df=self.valid_df)
            
        if stage == 'predict':           
            self.test_df = self.test_path if isinstance(self.test_path, pd.DataFrame) else pd.read_csv(self.test_path, dtype=self.dtypes)
            self.test_df = self._strip_extraneous(self.test_df)
            self.test_data = CommonLitDataset(df=self.test_df)
    
    def kfold_data(self):
        # TODO: wondering how to integrate kfolds into the datamodule
        pass
    
    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_data
                          , batch_size=self.batch_size
                          , shuffle=self.shuffle_dataloader
                          , collate_fn=self.collate_fn
                          , num_workers=self.num_dataloader_workers
                          , pin_memory=True
                          )
    def val_dataloader(self) -> DataLoader:
        if self.valid_data is None:
            return None
        else:
            return DataLoader(self.valid_data
                              , batch_size=self.batch_size
                              , shuffle=False
                              , collate_fn=self.collate_fn
                              , num_workers=self.num_dataloader_workers
                              , pin_memory=True
                              )
    def predict_dataloader(self) -> DataLoader:
        if self.test_data is None:
            return None
        else:
            return DataLoader(self.test_data
                              , batch_size=self.batch_size
                              , shuffle=False
                              , collate_fn=self.collate_fn
                              , num_workers=self.num_dataloader_workers
                              , pin_memory=True
                              ) 

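The data module relies on `functools.partial` to pre-bind tokenizer settings to the collate function, so that the `DataLoader` only has to pass the batch. Here is a minimal, self-contained sketch of that pattern. `toy_tokenizer` and `collate` are hypothetical stand-ins (a whitespace tokenizer instead of the RoBERTa tokenizer), used only to show how the arguments get bound.

```python
from functools import partial


def toy_tokenizer(texts, max_length):
    """Hypothetical stand-in for a real tokenizer: whitespace split, truncate, pad."""
    encoded = []
    for text in texts:
        tokens = text.split()[:max_length]
        encoded.append(tokens + ["<pad>"] * (max_length - len(tokens)))
    return encoded


def collate(batch, tokenizer, max_length=4):
    """Turn a list of (label, text, id) tuples into batched labels, encodings, and ids."""
    labels = [label for label, _, _ in batch if label is not None]
    texts = [text for _, text, _ in batch]
    ids = [row_id for _, _, row_id in batch]
    return labels, tokenizer(texts, max_length=max_length), ids


# bind everything except `batch`, which the caller passes positionally
collate_fn = partial(collate, tokenizer=toy_tokenizer, max_length=3)

batch = [(0.5, "a short excerpt", "id1"),
         (-1.0, "another slightly longer excerpt", "id2")]
labels, encoded, ids = collate_fn(batch)
print(labels, encoded, ids)
```

Only keyword arguments other than the first positional parameter should be bound; otherwise the positional batch collides with the bound keyword.
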
In [None]:
def objective(trial: optuna.trial.Trial
              , logger=None
              , datamodule=None
              , train_df=None
              , valid_df=None) -> float:
    """Reference: https://github.com/optuna/optuna-examples/blob/main/pytorch/pytorch_lightning_simple.py#L81
    """

    # propose params to study
    batch_size = BATCH_SIZE #trial.suggest_int("batch_size", 8, 16)
    # tried tuning batch size but ended up with better results by setting it to 16; ideally with larger capacity we go higher to 32
    tokenizer_max_len = trial.suggest_int("tokenizer_max_len", 128, 512)
    learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1e-1)
    warmup_steps = trial.suggest_int("warmup_steps", 0, 1000)
    max_epochs = MAX_EPOCHS
    # also tried tuning epochs but seems to be too much to tune when considering everything else
    gradient_clip_val = trial.suggest_float("gradient_clip_val", 0, .5)
    weight_decay = trial.suggest_float("weight_decay", 0, 1e-2)
    # we can add more things to study here...
    
    checkpoint_filename = f'crp_roberta_trial_{trial.number}'
    checkpoint_save = ModelCheckpoint(dirpath=CHECKPOINTS_PATH
                                          , filename=checkpoint_filename
                                         )
    
    early_stopping_callback = EarlyStopping(monitor='val_loss'
                                            , patience=2
                                            )
    
    logger = pl.loggers.TensorBoardLogger(save_dir=LOGS_PATH) if logger is None else logger
    
    trainer = pl.Trainer(max_epochs=max_epochs
                         , logger=logger
                         , gpus=ACCELERATOR_TYPE['gpus']
                         , tpu_cores=ACCELERATOR_TYPE['tpu_cores']
                         , callbacks = [PyTorchLightningPruningCallback(trial, monitor='val_loss') 
                                        , checkpoint_save
                                        # come back and decide to use early stopping mid-epoch
#                                         , early_stopping_callback
                                        ]
                         , precision=16 if USE_16_BIT_PRECISION else 32
                         , default_root_dir=CHECKPOINTS_PATH
                         , gradient_clip_val=gradient_clip_val
                         , stochastic_weight_avg=True
                         # TODO: debug how to ensure trainer consumes stopped state
#                          , val_check_interval=.33
                         # TODO: debug how to auto scale batch search with optuna-lightning
#                          , auto_scale_batch_size='binsearch'
                        )
    # default to using the data module but allow for cross validation via train and valid loader objects
    if datamodule is None and train_df is None and valid_df is None:
        datamodule = CommonLitDataModule(collate_fn=my_collate_fn
                                         , tokenizer=PRETRAINED_ROBERTA_BASE_TOKENIZER
                                         , train_path=KAGGLE_TRAIN_PATH
                                         , test_path=KAGGLE_TEST_PATH
                                         , max_length=tokenizer_max_len
                                         , batch_size=batch_size
                                         , train_valid_split=TRAIN_VALID_SPLIT
                                          )
        # manually calling this stage since we need some params to set up model initially
        datamodule.setup(stage='fit')
        
    elif datamodule is None and train_df is not None and valid_df is not None:
        datamodule = CommonLitDataModule(collate_fn=my_collate_fn
                                         , tokenizer=PRETRAINED_ROBERTA_BASE_TOKENIZER
                                         , train_path=train_df
                                         , valid_path=valid_df
                                         , test_path=KAGGLE_TEST_PATH
                                         , max_length=tokenizer_max_len
                                         , batch_size=batch_size
                                         , train_valid_split=TRAIN_VALID_SPLIT
                                          )
        datamodule.setup(stage='fit')
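    else:
        # returning False here would be recorded by Optuna as a loss of 0.0, so fail loudly instead
        raise ValueError("Pass either no data arguments, or both train_df and valid_df.")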
        
    model = LitRobertaLogitRegressor(pre_trained_path=PRETTRAINED_ROBERTA_BASE_MODEL_PATH
                                      , train_size=datamodule.train_size
                                      , batch_size=datamodule.batch_size
                                      , n_gpus=trainer.gpus
                                      , n_tpus=trainer.tpu_cores
                                      , max_epochs=trainer.max_epochs
                                      , accumulate_grad_batches=trainer.accumulate_grad_batches
                                      , learning_rate=learning_rate
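                                      , warmup_steps=warmup_steps
                                      # pass the suggested weight decay so that tuning it actually affects the optimizer
                                      , weight_decay=weight_decay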
                                      )
    
    hyperparameters = dict(learning_rate=learning_rate
                           , warmup_steps=warmup_steps
                           , max_epochs=max_epochs
                           , gradient_clip_val=gradient_clip_val
                           , weight_decay=weight_decay
                           , batch_size=batch_size
                           , tokenizer_max_len=tokenizer_max_len
                          )
    
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, datamodule=datamodule)
    
    # saving the fine-tuned states of roberta transformers
#     model_file_name = f"trial_{trial.number}_tuned_roberta"
#     model_file_path = os.path.join(MODEL_PATH, model_file_name)
#     model.save_pretrained(model_file_path)
    # it turns out that we don't actually need to save the RoBERTa model weights separately, that's all in the checkpoint

    curr_loss = trainer.callback_metrics['val_loss'].item()
        
    return curr_loss

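The learning rate above is searched on a log scale, which means the search space is uniform in the exponent rather than in the raw value. The sketch below illustrates that idea with the standard library only; it is not Optuna's code, and `sample_loguniform` is a made-up helper. Note that a sampler such as TPE will bias later suggestions toward promising regions, so this only describes the shape of the space.

```python
import math
import random


def sample_loguniform(low: float, high: float, rng: random.Random) -> float:
    """Draw a value whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(42)
samples = [sample_loguniform(1e-5, 1e-1, rng) for _ in range(1000)]

# the geometric midpoint of [1e-5, 1e-1] is 1e-3, so about half the draws fall below it
below_midpoint = sum(s < 1e-3 for s in samples) / len(samples)
print(f"share of samples below 1e-3: {below_midpoint:.2f}")
```

A plain uniform draw over the same range would put almost every sample above 1e-3, which is why log scale is the usual choice for learning rates.
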
In [None]:
def objective_cv(trial):

    # wrap a cross validation dataset around the objective function
    # source: https://stackoverflow.com/questions/63224426/how-can-i-cross-validate-by-pytorch-and-optuna
    # we can call this function instead of objective to run kfolds in each trial
    # FIXME: run k folds within each trial without incrementing the objective counter
    # Issue: Each iteration of a trial increases disk usage and with kaggle we run out of memory
    
    datamodule = CommonLitDataModule(tokenizer=PRETRAINED_ROBERTA_BASE_TOKENIZER
                                    , train_path=KAGGLE_TRAIN_PATH
                                     )
    
    datamodule.prepare(prep_type='train')
    train_df = datamodule.train_df

    fold = KFold(n_splits=3, shuffle=True, random_state=SEED_VAL)
    scores = []
    
    for fold_idx, (train_idx, valid_idx) in enumerate(fold.split(range(len(train_df)))):
        # clean up memory
        torch.cuda.empty_cache()
        gc.collect()

        train_data = train_df.iloc[train_idx]
        valid_data = train_df.iloc[valid_idx]
        # pass data objects to objective and collect the validation loss of each fold
        loss = objective(trial
                         , train_df=train_data
                         , valid_df=valid_data)
        print(f"=== kfold: {fold_idx + 1}: val_loss: {loss} \n")
        scores.append(loss)
        
    return np.mean(scores)

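The cross-validation wrapper boils down to: split the row indices into k folds, score each fold, and return the mean. Below is a dependency-free sketch of that loop, assuming contiguous, unshuffled folds (the kernel itself uses scikit-learn's `KFold` with shuffling). `kfold_indices`, `cross_validate`, and the toy scorer are hypothetical names for illustration.

```python
from statistics import mean


def kfold_indices(n_samples: int, n_splits: int):
    """Yield (train_idx, valid_idx) index lists for contiguous, unshuffled folds."""
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        size = base + (1 if fold < extra else 0)
        valid_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, valid_idx
        start += size


def cross_validate(n_samples: int, n_splits: int, score_fold) -> float:
    """Average a per-fold score over all folds."""
    scores = [score_fold(train_idx, valid_idx)
              for train_idx, valid_idx in kfold_indices(n_samples, n_splits)]
    return mean(scores)


folds = list(kfold_indices(10, 3))
# toy scorer: the share of rows held out in each fold (stands in for a validation loss)
average = cross_validate(10, 3, lambda train_idx, valid_idx: len(valid_idx) / 10)
print([len(valid_idx) for _, valid_idx in folds], average)
```

In a real run the scorer would train on `train_idx` and return the validation loss on `valid_idx`, which is where the time and memory cost per trial multiplies by k.
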
In [None]:
if __name__ == '__main__':
    # necessary to clean up space when running cells repeatedly in kernel
    torch.cuda.empty_cache()
    gc.collect()
    # testing the hyperbandpruner which is supposed to work better
#     pruner = optuna.pruners.HyperbandPruner(min_resource=2)
    pruner = optuna.pruners.MedianPruner(n_startup_trials=2
                                         # steps are epochs, this ensures trials are not pruned until second epoch
                                        , n_warmup_steps=2
                                        )
    sampler = TPESampler(multivariate=True
                         , seed=SEED_VAL)
    # testing some other samplers which requires some trial and error
#     sampler = RandomSampler(seed=SEED_VAL)
    # params per docs, re CMA-ES sampler with hyperband pruner
#     sampler = CmaEsSampler(consider_pruned_trials=True
#                            , n_startup_trials=
#                            , seed=SEED_VAL
#                            , restart_strategy='ipop')
    
    study = optuna.create_study(study_name="crp-roberta-tuning"
                                , direction="minimize"
                                , pruner=pruner
                                , sampler=sampler
                                # testing study storage in rdb instead of memory
                                , storage='sqlite:///crp-study.db'
                               )
    
    study.optimize(objective
                   # or objective_cv for cross validator
                   # exceeding memory constraints past 10-12 trials in kaggle
                   , n_trials=N_OPTUNA_TRIALS
                   # if using timeout, trials end at time instead of running all expected trials
#                    , timeout=600
#                    , gc_after_trial=True
                   # test if callbacks with lambda manages memory instead
                   , callbacks=[lambda study, trial: gc.collect()]
                  )

    print("Number of finished trials: {}".format(len(study.trials)))

    print("Best trial:")
    best_trial = study.best_trial

    print("  Value: {}".format(best_trial.value))

    print("  Params: ")
    for key, value in best_trial.params.items():
        print("    {}: {}".format(key, value))
        
    print("  Full Summary of Trials:  ")
    print(study.trials_dataframe())
    
    plot_optimization_history(study).show()
    plot_intermediate_values(study).show()
    try:
        plot_param_importances(study).show()
    except ValueError:
        pass
        
    #TODO: link logs to offline graphs, i.e. tensorboard UI
    #TODO: link logs to neptune for real time  awareness
    #TODO: figure out an easy way to toggle TPU and GPU usage
    #TODO: debug early stopping - early stop during optuna seems to work but the state of early stop does not seem to persist

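To make the pruner settings above more concrete, here is a simplified sketch of a median pruning rule for a minimized metric. It is not Optuna's implementation; it assumes a trial is stopped when its value at a given step is worse than the median of earlier trials at that step, and that nothing is pruned before `n_startup_trials` trials exist or before `n_warmup_steps` steps have passed. The function and the example history are hypothetical.

```python
from statistics import median


def should_prune(current_value: float, step: int, history: list,
                 n_startup_trials: int = 2, n_warmup_steps: int = 2) -> bool:
    """Decide whether to stop a trial early (simplified median rule, lower is better).

    history: earlier trials, each a dict mapping step -> intermediate value.
    """
    if len(history) < n_startup_trials or step < n_warmup_steps:
        return False
    peers = [trial[step] for trial in history if step in trial]
    if not peers:
        return False
    return current_value > median(peers)


# assumed validation losses per epoch for three earlier trials
history = [
    {0: 0.90, 1: 0.70, 2: 0.60},
    {0: 0.80, 1: 0.75, 2: 0.65},
    {0: 1.00, 1: 0.90, 2: 0.80},
]
print(should_prune(0.70, step=2, history=history))  # worse than the median at step 2
print(should_prune(0.60, step=2, history=history))  # better than the median at step 2
print(should_prune(5.00, step=1, history=history))  # still inside the warmup window
```

With only a handful of epochs per trial, a warmup of two steps leaves little room for pruning to save compute, so these two settings are worth tuning together with the epoch budget.
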
In [None]:
# testing restoring from checkpoints (buggy)
# clean up memory
torch.cuda.empty_cache()
gc.collect()

# # recall the best trial checkpoint
best_trial_checkpoint = os.path.join(CHECKPOINTS_PATH, f'crp_roberta_trial_{best_trial.number}.ckpt')

# we can recall the roberta fine tuned state if we want
# tuned_roberta_file_name = f"trial_{best_trial.number}_tuned_roberta"
# tuned_roberta_file_path = os.path.join(MODEL_PATH, tuned_roberta_file_name)

# TODO: organize code so we don't have to call it in optuna study and then again here
crp_data = CommonLitDataModule(collate_fn=my_collate_fn
                               , tokenizer=PRETRAINED_ROBERTA_BASE_TOKENIZER
                               , train_path=KAGGLE_TRAIN_PATH
                               , test_path=KAGGLE_TEST_PATH
                               , max_length=best_trial.params['tokenizer_max_len']
                               , batch_size=BATCH_SIZE #best_trial.params['batch_size']
                              )

crp_data.setup(stage='predict')
# restore the model checkpoint of the best trial
model = LitRobertaLogitRegressor.load_from_checkpoint(best_trial_checkpoint
                                                    # no need to load the tuned roberta model separately here; it's all in the checkpoint file
#                                                        , pre_trained_path=tuned_roberta_file_path
                                                     )
# freeze the model for prediction
model.eval()
model.freeze()

# set up a new trainer object to run prediction
trainer = pl.Trainer(gpus=ACCELERATOR_TYPE['gpus']
                     , tpu_cores=ACCELERATOR_TYPE['tpu_cores']
                     )

# run predict on the test data
predictions = trainer.predict(model=model, datamodule=crp_data)

submission = pd.concat(predictions).reset_index(drop=True)

print(submission)
submission.to_csv('submission.csv', index=False)

# TODO: test whether we need to save and upload the fine-tuned state of roberta or if pytorch lightning checkpoints take care of it all
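
The restore cell above is marked as buggy. One possible failure mode, which is an assumption and not something the notebook confirms, is that the checkpoint file for the best trial is missing from `CHECKPOINTS_PATH`, for example after a session restart. A small guard like the one below could make that case fail with a clear message before `load_from_checkpoint` is called. The helper name `best_checkpoint_path` is hypothetical. The file name pattern follows the one used in the cell, and the demo runs against a temporary directory holding an empty placeholder file.

```python
import os
import tempfile


def best_checkpoint_path(checkpoints_dir, trial_number):
    """Return the checkpoint path for a trial, or raise if the file is missing."""
    path = os.path.join(checkpoints_dir, f"crp_roberta_trial_{trial_number}.ckpt")
    if not os.path.isfile(path):
        # List the checkpoints that do exist to make the error actionable.
        available = sorted(
            name for name in os.listdir(checkpoints_dir) if name.endswith(".ckpt")
        )
        raise FileNotFoundError(f"No checkpoint at {path}; found: {available}")
    return path


# Demo: a temporary directory with one empty placeholder checkpoint file.
demo_dir = tempfile.mkdtemp()
open(os.path.join(demo_dir, "crp_roberta_trial_3.ckpt"), "w").close()

found = best_checkpoint_path(demo_dir, 3)
print(found)

try:
    best_checkpoint_path(demo_dir, 7)
    missing_raised = False
except FileNotFoundError as err:
    missing_raised = True
    print(err)
```

With a guard like this, the cell would call `best_checkpoint_path(CHECKPOINTS_PATH, best_trial.number)` in place of the bare `os.path.join`, which separates "the file is not there" from problems inside the checkpoint itself.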

#### Helpful Resources

* Optuna Docs: [https://optuna.readthedocs.io/en/stable/index.html](https://optuna.readthedocs.io/en/stable/index.html)

* PyTorch Lightning Docs: [https://pytorch-lightning.readthedocs.io/en/latest/](https://pytorch-lightning.readthedocs.io/en/latest/)

* For learning rate tuning: [https://medium.com/pytorch/using-optuna-to-optimize-pytorch-hyperparameters-990607385e36](https://medium.com/pytorch/using-optuna-to-optimize-pytorch-hyperparameters-990607385e36)

* For PyTorch Lightning Precision: [https://pytorch-lightning.readthedocs.io/en/stable/advanced/amp.html](https://pytorch-lightning.readthedocs.io/en/stable/advanced/amp.html)

* For PyTorch Lightning Early Stopping: [https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html](https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html)

* For PyTorch Lightning Checkpointing: [https://pytorch-lightning.readthedocs.io/en/stable/common/weights_loading.html](https://pytorch-lightning.readthedocs.io/en/stable/common/weights_loading.html)

* BERT Example from PyTorch Lightning: [https://pytorch-lightning.readthedocs.io/en/stable/advanced/transfer_learning.html](https://pytorch-lightning.readthedocs.io/en/stable/advanced/transfer_learning.html)

* Fine-Tuning a Transformer from PyTorch Lightning: [https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/text-transformers.html](https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/text-transformers.html)

* Example of Optuna with PyTorch Lightning: [https://github.com/optuna/optuna-examples/blob/main/pytorch/pytorch_lightning_simple.py](https://github.com/optuna/optuna-examples/blob/main/pytorch/pytorch_lightning_simple.py)

* For PyTorch Lightning Logging: [https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html](https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html)

* For Predict Mode with PyTorch Lightning: [https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html](https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html)

* Restoring Checkpoints and Continuing Training: [https://pytorch-lightning.readthedocs.io/en/latest/common/weights_loading.html?highlight=checkpoint#checkpoint-loading](https://pytorch-lightning.readthedocs.io/en/latest/common/weights_loading.html?highlight=checkpoint#checkpoint-loading)

* Gradient Clipping in PyTorch Lightning: [https://pytorch-lightning.readthedocs.io/en/stable/advanced/training_tricks.html?highlight=memory#advanced-gpu-optimizations](https://pytorch-lightning.readthedocs.io/en/stable/advanced/training_tricks.html?highlight=memory#advanced-gpu-optimizations)

* How to approach trial suggestions in Optuna: [https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html?highlight=suggest#optuna.trial.Trial.suggest_int](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html?highlight=suggest#optuna.trial.Trial.suggest_int)

* Guidance on Early Stopping Callbacks with PyTorch Lightning: [https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.early_stopping.html#pytorch_lightning.callbacks.early_stopping.EarlyStopping](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.early_stopping.html#pytorch_lightning.callbacks.early_stopping.EarlyStopping)

* For Reproducible Optuna Studies: [https://optuna.readthedocs.io/en/stable/faq.html#how-can-i-obtain-reproducible-optimization-results](https://optuna.readthedocs.io/en/stable/faq.html#how-can-i-obtain-reproducible-optimization-results)

* Guidance on Optuna Pruners: [https://optuna.readthedocs.io/en/stable/reference/generated/optuna.pruners.HyperbandPruner.html#optuna.pruners.HyperbandPruner](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.pruners.HyperbandPruner.html#optuna.pruners.HyperbandPruner)

* More guidance on which Optuna Pruners to use based on ML task: [https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/003_efficient_optimization_algorithms.html?highlight=memory#activating-pruners](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/003_efficient_optimization_algorithms.html?highlight=memory#activating-pruners)

* For TPUs with PyTorch Lightning: [https://www.kaggle.com/justusschock/pytorch-on-tpu-with-pytorch-lightning](https://www.kaggle.com/justusschock/pytorch-on-tpu-with-pytorch-lightning)

* For Neptune to PyTorch Lightning Integration: [https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch-lightning](https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch-lightning)