In [None]:
import logging

import config
import lightgbm as lgb
import polars as pl
from lifelines import KaplanMeierFitter
from lightgbm import LGBMModel

from src.customs.fold import add_kfold
from src.customs.metrics import LGBMMetric, Metric
from src.feature.tabular import OrdinalEncoder, RawEncoder
from src.model.sklearn_like import LightGBMWapper
from src.trainer.tabular.simple import single_inference_fn, single_train_fn

# Configure the root logger to emit INFO and above, then obtain the
# module-level logger used by the cells below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
class CreateTargetFn:
    """Callable that appends survival-analysis target columns to a frame.

    Three columns are added (their names are listed in ``target_cols``):

    * ``target_kmf_race_group`` -- Kaplan-Meier survival estimate evaluated at
      each row's own time, fitted separately within every ``group_col`` group.
      The column keeps this name even when ``group_col`` is overridden.
    * ``target_kmf`` -- the same estimate, fitted on all rows together.
    * ``target_xgboost_cox`` -- the time value, negated for rows whose event
      column equals 0.

    Args:
        time_col: Name of the column holding the observed time.
        event_col: Name of the column passed to lifelines as ``event_observed``.
        group_col: Name of the column used to split rows for the per-group
            Kaplan-Meier fit. Defaults to ``"race_group"``.
    """

    def __init__(self, time_col: str, event_col: str, group_col: str = "race_group"):
        self.time_col = time_col
        self.event_col = event_col
        self.group_col = group_col
        self.target_cols = [
            "target_kmf_race_group",
            "target_kmf",
            "target_xgboost_cox",
        ]

    def _km_survival(self, df: pl.DataFrame, name: str) -> pl.Series:
        """Fit Kaplan-Meier on ``df`` and return the survival estimate at each row's time, named ``name``."""
        kmf = KaplanMeierFitter()
        kmf.fit(df[self.time_col], event_observed=df[self.event_col])
        return pl.Series(name, kmf.survival_function_at_times(df[self.time_col]))

    def __call__(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return ``df`` with the three target columns appended.

        Note that the result is the concatenation of the per-group frames, so
        rows come back grouped by ``group_col`` (groups in order of first
        appearance) rather than in their original positions.

        Args:
            df: Frame containing ``time_col``, ``event_col`` and ``group_col``.

        Returns:
            A new frame with the columns listed in ``target_cols`` added.
        """
        # Per-group Kaplan-Meier target: one independent fit per group.
        per_group_dfs = [
            group_df.with_columns(self._km_survival(group_df, "target_kmf_race_group"))
            for _, group_df in df.group_by(self.group_col, maintain_order=True)
        ]
        target_df = pl.concat(per_group_dfs)

        # Global Kaplan-Meier target, fitted on the regrouped frame so the
        # resulting Series lines up row-for-row with ``target_df``.
        target_df = target_df.with_columns(self._km_survival(target_df, "target_kmf"))

        # Signed-time target: negative time where the event column is 0,
        # positive time otherwise.
        time_expr = pl.col(self.time_col)
        return target_df.with_columns(
            pl.when(pl.col(self.event_col) == 0)
            .then(-1 * time_expr)
            .otherwise(time_expr)
            .alias("target_xgboost_cox"),
        )


# --- Load the training split -------------------------------------------------
# Every training row is tagged "TRAIN" and gets its fold column initialised to -1.
train_csv_path = config.COMP_DATASET_DIR / "train.csv"
raw_train_df = pl.read_csv(train_csv_path).with_columns(
    pl.lit("TRAIN").alias(config.DATASET_COL),
    pl.lit(-1).alias(config.FOLD_COL),
)

# --- Attach the survival targets to the training rows ------------------------
create_target_fn = CreateTargetFn(
    time_col=config.SURVIVAL_TIME_COL,
    event_col=config.EVENT_COL,
)
raw_train_df = create_target_fn(raw_train_df)

# --- Load the test split -----------------------------------------------------
test_csv_path = config.COMP_DATASET_DIR / "test.csv"
raw_test_df = pl.read_csv(test_csv_path).with_columns(
    pl.lit("TEST").alias(config.DATASET_COL),
)

# Register the new target columns as meta columns (de-duplicated, sorted).
config.META_COLS = sorted(set(config.META_COLS) | set(create_target_fn.target_cols))

# Stack both splits; "diagonal_relaxed" tolerates columns that exist in only
# one of the two frames.
train_test_df = pl.concat([raw_train_df, raw_test_df], how="diagonal_relaxed")

In [None]:
# Column-name prefixes that distinguish numerical from categorical features.
numerical_prefix = f"{config.FEATURE_PREFIX}n_"
categorical_prefix = f"{config.FEATURE_PREFIX}c_"

# One encoder per column family: meta columns pass through without a prefix,
# numerical columns pass through with a prefix, categoricals are ordinal-encoded.
encoders = [
    RawEncoder(columns=config.META_COLS, prefix=""),
    RawEncoder(columns=list(config.NUMERICAL_COLS), prefix=numerical_prefix),
    OrdinalEncoder(columns=list(config.CATEGORICAL_COLS), prefix=categorical_prefix),
]

# Fit every encoder on the training rows only.
for encoder in encoders:
    encoder.fit(raw_train_df)

# Encode the training portion of the combined frame and glue the encoder
# outputs together column-wise.
is_train_row = pl.col(config.DATASET_COL) == "TRAIN"
train_df = train_test_df.filter(is_train_row)
encoded_train_parts = [encoder.transform(train_df) for encoder in encoders]
train_features_df = pl.concat(encoded_train_parts, how="horizontal")

# Model inputs are the columns carrying the feature prefix; the categorical
# subset is identified by its own prefix.
feature_names = sorted(col for col in train_features_df.columns if col.startswith(config.FEATURE_PREFIX))
cat_features = [col for col in feature_names if col.startswith(categorical_prefix)]

logger.info(f"# of features: {len(feature_names)}")
logger.info(f"# of cat_features: {len(cat_features)}")

In [None]:
# Write cross-validation fold assignments into config.FOLD_COL.
# NOTE(review): random_state is seeded with config.N_SPLITS rather than a
# dedicated seed constant — confirm this is intentional.
train_features_df = add_kfold(
    train_features_df,
    n_splits=config.N_SPLITS,
    random_state=config.N_SPLITS,
    fold_col=config.FOLD_COL,
)

# Train one LightGBM regressor per fold on the global Kaplan-Meier target.
va_result_df, va_scores, trained_models = single_train_fn(
    model=LightGBMWapper(
        name="lgb",
        model=LGBMModel(
            objective="rmse",
            boosting="gbdt",
            n_estimators=10000,
            learning_rate=0.01,
            num_leaves=31,
            colsample_bytree=0.2,
            subsample=0.5,
            # LightGBM ignores `subsample` unless `subsample_freq` is > 0
            # (the default of 0 disables bagging entirely), so enable
            # bagging on every iteration to make subsample=0.5 effective.
            subsample_freq=1,
            importance_type="gain",
            # The string "None" disables the built-in metrics, so only the
            # custom eval_metric below is tracked.
            metric="None",
        ),
        fit_params={
            "callbacks": [
                # Stop when the first metric has not improved for 500 rounds.
                lgb.early_stopping(500, first_metric_only=True),
                lgb.log_evaluation(period=100),
            ],
            "eval_metric": LGBMMetric(),
            "categorical_feature": cat_features,
            "feature_name": feature_names,
        },
    ),
    features_df=train_features_df,
    feature_cols=feature_names,
    target_col="target_kmf",
    fold_col=config.FOLD_COL,
    meta_cols=config.META_COLS,
    out_dir=config.OUTPUT_DIR,
    train_folds=None,
    eval_fn=Metric(),
    overwrite=True,
    use_eval_metric_extra_va_df=True,
)

In [None]:
# Encode the test portion of the combined frame with the already-fitted encoders.
is_test_row = pl.col(config.DATASET_COL) == "TEST"
test_df = train_test_df.filter(is_test_row)
encoded_test_parts = [encoder.transform(test_df) for encoder in encoders]
test_features_df = pl.concat(encoded_test_parts, how="horizontal")

# Run inference over every fold index.
# NOTE(review): training was given out_dir=config.OUTPUT_DIR while inference
# reads model_dir=config.ARTIFACT_EXP_DIR() — confirm both point at the same
# location.
inference_kwargs = {
    "model": LightGBMWapper(name="lgb"),
    "features_df": test_features_df,
    "feature_names": feature_names,
    "model_dir": config.ARTIFACT_EXP_DIR(),
    "inference_folds": list(range(config.N_SPLITS)),
}
test_preds = single_inference_fn(**inference_kwargs)

In [None]:
# Pair each test id with its prediction and write the submission file.
submission_df = test_features_df.select(config.ID_COL).with_columns(
    pl.Series("prediction", test_preds),
)
submission_path = config.OUTPUT_DIR / "submission.csv"
submission_df.write_csv(submission_path)