In [None]:
import logging
import os
import gc
from pathlib import Path
import json

import joblib
import torch
import wandb
import pandas as pd
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from kaggle import KaggleApi

from lightning import seed_everything
from sklearn.model_selection import BaseCrossValidator

from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.kaggle_utils import download_kaggle_competition_dataset, download_kaggle_datasets

from src.utils.metrics import binary_classification_metrics

In [None]:
# Comma-separated Hydra overrides, e.g. OVERRIDES="experiment=000-finetune".
# Entries are stripped and blank ones dropped so stray spaces or a trailing comma
# are not handed to Hydra as malformed overrides.
OVERRIDES: list[str] = [
    override.strip()
    for override in os.getenv("OVERRIDES", "experiment=000-finetune").split(",")
    if override.strip()
]
WANDB_KEY = os.getenv("WANDB_KEY", None)  # input your wandb key as environment variable

In [None]:
# OVERRIDES is built from the OVERRIDES env var as a list, so an `is None` test alone can
# never fire; also reject blank entries, which would otherwise be passed to Hydra as
# malformed overrides.
if OVERRIDES is None or any(not override.strip() for override in OVERRIDES):
    raise ValueError(f"OVERRIDES is not set or contains empty entries: {OVERRIDES!r}")

# Compose the config with Hydra's compose API (no @hydra.main in a notebook) and register it
# with HydraConfig so code relying on the Hydra runtime config can run here.
with initialize(version_base=None, config_path="../../configs"):
    CFG = compose(
        config_name="config.yaml",
        return_hydra_config=True,
        overrides=OVERRIDES,
    )
    HydraConfig.instance().set_config(CFG)  # use HydraConfig for notebook to use hydra job

# Root logger at DEBUG. The handler is attached only when none exists, so re-running this
# cell does not duplicate every log line.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

if not logger.handlers:
    handler = logging.StreamHandler()
    logger.addHandler(handler)

# Authenticated Kaggle client shared by the download cells below.
KAGGLE_CLIENT = KaggleApi()
KAGGLE_CLIENT.authenticate()

INPUT_DIR = Path(CFG.paths.input_dir)

logger.info("start %s 🚀", OVERRIDES)
seed_everything(CFG.seed)
# NOTE(review): WANDB_KEY may be None here; presumably wandb then falls back to its own
# credential lookup / prompt — confirm for non-interactive runs.
wandb.login(key=WANDB_KEY)

### Load Data


In [None]:
# Download the competition files and the configured external Kaggle datasets into the same
# input directory (INPUT_DIR, derived from CFG.paths.input_dir) that later cells read from.
download_kaggle_competition_dataset(
    client=KAGGLE_CLIENT,
    competition=CFG.meta.competition,
    out_dir=INPUT_DIR,
)

download_kaggle_datasets(
    client=KAGGLE_CLIENT,
    datasets=CFG.kaggle.external_datasets,
    out_dir=INPUT_DIR,
)

In [None]:
# External training data: the Kaggle dataset `thedrcat/daigt-v2-train-dataset`,
# expected to have been downloaded under INPUT_DIR.
train_df = pd.read_csv(
    INPUT_DIR.joinpath("thedrcat/daigt-v2-train-dataset/train_v2_drcat_02.csv"),
)

### Preprocess


In [None]:
# Debug mode: shrink the data to a small seeded sample and tag the W&B run group so debug
# runs are kept separate from real experiments.
DEBUG_SAMPLE_SIZE = 100

if CFG.debug:
    # Cap at the frame length: DataFrame.sample(n) without replacement raises when n > len.
    n_debug_rows = min(DEBUG_SAMPLE_SIZE, len(train_df))
    train_df = train_df.sample(n_debug_rows, random_state=CFG.seed).reset_index(drop=True)
    wandb_group = CFG.lightning.logger.wandb.group
    # The group may be unset (None); `"debug" not in None` would raise TypeError.
    if wandb_group is None or "debug" not in wandb_group:
        CFG.lightning.logger.wandb.group = CFG.experiment_name + "_debug"

logger.debug("train shape : %s", train_df.shape)
logger.debug("train generated label : %s", train_df["label"].sum())

### CV Split


In [None]:
def assign_fold_index(train_df: pd.DataFrame, kfold: "BaseCrossValidator") -> pd.DataFrame:
    """Add a ``fold`` column holding the CV fold in which each row is used for validation.

    ``kfold.split`` yields positional row indices, so they are applied by position
    rather than as index labels. This keeps the result correct for frames whose index
    is not a 0..n-1 RangeIndex.

    Args:
        train_df: Training data; its ``label`` column is passed to the splitter as ``y``.
            The frame is modified in place.
        kfold: A scikit-learn style cross-validator exposing ``split(X, y)``.

    Returns:
        The same ``train_df`` object with an int64 ``fold`` column. Rows never yielded
        as validation rows keep ``-1``.
    """
    fold_ids = pd.Series(-1, index=pd.RangeIndex(len(train_df)), dtype="int64")
    for fold_index, (_, valid_index) in enumerate(kfold.split(X=train_df, y=train_df["label"])):
        # Positional assignment: valid_index are row positions, not index labels.
        fold_ids.iloc[valid_index] = fold_index
    # Assign the raw values so no index alignment against train_df's own index happens.
    train_df["fold"] = fold_ids.to_numpy()
    return train_df


# Build the cross-validator from the `cv` config node and tag every row with its validation fold.
kfold = instantiate(CFG.cv)
train_df = assign_fold_index(train_df=train_df, kfold=kfold)

# Rows still at -1 were never yielded as validation rows by the splitter; presumably they
# would drop out of any later per-fold filtering, so surface it instead of failing silently.
unassigned_rows = int((train_df["fold"] < 0).sum())
if unassigned_rows:
    logger.warning(
        "%d rows were not assigned to any validation fold by %s",
        unassigned_rows,
        type(kfold).__name__,
    )
logger.debug("fold sizes : %s", train_df["fold"].value_counts().sort_index().to_dict())