In [None]:
import datetime
import json
import logging

import lightgbm as lgb
import numpy as np
import polars as pl
import rootutils
import scipy.sparse as sp
import seaborn as sns
from jpholiday import JPHoliday
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

# Global logging / plotting defaults.
logging.basicConfig(level=logging.INFO)
sns.set_style("whitegrid")

# pythonpath=True / cwd=True ask rootutils to put the detected project root on the import path and
# make it the working directory; the `src.*` imports that follow rely on this.
ROOT = rootutils.setup_root(".", pythonpath=True, cwd=True)

from src.feature.tabular import AggregateEncoder, OrdinalEncoder, RawEncoder
from src.feature.utils import cache
from src.model.sklearn_like import LightGBMWapper
from src.trainer.tabular.simple import single_inference_fn_v2, single_train_fn

# Widen polars' text repr so long category names and wide feature frames stay readable.
for _config_setter, _config_value in (
    (pl.Config.set_fmt_str_lengths, 200),
    (pl.Config.set_tbl_cols, 50),
    (pl.Config.set_tbl_rows, 50),
):
    _config_setter(_config_value)
del _config_setter, _config_value

# Japanese public-holiday calendar (queried with `.between(...)` to build `holiday_df`).
jpholiday = JPHoliday()


In [None]:
# ---- experiment configuration ----
EXP_NAME = "032"

# True -> reuse cached feature frames (passed to @cache as `overwrite=not USE_FE_CACHE`).
USE_FE_CACHE = True

SEED = 42
# Seed averaging: one model per seed, predictions averaged per session.
SEEDS = list(range(SEED, SEED + 3))
# Number of folds; 1 means a plain holdout (every training row is assigned fold 0).
N_SPLITS = 1

# DEBUG -> train on a 10k-session subsample for a quick run.
DEBUG = False
# Fraction of the candidate validation sessions that is kept.
VALID_SAMPLE_FRAC = 1
# A customer's latest session becomes validation data if it is on or after this date.
VALID_DATE = "2024-10-01"


In [None]:
# ---- paths (all relative to the project root) ----
DATA_DIR = ROOT / "data"
INPUT_DIR = DATA_DIR / "atmacup19_dataset"
OUTPUT_DIR = DATA_DIR / "output"
CACHE_DIR = DATA_DIR / "cache" / EXP_NAME
EXP_OUTPUT_DIR = OUTPUT_DIR / EXP_NAME

for _directory in (DATA_DIR, INPUT_DIR, OUTPUT_DIR, CACHE_DIR, EXP_OUTPUT_DIR):
    _directory.mkdir(parents=True, exist_ok=True)
del _directory

# ---- column naming ----
# Every model feature starts with "f_"; numerical ones with "f_n_", categorical ones with "f_c_".
FEATURE_PREFIX = "f_"
NUMERICAL_FEATURE_PREFIX = FEATURE_PREFIX + "n_"
CATEGORICAL_FEATURE_PREFIX = FEATURE_PREFIX + "c_"

# The four product categories predicted per session (one binary label each).
TARGET_COLS = ["ヘアケア", "チョコレート", "ビール", "米（5㎏以下）"]
# Identifier / label columns passed through feature encoding with an empty prefix.
META_COLS = ["session_id", "顧客CD", "session_datetime", "original_顧客CD", *TARGET_COLS]

FOLD_COL = "fold"


### load data


In [None]:
# Raw competition tables.
ec_log_df = pl.read_csv(INPUT_DIR / "ec_log.csv", infer_schema_length=200000)
jan_df = pl.read_csv(INPUT_DIR / "jan.csv")
test_session_df = pl.read_csv(INPUT_DIR / "test_session.csv")
train_session_df = pl.read_csv(INPUT_DIR / "train_session.csv")
train_log_df = pl.read_csv(INPUT_DIR / "train_log.csv")

# Binary targets: a session is positive for a category when its summed 売上数量 for that category is > 0.
_target_jan_df = jan_df.filter(pl.col("カテゴリ名").is_in(TARGET_COLS))
_purchased_qty_df = (
    train_session_df.select("session_id")
    .join(train_log_df, on="session_id", how="left")
    .join(_target_jan_df, on="JAN", how="left")
    .select(["session_id", "カテゴリ名", "売上数量"])
    .group_by(["session_id", "カテゴリ名"])
    .agg(pl.col("売上数量").sum())
)
train_session_df = (
    _purchased_qty_df.filter(pl.col("売上数量") > 0)
    .with_columns(target=pl.lit(1))
    .pivot("カテゴリ名", index="session_id", values="target")
    .select(["session_id", *TARGET_COLS])
    .join(train_session_df, on="session_id", how="right")
    .with_columns([pl.col(col).fill_null(0) for col in TARGET_COLS])
)
del _target_jan_df, _purchased_qty_df


# Validation split: each customer's most recent session, kept only if it falls on or after VALID_DATE.
# 時刻 and session_id are part of the sort key so that "most recent" is well defined (and deterministic)
# when a customer has several sessions on the same 売上日.
valid_session_id = (
    train_session_df.sort(["売上日", "時刻", "session_id"])
    .group_by("顧客CD")
    .tail(1)
    .filter(pl.col("売上日") >= VALID_DATE)["session_id"]
    .unique()
    # unique() has no guaranteed order; fixing it makes the seeded sample reproducible
    # when VALID_SAMPLE_FRAC < 1.
    .sort()
    .sample(fraction=VALID_SAMPLE_FRAC, seed=SEED)
)
valid_session_df = train_session_df.filter(pl.col("session_id").is_in(valid_session_id))
train_session_df = train_session_df.filter(~pl.col("session_id").is_in(valid_session_id))

# Optional subsample of the training sessions for quick debug runs.
if DEBUG:
    train_session_df = train_session_df.sample(10000, seed=SEED)


# Stack train / valid / test sessions into one frame tagged by `dataset`, so per-customer features can
# see the whole history. "diagonal_relaxed" unions the schemas (e.g. test has no target columns).
_dataset_frames = [
    train_session_df.with_columns(dataset=pl.lit("TRAIN")),
    valid_session_df.with_columns(dataset=pl.lit("VALID")),
    test_session_df.with_columns(dataset=pl.lit("TEST")),
]
_sale_date = pl.col("売上日")
raw_full_session_df = (
    pl.concat(_dataset_frames, how="diagonal_relaxed")
    .with_columns(_sale_date.cast(pl.Date))
    # Hour-resolution timestamp: 売上日 (date) + 時刻 (hour).
    .with_columns(
        session_datetime=pl.datetime(
            _sale_date.dt.year(), _sale_date.dt.month(), _sale_date.dt.day(), pl.col("時刻")
        )
    )
)
del _dataset_frames, _sale_date

for _label, _frame in (
    ("train_session_df", train_session_df),
    ("valid_session_df", valid_session_df),
    ("test_session_df", test_session_df),
):
    print(f"{_label}: {_frame.shape}")

In [None]:
# One row per Japanese public holiday between 2024-07-01 and 2024-11-30 (joined on 売上日 later).
# NOTE(review): the window is hard-coded; sessions dated outside it never match a holiday row.
_holiday_records = [
    {"売上日": holiday.date, "name": holiday.name, "is_holiday": 1}
    for holiday in jpholiday.between(datetime.date(2024, 7, 1), datetime.date(2024, 11, 30))
]
holiday_df = pl.DataFrame(_holiday_records)
del _holiday_records

# Keep a copy of the customer code under the META column name "original_顧客CD".
full_session_df = raw_full_session_df.with_columns(
    pl.col("顧客CD").alias("original_顧客CD"),
)


### Feature engineering


In [None]:
@cache(cache_dir=CACHE_DIR, overwrite=not USE_FE_CACHE)
def get_top_co_categories(
    train_log_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    n_top_per_target: int = 100,
    target_categories: list[str] = ["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],  # noqa
) -> list[str]:
    """Pool the categories most often bought together with each target category.

    For every target category, the sessions that contain it are selected and all categories of
    those sessions are ranked by the number of distinct sessions they occur in; the top
    ``n_top_per_target`` of each ranking are pooled together with the targets themselves.

    Args:
        train_log_df: purchase log with at least ``session_id`` and ``JAN``.
        jan_df: product master with ``JAN`` and ``カテゴリ名``.
        n_top_per_target: number of co-occurring categories kept per target.
        target_categories: categories to rank co-purchases for; always included in the result.

    Returns:
        Sorted list of unique category names.
    """
    print("🚀 get_top_co_categories")

    # Each session's distinct categories. This does not depend on the target, so it is built once
    # rather than once per target category.
    session_categories_df = (
        train_log_df.lazy()
        .join(jan_df.lazy().select(["JAN", "カテゴリ名"]), on="JAN", how="inner")
        .select(["session_id", "カテゴリ名"])
        # A null category cannot be ranked meaningfully and would break sorted() below.
        .drop_nulls("カテゴリ名")
        .unique()
        .group_by("session_id")
        .agg(pl.col("カテゴリ名"))
        .collect()
    )

    all_top_co_categories = set(target_categories)
    for target_category in target_categories:
        top_co_categories = (
            session_categories_df.filter(pl.col("カテゴリ名").list.contains(target_category))
            .explode("カテゴリ名")
            .group_by("カテゴリ名")
            .agg(pl.len().alias("n_cooccurrence"))
            # The name as secondary key makes the top-n cut-off deterministic when counts tie.
            .sort(["n_cooccurrence", "カテゴリ名"], descending=[True, False])
            .head(n_top_per_target)["カテゴリ名"]
            .to_list()
        )
        all_top_co_categories.update(top_co_categories)

    return sorted(all_top_co_categories)


@cache(cache_dir=CACHE_DIR, overwrite=not USE_FE_CACHE)
def create_session_duration_last_cat_df(
    train_log_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    session_df: pl.DataFrame,
    prefix: str = "days_since_cat_",
    target_categories: list[str] = ["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],  # noqa
    cat_type: str = "カテゴリ名",
    group_by: str | list[str] = "顧客CD",
) -> pl.DataFrame:
    """Days since the same ``group_by`` group last bought each target category.

    Sessions are ordered by (group_by, session_datetime, session_id). For each target category the
    column ``f"{prefix}_{category}"`` holds ``(session_datetime - previous purchase datetime)`` in days,
    where the previous purchase is the latest *earlier* row of the same group whose session contained
    the category; it is null when there is none. The current row itself never counts.

    Returns:
        DataFrame with ``session_id`` plus one column per target category.
    """
    print("🚀 create_session_duration_last_cat_df")
    print(f"# target_categories: {len(target_categories)}")
    keys = [group_by] if isinstance(group_by, str) else group_by

    # One row per session with the list of `cat_type` values it contained.
    session_cats = (
        session_df.lazy()
        .join(train_log_df.lazy(), on="session_id", how="left")
        .join(jan_df.lazy().select(["JAN", cat_type]), on="JAN", how="left")
        .select(["session_id", "session_datetime", "売上数量", cat_type, *keys])
        .unique()
        .group_by(["session_id", "session_datetime", *keys])
        .agg(pl.col(cat_type))
    )

    bought_at = {cat: f"true_datetime_{cat}" for cat in target_categories}
    previous_at = {cat: f"last_true_datetime_{cat}" for cat in target_categories}
    output_name = {cat: f"{prefix}_{cat}" for cat in target_categories}

    # Boolean flag per category: did this session contain it?
    flag_exprs = [pl.col(cat_type).list.contains(cat).alias(cat) for cat in target_categories]
    # Session timestamp when the category was bought, null otherwise.
    bought_at_exprs = [
        pl.when(pl.col(cat)).then(pl.col("session_datetime")).otherwise(None).alias(bought_at[cat])
        for cat in target_categories
    ]
    # forward_fill then shift, per group: the latest purchase timestamp strictly before this row.
    previous_exprs = [
        pl.col(bought_at[cat]).forward_fill().shift().over(keys).alias(previous_at[cat])
        for cat in target_categories
    ]
    elapsed_exprs = [
        pl.when(pl.col(previous_at[cat]).is_not_null())
        .then((pl.col("session_datetime") - pl.col(previous_at[cat])) / pl.duration(days=1))
        .otherwise(None)
        .alias(output_name[cat])
        for cat in target_categories
    ]

    return (
        session_cats.with_columns(flag_exprs)
        # Same-timestamp sessions of one group are ordered by session_id.
        .sort([*keys, "session_datetime", "session_id"])
        .with_columns(bought_at_exprs)
        .with_columns(previous_exprs)
        .with_columns(elapsed_exprs)
        .select("session_id", *[pl.col(output_name[cat]) for cat in target_categories])
        .collect()
    )


@cache(cache_dir=CACHE_DIR, overwrite=not USE_FE_CACHE)
def create_window_agg_cat_df(
    session_df: pl.DataFrame,
    train_log_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    target_categories: list[str] = ["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],  # noqa
    window_size: str | int = "3mo",
    agg_method: str = "mean",  # [mean, sum, max, min, std, diff, shift]
    value_type: str = "売上有無",  # [売上有無, 売上数量, 売上金額, 値割金額, 値割数量]
    cat_type: str = "カテゴリ名",  # [カテゴリ名, 部門]
    prefix: str = "window_agg_cat_",
    group_by: str | list[str] = "顧客CD",
) -> pl.DataFrame:
    """Rolling / lagged per-category aggregates of past sessions, computed within each ``group_by`` group.

    Every session is first turned into one row holding, for each category of ``target_categories``,
    ``value_type`` aggregated over the session's log rows: ``max`` for the 0/1 flag ``売上有無`` (a
    presence indicator) and ``sum`` for quantities/amounts. Rows are ordered by
    (group_by, session_datetime, session_id) and then, within each group:

    * ``mean``/``sum``/``max``/``min``/``std``: polars ``rolling_*_by`` over the time window
      ``window_size`` (duration string such as "3mo") with ``closed="left"``, so only rows with an
      earlier ``session_datetime`` are aggregated;
    * ``diff``/``shift``: ``Expr.diff`` / ``Expr.shift`` by ``window_size`` rows (must be an int).

    NOTE(review): quantity/amount sums no longer drop log rows that happen to share a value; clear
    cached outputs in CACHE_DIR if the ``cache`` decorator keys only on the arguments.

    Returns:
        DataFrame with ``session_id`` and one column per category present in the pivot, named
        ``f"{prefix}_{agg_method}_{category}_{value_type}_{window_size}"``.

    Raises:
        ValueError: if ``agg_method`` is not one of the supported methods.
    """
    rolling_methods = {
        "mean": "rolling_mean_by",
        "sum": "rolling_sum_by",
        "max": "rolling_max_by",
        "min": "rolling_min_by",
        "std": "rolling_std_by",
    }
    lag_methods = ("diff", "shift")
    # Validate before the expensive join/pivot below.
    if agg_method not in rolling_methods and agg_method not in lag_methods:
        raise ValueError(f"Invalid agg_method: {agg_method}")
    if agg_method in lag_methods:
        assert isinstance(window_size, int), "window_size must be int"

    group_by = [group_by] if isinstance(group_by, str) else group_by
    jan_cols = ["JAN"] if cat_type == "JAN" else ["JAN", cat_type]
    # No `.unique()` before pivoting: once JAN is projected away, distinct products of one category
    # with equal values would collapse into one row and be undercounted by "sum". For the 売上有無
    # flag, "max" yields the same 0/1 presence that deduplication produced.
    pivot_agg = "max" if value_type == "売上有無" else "sum"
    base_df = (
        session_df.lazy()
        .join(train_log_df.lazy().with_columns(pl.lit(1).alias("売上有無")), on="session_id", how="left")
        .join(
            jan_df.filter(pl.col(cat_type).is_in(target_categories)).lazy().select(jan_cols),
            on="JAN",
            how="left",
        )
        .select(["session_id", "session_datetime", cat_type, *group_by, value_type])
        .collect()
        .pivot(
            cat_type,
            index=[*group_by, "session_id", "session_datetime"],
            values=value_type,
            aggregate_function=pivot_agg,
        )
        .fill_null(0)
    )
    # Categories that never occur in the log have no pivot column.
    available_categories = [x for x in target_categories if x in base_df.columns]
    ordered_df = (
        base_df.select([pl.col("session_id"), pl.col("session_datetime")] + group_by + available_categories)
        .lazy()
        # session_id breaks timestamp ties so row-based diff/shift are deterministic.
        .sort([*group_by, "session_datetime", "session_id"])
    )

    def _window_expr(target: str) -> pl.Expr:
        # One aggregated feature column for `target`, evaluated per group.
        column = pl.col(target)
        if agg_method in rolling_methods:
            # closed="left": window [t - window_size, t) excludes the current timestamp.
            column = getattr(column, rolling_methods[agg_method])(
                "session_datetime", window_size=window_size, closed="left"
            )
        elif agg_method == "diff":
            column = column.diff(n=window_size)
        else:  # "shift"
            column = column.shift(n=window_size)
        return column.over(group_by).alias(f"{prefix}_{agg_method}_{target}_{value_type}_{window_size}")

    agg_df = ordered_df.select(
        pl.col("session_id"),
        *[_window_expr(target) for target in available_categories],
    )
    return agg_df.collect()


@cache(cache_dir=CACHE_DIR, overwrite=not USE_FE_CACHE)
def create_session_embedding_df(
    train_log_df: pl.DataFrame,
    session_df: pl.DataFrame,
    jan_df: pl.DataFrame,
    item_type: str = "カテゴリ名",
    value_type: str = "売上数量",
    dim: int = 32,
    window_size: str | int = "6mo",
    agg_method: str = "sum",
    group_by: str | list[str] = "顧客CD",
    prefix: str = "session_embedding_",
) -> pl.DataFrame:
    """Compress a windowed session x item matrix into a dense ``dim``-dimensional session embedding.

    The matrix is ``create_window_agg_cat_df`` evaluated over every distinct value of
    ``jan_df[item_type]`` (nulls filled with 0), reduced with ``TruncatedSVD(random_state=42)``.

    Returns:
        DataFrame with ``dim`` embedding columns named
        ``f"{prefix}_{item_type}_{value_type}_{window_size}_{agg_method}_d{dim}_{i:03}"`` (i = 1..dim)
        plus ``session_id``.
    """
    print("🚀 create_session_embedding_df")
    # Sorted so that the matrix column order, and therefore the SVD result and any argument-based
    # cache key of the inner call, is the same on every run; Series.unique() has no fixed order.
    # Null items never become pivot columns, so dropping them changes nothing else.
    all_items = jan_df[item_type].drop_nulls().unique().sort().to_list()
    noleak_mtx_df = (
        create_window_agg_cat_df(
            session_df=session_df,
            train_log_df=train_log_df,
            jan_df=jan_df,
            target_categories=all_items,
            window_size=window_size,
            agg_method=agg_method,
            cat_type=item_type,
            value_type=value_type,
            group_by=group_by,
            prefix="",
        )
        .fill_null(0)
        .select(pl.all().shrink_dtype())
    )

    print(f"TruncatedSVD: dim={dim}")
    svd = TruncatedSVD(n_components=dim, random_state=42)
    session_embeddings = svd.fit_transform(sp.csr_matrix(noleak_mtx_df.drop("session_id").to_numpy()))
    embedding_df = pl.DataFrame(
        session_embeddings,
        schema=[
            f"{prefix}_{item_type}_{value_type}_{window_size}_{agg_method}_d{dim}_{i + 1:03}"
            for i in range(session_embeddings.shape[1])
        ],
    ).with_columns(pl.Series(name="session_id", values=noleak_mtx_df["session_id"]))
    return embedding_df

In [None]:
# "Days since last purchase" features for the targets and their top-128 co-purchased categories.
_duration_categories = get_top_co_categories(
    train_log_df,
    jan_df,
    n_top_per_target=128,
    target_categories=["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],
)
session_duration_last_cat_df = create_session_duration_last_cat_df(
    train_log_df=train_log_df,
    jan_df=jan_df,
    session_df=full_session_df,
    target_categories=_duration_categories,
    cat_type="カテゴリ名",
)

# (window_size, agg_method, value_type) settings for the SVD session embeddings.
SESSION_EMBEDDING_SETTINGS = [
    ("6mo", "sum", "売上数量"),
    ("1w", "sum", "売上数量"),
    ("3d", "sum", "売上数量"),
]
# Embedding families: v1 uses the function's default group_by (顧客CD, i.e. the customer's own
# history); v2 windows over sessions sharing (年代, 性別, 店舗名).
_EMBEDDING_VARIANTS = [
    {"prefix": "session_embedding_v1_"},
    {"group_by": ["年代", "性別", "店舗名"], "prefix": "session_embedding_v2_"},
]

master_session_embedding_df = full_session_df.select("session_id")
for window_size, agg_method, value_type in tqdm(SESSION_EMBEDDING_SETTINGS):
    for _variant_kwargs in _EMBEDDING_VARIANTS:
        _embedding_df = create_session_embedding_df(
            train_log_df=train_log_df,
            session_df=full_session_df,
            jan_df=jan_df,
            item_type="カテゴリ名",
            value_type=value_type,
            dim=128,
            window_size=window_size,
            agg_method=agg_method,
            **_variant_kwargs,
        )
        master_session_embedding_df = master_session_embedding_df.join(
            _embedding_df, on="session_id", how="left"
        )

# (window_size, agg_method, value_type) settings for the rolling per-category aggregate features.
ROLLING_AGG_SETTINGS = [
    ("6mo", "sum", "売上数量"),
    ("1w", "sum", "売上数量"),
    ("3d", "sum", "売上数量"),
    ("6mo", "mean", "売上数量"),
    ("1w", "mean", "売上数量"),
    ("3d", "mean", "売上数量"),
    ("6mo", "max", "売上数量"),
    ("1w", "max", "売上数量"),
    ("3d", "max", "売上数量"),
    ("6mo", "std", "売上数量"),
    ("1w", "std", "売上数量"),
    ("3d", "std", "売上数量"),
    ("6mo", "max", "値割数量"),
    ("6mo", "sum", "値割数量"),
    ("6mo", "mean", "値割数量"),
    ("6mo", "std", "値割数量"),
]

# The category set does not depend on the loop variables, so resolve it once instead of per setting.
rolling_agg_categories = get_top_co_categories(
    train_log_df,
    jan_df,
    n_top_per_target=16,
    target_categories=["チョコレート", "ビール", "米（5㎏以下）", "ヘアケア"],
)

master_rolling_agg_cat_df = full_session_df.select("session_id")
for window_size, agg_method, value_type in tqdm(ROLLING_AGG_SETTINGS):
    rolling_agg_df = create_window_agg_cat_df(
        session_df=full_session_df,
        train_log_df=train_log_df,
        jan_df=jan_df,
        target_categories=rolling_agg_categories,
        window_size=window_size,
        agg_method=agg_method,
        value_type=value_type,
        cat_type="カテゴリ名",
    )
    master_rolling_agg_cat_df = master_rolling_agg_cat_df.join(
        rolling_agg_df,
        on="session_id",
        how="left",
    )


# add feature
# Session-level calendar / visit features, then every engineered feature frame joined on session_id.
# cum_count, diff and rolling_*_by over a customer depend on row order, and full_session_df is a
# train/valid/test concatenation. They are therefore evaluated on rows sorted chronologically within
# each customer (stable sort: rows with equal timestamps keep their input order). The input row order
# is restored afterwards.
full_session_df_with_feats = (
    (
        full_session_df.with_row_index("_row_nr")
        .sort(["顧客CD", "session_datetime"], maintain_order=True)
        .with_columns(pl.lit(1).alias("dummy"))
        .with_columns(
            # Customer also appears in a VALID or TEST session.
            pl.col("顧客CD")
            .is_in(full_session_df.filter(pl.col("dataset") != "TRAIN")["顧客CD"].unique())
            .alias("is_hot"),
            pl.col("売上日").dt.day().alias("day"),
            pl.col("売上日").dt.weekday().alias("weekday"),
            # Age-decade label -> number (e.g. "20代" -> 20); "不明" (unknown) -> null.
            pl.col("年代")
            .replace(
                {
                    "10代以下": 10,
                    "20代": 20,
                    "30代": 30,
                    "40代": 40,
                    "50代": 50,
                    "60代": 60,
                    "70代": 70,
                    "80代以上": 80,
                    "不明": None,
                }
            )
            .cast(pl.Float32)
            .alias("age"),
            pl.col("session_id").cum_count().over("顧客CD").alias("cum_visit_count"),
            pl.col("session_datetime").diff().over("顧客CD").alias("days_since_last_visit").dt.total_days(),
            pl.col("dummy")
            .rolling_sum_by("session_datetime", window_size="1mo")
            .over("顧客CD")
            .alias("visit_count_1mo"),
            pl.col("dummy").rolling_sum_by("session_datetime", window_size="1w").over("顧客CD").alias("visit_count_1w"),
            pl.col("dummy").sum().over(["顧客CD", "売上日"]).alias("visit_count_today"),
            pl.col("dummy").sum().over(["顧客CD"]).alias("visit_count_total"),
            # 1 when 時刻 is in 8..12 (inclusive), else 0.
            pl.when(pl.col("時刻").is_between(8, 12, closed="both")).then(1).otherwise(0).alias("is_am"),
        )
        .drop("dummy")
        .sort("_row_nr")
        .drop("_row_nr")
    )
    .join(holiday_df, on="売上日", how="left")
    .with_columns(pl.col("is_holiday").fill_null(0))
    .join(session_duration_last_cat_df, on="session_id", how="left")
    .join(master_rolling_agg_cat_df, on="session_id", how="left")
    .join(master_session_embedding_df, on="session_id", how="left")
)


In [None]:
def fe(
    train_df: pl.DataFrame,
    test_df: pl.DataFrame,
    valid_df: pl.DataFrame | None = None,
) -> tuple[pl.DataFrame, pl.DataFrame] | tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Encode session frames into model-ready feature frames.

    Encoders are fitted on ``train_df`` only and then applied to ``test_df`` (and ``valid_df``).
    META_COLS are passed through with an empty prefix. Numerical columns use
    NUMERICAL_FEATURE_PREFIX, and ordinal-encoded categoricals use CATEGORICAL_FEATURE_PREFIX.
    Each resulting frame is shrunk to the smallest fitting dtypes.

    Returns:
        ``(train, test)``, or ``(train, test, valid)`` when ``valid_df`` is given.
    """
    # Engineered feature families are discovered by name prefix, grouped in this order.
    generated_prefixes = ("days_since_cat_", "window_agg_cat_", "session_embedding_")
    generated_cols = [col for p in generated_prefixes for col in train_df.columns if col.startswith(p)]

    encoders = [
        RawEncoder(
            columns=META_COLS,
            prefix="",
        ),
        RawEncoder(
            columns=[
                "時刻",
                "day",
                "weekday",
                "age",
                "cum_visit_count",
                "days_since_last_visit",
                "visit_count_1mo",
                "visit_count_1w",
                "visit_count_today",
                "visit_count_total",
                "is_hot",
                "is_holiday",
                # "is_am",
                *generated_cols,
            ],
            prefix=NUMERICAL_FEATURE_PREFIX,
        ),
        OrdinalEncoder(
            columns=[
                "性別",
                "顧客CD",
                "店舗名",
            ],
            prefix=CATEGORICAL_FEATURE_PREFIX,
        ),
    ]

    def _encode(df: pl.DataFrame, fit: bool = False) -> pl.DataFrame:
        # Stack every encoder's output horizontally, then shrink dtypes.
        parts = [encoder.fit_transform(df) if fit else encoder.transform(df) for encoder in encoders]
        return pl.concat(parts, how="horizontal").select(pl.all().shrink_dtype())

    train_feature_df = _encode(train_df, fit=True)
    test_feature_df = _encode(test_df)
    if valid_df is None:
        return train_feature_df, test_feature_df
    return train_feature_df, test_feature_df, _encode(valid_df)


# Split the feature-enriched sessions back into their datasets.
train_df, valid_df, test_df = (
    full_session_df_with_feats.filter(pl.col("dataset") == split) for split in ("TRAIN", "VALID", "TEST")
)

train_feature_df, test_feature_df, valid_feature_df = fe(train_df, test_df, valid_df=valid_df)

cat_feature_cols = [col for col in train_feature_df.columns if col.startswith(CATEGORICAL_FEATURE_PREFIX)]
num_feature_cols = [col for col in train_feature_df.columns if col.startswith(NUMERICAL_FEATURE_PREFIX)]
# Model input order: categorical columns first, then numerical.
feature_cols = [*cat_feature_cols, *num_feature_cols]

for _kind, _cols in (("numerical", num_feature_cols), ("categorical", cat_feature_cols)):
    print(f"{_kind} features: {len(_cols)}")

### train


In [None]:
class EvalFn:
    """Evaluation callable: ROC-AUC of a prediction column against a binary target column.

    It also exposes ``__name__`` like a plain function, presumably because the trainer treats
    ``eval_fn`` as one (verify in ``single_train_fn``).
    """

    def __init__(self, target_col: str, pred_col: str = "pred"):
        self.target_col = target_col
        # Column holding the model scores (previously accepted but ignored).
        self.pred_col = pred_col

    def __call__(self, input_df: pl.DataFrame) -> dict[str, float]:
        """Return ``{"rocauc": ROC-AUC(target_col, pred_col)}`` for ``input_df``."""
        y_true = input_df[self.target_col].to_numpy()
        y_pred = input_df[self.pred_col].to_numpy()
        return {"rocauc": roc_auc_score(y_true, y_pred)}

    @property
    def __name__(self) -> str:
        return self.__class__.__name__


def add_kfold(
    input_df: pl.DataFrame,
    n_splits: int,
    random_state: int,
    fold_col: str,
) -> pl.DataFrame:
    """Attach a fold-index column named ``fold_col``.

    ``n_splits == 1`` means holdout: every row is assigned fold 0. Otherwise each row gets the index
    of the shuffled ``KFold`` split in which it is a validation row (rows are not grouped).
    """
    if n_splits != 1:
        # NOTE: should this be gkf (presumably GroupKFold, grouping sessions by customer)?
        splitter = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        fold_ids = np.zeros(len(input_df), dtype=np.int32)
        for fold_id, (_, holdout_idx) in enumerate(splitter.split(X=input_df)):
            fold_ids[holdout_idx] = fold_id
        return input_df.with_columns(pl.Series(name=fold_col, values=fold_ids))
    return input_df.with_columns(pl.lit(0).alias(fold_col))

In [None]:
# Train one LightGBM model per (target, seed); validation predictions are averaged over seeds.
results, scores = {}, {}

for target_col in TARGET_COLS:
    seed_result_dfs, va_scores = [], {}
    all_models = []
    for seed in SEEDS:
        print(f"# ================== seed={seed}: {target_col} ================== #")
        name = f"seed{seed}_lgb_{target_col}"
        _va_result_df, _va_scores, trained_models = single_train_fn(
            model=LightGBMWapper(
                name=name,
                model=lgb.LGBMModel(
                    objective="binary",
                    boosting="gbdt",
                    n_estimators=10000,
                    learning_rate=0.01,
                    num_leaves=31,
                    colsample_bytree=0.1,
                    subsample=0.1,
                    importance_type="gain",
                    random_state=seed,
                    force_col_wise=True,
                    class_weight="balanced",
                    # verbose=-1,
                ),
                fit_params={
                    "callbacks": [
                        lgb.early_stopping(100, first_metric_only=True),
                        lgb.log_evaluation(period=100),
                    ],
                    "categorical_feature": cat_feature_cols,
                    "feature_name": feature_cols,
                    "eval_metric": "auc",
                },
            ),
            features_df=add_kfold(
                train_feature_df,
                n_splits=N_SPLITS,
                random_state=seed,
                fold_col=FOLD_COL,
            ),
            feature_cols=feature_cols,
            target_col=target_col,
            fold_col=FOLD_COL,
            meta_cols=META_COLS + [FOLD_COL],
            out_dir=EXP_OUTPUT_DIR,
            eval_fn=EvalFn(target_col=target_col),
            overwrite=True,
            val_features_df=valid_feature_df,
            full_training=True,
        )
        seed_result_dfs.append(_va_result_df)
        va_scores[name] = _va_scores
        all_models.extend(trained_models)

    va_result_df = pl.concat(seed_result_dfs, how="diagonal_relaxed")

    # va_result_df presumably repeats each validation session once per seed, so the metadata is
    # deduplicated before joining; otherwise every averaged row would be repeated len(SEEDS) times.
    session_meta_df = va_result_df.select(META_COLS).unique(subset="session_id", keep="first", maintain_order=True)
    va_result_agg_df = (
        va_result_df.group_by("session_id")
        .agg(pl.col("pred").mean())
        .sort("session_id")
        .join(session_meta_df, on="session_id", how="left")
    )
    results[target_col] = {
        "result_df": va_result_agg_df,
        "models": all_models,
    }
    scores[target_col] = va_scores


In [None]:
def construct_va_result_df(results: dict) -> pl.DataFrame:
    """Lay the per-target validation frames side by side.

    Output columns: ``session_id`` and ``顧客CD`` (taken from the first target), then for each target
    ``pred_{target}`` followed by the label column ``{target}``. Frames are concatenated horizontally
    (by position), so every ``result_df`` must list the same sessions in the same row order.
    """
    pieces = []
    for position, (target_name, entry) in enumerate(results.items()):
        leading_cols = ["session_id", "顧客CD"] if position == 0 else []
        pieces.append(
            entry["result_df"]
            .with_columns(pl.col("pred").alias(f"pred_{target_name}"))
            .select([*leading_cols, f"pred_{target_name}", target_name])
        )
    if not pieces:
        return pl.DataFrame()
    return pl.concat(pieces, how="horizontal")


va_result_df = construct_va_result_df(results)

# Competition metric: macro-averaged ROC-AUC over the four targets.
pred_cols = [f"pred_{target}" for target in TARGET_COLS]
score = roc_auc_score(
    va_result_df[TARGET_COLS].to_numpy(),
    va_result_df[pred_cols].to_numpy(),
    average="macro",
)
scores["final_metric"] = score

va_result_df.write_parquet(EXP_OUTPUT_DIR / "va_result_df.parquet")
scores_json = json.dumps(scores, indent=4, ensure_ascii=False)
(EXP_OUTPUT_DIR / "scores.json").write_text(scores_json)
print(scores_json)

### inference


In [None]:
# Score the test features with every model trained for each target; one prediction column per target.
_te_frames = []
for position, (target_name, res) in enumerate(results.items()):
    pred_expr = pl.col("pred").alias(target_name)
    selected = [pl.col("session_id"), pred_expr] if position == 0 else [pred_expr]
    _te_frames.append(
        single_inference_fn_v2(
            models=res["models"],
            features_df=test_feature_df,
            feature_names=feature_cols,
        ).select(selected)
    )
te_result_df = pl.concat(_te_frames, how="horizontal") if _te_frames else pl.DataFrame()

# Re-align predictions to the original test-session order.
submission_df = (
    test_session_df.select("session_id")
    .join(te_result_df, on="session_id", how="left")
    .select(["session_id", *TARGET_COLS])
)
submission_df.write_parquet(EXP_OUTPUT_DIR / "te_result_df.parquet")

# Submission column order (presumably the order of the competition's sample submission).
SUBMISSION_COLS = ["チョコレート", "ビール", "ヘアケア", "米（5㎏以下）"]
submission_df.select(SUBMISSION_COLS).write_csv(EXP_OUTPUT_DIR / "submission.csv")