In [None]:
import logging
import os
import gc
from pathlib import Path
import json

import joblib
import torch
import wandb
import pandas as pd
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from kaggle import KaggleApi

from lightning import seed_everything
from lightning.pytorch.utilities.memory import garbage_collection_cuda
from sklearn.model_selection import BaseCrossValidator

from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.kaggle_utils import download_kaggle_competition_dataset, download_kaggle_datasets

from src.utils.metrics import binary_classification_metrics
from lightning.pytorch.plugins import BitsandbytesPrecision

In [None]:
# NOTE(review): `precision` (bitsandbytes quantization plugin, mode="nf4-dq" — presumably 4-bit NF4
# with double quantization) is not referenced anywhere in the cells shown; the Trainer below is
# built purely from CFG.lightning.trainer. Confirm whether it should be wired in (e.g. via
# `plugins=[precision]`) or removed.
precision = BitsandbytesPrecision(mode="nf4-dq")

In [None]:
# Comma-separated Hydra overrides, e.g. OVERRIDES="experiment=000-finetune,seed=42".
# Entries are stripped and blank ones (from a trailing comma or an empty variable) are dropped,
# so Hydra never receives an empty override string.
OVERRIDES: list[str] = [
    override.strip()
    for override in os.getenv("OVERRIDES", "experiment=000-finetune").split(",")
    if override.strip()
]
WANDB_KEY = os.getenv("WANDB_KEY", None)  # input your wandb key as environment variable

In [None]:
# Fail fast on a bad override list. OVERRIDES is always a list (never None); the real failure
# modes are an empty list or blank entries (e.g. from a trailing comma in the env var).
if not OVERRIDES or not all(override.strip() for override in OVERRIDES):
    raise ValueError(f"OVERRIDES is not set or contains blank entries: {OVERRIDES!r}")

with initialize(version_base=None, config_path="../../configs"):
    CFG = compose(
        config_name="config.yaml",
        return_hydra_config=True,
        overrides=OVERRIDES,
    )
    HydraConfig.instance().set_config(CFG)  # use HydraConfig for notebook to use hydra job

# Root logger at DEBUG; attach a stream handler only once so re-running this cell
# does not duplicate every log line.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

if not logger.handlers:
    handler = logging.StreamHandler()
    logger.addHandler(handler)

KAGGLE_CLIENT = KaggleApi()
KAGGLE_CLIENT.authenticate()

INPUT_DIR = Path(CFG.paths.input_dir)

logger.info(f"start {OVERRIDES} 🚀")
seed_everything(CFG.seed)
wandb.login(key=WANDB_KEY)

### Load Data


In [None]:
# Fetch the competition data and the external datasets listed in CFG.kaggle.external_datasets
# into INPUT_DIR, the same directory the cells below read from.
download_kaggle_competition_dataset(
    client=KAGGLE_CLIENT,
    competition=CFG.meta.competition,
    out_dir=INPUT_DIR,
)

download_kaggle_datasets(
    client=KAGGLE_CLIENT,
    datasets=CFG.kaggle.external_datasets,
    out_dir=INPUT_DIR,
)

In [None]:
# External dataset: the DAIGT v2 training CSV, read from its location under INPUT_DIR.
train_df = pd.read_csv(
    INPUT_DIR / "thedrcat/daigt-v2-train-dataset/train_v2_drcat_02.csv"
)

### Preprocess


In [None]:
if CFG.debug:
    # Subsample for a quick smoke-test run. Cap at the dataset size because
    # `DataFrame.sample(n)` (replace=False) raises ValueError when n > len(df).
    train_df = train_df.sample(min(100, len(train_df)), random_state=CFG.seed).reset_index(drop=True)
    # Route debug runs to a separate W&B group; treat an unset (None) group as "no debug suffix yet"
    # instead of crashing on `"debug" in None`.
    if "debug" not in (CFG.lightning.logger.wandb.group or ""):
        CFG.lightning.logger.wandb.group = CFG.experiment_name + "_debug"

logger.debug(f"train shape : {train_df.shape}")
logger.debug(f"train generated label : {train_df['label'].sum()}")

### CV Split


In [None]:
def assign_fold_index(
    train_df: pd.DataFrame,
    kfold: "BaseCrossValidator",
    target_col: str = "label",
) -> pd.DataFrame:
    """Add a ``fold`` column holding, for each row, the CV fold in which it is a validation row.

    ``kfold.split`` yields *positional* indices, so assignment uses ``iloc``. This stays correct
    even when ``train_df`` has a non-default index, where label-based ``loc`` would mis-assign or raise.

    Args:
        train_df: Training frame. Modified in place (a ``fold`` column is added/overwritten)
            and also returned, so both calling styles keep working.
        kfold: Splitter with ``split(X=..., y=...)`` yielding ``(train_idx, valid_idx)`` pairs.
        target_col: Column passed as ``y`` to ``kfold.split`` (used by stratified splitters).

    Returns:
        ``train_df`` with a ``fold`` column; rows never yielded as validation rows keep ``-1``.
    """
    train_df["fold"] = -1
    fold_col = train_df.columns.get_loc("fold")
    for fold_index, (_, valid_index) in enumerate(kfold.split(X=train_df, y=train_df[target_col])):
        train_df.iloc[valid_index, fold_col] = fold_index
    return train_df


kfold = instantiate(CFG.cv)
train_df = assign_fold_index(train_df=train_df, kfold=kfold)
# Every row must land in some validation fold; a splitter whose validation sets do not cover
# the data (e.g. ShuffleSplit) would otherwise silently drop rows from out-of-fold evaluation.
assert (train_df["fold"] >= 0).all(), "some rows were not assigned to any validation fold"

### Training


In [None]:
# Train one model per fold listed in CFG.train_folds, collect out-of-fold predictions, and write
# per-fold artifacts (val_predictions.pkl, valid_scores.json, weights/best.pth) under
# `<output_dir>/fold{i}`. CFG.paths.output_dir is temporarily pointed at the fold directory;
# the trainer/checkpoint paths logged below presumably interpolate it (TODO confirm in configs).
result_dfs = []
base_output_dir = Path(CFG.paths.output_dir)  # store output_dir for later use

try:
    for i_fold in range(CFG.n_splits):
        if i_fold not in CFG.train_folds:
            continue

        i_train_df = train_df.query(f"fold != {i_fold}").reset_index(drop=True)
        i_valid_df = train_df.query(f"fold == {i_fold}").reset_index(drop=True)

        logger.info(f"# --------------- # start training fold={i_fold} 🚀 # --------------- # ")
        CFG.paths.output_dir = str(base_output_dir / f"fold{i_fold}")
        CFG.lightning.logger.wandb.name = f"fold{i_fold}"

        logger.debug(f"lightning trainer default_root_dir : {CFG.lightning.trainer.default_root_dir}")
        logger.debug(f"lightning callbacks model_checkpoint dirpath : {CFG.lightning.callbacks.model_checkpoint.dirpath}")

        # instantiate lightning module, datamodule and trainer by fold
        logger.info(f"Instantiating lightning module <{CFG.lightning.model._target_}>")
        lt_module = instantiate(CFG.lightning.model)

        logger.info(f"Instantiating lightning datamodule <{CFG.lightning.data.lt_datamodule._target_}>")
        train_dataset = instantiate(CFG.lightning.data.train_dataset, df=i_train_df)
        val_dataset = instantiate(CFG.lightning.data.val_dataset, df=i_valid_df)
        lt_datamodule = instantiate(CFG.lightning.data.lt_datamodule, train_dataset=train_dataset, val_dataset=val_dataset)

        logger.info(f"Instantiating lightning trainer <{CFG.lightning.trainer}>")
        lt_logger = instantiate_loggers(CFG.lightning.logger)
        callbacks = instantiate_callbacks(CFG.lightning.callbacks)
        trainer = instantiate(CFG.lightning.trainer, logger=lt_logger, callbacks=callbacks)

        ckpt_path = None
        if CFG.ckpt_path is not None:
            ckpt_path = CFG.ckpt_path.format(fold=i_fold)

        trainer.fit(model=lt_module, datamodule=lt_datamodule, ckpt_path=ckpt_path)
        val_predictions = trainer.predict(model=lt_module, datamodule=lt_datamodule, ckpt_path="best")
        # Upcast before the sigmoid so reduced-precision logits keep resolution near 0/1, and
        # move to CPU because `.numpy()` fails on CUDA tensors.
        val_predictions = torch.concatenate(val_predictions).float().sigmoid().cpu().numpy().reshape(-1)

        # `ckpt_path="best"` above resolves to the checkpoint callback's best_model_path. Capture
        # it now (the trainer is deleted below) so the exported weights come from the same
        # checkpoint the out-of-fold predictions were made with.
        best_ckpt_path = trainer.checkpoint_callback.best_model_path

        # save dataframe assigned validation predictions
        i_result_df = i_valid_df.assign(pred=val_predictions)
        joblib.dump(i_result_df, Path(CFG.paths.output_dir) / "val_predictions.pkl")
        result_dfs.append(i_result_df)

        # evaluate
        scores = binary_classification_metrics(y_true=i_valid_df["label"], y_pred=val_predictions)
        with open(Path(CFG.paths.output_dir) / "valid_scores.json", "w") as f:
            json.dump(scores, f)
        logger.info(f"fold{i_fold} scores: {scores}")

        del i_result_df, i_train_df, i_valid_df, trainer, train_dataset, val_dataset, lt_datamodule
        gc.collect()
        torch.cuda.empty_cache()
        garbage_collection_cuda()

        # save only best weights
        best_weights_path = Path(CFG.paths.output_dir) / "weights" / "best.pth"
        best_weights_path.parent.mkdir(parents=True, exist_ok=True)
        # map_location="cpu": read the checkpoint without allocating another copy on the GPU.
        lt_module.load_state_dict(torch.load(best_ckpt_path, map_location="cpu")["state_dict"])
        torch.save(lt_module.net.state_dict(), best_weights_path)

        del lt_module
        gc.collect()
        torch.cuda.empty_cache()
        garbage_collection_cuda()

        # WandbLogger reuses an already-active wandb run, so close this fold's run; otherwise the
        # next fold would log into it instead of starting its own "fold{i}" run.
        wandb.finish()
finally:
    # Restore output_dir even if a fold raises. Otherwise re-running this cell would take the
    # per-fold directory as `base_output_dir` and nest outputs (fold0/fold0/...).
    CFG.paths.output_dir = str(base_output_dir)  # restore output_dir

if not result_dfs:
    raise ValueError(
        f"No folds were trained: CFG.train_folds={CFG.train_folds} has no fold in range(CFG.n_splits={CFG.n_splits})"
    )
valid_results_df = pd.concat(result_dfs, axis=0).reset_index(drop=True)
valid_results_df.to_csv(Path(CFG.paths.output_dir) / "valid_results.csv", index=False)

In [None]:
# Concatenated out-of-fold predictions of all trained folds (the same frame is written to valid_results.csv).
valid_results_df