In [1]:
%cd ..

/kaggle/working


In [2]:
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="../generate_datasets/005"):
    cfg = compose(config_name="config.yaml", overrides=["debug=True"])
    print(OmegaConf.to_yaml(cfg))

debug: true
seed: 7
dir:
  data_dir: /kaggle/working/input/atmaCup16_Dataset
  output_dir: /kaggle/working/output
  exp_dir: /kaggle/working/output/exp
  cand_unsupervised_dir: /kaggle/working/output/cand_unsupervised
  cand_supervised_dir: /kaggle/working/output/cand_supervised
  datasets_dir: /kaggle/working/output/datasets
exp:
  fold_path: /kaggle/working/output/datasets/make_cv/base/train_fold.parquet
  candidate_info_list:
  - name: transition_prob/base
    max_num_candidates: 100
    dir: /kaggle/working/output/cand_unsupervised/transition_prob/base
  - name: ranking_location/sml_cd
    max_num_candidates: 50
    dir: /kaggle/working/output/cand_unsupervised/ranking_location/sml_cd
  - name: ranking_location/lrg_cd
    max_num_candidates: 50
    dir: /kaggle/working/output/cand_unsupervised/ranking_location/lrg_cd
  transition_prob_path: /kaggle/working/output/cand_unsupervised/transition_prob/base/yad2yad_feature.parquet
  yad_feature_paths:
  - output/cand_unsupervised/ranking

In [3]:
import os
import sys
from pathlib import Path

import hydra
import polars as pl
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from sklearn.preprocessing import OrdinalEncoder

from utils.data import convert_to_32bit
from utils.load import load_label_data, load_log_data, load_yad_data
from utils.logger import get_logger

numerical_cols = [  # あとで書き換えるので注意
    "total_room_cnt",
    "wireless_lan_flg",
    "onsen_flg",
    "kd_stn_5min",
    "kd_bch_5min",
    "kd_slp_5min",
]

categorical_cols = [
    "yad_type",
    "wid_cd",
    "ken_cd",
    "lrg_cd",
    "sml_cd",
]

logger = None
ordinal_encoder = None


def load_yad_data_with_features(cfg):
    global numerical_cols
    yad_df = load_yad_data(Path(cfg.dir.data_dir))
    original_cols = yad_df.columns
    for path in cfg.exp.yad_feature_paths:
        feature_df = pl.read_parquet(path)
        yad_df = yad_df.join(feature_df, on="yad_no")
    new_cols = [col for col in yad_df.columns if col not in original_cols]
    numerical_cols = list(set(numerical_cols) | set(new_cols))
    return yad_df

In [4]:
def load_and_union_candidates(cfg, mode: str):
    # logデータのsession中のyad_noを候補に加える
    log_df = load_log_data(Path(cfg.dir.data_dir), mode)
    df = log_df.group_by("session_id").agg(pl.col("yad_no").alias("candidates"))
    dfs = [df]
    for candidate_info in cfg.exp.candidate_info_list:
        df = pl.read_parquet(Path(candidate_info["dir"]) / f"{mode}_candidate.parquet")
        df = df.with_columns(
            pl.col("candidates")
            .list.head(candidate_info["max_num_candidates"])
            .alias("candidates")
        ).filter(pl.col("candidates").list.len() > 0)

        dfs.append(df)
    df = pl.concat(dfs)
    df = (
        df.group_by("session_id")
        .agg(pl.col("candidates").flatten())
        .with_columns(pl.col("candidates").list.unique())
    ).select(["session_id", "candidates"])

    # リストを展開
    candidate_df = df.explode("candidates")

    # セッション最後のyad_noを除外
    last_df = (
        load_log_data(Path(cfg.dir.data_dir), mode)
        .group_by("session_id")
        .agg(pl.col("yad_no").last().alias("candidates"))
        .with_columns(pl.lit(True).alias("last"))
        .sort(by="session_id")
    )
    candidate_df = (
        candidate_df.join(last_df, on=["session_id", "candidates"], how="left")
        .filter(pl.col("last").is_null())
        .drop("last")
    )
    return candidate_df


def concat_label_fold(cfg, mode: str, candidate_df):
    """
    train に対して original, label, fold を付与する
    validationのスコア計算時にはoriginalを外して計算を行う
    """
    if mode == "train":
        candidate_df = (
            pl.concat(
                [
                    candidate_df.with_columns(
                        pl.lit(True).alias("original"), pl.lit(False).alias("label")
                    ),
                    load_label_data(Path(cfg.dir.data_dir))
                    .with_columns(
                        pl.col("yad_no").alias("candidates"),
                        pl.lit(False).alias("original"),
                        pl.lit(True).alias("label"),
                    )
                    .drop("yad_no"),
                ]
            )
            .group_by(["session_id", "candidates"])
            .agg(pl.sum("original"), pl.sum("label"))
        )
        fold_df = pl.read_parquet(cfg.exp.fold_path)
        candidate_df = candidate_df.join(fold_df, on="session_id")
    return candidate_df


def concat_session_feature(cfg, mode: str, candidate_df: pl.DataFrame):
    """
    # TODO: categorical_colsの情報もあとで追加する
    session_id, seq_no, yad_no に yado.csv を結合して集約し、セッションに関する特徴量を作成する
    """
    log_df = load_log_data(Path(cfg.dir.data_dir), mode)
    yad_df = load_yad_data_with_features(cfg)
    log_yad_df = log_df.join(yad_df.fill_null(0), on="yad_no")
    log_yad_df = log_yad_df.group_by(by="session_id").agg(
        [pl.sum(col).name.suffix("_session_sum") for col in numerical_cols]
        + [pl.min(col).name.suffix("_session_min") for col in numerical_cols]
        + [pl.max(col).name.suffix("_session_max") for col in numerical_cols]
        + [pl.std(col).name.suffix("_session_std") for col in numerical_cols]
    )

    candidate_df = candidate_df.join(log_yad_df, on="session_id")

    return candidate_df


def concat_candidate_feature(cfg, mode: str, candidate_df: pl.DataFrame):
    """
    # TODO: categorical_colsの情報もあとで追加する
    candidateの特徴量を抽出する
    """
    original_cols = candidate_df.columns

    yad_df = load_yad_data_with_features(cfg)
    candidate_yad_df = candidate_df.join(
        yad_df.select(["yad_no"] + numerical_cols + categorical_cols),
        left_on="candidates",
        right_on="yad_no",
    )

    new_cols = [col for col in candidate_yad_df.columns if col not in original_cols]
    print(f"new_cols: {new_cols}")
    return candidate_yad_df

In [5]:
mode = "train"
candidate_df = load_and_union_candidates(cfg, mode)
candidate2_df = concat_label_fold(cfg, mode, candidate_df)
candidate3_df = concat_candidate_feature(cfg, mode, candidate2_df)
candidate3_df.head()

new_cols: ['wireless_lan_flg', 'kd_bch_5min', 'rank_ranking_location/lrg_cd', 'counts_ranking_location/sml_cd', 'rank_ranking/base', 'rank_ranking_location/ken_cd', 'total_room_cnt', 'counts_ranking_location/ken_cd', 'counts_ranking_location/wid_cd', 'rank_ranking_location/wid_cd', 'onsen_flg', 'kd_slp_5min', 'rank_ranking_location/sml_cd', 'kd_stn_5min', 'counts_ranking_location/lrg_cd', 'counts_ranking/base', 'yad_type', 'wid_cd', 'ken_cd', 'lrg_cd', 'sml_cd']


session_id,candidates,original,label,fold,wireless_lan_flg,kd_bch_5min,rank_ranking_location/lrg_cd,counts_ranking_location/sml_cd,rank_ranking/base,rank_ranking_location/ken_cd,total_room_cnt,counts_ranking_location/ken_cd,counts_ranking_location/wid_cd,rank_ranking_location/wid_cd,onsen_flg,kd_slp_5min,rank_ranking_location/sml_cd,kd_stn_5min,counts_ranking_location/lrg_cd,counts_ranking/base,yad_type,wid_cd,ken_cd,lrg_cd,sml_cd
str,i64,u32,u32,i64,f64,f64,f64,u32,f64,f64,f64,u32,u32,f64,i64,f64,f64,f64,u32,u32,i64,str,str,str,str
"""40614d55c5d0b5…",8347,1,0,1,1.0,,39.0,69,3182.0,436.5,161.0,69,69,747.5,0,,12.0,1.0,69,69,0,"""46e33861f921c3…","""107c7305a74c8d…","""c9d5e891463e53…","""7cf2b4f31fb207…"
"""40614d55c5d0b5…",13593,1,0,1,1.0,,26.5,85,2460.0,348.0,112.0,85,85,602.5,0,,8.5,1.0,85,85,0,"""46e33861f921c3…","""107c7305a74c8d…","""c9d5e891463e53…","""7cf2b4f31fb207…"
"""554d8619ba352c…",4216,1,0,0,1.0,,15.0,71,3064.5,105.5,30.0,71,71,725.0,0,,15.0,,71,71,0,"""46e33861f921c3…","""572d60f0f5212a…","""2e63024b11908f…","""d075eb4a966945…"
"""554d8619ba352c…",4915,1,0,0,1.0,,28.0,88,2340.5,333.0,107.0,88,88,578.0,0,,15.0,1.0,88,88,0,"""46e33861f921c3…","""107c7305a74c8d…","""aabf8b3cf64147…","""c3b55cb211c69d…"
"""5ee2f4ed6a5c68…",4024,1,0,1,1.0,,21.0,35,5696.0,70.5,50.0,35,35,367.5,1,,21.0,,35,35,0,"""d86102dd9c232b…","""3831f43bb997a3…","""8945cbc9f218eb…","""0add1dfa772542…"


In [10]:

def concat_session_candidate_feature(cfg, mode: str, candidate_df: pl.DataFrame):
    """
    session中の特徴とcandidateの関係性を特徴量として抽出する
    例: session中におけるcandidateの出現回数(割合)、candidateと同一地域のものを見た回数(割合)
    """
    original_cols = candidate_df.columns
    print(original_cols)

    # 同じ categorical の出現回数
    ## (series_id, categorical) でグループ化して、session_id ごとに出現回数を集計する
    log_df = load_log_data(Path(cfg.dir.data_dir), mode)
    yad_df = load_yad_data(Path(cfg.dir.data_dir))
    log_yad_df = log_df.join(yad_df.fill_null(0), on="yad_no")
    for col in categorical_cols:
        tmp = (
            log_yad_df.group_by(by=["session_id", col])
            .agg(pl.count("session_id").alias(f"same_{col}_count"))
            .with_columns(
                pl.col(f"same_{col}_count").sum().over("session_id").alias("seq_sum")
            )
            .with_columns(
                (pl.col(f"same_{col}_count") / pl.col("seq_sum")).alias(
                    f"same_{col}_rate"
                )
            )
        )
        candidate_df = candidate_df.join(
            tmp.select(["session_id", col, f"same_{col}_count", f"same_{col}_rate"]),
            on=["session_id", col],
            how="left",
        )

    # transition probを追加
    yad2yad_prob = pl.read_parquet(cfg.exp.transition_prob_path)
    log_df = load_log_data(Path(cfg.dir.data_dir), mode)
    last_log_df = (
        log_df.group_by("session_id")
        .agg(pl.all().sort_by("seq_no").last())
        .sort(by="session_id")
        .with_columns(pl.col("yad_no").alias("from_yad_no"))
    ).select(["session_id", "from_yad_no"])
    last_log_prob_df = last_log_df.join(yad2yad_prob, on="from_yad_no")
    candidate_df = candidate_df.join(
        last_log_prob_df,
        left_on=["session_id", "candidates"],
        right_on=["session_id", "to_yad_no"],
        how="left",
    ).drop("from_yad_no")

    # last 以外からのtransition probも追加
    yad2yad_prob = pl.read_parquet(cfg.exp.transition_prob_path)
    prob_col = "transition_prob_transition_prob/base"
    log_df = load_log_data(Path(cfg.dir.data_dir), mode)
    log_df = (
        log_df.sort(by="session_id").with_columns(pl.col("yad_no").alias("from_yad_no"))
    ).select(["session_id", "from_yad_no"])
    log_df = (
        log_df.join(yad2yad_prob, on="from_yad_no")
        .group_by(["session_id", "to_yad_no"])
        .agg(pl.sum(prob_col).alias(prob_col + "_from_all"))
    )
    candidate_df = candidate_df.join(
        log_df,
        left_on=["session_id", "candidates"],
        right_on=["session_id", "to_yad_no"],
        how="left",
    ).drop("from_yad_no")

    # 増えたカラムを出力
    new_cols = [col for col in candidate_df.columns if col not in original_cols]
    print(f"new_cols: {new_cols}")

    return candidate_df

candidate4_df = concat_session_candidate_feature(cfg, mode, candidate3_df)

['session_id', 'candidates', 'original', 'label', 'fold', 'wireless_lan_flg', 'kd_bch_5min', 'rank_ranking_location/lrg_cd', 'counts_ranking_location/sml_cd', 'rank_ranking/base', 'rank_ranking_location/ken_cd', 'total_room_cnt', 'counts_ranking_location/ken_cd', 'counts_ranking_location/wid_cd', 'rank_ranking_location/wid_cd', 'onsen_flg', 'kd_slp_5min', 'rank_ranking_location/sml_cd', 'kd_stn_5min', 'counts_ranking_location/lrg_cd', 'counts_ranking/base', 'yad_type', 'wid_cd', 'ken_cd', 'lrg_cd', 'sml_cd']
new_cols: ['same_yad_type_count', 'same_yad_type_rate', 'same_wid_cd_count', 'same_wid_cd_rate', 'same_ken_cd_count', 'same_ken_cd_rate', 'same_lrg_cd_count', 'same_lrg_cd_rate', 'same_sml_cd_count', 'same_sml_cd_rate', 'transition_prob_transition_prob/base', 'transition_prob_transition_prob/base_from_all']


In [11]:
candidate4_df

session_id,candidates,original,label,fold,wireless_lan_flg,kd_bch_5min,rank_ranking_location/lrg_cd,counts_ranking_location/sml_cd,rank_ranking/base,rank_ranking_location/ken_cd,total_room_cnt,counts_ranking_location/ken_cd,counts_ranking_location/wid_cd,rank_ranking_location/wid_cd,onsen_flg,kd_slp_5min,rank_ranking_location/sml_cd,kd_stn_5min,counts_ranking_location/lrg_cd,counts_ranking/base,yad_type,wid_cd,ken_cd,lrg_cd,sml_cd,same_yad_type_count,same_yad_type_rate,same_wid_cd_count,same_wid_cd_rate,same_ken_cd_count,same_ken_cd_rate,same_lrg_cd_count,same_lrg_cd_rate,same_sml_cd_count,same_sml_cd_rate,transition_prob_transition_prob/base,transition_prob_transition_prob/base_from_all
str,i64,u32,u32,i64,f64,f64,f64,u32,f64,f64,f64,u32,u32,f64,i64,f64,f64,f64,u32,u32,i64,str,str,str,str,u32,f64,u32,f64,u32,f64,u32,f64,u32,f64,f64,f64
"""40614d55c5d0b5…",8347,1,0,1,1.0,,39.0,69,3182.0,436.5,161.0,69,69,747.5,0,,12.0,1.0,69,69,0,"""46e33861f921c3…","""107c7305a74c8d…","""c9d5e891463e53…","""7cf2b4f31fb207…",1,1.0,1,1.0,1,1.0,1,1.0,1,1.0,0.066667,0.066667
"""40614d55c5d0b5…",13593,1,0,1,1.0,,26.5,85,2460.0,348.0,112.0,85,85,602.5,0,,8.5,1.0,85,85,0,"""46e33861f921c3…","""107c7305a74c8d…","""c9d5e891463e53…","""7cf2b4f31fb207…",1,1.0,1,1.0,1,1.0,1,1.0,1,1.0,0.066667,0.066667
"""554d8619ba352c…",4216,1,0,0,1.0,,15.0,71,3064.5,105.5,30.0,71,71,725.0,0,,15.0,,71,71,0,"""46e33861f921c3…","""572d60f0f5212a…","""2e63024b11908f…","""d075eb4a966945…",1,1.0,1,1.0,1,1.0,,,,,0.00885,0.00885
"""554d8619ba352c…",4915,1,0,0,1.0,,28.0,88,2340.5,333.0,107.0,88,88,578.0,0,,15.0,1.0,88,88,0,"""46e33861f921c3…","""107c7305a74c8d…","""aabf8b3cf64147…","""c3b55cb211c69d…",1,1.0,1,1.0,,,,,,,0.00885,0.00885
"""5ee2f4ed6a5c68…",4024,1,0,1,1.0,,21.0,35,5696.0,70.5,50.0,35,35,367.5,1,,21.0,,35,35,0,"""d86102dd9c232b…","""3831f43bb997a3…","""8945cbc9f218eb…","""0add1dfa772542…",1,1.0,1,1.0,1,1.0,1,1.0,1,1.0,,
"""5461f93a8ed2a5…",179,1,0,0,1.0,,9.0,86,2418.5,26.0,76.0,86,86,281.5,1,,9.0,,86,86,0,"""dc414a17890cfc…","""31a0f630d36db5…","""f11d0a982fcea0…","""199073cb3739d7…",2,1.0,2,1.0,2,1.0,2,1.0,2,1.0,0.125,0.397727
"""5461f93a8ed2a5…",2004,1,0,0,,,51.0,1,13207.0,245.5,10.0,1,1,1584.0,1,,51.0,,1,1,0,"""dc414a17890cfc…","""31a0f630d36db5…","""f11d0a982fcea0…","""199073cb3739d7…",2,1.0,2,1.0,2,1.0,2,1.0,2,1.0,,
"""a5bd02d98d1beb…",1487,1,0,1,1.0,,4.0,273,179.0,23.5,102.0,273,273,70.5,0,,2.0,1.0,273,273,0,"""46e33861f921c3…","""107c7305a74c8d…","""e2034d4f2fbe08…","""086904b20a91b5…",3,1.0,3,1.0,3,1.0,3,1.0,3,1.0,,
"""a5bd02d98d1beb…",6157,1,0,1,1.0,,45.5,114,1547.5,216.0,150.0,114,114,404.5,0,,35.0,1.0,114,114,0,"""46e33861f921c3…","""107c7305a74c8d…","""e2034d4f2fbe08…","""086904b20a91b5…",3,1.0,3,1.0,3,1.0,3,1.0,3,1.0,,
"""a5bd02d98d1beb…",8793,1,0,1,1.0,,28.0,136,1075.5,145.5,135.0,136,136,292.0,0,,21.0,1.0,136,136,0,"""46e33861f921c3…","""107c7305a74c8d…","""e2034d4f2fbe08…","""086904b20a91b5…",3,1.0,3,1.0,3,1.0,3,1.0,3,1.0,,
