In [1]:
# Move one directory up (the printed cwd is /kaggle/working); presumably so the
# project-local `utils.*` imports and the relative hydra config_path resolve — confirm.
%cd ..

/kaggle/working


In [2]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the experiment config from the 028_limit_seq config directory,
# selecting the `two_001` experiment, and show the resulting YAML.
with initialize(version_base=None, config_path="../generate_datasets/028_limit_seq"):
    cfg = compose(
        config_name="config.yaml",
        overrides=["exp=two_001"],
    )
    cfg_yaml = OmegaConf.to_yaml(cfg)
    print(cfg_yaml)

debug: false
seed: 7
dir:
  data_dir: /kaggle/working/input/atmaCup16_Dataset
  output_dir: /kaggle/working/output
  exp_dir: /kaggle/working/output/exp
  cand_unsupervised_dir: /kaggle/working/output/cand_unsupervised
  cand_supervised_dir: /kaggle/working/output/cand_supervised
  datasets_dir: /kaggle/working/output/datasets
exp:
  fold_path: /kaggle/working/output/datasets/make_cv/base/train_fold.parquet
  limit_seq: 2
  candidate_info_list:
  - name: ranking_location/sml_cd
    max_num_candidates: 5
    dir: /kaggle/working/output/cand_unsupervised/ranking_location/sml_cd
  - name: ranking_location/lrg_cd
    max_num_candidates: 5
    dir: /kaggle/working/output/cand_unsupervised/ranking_location/lrg_cd
  - name: ranking_location_all/sml_cd
    max_num_candidates: 30
    dir: /kaggle/working/output/cand_unsupervised/prob_matrix_filter/two002
  transition_prob_path: /kaggle/working/output/cand_unsupervised/transition_prob_fix/base/yad2yad_feature.parquet
  transition_prob_all_path: 

In [3]:
import os
import sys
from pathlib import Path

import hydra
import polars as pl
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from sklearn.preprocessing import OrdinalEncoder

from utils.data import convert_to_32bit
from utils.load import load_label_data, load_log_data, load_yad_data
from utils.logger import get_logger

# Columns treated as numerical features.
# NOTE: this list will be rewritten later — do not treat it as final.
numerical_cols = [
    "total_room_cnt", "wireless_lan_flg", "onsen_flg",
    "kd_stn_5min", "kd_bch_5min", "kd_slp_5min",
]

# Columns treated as categorical features.
categorical_cols = ["yad_type", "wid_cd", "ken_cd", "lrg_cd", "sml_cd"]

# Placeholders (None here); presumably populated later in the notebook.
logger = ordinal_encoder = None

In [4]:
def load_limit_log_data(cfg, mode: str):
    """Load the session log for ``mode``, keeping only each session's last ``cfg.exp.limit_seq`` rows.

    Args:
        cfg: config object; reads ``cfg.dir.data_dir`` and ``cfg.exp.limit_seq``.
        mode: split name passed through to ``load_log_data`` (e.g. ``"train"``).

    Returns:
        polars DataFrame with columns ``session_id``, ``seq_no``, ``yad_no``.
        When ``limit_seq`` is None the loaded log is returned unchanged. Otherwise
        each session keeps at most its ``limit_seq`` rows with the highest ``seq_no``
        (original ``seq_no`` values are preserved), sorted by ``session_id`` then ``seq_no``.

    Raises:
        ValueError: if ``limit_seq`` is set but is smaller than 1.
    """
    log_df = load_log_data(Path(cfg.dir.data_dir), mode)
    limit_seq = cfg.exp.limit_seq
    if limit_seq is None:
        return log_df
    if limit_seq < 1:
        # A zero/negative limit would produce empty lists that explode into null rows.
        raise ValueError(
            f"cfg.exp.limit_seq must be a positive integer or None, got {limit_seq}"
        )
    return (
        # Sort by seq_no first so "last N" means the most recent N views,
        # independent of the row order load_log_data returns.
        log_df.sort(["session_id", "seq_no"])
        .group_by("session_id")
        .agg(
            pl.col("seq_no").tail(limit_seq),
            pl.col("yad_no").tail(limit_seq),
        )
        .explode(["seq_no", "yad_no"])
        # Sort on both keys: sorting on session_id alone is not stable in polars
        # (maintain_order=False by default), so within-session order could be lost.
        .sort(["session_id", "seq_no"])
    )

In [5]:
# Load the truncated training log and peek at the first rows.
mode = "train"
log_df = load_limit_log_data(cfg, mode=mode)
log_df.head(n=7)

session_id,seq_no,yad_no
str,i64,i64
"""000007603d533d…",0,2395
"""0000ca043ed437…",0,13535
"""0000d4835cf113…",0,123
"""0000fcda1ae1b2…",0,8475
"""000104bdffaaad…",0,96
"""000104bdffaaad…",1,898
"""00011afe25c343…",0,6868


In [6]:
# Sessions whose retained rows include seq_no == 9 (checks original seq_no values survive truncation).
log_df.filter(pl.col("seq_no").eq(9)).get_column("session_id").to_list()

['734cc105dc165cc485341e367b3c70ab']

In [7]:
# Inspect one such session: only its last rows remain, with their original seq_no values.
log_df.filter(pl.col("session_id").eq("734cc105dc165cc485341e367b3c70ab"))

session_id,seq_no,yad_no
str,i64,i64
"""734cc105dc165c…",8,5116
"""734cc105dc165c…",9,8567


In [21]:
def load_and_union_candidates(cfg, mode: str):
    """Build long-format (session_id, candidate yad_no) pairs for ``mode``.

    Candidates per session are the unique union of
      * every yad_no in the session's (truncated) log, and
      * the first ``max_num_candidates`` entries of each candidate file
        ``<dir>/<mode>_candidate.parquet`` listed in ``cfg.exp.candidate_info_list``.
    The session's last viewed yad_no (highest seq_no) is then removed.

    Args:
        cfg: config object; reads ``cfg.exp.candidate_info_list``, ``cfg.debug`` and
            whatever ``load_limit_log_data`` reads.
        mode: split name, e.g. ``"train"``.

    Returns:
        polars DataFrame with columns ``session_id`` and ``candidates``, one row per candidate.
    """
    # Read the log once; it feeds both the session candidates and the last-yad filter.
    log_df = load_limit_log_data(cfg, mode)

    # Every yad_no viewed in the session is a candidate.
    dfs = [log_df.group_by("session_id").agg(pl.col("yad_no").alias("candidates"))]
    for candidate_info in cfg.exp.candidate_info_list:
        cand_df = (
            pl.read_parquet(Path(candidate_info["dir"]) / f"{mode}_candidate.parquet")
            # Keep only the union columns, in a fixed order, so the vertical
            # pl.concat below does not fail on extra/reordered columns.
            .select(["session_id", "candidates"])
            .with_columns(
                pl.col("candidates")
                .list.head(candidate_info["max_num_candidates"])
                .alias("candidates")
            )
            .filter(pl.col("candidates").list.len() > 0)
        )
        dfs.append(cand_df)

    df = (
        pl.concat(dfs)
        .group_by("session_id")
        .agg(pl.col("candidates").flatten())
        .with_columns(pl.col("candidates").list.unique())
        .select(["session_id", "candidates"])
    )

    if cfg.debug:
        df = df.with_columns(pl.col("candidates").list.head(2).alias("candidates"))

    # One row per (session_id, candidate).
    candidate_df = df.explode("candidates")

    # Last viewed yad_no per session, chosen by seq_no rather than by row order.
    last_df = log_df.group_by("session_id").agg(
        pl.col("yad_no").sort_by("seq_no").last().alias("candidates")
    )
    # Anti join drops each session's last viewed yad_no from its candidates.
    return candidate_df.join(last_df, on=["session_id", "candidates"], how="anti")


candidate_df = load_and_union_candidates(cfg, mode)

In [22]:
# First candidate rows (long format: one row per session/candidate pair).
candidate_df.head(5)

session_id,candidates
str,i64
"""8fbc73e2b1d138…",724
"""8fbc73e2b1d138…",2256
"""8fbc73e2b1d138…",2837
"""8fbc73e2b1d138…",3147
"""8fbc73e2b1d138…",3186


In [23]:
# Also add transition probabilities from yads other than the last one (prob_matrix_path):
# `transition_prob_prob_matrix_{k}` = transition prob from the yad viewed k steps before
# the session's last view (k=0 is the last view) to the candidate.
# NOTE(review): `cfg.exp.prob_matrix_path` is not visible in the printed config above
# (the printout may be truncated) — confirm the key exists for exp=two_001.
yad2yad_prob = pl.read_parquet(cfg.exp.prob_matrix_path)
prob_col = "transition_prob"
log_df = load_limit_log_data(cfg, mode)
# yad_no_{k}: yad_no from k rows earlier within the same session.
log_df = (
    log_df.sort(by=["session_id", "seq_no"])
    .with_columns(
        [
            pl.col("yad_no").shift(si).over("session_id").alias(f"yad_no_{si}")
            for si in range(cfg.exp.limit_seq)
        ]
    )
    .drop(["yad_no"])
)
# Keep each session's last row: one row per session with its last `limit_seq` yads.
log_df = log_df.group_by("session_id").agg(pl.all().last()).sort(by="session_id")
for si in range(cfg.exp.limit_seq):
    feature_col = f"{prob_col}_prob_matrix_{si}"
    # Inner join: sessions with a null yad_no_{si} (too short) contribute no rows.
    session_probs = log_df.join(
        yad2yad_prob, left_on=f"yad_no_{si}", right_on="from_yad_no"
    ).select(["session_id", "to_yad_no", pl.col(prob_col).alias(feature_col)])

    # Left join keeps every candidate; pairs without a transition get a null feature.
    # (No trailing drop: `from_yad_no` is never carried into candidate_df.)
    candidate_df = candidate_df.join(
        session_probs,
        left_on=["session_id", "candidates"],
        right_on=["session_id", "to_yad_no"],
        how="left",
    )


session_id,seq_no,yad_no_0,yad_no_1
str,i64,i64,i64
"""000007603d533d…",0,2395,
"""0000ca043ed437…",0,13535,
"""0000d4835cf113…",0,123,
"""0000fcda1ae1b2…",0,8475,
"""000104bdffaaad…",1,898,96.0
"""00011afe25c343…",0,6868,
"""000125c737df18…",0,8602,
"""0001763050a10b…",0,13106,
"""000178c4d4d567…",0,12062,
"""0001e6a407a85d…",0,4866,


In [24]:

# Per-session wide view built above (yad_no_0 = last viewed yad).
log_df.head(5)

session_id,seq_no,yad_no_0,yad_no_1
str,i64,i64,i64
"""000007603d533d…",0,2395,
"""0000ca043ed437…",0,13535,
"""0000d4835cf113…",0,123,
"""0000fcda1ae1b2…",0,8475,
"""000104bdffaaad…",1,898,96.0


In [26]:
# Candidates with the per-step transition-probability features attached.
candidate_df.head(n=10)

session_id,candidates,transition_prob_prob_matrix_0,transition_prob_prob_matrix_1
str,i64,f64,f64
"""8fbc73e2b1d138…",724,0.013512,
"""8fbc73e2b1d138…",2256,0.029533,
"""8fbc73e2b1d138…",2837,0.01709,
"""8fbc73e2b1d138…",3147,0.012875,
"""8fbc73e2b1d138…",3186,0.022208,
"""8fbc73e2b1d138…",3279,0.004415,
"""8fbc73e2b1d138…",3507,0.027285,
"""8fbc73e2b1d138…",4053,0.040332,
"""8fbc73e2b1d138…",4940,0.02556,
"""8fbc73e2b1d138…",5268,0.026974,
