testの出現確率を重視した確率行列を用いて、seq_len>=2 のものの後処理も行いたい

- testに出現しないものは予測から取り除く or 重みを下げる　←　単純で効果が分かりやすいのでまずPoCとしてやる
- 確率行列を用いて得られる確率値との積を考える　←　出現しないものまで消えてしまうので、PoCの結果に応じて重みを変えるなど対応を取る（1.0~0.5の範囲で変化させるようにする等。欠損は0.5にするとか）

In [1]:
%cd ..

/kaggle/working


In [2]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the Hydra config from the experiment's config directory with the
# "exp=base" override, then print it as YAML for reference.
with initialize(
    version_base=None,
    config_path="../experiments/ensemble_006_filter",
):
    cfg = compose(config_name="config.yaml", overrides=["exp=base"])
    config_yaml = OmegaConf.to_yaml(cfg)
    print(config_yaml)

debug: false
seed: 42
dir:
  data_dir: /kaggle/working/input/atmaCup16_Dataset
  output_dir: /kaggle/working/output
  exp_dir: /kaggle/working/output/exp
  cand_unsupervised_dir: /kaggle/working/output/cand_unsupervised
  cand_supervised_dir: /kaggle/working/output/cand_supervised
  datasets_dir: /kaggle/working/output/datasets
exp:
  other_dirs:
    output/exp/008_split/base: 0.8
    output/exp/012_cat_boost/base: 0.2
  first_dirs:
    output/exp/008_split/v025_003_first: 1.0
  transision_path: output/cand_supervised/supervised-prob_matrix_test_weight/003/yad2yad_feature.parquet
  score_col: transition_prob



In [7]:
import logging
import os
import pickle
import sys
import time
from pathlib import Path

import hydra
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm

import utils
from utils.load import load_label_data, load_log_data, load_session_data, load_yad_data
from utils.logger import get_logger
from utils.metrics import calculate_metrics

# Module-level logger placeholder; none of the cells shown here assign it.
logger = None

In [4]:
def make_eval_df(cfg, other_oof_df: pl.DataFrame, first_oof_df: pl.DataFrame):
    """Merge two OOF prediction frames into ranked candidate lists with labels.

    Rows whose ``session_count`` is 1 are taken from ``first_oof_df``; rows
    with any other ``session_count`` are taken from ``other_oof_df``.  The
    merged rows are sorted by ``session_id`` and ``pred`` (both descending),
    the ``candidates`` of each session are collected into a list, and the
    frame returned by ``load_label_data`` is left-joined on ``session_id``.

    Args:
        cfg: Config object; ``cfg.dir.data_dir`` is handed to ``load_label_data``.
        other_oof_df: Predictions used for sessions whose ``session_count`` != 1.
        first_oof_df: Predictions used for sessions whose ``session_count`` == 1.

    Returns:
        The per-session ``candidates`` lists joined with the label columns.
    """
    multi_rows = other_oof_df.filter(pl.col("session_count") != 1)
    single_rows = first_oof_df.filter(pl.col("session_count") == 1)

    merged = pl.concat(
        [multi_rows.drop("session_count"), single_rows.drop("session_count")]
    )
    # Highest score first, so each collected list is already in ranking order.
    ranked = merged.sort(by=["session_id", "pred"], descending=True)
    per_session = ranked.group_by("session_id").agg(pl.col("candidates"))

    labels = load_label_data(Path(cfg.dir.data_dir))
    return per_session.join(labels, on="session_id", how="left")


def make_submission(cfg, other_test_df: pl.DataFrame, first_test_df: pl.DataFrame):
    """Build the submission frame from the two test prediction frames.

    Rows whose ``session_count`` is 1 are taken from ``first_test_df``; rows
    with any other ``session_count`` are taken from ``other_test_df``.  The
    merged rows are sorted by ``session_id`` and ``pred`` (both descending)
    and the first ten candidates of each session become the columns
    ``predict_0`` .. ``predict_9``.

    Args:
        cfg: Config object; ``cfg.dir.data_dir`` is handed to ``load_session_data``.
        other_test_df: Predictions used for sessions whose ``session_count`` != 1.
        first_test_df: Predictions used for sessions whose ``session_count`` == 1.

    Returns:
        The test session frame left-joined with the ``predict_*`` columns,
        with nulls replaced by -1 and ``session_id`` dropped.
    """
    multi_rows = other_test_df.filter(pl.col("session_count") != 1)
    single_rows = first_test_df.filter(pl.col("session_count") == 1)

    merged = pl.concat(
        [multi_rows.drop("session_count"), single_rows.drop("session_count")]
    )
    # Highest score first, so list position equals ranking position.
    ranked = merged.sort(by=["session_id", "pred"], descending=True)

    session_df = load_session_data(Path(cfg.dir.data_dir), "test")
    per_session = ranked.group_by("session_id").agg(pl.col("candidates"))

    # One output column per ranking position.
    predict_exprs = [
        pl.col("candidates").list.get(rank).alias(f"predict_{rank}")
        for rank in range(10)
    ]
    wide_df = per_session.with_columns(predict_exprs).drop("candidates")

    # Left join keeps every row of the session frame, in its order.
    joined = session_df.join(wide_df, on="session_id", how="left")
    return joined.fill_null(-1).drop("session_id")


def concat_label_pred(cfg, first_df, transition_df, mode):
    """Add transition-matrix scores to the predictions of length-1 sessions.

    For every session whose log has ``session_count == 1``, the logged
    ``yad_no`` is used as ``from_yad_no`` and joined against
    ``transition_df``; each matching ``to_yad_no`` becomes a candidate whose
    score is ``(score + 1) * 1000``.  These rows are stacked on ``first_df``
    and the scores are summed per ``(session_id, candidates)``.

    Args:
        cfg: Config object; reads ``cfg.dir.data_dir`` and ``cfg.exp.score_col``
            (the name of the score column in ``transition_df``).
        first_df: Existing predictions with the columns ``session_id``,
            ``candidates``, ``pred`` and ``session_count``.
        transition_df: Frame with ``from_yad_no``, ``to_yad_no`` and the
            score column.
        mode: Passed through to ``load_log_data`` / ``load_session_data``.

    Returns:
        One row per ``(session_id, candidates)`` with the summed ``pred`` and
        the maximum ``session_count``, sorted by ``session_id`` and ``pred``
        (both descending).
    """
    data_dir = Path(cfg.dir.data_dir)

    # Session length is the largest seq_no in the session plus one
    # (presumably seq_no is zero-based; confirm against the log data).
    # Only sessions of length 1 are kept.
    session_len = pl.col("seq_no").max().over("session_id") + 1
    log_df = load_log_data(data_dir, mode)
    single_log_df = (
        log_df.with_columns(session_len.alias("session_count"))
        .filter(pl.col("session_count") == 1)
        .rename({"yad_no": "from_yad_no"})
    )

    # Attach the source yad and the session length to the session frame.
    session_df = load_session_data(data_dir, mode)
    session_src_df = session_df.join(
        single_log_df.select(["session_id", "from_yad_no", "session_count"]),
        on="session_id",
    )

    # Expand every source yad into its transition targets and rescale the score.
    scored_transitions = transition_df.rename({cfg.exp.score_col: "pred"})
    transition_pred_df = (
        session_src_df.join(scored_transitions, on="from_yad_no")
        .with_columns(
            pl.col("to_yad_no").cast(pl.Int32).alias("candidates"),
            pl.col("session_count").cast(pl.Int32),
            ((pl.col("pred") + 1) * 1000).alias("pred"),
        )
        .drop(["from_yad_no", "to_yad_no"])
        .select(["session_id", "candidates", "pred", "session_count"])
    )

    # Stack on the existing predictions and sum the scores per candidate.
    stacked = pl.concat([first_df, transition_pred_df])
    summed = stacked.group_by(["session_id", "candidates"]).agg(
        pl.col("pred").sum(), pl.col("session_count").max()
    )
    return summed.sort(by=["session_id", "pred"], descending=True)

In [5]:
# Collect the weighted OOF / test predictions of the two model groups in
# cfg.exp: `other_dirs` and `first_dirs` each map a directory to a weight.
other_oof_dfs, other_test_dfs = [], []
first_oof_dfs, first_test_dfs = [], []

for model_dirs, oof_sink, test_sink in (
    (cfg.exp.other_dirs, other_oof_dfs, other_test_dfs),
    (cfg.exp.first_dirs, first_oof_dfs, first_test_dfs),
):
    for path, weight in model_dirs.items():
        for file_name, sink in (
            ("oof_pred.parquet", oof_sink),
            ("test_pred.parquet", test_sink),
        ):
            # Scale the score by the model weight; cast session_count to Int32.
            df = pl.read_parquet(Path(path) / file_name).with_columns(
                pl.col("pred") * weight,
                pl.col("session_count").cast(pl.Int32),
            )
            sink.append(df)

In [19]:
# Blend the weighted OOF frames: sum the scores of each (session, candidate)
# pair across models and keep the largest session_count seen for it.
stacked_oof_df = pl.concat(other_oof_dfs)
other_oof_df = stacked_oof_df.group_by(["session_id", "candidates"]).agg(
    [
        pl.col("pred").sum(),
        pl.col("session_count").max(),
    ]
)

In [21]:
# Summary statistics of the blended OOF predictions.
other_oof_df.describe()

describe,session_id,candidates,pred,session_count
str,str,f64,f64,f64
"""count""","""10843218""",10843218.0,10843218.0,10843218.0
"""null_count""","""0""",0.0,0.0,0.0
"""mean""",,6948.763338,-1.328015,1.499708
"""std""",,4024.038394,1.484091,0.731467
"""min""","""000007603d533d…",1.0,-6.379624,1.0
"""25%""",,3441.0,-2.653332,1.0
"""50%""",,6925.0,-0.503633,1.0
"""75%""",,10478.0,-0.318207,2.0
"""max""","""fffffa7baf3700…",13806.0,6.416839,10.0


In [10]:
# Per-yad occurrence counts in the test log; duplicate (session_id, yad_no)
# rows are dropped first, so a yad counts at most once per session.
test_log_df = load_log_data(Path(cfg.dir.data_dir), "test")
test_count_df = (
    test_log_df.unique(["session_id", "yad_no"]).get_column("yad_no").value_counts()
)

# Single-column frame with the yad_no values returned by load_yad_data.
yad_df = load_yad_data(Path(cfg.dir.data_dir)).select(["yad_no"])

In [14]:
# Left join keeps every yad_no of yad_df; yads without a match in the test
# counts come back as null and are turned into 0.
joined_counts_df = yad_df.join(test_count_df, on="yad_no", how="left")
yad_counts_df = joined_counts_df.fill_null(0)
yad_counts_df.head()

yad_no,counts
i64,u32
1,30
2,5
3,29
4,17
5,0


In [16]:
# yad_no values with a test-log count of 0, i.e. the filter targets.
zero_count_df = yad_counts_df.filter(pl.col("counts") == 0)
filter_yad_list = zero_count_df.get_column("yad_no").to_list()

In [18]:
# Number of yads with a test-log count of 0.
len(filter_yad_list)

2583

In [22]:
def filter_unseen_yad(cfg, df, penalty=100):
    """Lower the score of candidates that never occur in the test log.

    A yad counts as "unseen" when it is listed in the frame returned by
    ``load_yad_data`` but has no row in the test log.  The ``pred`` of every
    row whose ``candidates`` value is such a yad is reduced by ``penalty``;
    all other rows are returned unchanged.

    The unseen yads are found with an anti-join rather than by filtering the
    output of ``value_counts()``, so the function does not depend on the name
    of the count column, which differs between polars releases.

    Args:
        cfg: Config object; ``cfg.dir.data_dir`` is handed to the data loaders.
        df: Prediction frame with at least the columns ``candidates`` and
            ``pred``.
        penalty: Amount subtracted from ``pred`` for unseen candidates.  The
            default of 100 keeps the previous hard-coded behavior.

    Returns:
        ``df`` with the ``pred`` column replaced by the penalized scores.
    """
    data_dir = Path(cfg.dir.data_dir)

    # Distinct yad_no values that occur at least once in the test log.
    seen_yad_df = load_log_data(data_dir, "test").select("yad_no").unique()

    # Anti join: yads from the master list without any match in the test log.
    unseen_yad_list = (
        load_yad_data(data_dir)
        .select(["yad_no"])
        .join(seen_yad_df, on="yad_no", how="anti")["yad_no"]
        .to_list()
    )

    # Subtract the penalty only where the candidate is an unseen yad.
    return df.with_columns(
        pl.when(pl.col("candidates").is_in(unseen_yad_list))
        .then(pl.col("pred") - penalty)
        .otherwise(pl.col("pred"))
        .alias("pred")
    )

In [23]:
# Apply the unseen-yad penalty to the blended OOF predictions.
df = filter_unseen_yad(cfg, other_oof_df)

In [24]:
# Summary statistics after the penalty, for comparison with the earlier describe().
df.describe()

describe,session_id,candidates,pred,session_count
str,str,f64,f64,f64
"""count""","""10843218""",10843218.0,10843218.0,10843218.0
"""null_count""","""0""",0.0,0.0,0.0
"""mean""",,6948.763338,-6.224366,1.499708
"""std""",,4024.038394,21.717207,0.731467
"""min""","""000007603d533d…",1.0,-106.213527,1.0
"""25%""",,3441.0,-2.944577,1.0
"""50%""",,6925.0,-0.534332,1.0
"""75%""",,10478.0,-0.321592,2.0
"""max""","""fffffa7baf3700…",13806.0,6.416839,10.0


In [25]:
# Display the penalized prediction frame.
df

session_id,candidates,pred,session_count
str,i32,f64,i32
"""fffffa7baf3700…",2439,1.662029,2
"""fffffa7baf3700…",3338,-3.436939,2
"""fffffa7baf3700…",3422,-3.939829,2
"""ffff2360540745…",10259,-0.578957,1
"""ffff2262d38abd…",4193,-0.389775,1
"""fffe8c99c5b332…",4712,-4.007725,2
"""fffe8a472ae6a9…",6257,-4.438817,3
"""fffe8a472ae6a9…",8640,-4.468189,3
"""fffe78a078a176…",5420,-0.278949,1
"""fffe78a078a176…",6611,-0.317324,1
