In [None]:
import datetime
import json
import logging

import lightgbm as lgb
import numpy as np
import polars as pl
import rootutils
import seaborn as sns
from joblib import Memory
from jpholiday import JPHoliday
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold
from tqdm import tqdm

sns.set_style("whitegrid")
logging.basicConfig(level=logging.INFO)

# Resolve the project root, add it to sys.path (pythonpath=True) and chdir into
# it (cwd=True). The project-local `src.*` imports below depend on this, so they
# must stay *after* this call.
ROOT = rootutils.setup_root(".", pythonpath=True, cwd=True)

from src.feature.tabular import AggregateEncoder, OrdinalEncoder, RawEncoder
from src.model.sklearn_like import LightGBMWapper
from src.trainer.tabular.simple import single_inference_fn_v2, single_train_fn

# Directory layout, all relative to the project root.
DATA_DIR = ROOT / "data"
INPUT_DIR = DATA_DIR / "atmacup19_dataset"
OUTPUT_DIR = DATA_DIR / "output"
CACHE_DIR = DATA_DIR / "cache"

for d in [DATA_DIR, INPUT_DIR, OUTPUT_DIR, CACHE_DIR]:
    d.mkdir(exist_ok=True, parents=True)

# Wider polars repr so long category names and wide frames stay readable.
pl.Config.set_fmt_str_lengths(200)
pl.Config.set_tbl_cols(50)
pl.Config.set_tbl_rows(50)

# On-disk memoization backing the `@memory.cache` feature builders below.
memory = Memory(CACHE_DIR, verbose=0)
# Japanese national-holiday calendar, used to build the `is_holiday` feature.
jpholiday = JPHoliday()


In [None]:
EXP_NAME = "018"

# --- reproducibility / cross-validation ---
SEED = 42
SEEDS = list(range(SEED, SEED + 1))  # one model per seed, predictions averaged
N_SPLITS = 1  # 1 => single holdout fold instead of K-fold

# --- data selection ---
DEBUG = False  # True => subsample the train sessions for a quick run
VALID_SAMPLE_FRAC = 1  # fraction of candidate validation sessions kept
VALID_DATE = "2024-10-01"  # last sessions on/after this date become validation


In [None]:
EXP_OUTPUT_DIR = OUTPUT_DIR / EXP_NAME
EXP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Model-input naming scheme: "f_" followed by "n_" (numeric) or "c_" (categorical).
# Feature columns are later selected purely by these prefixes.
FEATURE_PREFIX = "f_"
NUMERICAL_FEATURE_PREFIX = FEATURE_PREFIX + "n_"
CATEGORICAL_FEATURE_PREFIX = FEATURE_PREFIX + "c_"

# The four categories predicted per session (one binary model each).
TARGET_COLS = ["ヘアケア", "チョコレート", "ビール", "米（5㎏以下）"]
# Columns passed through feature engineering as-is: ids, timestamp and labels.
META_COLS = ["session_id", "顧客CD", "session_datetime", "original_顧客CD", *TARGET_COLS]

FOLD_COL = "fold"


### load data


In [None]:
ec_log_df = pl.read_csv(INPUT_DIR / "ec_log.csv", infer_schema_length=200000)
jan_df = pl.read_csv(INPUT_DIR / "jan.csv")
test_session_df = pl.read_csv(INPUT_DIR / "test_session.csv")
train_session_df = pl.read_csv(INPUT_DIR / "train_session.csv")
train_log_df = pl.read_csv(INPUT_DIR / "train_log.csv")

# Multi-label target per train session: 1 when the session's summed 売上数量 for
# the category is positive, otherwise 0 (including sessions without log rows).
train_session_df = (
    train_session_df.select("session_id")
    .join(train_log_df, on="session_id", how="left")
    .join(jan_df.filter(pl.col("カテゴリ名").is_in(TARGET_COLS)), on="JAN", how="left")
    .group_by("session_id", "カテゴリ名")
    .agg(pl.col("売上数量").sum())
    .filter(pl.col("売上数量") > 0)
    .with_columns(target=pl.lit(1))
    .pivot("カテゴリ名", index="session_id", values="target")
    .select("session_id", *TARGET_COLS)
    .join(train_session_df, on="session_id", how="right")
    .with_columns(pl.col(TARGET_COLS).fill_null(0))
)

# Validation split that mimics the test set, in which every customer (顧客CD)
# appears only once: take each customer's most recent session and hold it out
# when it falls on/after VALID_DATE; all other sessions stay in train.
# Sort by date AND hour with a stable sort. Sorting by date alone (unstable) left
# the choice among same-day sessions arbitrary, which could keep a later session
# in train while the earlier one went to validation. The ids are sorted before
# the seeded sample because `unique()` returns them in no guaranteed order.
valid_session_id = (
    train_session_df.sort(["売上日", "時刻"], maintain_order=True)
    .group_by("顧客CD", maintain_order=True)
    .tail(1)
    .filter(pl.col("売上日") >= VALID_DATE)["session_id"]
    .unique()
    .sort()
    .sample(fraction=VALID_SAMPLE_FRAC, seed=SEED)
)
valid_session_df = train_session_df.filter(pl.col("session_id").is_in(valid_session_id))
train_session_df = train_session_df.filter(pl.col("session_id").is_in(valid_session_id).not_())

# Quick-iteration mode: subsample the remaining training sessions.
if DEBUG:
    train_session_df = train_session_df.sample(10000, seed=SEED)


# Stack train / valid / test sessions, tagged by `dataset`, so customer-level
# features can be computed over all of a customer's sessions at once. Also add
# `session_datetime` built from the sales date (売上日) and hour (時刻).
raw_full_session_df = (
    pl.concat(
        [
            frame.with_columns(dataset=pl.lit(tag))
            for frame, tag in (
                (train_session_df, "TRAIN"),
                (valid_session_df, "VALID"),
                (test_session_df, "TEST"),
            )
        ],
        how="diagonal_relaxed",
    )
    .with_columns(pl.col("売上日").cast(pl.Date))
    .with_columns(
        session_datetime=pl.datetime(
            year=pl.col("売上日").dt.year(),
            month=pl.col("売上日").dt.month(),
            day=pl.col("売上日").dt.day(),
            hour=pl.col("時刻"),
        )
    )
)

print(f"train_session_df: {train_session_df.shape}")
print(f"valid_session_df: {valid_session_df.shape}")
print(f"test_session_df: {test_session_df.shape}")

In [None]:
# National holidays covering the full span of session dates. The range is derived
# from the data instead of being hard-coded (previously 2024-07-01..2024-11-30),
# so sessions outside that window are not silently marked as non-holidays.
# The frame is deduplicated to one row per date so the later left join on 売上日
# cannot duplicate sessions. An explicit schema keeps that join valid even when
# the range contains no holiday.
holiday_df = pl.DataFrame(
    [
        {"売上日": x.date, "name": x.name, "is_holiday": 1}
        for x in jpholiday.between(
            raw_full_session_df["売上日"].min(),
            raw_full_session_df["売上日"].max(),
        )
    ],
    schema={"売上日": pl.Date, "name": pl.String, "is_holiday": pl.Int64},
).unique(subset="売上日", keep="first", maintain_order=True)

# Keep the raw customer id alongside 顧客CD (which is later ordinal-encoded).
full_session_df = raw_full_session_df.with_columns(pl.col("顧客CD").alias("original_顧客CD"))


### Feature engineering


In [None]:
@memory.cache
def get_top_co_categories(
    train_log_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    n_top_per_target: int = 100,
    target_categories: list[str] = ["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],  # noqa
) -> list[str]:
    """Collect the categories most often bought together with each target category.

    For every target category, take the sessions that contain it and rank all
    categories in those sessions by the number of sessions they appear in. The
    top ``n_top_per_target`` of each ranking are unioned with the targets.

    Args:
        train_log_df: purchase log with at least ``session_id`` and ``JAN``.
        jan_df: product master with ``JAN`` and ``カテゴリ名``.
        n_top_per_target: categories kept per target. The target itself appears in
            every one of its sessions, so it ranks first and uses one slot.
        target_categories: anchor categories (read only, never mutated).

    Returns:
        Sorted, de-duplicated list of category names, targets included.
    """
    # One row per (session, category). This does not depend on the target, so it
    # is built once instead of once per loop iteration.
    session_cat_df = (
        train_log_df.lazy()
        .join(jan_df.lazy().select(["JAN", "カテゴリ名"]), on="JAN", how="inner")
        .select(["session_id", "カテゴリ名"])
        .unique()
        .collect()
    )

    all_top_co_categories = set(target_categories)
    for target_category in target_categories:
        sessions_with_target = session_cat_df.filter(pl.col("カテゴリ名") == target_category)["session_id"]
        top_co_categories = (
            session_cat_df.filter(pl.col("session_id").is_in(sessions_with_target))
            .group_by("カテゴリ名")
            .agg(pl.len().alias("n_cooccurrence"))
            # Secondary key breaks count ties so the top-N cut is deterministic.
            .sort(["n_cooccurrence", "カテゴリ名"], descending=[True, False])
            .head(n_top_per_target)["カテゴリ名"]
            .to_list()
        )
        all_top_co_categories.update(top_co_categories)

    return sorted(all_top_co_categories)


@memory.cache
def create_session_duration_last_cat_df(
    train_log_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    session_df: pl.DataFrame,
    prefix: str = "days_since_cat_",
    target_categories: list[str] = ["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],  # noqa
    cat_type: str = "カテゴリ名",
) -> pl.DataFrame:
    """Days since the customer's previous purchase of each category.

    For each session and each category in ``target_categories``, the value is
    ``(session_datetime - datetime of the most recent EARLIER session of the same
    顧客CD containing that category) / 1 day``. The current session is excluded
    via ``shift``, so a session's own purchases never feed its own feature. The
    value is null when no earlier session of the customer contained the category.

    Args:
        train_log_df: purchase log (``session_id``, ``JAN``, ...).
        jan_df: product master holding ``JAN`` and ``cat_type``.
        session_df: sessions with ``session_id``, ``session_datetime``, ``顧客CD``.
        prefix: output column prefix.
        target_categories: categories to compute the feature for.
        cat_type: category level column in ``jan_df``.

    Returns:
        ``session_id`` plus one ``f"{prefix}{category}"`` column per category,
        left-joined onto ``session_df``'s session ids.
    """
    session_master_df = session_df.select("session_id")

    # One row per session with the list of categories bought in it, plus a boolean
    # column per target category. Sessions without log rows (e.g. test) get a list
    # holding only null, so every flag is False.
    base_df = (
        session_df.lazy()
        .join(train_log_df.lazy(), on="session_id", how="left")
        .join(jan_df.lazy().select(["JAN", cat_type]), on="JAN", how="left")
        .select(["session_id", "session_datetime", cat_type, "顧客CD"])
        .unique()
        .group_by(["session_id", "session_datetime", "顧客CD"])
        .agg(pl.col(cat_type))
        .with_columns(
            pl.col(cat_type).list.contains(target_category).alias(target_category)
            for target_category in target_categories
        )
        # The window logic below is order dependent. Sort each customer's sessions
        # chronologically, with session_id as a deterministic tie-break for
        # same-hour sessions.
        .sort(["顧客CD", "session_datetime", "session_id"])
        # Materialise once. Previously the join/group_by/sort plan was re-executed
        # per category, and the unstable sort could order ties differently each time.
        .collect()
    )

    duration_exprs = []
    for target_category in target_categories:
        # Datetime of sessions that bought the category (null otherwise), forward
        # filled and shifted by one within the customer: the latest earlier purchase.
        last_purchase_at = (
            pl.when(pl.col(target_category))
            .then(pl.col("session_datetime"))
            .forward_fill()
            .shift()
            .over("顧客CD")
        )
        duration_exprs.append(
            ((pl.col("session_datetime") - last_purchase_at) / pl.duration(days=1)).alias(
                f"{prefix}{target_category}"
            )
        )

    session_duration_df = base_df.select(pl.col("session_id"), *duration_exprs)
    return session_master_df.join(session_duration_df, on="session_id", how="left")


@memory.cache
def create_window_agg_cat_df(
    session_df: pl.DataFrame,
    train_log_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    target_categories: list[str] = ["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],  # noqa
    window_size: str | int = "3mo",
    agg_method: str = "mean",  # [mean, sum, max, std, diff, shift]
    value_type: str = "売上有無",  # [売上有無, 売上数量, 売上金額, 値割金額, 値割数量]
    cat_type: str = "カテゴリ名",  # [カテゴリ名, 部門]
    prefix: str = "window_agg_cat_",
) -> pl.DataFrame:
    """Per-customer aggregates of each category's value over the customer's past sessions.

    First builds one row per session with, for every target category, the session
    total of ``value_type``. ``売上有無`` is a presence flag: 1 if the category was
    bought, else 0. Then aggregates each category per 顧客CD in time order:

    * ``mean``/``sum``/``max``/``std``: rolling aggregate over the time window
      ``window_size`` (polars duration string, e.g. "1w", "3mo") with
      ``closed="left"``, i.e. [t - window_size, t). The current session is excluded.
    * ``diff``: ``value[k-1] - value[k-1-window_size]`` over previous sessions.
    * ``shift``: ``value[k-window_size]`` (lag by ``window_size`` sessions).

    Both ``diff`` and ``shift`` need an int ``window_size >= 1``.

    Returns:
        ``session_id`` plus one
        ``f"{prefix}_{agg_method}_{category}_{value_type}_{window_size}"`` column
        per target category.

    Raises:
        ValueError: unknown ``agg_method``.
        AssertionError: ``diff``/``shift`` with a non-int or < 1 ``window_size``.
    """
    rolling_fns = {
        "mean": pl.Expr.rolling_mean_by,
        "sum": pl.Expr.rolling_sum_by,
        "max": pl.Expr.rolling_max_by,
        "std": pl.Expr.rolling_std_by,
    }
    # Validate before doing any expensive work.
    if agg_method not in rolling_fns and agg_method not in ("diff", "shift"):
        raise ValueError(f"Invalid agg_method: {agg_method}")
    if agg_method in ("diff", "shift"):
        assert isinstance(window_size, int), "window_size must be int"
        # 0 or a negative lag would read the current (label-bearing) or a future session.
        assert window_size >= 1, "window_size must be >= 1"

    def _aggregate(col: pl.Expr) -> pl.Expr:
        """Aggregate `col` using only the customer's earlier sessions (applied under `.over`)."""
        if agg_method == "diff":
            # shift(1) first: a plain diff includes the current session's own value,
            # which for target categories is the label itself (target leakage).
            return col.shift(1).diff(n=window_size)
        if agg_method == "shift":
            return col.shift(n=window_size)
        return rolling_fns[agg_method](col, "session_datetime", window_size=window_size, closed="left")

    rows_df = (
        session_df.lazy()
        .join(train_log_df.lazy().with_columns(pl.lit(1).alias("売上有無")), on="session_id", how="left")
        .join(
            jan_df.filter(pl.col(cat_type).is_in(target_categories)).lazy().select(["JAN", cat_type]),
            on="JAN",
            how="left",
        )
        .select(["session_id", "session_datetime", cat_type, "顧客CD", value_type])
        .collect()
    )
    # The presence flag must collapse to one row per (session, category) so the
    # pivot sum is 0/1. Quantities and amounts must NOT be de-duplicated: two
    # different products with the same quantity in one session were previously
    # counted once.
    if value_type == "売上有無":
        rows_df = rows_df.unique()

    pivot_df = rows_df.pivot(
        cat_type,
        index=["顧客CD", "session_id", "session_datetime"],
        values=value_type,
        aggregate_function="sum",
    ).fill_null(0)
    # Pivot only creates columns for categories present in the data; add the rest as 0.
    missing = [c for c in target_categories if c not in pivot_df.columns]
    if missing:
        pivot_df = pivot_df.with_columns(pl.lit(0).alias(c) for c in missing)

    base_df = pivot_df.select(
        [pl.col("session_id"), pl.col("session_datetime"), pl.col("顧客CD")] + target_categories
    ).lazy()

    # Chronological per customer; session_id breaks same-hour ties deterministically
    # (this matters for the session-count based diff/shift).
    agg_df = base_df.sort(["顧客CD", "session_datetime", "session_id"]).select(
        pl.col("session_id"),
        *[
            _aggregate(pl.col(target))
            .over("顧客CD")
            .alias(f"{prefix}_{agg_method}_{target}_{value_type}_{window_size}")
            for target in target_categories
        ],
    )
    return agg_df.collect()


@memory.cache
def craete_target_cat_df(
    session_df: pl.DataFrame,
    train_log_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    target_categories: list[str] = ["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],  # noqa
    value_type: str = "売上有無",  # [売上有無, 売上数量, 売上金額, 値割金額, 値割数量]
    cat_type: str = "カテゴリ名",  # [カテゴリ名, 部門]
    prefix: str = "target_cat_",
) -> pl.DataFrame:
    """Pivot of each session's OWN purchases: one column per category.

    WARNING: values come from the session's own log rows. For the target
    categories this reproduces the label, so using them as model features for
    those sessions is target leakage. (The "craete" typo in the name is kept so
    existing callers keep working.)

    Args:
        session_df: sessions with ``session_id``, ``session_datetime``, ``顧客CD``.
        train_log_df: purchase log.
        jan_df: product master holding ``JAN`` and ``cat_type``.
        target_categories: categories to pivot.
        value_type: ``売上有無`` (presence flag, 0/1) or a log column to sum.
        cat_type: category level column in ``jan_df``.
        prefix: output column prefix.

    Returns:
        ``session_id``, ``session_datetime``, ``顧客CD`` plus one
        ``f"{prefix}_{category}"`` column per target category (0 when absent).
    """
    rows_df = (
        session_df.join(train_log_df.with_columns(pl.lit(1).alias("売上有無")), on="session_id", how="left")
        .join(
            jan_df.filter(pl.col(cat_type).is_in(target_categories)).select(["JAN", cat_type]),
            on="JAN",
            how="left",
        )
        .select(["session_id", "session_datetime", cat_type, "顧客CD", value_type])
    )
    # Presence flags collapse to one row per (session, category) so the sum is 0/1.
    # Quantities/amounts must not be de-duplicated, or equal values from different
    # products in one session would be counted once.
    if value_type == "売上有無":
        rows_df = rows_df.unique()

    df = rows_df.pivot(
        cat_type,
        index=["顧客CD", "session_id", "session_datetime"],
        values=value_type,
        aggregate_function="sum",
    ).fill_null(0)
    # Pivot only creates columns for categories present in the data; add the rest as 0.
    missing = [t for t in target_categories if t not in df.columns]
    if missing:
        df = df.with_columns(pl.lit(0).alias(t) for t in missing)

    return df.select(
        [pl.col("session_id"), pl.col("session_datetime"), pl.col("顧客CD")]
        + [pl.col(target).alias(f"{prefix}_{target}") for target in target_categories]
    )


@memory.cache
def create_discount_ratio_master_df(
    session_df: pl.DataFrame,
    train_log_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    target_categories: list[str] = ["ヘアケア", "チョコレート", "ビール", "米（5㎏以下）"],  # noqa
    cat_type: str = "カテゴリ名",
    prefix: str = "discount_ratio_",
) -> pl.DataFrame:
    """Daily discount-ratio proxy per category.

    For each 売上日 and category, over log rows with ``値割数量 > 0`` and
    ``値割金額 < 売上金額``:

        (sum 値割金額 / sum 値割数量) / (sum 売上金額 / sum 売上数量)

    That is, the average discount per discounted unit relative to the average
    sale amount per unit. The original author noted it is unverified whether this
    is a true discount ratio.

    NOTE(review): the ratio for a day is computed from every logged session of
    that day in ``session_df``, including the session being scored; confirm this
    is acceptable.

    Returns:
        One row per 売上日 with at least one qualifying row, and one
        ``f"{prefix}_{category}"`` column per target category. A column is null
        when that category had no qualifying row that day.
    """
    df = (
        session_df.select("session_id", "売上日")
        .join(train_log_df, on="session_id", how="inner")
        .join(jan_df, on="JAN", how="inner")
        # Discounted rows only (see docstring for the exact condition).
        .filter((pl.col("値割数量") > 0) & (pl.col("値割金額") < pl.col("売上金額")))
        .filter(pl.col(cat_type).is_in(target_categories))
        .group_by(["売上日", cat_type])
        .agg(pl.col("値割金額").sum(), pl.col("売上金額").sum(), pl.col("値割数量").sum(), pl.col("売上数量").sum())
        .with_columns(
            ((pl.col("値割金額") / pl.col("値割数量")) / (pl.col("売上金額") / pl.col("売上数量"))).alias(
                "discount_ratio"
            )
        )
        .pivot(cat_type, index="売上日", values="discount_ratio")
    )
    # Pivot only creates columns for categories with at least one discounted row.
    # Add the rest as nulls so the select below cannot fail with a missing column.
    missing = [t for t in target_categories if t not in df.columns]
    if missing:
        df = df.with_columns(pl.lit(None, dtype=pl.Float64).alias(t) for t in missing)
    return df.select([pl.col("売上日")] + [pl.col(t).alias(f"{prefix}_{t}") for t in target_categories])

In [None]:
# Candidate categories for the per-category features: the four targets plus the
# 32 categories most often bought together with each of them.
target_categories = get_top_co_categories(
    train_log_df=train_log_df,
    jan_df=jan_df,
    n_top_per_target=32,
    target_categories=["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],
)

# Days since each customer's previous purchase of every candidate category.
session_duration_last_cat_df = create_session_duration_last_cat_df(
    train_log_df=train_log_df,
    jan_df=jan_df,
    session_df=full_session_df,
    target_categories=target_categories,
    cat_type="カテゴリ名",
)

# Daily discount-ratio proxy per candidate category (joined on 売上日 later).
discount_ratio_master_df = create_discount_ratio_master_df(
    session_df=full_session_df,
    train_log_df=train_log_df,
    jan_df=jan_df,
    target_categories=target_categories,
)


# Per-customer rolling aggregates of 売上数量 for every candidate category:
# each of sum / mean / max / std over 6-month, 1-month, 1-week and 3-day windows.
# Every result is left-joined onto a single session_id frame.
master_rolling_agg_cat_df = full_session_df.select("session_id")
for window_size, agg_method, value_type, cat_type in tqdm(
    [
        # (window_size, agg_method, value_type, cat_type)
        (window, method, "売上数量", "カテゴリ名")
        for method in ("sum", "mean", "max", "std")
        for window in ("6mo", "1mo", "1w", "3d")
    ]
):
    rolling_agg_df = create_window_agg_cat_df(
        session_df=full_session_df,
        train_log_df=train_log_df,
        jan_df=jan_df,
        target_categories=target_categories,
        window_size=window_size,
        agg_method=agg_method,
        value_type=value_type,
        cat_type=cat_type,
    )
    master_rolling_agg_cat_df = master_rolling_agg_cat_df.join(rolling_agg_df, on="session_id", how="left")


# target_cat_df = craete_target_cat_df(
#     session_df=full_session_df,
#     train_log_df=train_log_df,
#     jan_df=jan_df,
#     value_type="売上有無",  # [売上有無, 売上数量, 売上金額, 値割金額, 値割数量]
#     cat_type="カテゴリ名",  # [カテゴリ名, 部門]
#     target_categories=target_categories,
# )


# Session-level features.
# cum_count, diff and rolling_sum_by evaluated under `.over("顧客CD")` depend on
# row order. However, `full_session_df` is a TRAIN/VALID/TEST concatenation in
# file order (randomly re-ordered under DEBUG sampling), not a per-customer
# timeline. So: sort each customer's sessions chronologically (stable, so
# same-hour sessions keep their relative order), compute the features, then
# restore the original row order via a temporary row index.
full_session_df_with_feats = (
    full_session_df.with_row_index("_row_idx")
    .sort(["顧客CD", "session_datetime"], maintain_order=True)
    .with_columns(pl.lit(1).alias("dummy"))
    .with_columns(
        # customer also has a VALID or TEST session
        pl.col("顧客CD")
        .is_in(full_session_df.filter(pl.col("dataset") != "TRAIN")["顧客CD"].unique())
        .alias("is_hot"),
        pl.col("売上日").dt.day().alias("day"),
        pl.col("売上日").dt.weekday().alias("weekday"),
        pl.col("年代")
        .replace(
            {
                "10代以下": 10,
                "20代": 20,
                "30代": 30,
                "40代": 40,
                "50代": 50,
                "60代": 60,
                "70代": 70,
                "80代以上": 80,
                "不明": None,
            }
        )
        .cast(pl.Float32)
        .alias("age"),
        # running count of the customer's sessions so far
        pl.col("session_id").cum_count().over("顧客CD").alias("cum_visit_count"),
        # whole days since the customer's previous session (null for the first one)
        pl.col("session_datetime").diff().over("顧客CD").dt.total_days().alias("days_since_last_visit"),
        pl.col("dummy")
        .rolling_sum_by("session_datetime", window_size="1mo")
        .over("顧客CD")
        .alias("visit_count_1mo"),
        pl.col("dummy").rolling_sum_by("session_datetime", window_size="1w").over("顧客CD").alias("visit_count_1w"),
        # NOTE: these two counts also include the customer's later sessions
        pl.col("dummy").sum().over(["顧客CD", "売上日"]).alias("visit_count_today"),
        pl.col("dummy").sum().over(["顧客CD"]).alias("visit_count_total"),
        # 1 for hours 8-12 inclusive, else 0 (was inverted: those sessions got 0)
        pl.when(pl.col("時刻").is_between(8, 12, closed="both")).then(1).otherwise(0).alias("is_am"),
    )
    .sort("_row_idx")
    .drop("dummy", "_row_idx")
    .join(holiday_df, on="売上日", how="left")
    .with_columns(pl.col("is_holiday").fill_null(0))
    .join(session_duration_last_cat_df, on="session_id", how="left")
    .join(master_rolling_agg_cat_df, on="session_id", how="left")
    .join(discount_ratio_master_df, on="売上日", how="left")
    .with_columns(pl.col(x).fill_null(0) for x in discount_ratio_master_df.columns if x != "売上日")
)


In [None]:
def fe(
    train_df: pl.DataFrame,
    test_df: pl.DataFrame,
    valid_df: pl.DataFrame | None = None,
) -> tuple[pl.DataFrame, ...]:
    """Encode session frames into model-ready feature frames.

    Encoders are fitted on ``train_df`` only, then applied with their fitted
    state to ``test_df`` (and ``valid_df`` when given). Output columns are
    META_COLS (prefix ""), raw numeric features prefixed with
    NUMERICAL_FEATURE_PREFIX, and ordinal-encoded categoricals prefixed with
    CATEGORICAL_FEATURE_PREFIX. Presumably each encoder prepends its prefix to the
    column names; confirm in ``src.feature.tabular``.

    Returns:
        ``(train_feature_df, test_feature_df)``, or
        ``(train_feature_df, test_feature_df, valid_feature_df)`` when ``valid_df``
        is given. The previous annotation advertised a 2-tuple in both cases.
    """
    encoders = [
        RawEncoder(
            columns=META_COLS,
            prefix="",
        ),
        RawEncoder(
            columns=[
                "時刻",
                "day",
                "weekday",
                "age",
                "cum_visit_count",
                "days_since_last_visit",
                "visit_count_1mo",
                "visit_count_1w",
                "visit_count_today",
                "visit_count_total",
                "is_hot",
                "is_holiday",
                # "is_am",
                *[x for x in train_df.columns if x.startswith("days_since_cat_")],
                *[x for x in train_df.columns if x.startswith("window_agg_cat_")],
                *[x for x in train_df.columns if x.startswith("discount_ratio_")],
            ],
            prefix=NUMERICAL_FEATURE_PREFIX,
        ),
        OrdinalEncoder(
            columns=[
                "性別",
                "顧客CD",
                "店舗名",
            ],
            prefix=CATEGORICAL_FEATURE_PREFIX,
        ),
    ]

    def _encode(df: pl.DataFrame, fit: bool) -> pl.DataFrame:
        """Run every encoder, stack their column blocks side by side and downcast dtypes."""
        parts = [encoder.fit_transform(df) if fit else encoder.transform(df) for encoder in encoders]
        return pl.concat(parts, how="horizontal").select(pl.all().shrink_dtype())

    # Fit on train first; test/valid reuse the state fitted here.
    train_feature_df = _encode(train_df, fit=True)
    test_feature_df = _encode(test_df, fit=False)

    if valid_df is None:
        return train_feature_df, test_feature_df
    return train_feature_df, test_feature_df, _encode(valid_df, fit=False)


# Split the feature-augmented frame back into its three partitions.
train_df, valid_df, test_df = (
    full_session_df_with_feats.filter(pl.col("dataset") == split) for split in ("TRAIN", "VALID", "TEST")
)

train_feature_df, test_feature_df, valid_feature_df = fe(train_df, test_df, valid_df=valid_df)

# Model inputs are identified purely by prefix; categorical columns come first.
cat_feature_cols = [col for col in train_feature_df.columns if col.startswith(CATEGORICAL_FEATURE_PREFIX)]
num_feature_cols = [col for col in train_feature_df.columns if col.startswith(NUMERICAL_FEATURE_PREFIX)]
feature_cols = [*cat_feature_cols, *num_feature_cols]

### train


In [None]:
class EvalFn:
    """Evaluation callable: ROC-AUC of a prediction column against a binary label.

    Args:
        target_col: column holding the true 0/1 label.
        pred_col: column holding the predicted score (default ``"pred"``).
    """

    def __init__(self, target_col: str, pred_col: str = "pred"):
        self.target_col = target_col
        # Previously accepted but ignored: the column name "pred" was hard-coded.
        self.pred_col = pred_col

    def __call__(self, input_df: pl.DataFrame) -> dict[str, float]:
        """Return ``{"rocauc": score}`` for the rows of ``input_df``."""
        y_true = input_df[self.target_col].to_numpy()
        y_pred = input_df[self.pred_col].to_numpy()
        return {"rocauc": roc_auc_score(y_true, y_pred)}

    @property
    def __name__(self) -> str:
        # Expose a function-like name; presumably read by the trainer for logging.
        return self.__class__.__name__


def add_kfold(
    input_df: pl.DataFrame,
    n_splits: int,
    random_state: int,
    fold_col: str,
    group_col: str | None = None,
) -> pl.DataFrame:
    """Attach a fold index column for cross-validation.

    Args:
        input_df: rows to split.
        n_splits: number of folds. 1 means holdout mode: every row gets fold 0.
        random_state: seed for the shuffled KFold (unused when ``group_col`` is set).
        fold_col: name of the fold column to add.
        group_col: optional column (e.g. "顧客CD") whose values must not be split
            across folds. When given, GroupKFold is used, so one customer's
            sessions never land in both the train and validation side of a split.
            Plain KFold otherwise (the previous, still default, behavior).

    Returns:
        ``input_df`` with an added ``fold_col``.
    """
    if n_splits == 1:
        return input_df.with_columns(pl.lit(0).alias(fold_col))

    if group_col is None:
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        splits = kf.split(X=input_df)
    else:
        from sklearn.model_selection import GroupKFold  # only needed for grouped CV

        gkf = GroupKFold(n_splits=n_splits)
        splits = gkf.split(X=np.zeros((len(input_df), 1)), groups=input_df[group_col].to_numpy())

    folds = np.zeros(len(input_df), dtype=np.int32)
    for fold, (_, valid_idx) in enumerate(splits):
        folds[valid_idx] = fold
    return input_df.with_columns(pl.Series(name=fold_col, values=folds))

In [None]:
results, scores = {}, {}

for target_col in TARGET_COLS:
    # One binary LightGBM per target; with several SEEDS the predictions are averaged.
    va_result_df, va_scores = pl.DataFrame(), {}
    all_models = []
    for seed in SEEDS:
        print(f"# ================== seed={seed}: {target_col} ================== #")
        name = f"seed{seed}_lgb_{target_col}"
        _va_result_df, _va_scores, trained_models = single_train_fn(
            model=LightGBMWapper(
                name=name,
                model=lgb.LGBMModel(
                    objective="binary",
                    boosting="gbdt",
                    n_estimators=5000,
                    learning_rate=0.1,
                    num_leaves=31,
                    colsample_bytree=0.1,
                    # NOTE(review): per the LightGBM docs, row subsampling (bagging)
                    # only takes effect when subsample_freq (bagging_freq) > 0. With
                    # the default of 0, this setting is ignored; confirm intent.
                    subsample=0.1,
                    importance_type="gain",
                    random_state=seed,
                    force_col_wise=True,
                    class_weight="balanced",
                    # verbose=-1,
                ),
                fit_params={
                    "callbacks": [
                        lgb.early_stopping(100, first_metric_only=True),
                        lgb.log_evaluation(period=100),
                    ],
                    "categorical_feature": cat_feature_cols,
                    "feature_name": feature_cols,
                    "eval_metric": "auc",
                },
            ),
            features_df=add_kfold(
                train_feature_df,
                n_splits=N_SPLITS,
                random_state=seed,
                fold_col=FOLD_COL,
            ),
            feature_cols=feature_cols,
            target_col=target_col,
            fold_col=FOLD_COL,
            meta_cols=META_COLS + [FOLD_COL],
            out_dir=EXP_OUTPUT_DIR,
            eval_fn=EvalFn(target_col=target_col),
            overwrite=True,
            val_features_df=valid_feature_df,
            full_training=True,
        )
        va_result_df = pl.concat([va_result_df, _va_result_df], how="diagonal_relaxed")
        va_scores[name] = _va_scores
        all_models.extend(trained_models)

    # Average the per-seed predictions, then re-attach the meta columns. With more
    # than one seed, `va_result_df` holds one row per (seed, session). The meta
    # frame is therefore de-duplicated first; otherwise the join would multiply
    # every session's row by len(SEEDS).
    va_result_agg_df = (
        va_result_df.group_by("session_id")
        .agg(pl.col("pred").mean())
        .sort("session_id")
        .join(
            va_result_df.select(META_COLS).unique(subset="session_id", keep="first", maintain_order=True),
            on="session_id",
            how="left",
        )
    )
    results[target_col] = {
        "result_df": va_result_agg_df,
        "models": all_models,
    }
    scores[target_col] = va_scores


In [None]:
def construct_va_result_df(results: dict) -> pl.DataFrame:
    """Combine per-target validation results into one wide frame.

    Args:
        results: mapping target name -> ``{"result_df": frame, ...}``. Each frame
            has ``session_id``, ``顧客CD``, ``pred`` and the target's label column.

    Returns:
        ``session_id``, ``顧客CD``, then for each target in ``results`` order
        ``pred_<target>`` and ``<target>``. Targets after the first are attached
        by a left join on ``session_id``, onto the first target's sessions.
        An empty frame when ``results`` is empty.
    """
    va_result_df: pl.DataFrame | None = None
    for name, res in results.items():
        result_df = res["result_df"].with_columns(pl.col("pred").alias(f"pred_{name}"))
        if va_result_df is None:
            va_result_df = result_df.select(["session_id", "顧客CD", f"pred_{name}", name])
        else:
            # Join instead of concatenating horizontally. The previous horizontal
            # concat silently assumed every frame had identical rows in identical order.
            va_result_df = va_result_df.join(
                result_df.select(["session_id", f"pred_{name}", name]),
                on="session_id",
                how="left",
            )
    return va_result_df if va_result_df is not None else pl.DataFrame()


va_result_df = construct_va_result_df(results)

# Overall score: macro-averaged ROC-AUC across all targets.
pred_cols = ["pred_" + target for target in TARGET_COLS]
score = roc_auc_score(
    va_result_df[TARGET_COLS].to_numpy(),
    va_result_df[pred_cols].to_numpy(),
    average="macro",
)
scores["final_metric"] = score

va_result_df.write_parquet(EXP_OUTPUT_DIR / "va_result_df.parquet")
with open(EXP_OUTPUT_DIR / "scores.json", "w") as fp:
    json.dump(scores, fp, indent=4, ensure_ascii=False)

print(json.dumps(scores, indent=4, ensure_ascii=False))

### inference


In [None]:
# Predict every target on the test features. All predictions come from the same
# `test_feature_df`, so they are laid side by side; only the first block
# carries `session_id`.
te_result_df = pl.DataFrame()
for i, (name, res) in enumerate(results.items()):
    target_name = name
    cols = [pl.col("pred").alias(target_name)]
    if i == 0:
        cols = [pl.col("session_id"), *cols]

    _te_result_df = single_inference_fn_v2(
        models=res["models"],
        features_df=test_feature_df,
        feature_names=feature_cols,
    ).select(cols)
    te_result_df = pl.concat([te_result_df, _te_result_df], how="horizontal")

# Restrict to (and order by) the official test sessions.
submission_df = (
    test_session_df.select("session_id")
    .join(te_result_df, on="session_id", how="inner")
    .select("session_id", *TARGET_COLS)
)

submission_df.write_parquet(EXP_OUTPUT_DIR / "te_result_df.parquet")

# The CSV holds only the prediction columns, in this explicit order
# (which differs from TARGET_COLS).
(
    submission_df
    .select("チョコレート", "ビール", "ヘアケア", "米（5㎏以下）")
    .write_csv(EXP_OUTPUT_DIR / "submission.csv")
)