In [None]:
import logging
import os
import re
import gc
from pathlib import Path

import joblib
import torch
import wandb
import pandas as pd
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from kaggle import KaggleApi

from lightning import seed_everything
from lightning.pytorch.utilities.memory import garbage_collection_cuda
from sklearn.model_selection import BaseCrossValidator

from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.kaggle_utils import download_kaggle_competition_dataset, download_kaggle_datasets

In [None]:
# Experiment id used for the hydra `experiment` override; "000" is used when the env var is unset.
EXPERIMENT = os.environ.get("EXPERIMENT", "000")
# Weights & Biases API key read from the environment; None when the env var is unset.
WANDB_KEY = os.environ.get("WANDB_KEY")

In [None]:
# Fail fast when no experiment id is available; the hydra override below needs one.
# NOTE(review): EXPERIMENT is read from the environment with a default value, so this
# guard may never trigger — confirm whether the default is intended.
if EXPERIMENT is None:
    raise ValueError("EXPERIMENT is not set")

# Compose the hydra config from inside the notebook, selecting the experiment via an override.
with initialize(version_base=None, config_path="../../configs"):
    CFG = compose(
        config_name="config.yaml",
        return_hydra_config=True,
        overrides=[f"experiment={EXPERIMENT}"],
    )
    HydraConfig.instance().set_config(CFG)  # use HydraConfig for notebook to use hydra job

# Root logger at DEBUG level.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Attach a stream handler only when the root logger has none yet, so running this
# cell again does not add a second handler.
if not logger.handlers:
    handler = logging.StreamHandler()
    logger.addHandler(handler)

# Authenticated Kaggle API client used by the download helpers.
KAGGLE_CLIENT = KaggleApi()
KAGGLE_CLIENT.authenticate()

# Directory that holds the downloaded competition / external data.
INPUT_DIR = Path(CFG.paths.input_dir)

logger.info(f"start experiment={EXPERIMENT} 🚀")
seed_everything(CFG.seed)  # seed from the config
wandb.login(key=WANDB_KEY)  # WANDB_KEY may be None when the env var is not set

### Load Data


In [None]:
# Download the competition data into the shared input directory.
# INPUT_DIR is used for both calls so the two downloads cannot drift apart.
download_kaggle_competition_dataset(
    client=KAGGLE_CLIENT,
    competition=CFG.meta.competition,
    out_dir=INPUT_DIR,
)

# Download the external datasets listed in the config into the same directory.
download_kaggle_datasets(
    client=KAGGLE_CLIENT,
    datasets=CFG.kaggle.external_datasets,
    out_dir=INPUT_DIR,
)

In [None]:
# competition dataset: essays plus the prompts they were written for
raw_train_essays_df = pd.read_csv(INPUT_DIR.joinpath("train_essays.csv"))
raw_train_prompts_df = pd.read_csv(INPUT_DIR.joinpath("train_prompts.csv"))

# external datasets (two files from radek1, one from alejopaullier)
ex01_train_essays_df = pd.read_csv(INPUT_DIR.joinpath("radek1/llm-generated-essays/ai_generated_train_essays.csv"))
ex02_train_essays_df = pd.read_csv(
    INPUT_DIR.joinpath("radek1/llm-generated-essays/ai_generated_train_essays_gpt-4.csv")
)
ex03_train_essays_df = pd.read_csv(
    INPUT_DIR.joinpath("alejopaullier/daigt-external-dataset/daigt_external_dataset.csv")
)

### Preprocess


In [None]:
def remove_brackets_sentence(text: str) -> str:
    """Blank out every line of ``text`` that contains a ``[...]`` pair.

    Only the characters of a matching line are removed; its line break is kept,
    so an empty line remains where the line used to be.
    """
    return re.sub(r"^.*\[.*\].*$", "", text, flags=re.MULTILINE)


def remove_title_sentence(text: str) -> str:
    """Remove lines that start with ``title:`` and tidy the remaining line breaks.

    A line counts as a title line when, after optional leading spaces/tabs, it
    begins with ``title:`` (case-insensitive). After removal, runs of two or
    more newlines are collapsed to a single newline and the result is stripped
    of leading/trailing whitespace.

    Args:
        text: Raw essay text.

    Returns:
        The text without title lines.
    """
    # The pattern must begin with the line anchor: a literal character placed in
    # front of ``^`` can never be followed by a line start, so nothing would match.
    title_line = re.compile(r"^[ \t]*title:.*$", re.IGNORECASE | re.MULTILINE)
    without_titles = title_line.sub("", text)
    # Collapse the blank lines left behind by removed titles (and any others).
    collapsed = re.sub(r"\n{2,}", "\n", without_titles)
    return collapsed.strip()


def remove_quotes(text: str) -> str:
    """Return ``text`` with every single and double quote character removed."""
    quote_chars = re.compile(r'[\'"]')
    return quote_chars.sub("", text)


def clean_text(text: str) -> str:
    """Run the text-cleansing steps in order and lowercase the result.

    The steps are applied in this sequence: ``remove_brackets_sentence``,
    ``remove_title_sentence``, ``remove_quotes``.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned, lowercased string (the function returns ``str``, not a DataFrame).
    """
    cleanse_fns = [
        remove_brackets_sentence,
        remove_title_sentence,
        remove_quotes,
    ]
    # Each step receives the output of the previous one.
    for fn in cleanse_fns:
        text = fn(text)
    return text.lower()


def make_train_df(
    raw_train_essays_df: pd.DataFrame,
    raw_train_prompts_df: pd.DataFrame,
    ex01_train_essays_df: pd.DataFrame,  # radek1 dataset
    ex02_train_essays_df: pd.DataFrame,  # radek1 dataset
    ex03_train_essays_df: pd.DataFrame,  # alejopaullier dataset
) -> pd.DataFrame:
    """Combine the competition and external essays into one training frame.

    The competition essays and the two radek1 frames are stacked and left-joined
    with the prompts on ``prompt_id``. The alejopaullier frame is appended
    afterwards with ``generated`` set to 1. Finally ``cleansed_text`` and
    ``cleansed_source_text`` are derived from ``text`` and ``source_text`` via
    ``clean_text``.

    Returns:
        The combined frame with a fresh 0..n-1 index.
    """
    # Frames that are joined with the competition prompts.
    essays_df = pd.concat(
        [raw_train_essays_df, ex01_train_essays_df, ex02_train_essays_df],
        axis=0,
    ).reset_index(drop=True)
    essays_with_prompts_df = essays_df.merge(raw_train_prompts_df, on="prompt_id", how="left")

    # The alejopaullier frame skips the prompt join and is labelled generated=1.
    external_df = ex03_train_essays_df.assign(generated=1)
    train_df = pd.concat([essays_with_prompts_df, external_df], axis=0).reset_index(drop=True)

    for raw_col, cleansed_col in (("text", "cleansed_text"), ("source_text", "cleansed_source_text")):
        train_df[cleansed_col] = train_df[raw_col].apply(clean_text)
    return train_df

In [None]:
# Assemble the training frame from the competition and external data.
train_df = make_train_df(
    raw_train_essays_df=raw_train_essays_df,
    raw_train_prompts_df=raw_train_prompts_df,
    ex01_train_essays_df=ex01_train_essays_df,
    ex02_train_essays_df=ex02_train_essays_df,
    ex03_train_essays_df=ex03_train_essays_df,
)

# Debug runs work on a 10-row, seed-fixed subsample.
if CFG.debug:
    train_df = train_df.sample(n=10, random_state=CFG.seed).reset_index(drop=True)

# Debug runs are logged under a "<experiment_name>_debug" wandb group unless the
# configured group already mentions "debug".
if CFG.debug and "debug" not in CFG.lightning.logger.wandb.group:
    CFG.lightning.logger.wandb.group = CFG.experiment_name + "_debug"

logger.debug(f"train shape : {train_df.shape}")
logger.debug(f"train generated label : {train_df['generated'].sum()}")

### CV Split


In [None]:
def assign_fold_index(train_df: pd.DataFrame, kfold: BaseCrossValidator) -> pd.DataFrame:
    """Label every row with the index of the fold in which it is a validation row.

    The frame is modified in place (a ``fold`` column is added or overwritten)
    and also returned. Rows that never appear in a validation split keep -1.

    Args:
        train_df: Frame with a ``generated`` column, passed to the splitter as ``y``.
        kfold: Splitter whose ``split(X=..., y=...)`` yields
            ``(train_indices, valid_indices)`` pairs of positional indices.

    Returns:
        The same ``train_df`` object with the ``fold`` column filled in.
    """
    train_df["fold"] = -1
    fold_col = train_df.columns.get_loc("fold")
    for fold_index, (_, valid_index) in enumerate(kfold.split(X=train_df, y=train_df["generated"])):
        # The splitter yields row positions, not index labels, so assign
        # positionally; label-based `.loc` is only right for a default 0..n-1 index.
        train_df.iloc[valid_index, fold_col] = fold_index
    return train_df


# Build the cross-validation splitter from the `cv` config node, then label each
# row of train_df with its validation fold.
kfold = instantiate(CFG.cv)
train_df = assign_fold_index(train_df=train_df, kfold=kfold)

### Training


In [None]:
result_dfs = []
base_output_dir = Path(CFG.paths.output_dir)  # store output_dir for later use

# CFG.paths.output_dir is pointed at a per-fold directory while a fold runs. The
# try/finally guarantees it is restored even when a fold fails; otherwise running
# this cell again would start from ".../foldN" and nest the output directories.
try:
    for i_fold in range(CFG.n_splits):
        if i_fold not in CFG.train_folds:
            continue

        i_train_df = train_df.query(f"fold != {i_fold}").reset_index(drop=True)
        i_valid_df = train_df.query(f"fold == {i_fold}").reset_index(drop=True)

        logger.info(f"# --------------- # start training fold={i_fold} 🚀 # --------------- # ")
        fold_output_dir = base_output_dir / f"fold{i_fold}"
        CFG.paths.output_dir = str(fold_output_dir)
        CFG.lightning.logger.wandb.name = f"fold{i_fold}"

        # Log the paths lightning will use after output_dir was switched to this fold.
        logger.debug(f"lightning trainer default_root_dir : {CFG.lightning.trainer.default_root_dir}")
        logger.debug(
            f"lightning callbacks model_checkpoint dirpath : {CFG.lightning.callbacks.model_checkpoint.dirpath}"
        )

        # instantiate lightning module, datamodule and trainer by fold
        logger.info(f"Instantiating lightning module <{CFG.lightning.model._target_}>")
        lt_module = instantiate(CFG.lightning.model)

        logger.info(f"Instantiating lightning datamodule <{CFG.lightning.data.lt_datamodule._target_}>")
        train_dataset = instantiate(CFG.lightning.data.train_dataset, df=i_train_df)
        val_dataset = instantiate(CFG.lightning.data.val_dataset, df=i_valid_df)
        lt_datamodule = instantiate(
            CFG.lightning.data.lt_datamodule, train_dataset=train_dataset, val_dataset=val_dataset
        )

        # Log only the target class, consistent with the module/datamodule messages above.
        logger.info(f"Instantiating lightning trainer <{CFG.lightning.trainer._target_}>")
        lt_logger = instantiate_loggers(CFG.lightning.logger)
        callbacks = instantiate_callbacks(CFG.lightning.callbacks)
        trainer = instantiate(CFG.lightning.trainer, logger=lt_logger, callbacks=callbacks)

        # Optional checkpoint to resume from; the configured template is formatted with the fold index.
        ckpt_path = None
        if CFG.ckpt_path is not None:
            ckpt_path = CFG.ckpt_path.format(fold=i_fold)

        trainer.fit(model=lt_module, datamodule=lt_datamodule, ckpt_path=ckpt_path)
        val_predictions = trainer.predict(model=lt_module, datamodule=lt_datamodule, ckpt_path="best")

        # save dataframe assigned validation predictions
        # (sigmoid is applied here, i.e. the predict outputs are treated as logits)
        i_result_df = i_valid_df.assign(pred=torch.cat(val_predictions).sigmoid().float().numpy())
        fold_output_dir.mkdir(parents=True, exist_ok=True)  # make sure the target directory exists
        joblib.dump(i_result_df, fold_output_dir / "val_predictions.pkl")
        result_dfs.append(i_result_df)

        del i_result_df, i_valid_df, trainer, train_dataset, val_dataset, lt_datamodule
        gc.collect()
        torch.cuda.empty_cache()
        garbage_collection_cuda()

        # save only best weights
        best_weights_path = fold_output_dir / "weights" / "best.pth"
        best_weights_path.parent.mkdir(parents=True, exist_ok=True)
        # map_location="cpu" avoids re-allocating GPU memory that was just released.
        best_checkpoint = torch.load(
            Path(CFG.lightning.callbacks.model_checkpoint.dirpath) / "best.ckpt",
            map_location="cpu",
        )
        lt_module.load_state_dict(best_checkpoint["state_dict"])
        torch.save(lt_module.net.state_dict(), best_weights_path)

        del lt_module, best_checkpoint
        gc.collect()
        torch.cuda.empty_cache()
        garbage_collection_cuda()

        # Close this fold's wandb run so the next fold's logger opens a new run
        # instead of reusing the one still in progress (no-op when no run is active).
        wandb.finish()
finally:
    CFG.paths.output_dir = str(base_output_dir)  # restore output_dir

# Combine the per-fold validation predictions into a single file.
valid_results_df = pd.concat(result_dfs, axis=0).reset_index(drop=True)
valid_results_df.to_csv(Path(CFG.paths.output_dir) / "valid_results.csv", index=False)