In [1]:
%cd ..

/kaggle/working


In [44]:
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="../experiments/ensemble_001"):
    cfg = compose(config_name="config.yaml", overrides=["exp=base"])
    print(OmegaConf.to_yaml(cfg))

debug: false
seed: 42
dir:
  data_dir: /kaggle/working/input/atmaCup16_Dataset
  output_dir: /kaggle/working/output
  exp_dir: /kaggle/working/output/exp
  cand_unsupervised_dir: /kaggle/working/output/cand_unsupervised
  cand_supervised_dir: /kaggle/working/output/cand_supervised
  datasets_dir: /kaggle/working/output/datasets
exp:
  other_dirs:
  - output/exp/008_split/base
  first_dirs:
  - output/exp/008_split/first



In [45]:
import logging
import os
import pickle
import sys
import time
from pathlib import Path

import hydra
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm

import utils
import wandb
from utils.load import load_label_data, load_log_data, load_session_data
from utils.logger import get_logger
from utils.metrics import calculate_metrics

In [64]:
other_oof_df = pl.read_parquet(Path(cfg.exp.other_dirs[0]) / "oof_pred.parquet")
other_test_df = pl.read_parquet(Path(cfg.exp.other_dirs[0]) / "test_pred.parquet")
other_oof_df.head()

session_id,candidates,pred,session_count
str,i32,f64,u32
"""fffffa7baf3700…",2439,2.098421,2
"""fffffa7baf3700…",2981,0.334312,2
"""fffffa7baf3700…",10095,-0.838374,2
"""fffffa7baf3700…",1372,-1.271482,2
"""fffffa7baf3700…",3,-1.31913,2


In [65]:
first_oof_df = pl.read_parquet(Path(cfg.exp.first_dirs[0]) / "oof_pred.parquet")
first_test_df = pl.read_parquet(Path(cfg.exp.first_dirs[0]) / "test_pred.parquet")

In [72]:
def make_eval_df(other_oof_df: pl.DataFrame, first_oof_df: pl.DataFrame):
    other_oof_df = other_oof_df.filter(pl.col("session_count") != 1).drop(
        "session_count"
    )
    first_oof_df = first_oof_df.filter(pl.col("session_count") == 1).drop(
        "session_count"
    )
    pred_df = pl.concat([other_oof_df, first_oof_df]).sort(
        by=["session_id", "pred"], descending=True
    )
    pred_candidates_df = pred_df.group_by("session_id").agg(pl.col("candidates"))
    train_label_df = load_label_data(Path(cfg.dir.data_dir))
    candidaates_df = pred_candidates_df.join(
        train_label_df, on="session_id", how="left"
    )
    return candidaates_df

In [92]:
oof_candidate_df = make_eval_df(other_oof_df, first_oof_df)
print(oof_candidate_df.head())

metrics = calculate_metrics(
    oof_candidate_df, candidates_col="candidates", label_col="yad_no", k=[10]
)
print(metrics)

shape: (5, 3)
┌──────────────────────────────────┬────────────────────────┬────────┐
│ session_id                       ┆ candidates             ┆ yad_no │
│ ---                              ┆ ---                    ┆ ---    │
│ str                              ┆ list[i32]              ┆ i64    │
╞══════════════════════════════════╪════════════════════════╪════════╡
│ bab7bf061f6b4f9b2d9e726e7e76e321 ┆ [12053, 8513, … 10940] ┆ 12053  │
│ f0dad7dcb3fc66fc8c2808e36f34525d ┆ [4881, 9333, … 8663]   ┆ 10825  │
│ c94950cdb9cf420be8f6ded879fc2579 ┆ [5222, 12444, … 12239] ┆ 12964  │
│ a69e3d37b3aabf3e50f69466a3270c30 ┆ [7315, 8953, … 11935]  ┆ 5987   │
│ 5d506c5b8b9da018f631a7d8ebaef218 ┆ [5080, 6991, … 6489]   ┆ 5080   │
└──────────────────────────────────┴────────────────────────┴────────┘
k: 10
avg_num_candidates: 9.990380951721177
recall: 0.5965195463771831
precision: 0.059651954637718316
map@k: 0.40264062808754836

[{'k': 10, 'avg_num_candidates': 9.990380951721177, 'recall': 0.5965195463

In [75]:
def make_submission(other_test_df: pl.DataFrame, first_test_df: pl.DataFrame):
    other_test_df = other_test_df.filter(pl.col("session_count") != 1).drop(
        "session_count"
    )
    first_test_df = first_test_df.filter(pl.col("session_count") == 1).drop(
        "session_count"
    )
    pred_df = pl.concat([other_test_df, first_test_df]).sort(
        by=["session_id", "pred"], descending=True
    )
    session_df = load_session_data(Path(cfg.dir.data_dir), "test")
    pred_candidates_df = pred_df.group_by("session_id").agg(pl.col("candidates"))
    submission_df = (
        session_df.join(
            pred_candidates_df.with_columns(
                [
                    pl.col("candidates").list.get(i).alias(f"predict_{i}")
                    for i in range(10)
                ]
            ).drop("candidates"),
            on="session_id",
            how="left",
        )
        .fill_null(-1)
        .drop("session_id")
    )
    return submission_df

In [76]:
test_candidate_df = make_submission(other_test_df, first_test_df)
test_candidate_df.head()

predict_0,predict_1,predict_2,predict_3,predict_4,predict_5,predict_6,predict_7,predict_8,predict_9
i32,i32,i32,i32,i32,i32,i32,i32,i32,i32
3560,11561,4545,9534,4714,4420,5466,2680,6563,6488
143,6555,4066,11923,613,7014,8108,12862,6129,11237
757,7710,9190,9910,1774,410,10485,13570,6721,3400
12341,3359,6991,1542,13521,10861,5080,4180,5657,9319
2862,9020,5372,9623,10826,9611,3854,763,3476,6161


## ルールベースでの session_count==1 の変更

In [77]:
label_pred_df.head()

yad_no,yad_no_label,pred,session_count
i64,i64,f64,i32
13806,11113,400.0,1
13806,3326,200.0,1
13806,6997,200.0,1
13806,8762,200.0,1
13806,4020,200.0,1


In [95]:
# sessionに最後に見たyad_noを付与
mode = "train"


def concat_label_pred(first_df, mode):
    # 最後のyad_noだけを残す & labelを付与
    train_log_df = load_log_data(Path(cfg.dir.data_dir), "train")
    train_label_df = load_label_data(Path(cfg.dir.data_dir))
    train_last_log_label_df = (
        train_log_df.join(train_label_df, on="session_id", suffix="_label")
        .with_columns(
            (pl.col("seq_no").max().over("session_id") + 1).alias("session_count")
        )
        .filter(pl.col("seq_no") == pl.col("session_count") - 1)
    )
    # 実績ラベルからyad_noごとに良さそうな対象を探す
    label_pred_df = (
        train_last_log_label_df.group_by(["yad_no", "yad_no_label"])
        .agg(pl.col("yad_no").count().alias("pred"))
        .with_columns(pl.col("pred") * 100.0, pl.lit(1).alias("session_count"))
        .sort(by=["yad_no", "pred", "session_count"], descending=True)
    )
    print(label_pred_df)

    # 予測値作成
    log_df = load_log_data(Path(cfg.dir.data_dir), mode)
    last_log_df = log_df.with_columns(
        (pl.col("seq_no").max().over("session_id") + 1).alias("session_count")
    ).filter(pl.col("seq_no") == pl.col("session_count") - 1)
    session_df = load_session_data(Path(cfg.dir.data_dir), mode)
    session_last_df = (
        session_df.join(
            last_log_df.select(["session_id", "yad_no", "session_count"]),
            on="session_id",
        )
        .filter(pl.col("session_count") == 1)
        .drop("session_count")
    )
    first_df_from_label = (
        session_last_df.join(label_pred_df, on="yad_no")
        .with_columns(
            pl.col("yad_no_label").alias("candidates").cast(pl.Int32),
            pl.col("session_count").cast(pl.Int32),
        )
        .drop(["yad_no", "yad_no_label"])
        .group_by(["session_id", "candidates"])
        .agg(pl.col("pred").max(), pl.col("session_count").max())
        .select(["session_id", "candidates", "pred", "session_count"])
        .sort(by=["session_id", "pred"], descending=True)
    )
    # first と結合
    return pl.concat([first_df, first_df_from_label])


concat_label_pred(first_oof_df, "train")

shape: (188_071, 4)
┌────────┬──────────────┬───────┬───────────────┐
│ yad_no ┆ yad_no_label ┆ pred  ┆ session_count │
│ ---    ┆ ---          ┆ ---   ┆ ---           │
│ i64    ┆ i64          ┆ f64   ┆ i32           │
╞════════╪══════════════╪═══════╪═══════════════╡
│ 13806  ┆ 11113        ┆ 400.0 ┆ 1             │
│ 13806  ┆ 6997         ┆ 200.0 ┆ 1             │
│ 13806  ┆ 4020         ┆ 200.0 ┆ 1             │
│ 13806  ┆ 8762         ┆ 200.0 ┆ 1             │
│ …      ┆ …            ┆ …     ┆ …             │
│ 2      ┆ 36           ┆ 100.0 ┆ 1             │
│ 2      ┆ 12232        ┆ 100.0 ┆ 1             │
│ 2      ┆ 2200         ┆ 100.0 ┆ 1             │
│ 2      ┆ 9382         ┆ 100.0 ┆ 1             │
└────────┴──────────────┴───────┴───────────────┘


session_id,candidates,pred,session_count
str,i32,f64,i32
"""fffffa7baf3700…",2439,0.245733,2
"""fffffa7baf3700…",1372,0.169726,2
"""fffffa7baf3700…",10095,0.169726,2
"""fffffa7baf3700…",12154,0.141801,2
"""fffffa7baf3700…",3,0.127825,2
"""fffffa7baf3700…",9624,0.098462,2
"""fffffa7baf3700…",10439,0.098462,2
"""fffffa7baf3700…",10415,0.071892,2
"""fffffa7baf3700…",5294,0.071892,2
"""fffffa7baf3700…",6579,0.071892,2


In [96]:
oof_candidate_df = make_eval_df(other_oof_df, concat_label_pred(first_oof_df, "train"))
print(oof_candidate_df.head())

metrics = calculate_metrics(
    oof_candidate_df, candidates_col="candidates", label_col="yad_no", k=[10]
)
print(metrics)

shape: (188_071, 4)
┌────────┬──────────────┬───────┬───────────────┐
│ yad_no ┆ yad_no_label ┆ pred  ┆ session_count │
│ ---    ┆ ---          ┆ ---   ┆ ---           │
│ i64    ┆ i64          ┆ f64   ┆ i32           │
╞════════╪══════════════╪═══════╪═══════════════╡
│ 13806  ┆ 11113        ┆ 400.0 ┆ 1             │
│ 13806  ┆ 6997         ┆ 200.0 ┆ 1             │
│ 13806  ┆ 8762         ┆ 200.0 ┆ 1             │
│ 13806  ┆ 4020         ┆ 200.0 ┆ 1             │
│ …      ┆ …            ┆ …     ┆ …             │
│ 2      ┆ 1099         ┆ 100.0 ┆ 1             │
│ 2      ┆ 217          ┆ 100.0 ┆ 1             │
│ 2      ┆ 9840         ┆ 100.0 ┆ 1             │
│ 2      ┆ 11562        ┆ 100.0 ┆ 1             │
└────────┴──────────────┴───────┴───────────────┘
shape: (5, 3)
┌──────────────────────────────────┬───────────────────────┬────────┐
│ session_id                       ┆ candidates            ┆ yad_no │
│ ---                              ┆ ---                   ┆ ---    │
│ str 

In [97]:
test_candidate_df = make_submission(
    other_test_df, concat_label_pred(first_test_df, "test")
)
test_candidate_df.head()

shape: (188_071, 4)
┌────────┬──────────────┬───────┬───────────────┐
│ yad_no ┆ yad_no_label ┆ pred  ┆ session_count │
│ ---    ┆ ---          ┆ ---   ┆ ---           │
│ i64    ┆ i64          ┆ f64   ┆ i32           │
╞════════╪══════════════╪═══════╪═══════════════╡
│ 13806  ┆ 11113        ┆ 400.0 ┆ 1             │
│ 13806  ┆ 4020         ┆ 200.0 ┆ 1             │
│ 13806  ┆ 3326         ┆ 200.0 ┆ 1             │
│ 13806  ┆ 6997         ┆ 200.0 ┆ 1             │
│ …      ┆ …            ┆ …     ┆ …             │
│ 2      ┆ 9034         ┆ 100.0 ┆ 1             │
│ 2      ┆ 36           ┆ 100.0 ┆ 1             │
│ 2      ┆ 12232        ┆ 100.0 ┆ 1             │
│ 2      ┆ 8187         ┆ 100.0 ┆ 1             │
└────────┴──────────────┴───────┴───────────────┘


predict_0,predict_1,predict_2,predict_3,predict_4,predict_5,predict_6,predict_7,predict_8,predict_9
i32,i32,i32,i32,i32,i32,i32,i32,i32,i32
3560,11561,4545,9534,4714,4420,5466,2680,6563,6488
143,4066,7014,7913,6555,12862,6129,11237,8108,613
757,7710,9190,9910,1774,410,10485,13570,6721,3400
12341,3359,6991,1542,13521,10861,5080,4180,5657,9319
9020,3844,6161,4070,13235,9623,3854,12029,6565,3476
