In [2]:
# Move the working directory up one level (the recorded output shows /kaggle/working).
# NOTE(review): not idempotent — re-running this cell moves up another level.
%cd ..

/kaggle/working


In [56]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the Hydra config for this experiment; `cfg` is used by every cell below.
with initialize(version_base=None, config_path="../cand_supervised/te_transition_prob"):
    cfg = compose(config_name="config.yaml", overrides=["exp=base"])

# Dump the resolved config as YAML so the run's settings are visible in the notebook.
print(OmegaConf.to_yaml(cfg))

debug: false
seed: 7
dir:
  data_dir: /kaggle/working/input/atmaCup16_Dataset
  output_dir: /kaggle/working/output
  exp_dir: /kaggle/working/output/exp
  cand_unsupervised_dir: /kaggle/working/output/cand_unsupervised
  cand_supervised_dir: /kaggle/working/output/cand_supervised
  datasets_dir: /kaggle/working/output/datasets
exp:
  fold_path: /kaggle/working/output/datasets/make_cv/base/train_fold.parquet
  only_last: true
  num_candidate: 100
  k:
  - 1
  - 5
  - 10
  - 50
  - 100



In [57]:
import os
import sys
from pathlib import Path

import hydra
import numpy as np
import polars as pl
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import LabelEncoder
from tqdm.auto import tqdm

import utils
import wandb
from utils.load import (
    load_image_embeddings,
    load_label_data,
    load_log_data,
    load_session_data,
    load_yad_data,
)
from utils.metrics import calculate_metrics

2種類
- lastからの遷移
- 全体からlastへの遷移

In [80]:
# Read the CV fold assignment and the test log first, then load the train log
# and attach each session's fold to it in a single chained expression.
fold_df = pl.read_parquet(cfg.exp.fold_path)
test_log_df = load_log_data(Path(cfg.dir.data_dir), "test")
train_log_df = load_log_data(Path(cfg.dir.data_dir), "train").join(
    fold_df, on="session_id"
)
train_log_df.head()

session_id,seq_no,yad_no,fold
str,i64,i64,i64
"""000007603d533d…",0,2395,2
"""0000ca043ed437…",0,13535,2
"""0000d4835cf113…",0,123,0
"""0000fcda1ae1b2…",0,8475,4
"""000104bdffaaad…",0,96,3


In [59]:
def make_transition_prob(log_df, label_df, only_last=True):
    """Estimate transition probabilities from a viewed yad to the labelled yad.

    Args:
        log_df: View log with ``session_id``, ``seq_no`` and ``yad_no`` columns.
        label_df: Labels with ``session_id`` and ``yad_no`` columns.
        only_last: If True, keep only the row with the largest ``seq_no`` of
            each session as the transition source; otherwise every log row is
            used as a source.

    Returns:
        A frame with ``from_yad_no``, ``to_yad_no`` and ``transition_prob``,
        sorted by (``from_yad_no``, ``to_yad_no``). ``transition_prob`` is the
        pair count divided by the total count of its ``from_yad_no``.
    """
    if only_last:
        # Reduce every session to its final row (highest seq_no).
        last_row = pl.all().sort_by("seq_no").last()
        log_df = log_df.group_by("session_id").agg(last_row).sort(by="session_id")

    # Attach the label of each session; the log's yad is the source and the
    # label is the destination of the transition.
    labels = label_df.with_columns(pl.col("yad_no").alias("label"))
    pairs = log_df.join(labels, on=["session_id"], how="left").with_columns(
        pl.col("yad_no").alias("from_yad_no"),
        pl.col("label").alias("to_yad_no"),
    )

    # Count each (from, to) pair, then normalise by the per-source total.
    pair_counts = pairs.group_by(["from_yad_no", "to_yad_no"]).agg(
        pl.col("from_yad_no").count().alias("from_to_count")
    )
    source_total = pl.col("from_to_count").sum().over(["from_yad_no"])
    with_prob = pair_counts.with_columns(
        source_total.alias("from_count"),
    ).with_columns(
        (pl.col("from_to_count") / pl.col("from_count")).alias("transition_prob")
    )

    ordered = with_prob.sort(by=["from_yad_no", "to_yad_no"])
    return ordered.select(["from_yad_no", "to_yad_no", "transition_prob"])

In [60]:
# Load the train labels; make_transition_prob joins them onto the log by session_id.
train_label_df = load_label_data(Path(cfg.dir.data_dir), "train")

In [73]:
train_transtion_dfs = []
# Out-of-fold target encoding: the table tagged with fold k is estimated only
# from sessions outside fold k, so it can be joined onto fold k's sessions
# without using their own labels.
# Iterate over the fold ids actually present rather than assuming 0..n-1.
for fold in train_log_df["fold"].unique().sort().to_list():
    train_fold_df = train_log_df.filter(pl.col("fold") != fold)

    # Transition probabilities that will serve as features/candidates for `fold`.
    transition_df = make_transition_prob(
        train_fold_df, train_label_df, only_last=cfg.exp.only_last
    )
    # Mark which fold these rows are meant for.
    transition_df = transition_df.with_columns(
        pl.lit(fold).cast(pl.Int64).alias("fold")
    )
    train_transtion_dfs.append(transition_df)
# The variable names keep their original spelling so other cells still resolve them.
train_trainsition_df = pl.concat(train_transtion_dfs)

In [74]:
# For the test set, fit the target encoding (transition probabilities) on all of train.
test_transition_df = make_transition_prob(
    train_log_df, train_label_df, only_last=cfg.exp.only_last
)

test_transition_df.head()

from_yad_no,to_yad_no,transition_prob
i64,i64,f64
2,36,0.05
2,217,0.05
2,299,0.05
2,1099,0.05
2,2200,0.05


In [77]:
def make_candidate(session_df, log_df, transition_df, mode: str, only_last=True):
    """Build per-session candidate lists ordered by transition probability.

    Args:
        session_df: Frame with a ``session_id`` column; every row is kept in
            the result.
        log_df: View log with ``session_id``, ``seq_no`` and ``yad_no`` columns
            (and ``fold`` when ``mode == "train"``).
        transition_df: Frame with ``from_yad_no``, ``to_yad_no`` and
            ``transition_prob`` columns (and ``fold`` when ``mode == "train"``).
        mode: ``"train"`` joins on (yad, fold) because the transition table
            differs per fold; ``"test"`` joins on the yad only.
        only_last: If True, only the row with the largest ``seq_no`` of each
            session is used as the transition source.

    Returns:
        ``session_df`` with an added ``candidates`` list column holding the
        ``to_yad_no`` values in descending order of summed transition
        probability. Sessions without any match get an empty list. The list
        is not truncated.

    Raises:
        ValueError: If ``mode`` is neither ``"train"`` nor ``"test"``.
    """
    # Fail fast: an unknown mode used to skip the join silently and only
    # surfaced later as an unrelated missing-column error.
    if mode not in ("train", "test"):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    if only_last:
        log_df = (
            log_df.group_by("session_id")
            .agg(pl.all().sort_by("seq_no").last())
            .sort(by="session_id")
        )

    # Attach the transition probabilities; train tables are fold-specific.
    if mode == "train":
        left_keys, right_keys = ["yad_no", "fold"], ["from_yad_no", "fold"]
    else:
        left_keys, right_keys = ["yad_no"], ["from_yad_no"]
    log_df = log_df.join(
        transition_df,
        left_on=left_keys,
        right_on=right_keys,
        how="inner",
    )

    # Rank destinations per session by descending probability.
    candidate_df = (
        # When all log rows are used, the same to_yad_no can appear several
        # times within a session, so sum its probabilities first.
        log_df.group_by(["session_id", "to_yad_no"])
        .agg(pl.sum("transition_prob"))
        .sort(by=["session_id", "transition_prob"], descending=True)
        .group_by("session_id")
        .agg(pl.col("to_yad_no").alias("candidates"))
    )
    candidate_df = session_df.join(
        candidate_df, on="session_id", how="left"
    ).with_columns(
        # Sessions with no candidates get an empty list instead of null.
        pl.when(pl.col("candidates").is_null())
        .then(pl.Series("empty", [[]]))
        .otherwise(pl.col("candidates"))
        .alias("candidates")
    )
    return candidate_df

In [82]:
# Generate candidates for the test sessions from the transition table
# estimated on the whole train set.
test_session_df = load_session_data(Path(cfg.dir.data_dir), "test")
make_candidate(
    session_df=test_session_df,
    log_df=test_log_df,
    transition_df=test_transition_df,
    mode="test",
    only_last=cfg.exp.only_last,
)

session_id,candidates
str,list[i64]
"""00001149e9c739…","[4714, 11561, … 7902]"
"""0000e02747d749…","[4066, 143, … 13249]"
"""0000f17ae26282…","[10485, 7710, … 410]"
"""000174a6f7a569…","[3359, 12341, … 2047]"
"""00017e2a527901…","[9020, 4070, … 11910]"
"""00018613341f84…","[13292, 3811, … 13549]"
"""00027c33bbdb2e…",[11776]
"""0002f6aa27bcf9…","[13347, 2806, … 2824]"
"""000300aea0d549…","[3901, 4522, … 12217]"
"""00034cba60c960…","[1013, 11450, … 1563]"


In [78]:
# Load the train session table.
# NOTE(review): the recorded output under this cell is a candidates table, which
# this assignment alone does not produce — the cell was presumably edited after
# it was last run; confirm and re-run.
train_session_df = load_session_data(Path(cfg.dir.data_dir), "train")

session_id,candidates
str,list[i64]
"""000007603d533d…","[2808, 11882, 3324]"
"""0000ca043ed437…","[8253, 9881, … 1092]"
"""0000d4835cf113…","[9039, 5238, … 4355]"
"""0000fcda1ae1b2…","[626, 755, … 7872]"
"""000104bdffaaad…","[3894, 7749, … 4072]"
"""00011afe25c343…","[12544, 4823, 10510]"
"""000125c737df18…","[2480, 10378, … 9597]"
"""0001763050a10b…","[4744, 7681, … 10544]"
"""000178c4d4d567…","[12432, 3802, … 4962]"
"""0001e6a407a85d…","[10478, 379, … 7050]"


## 確認

In [13]:
# Load the raw train log and the yad table used by the checks below.
log_df = load_log_data(Path(cfg.dir.data_dir), "train")
yad_df = load_yad_data(Path(cfg.dir.data_dir))

In [29]:
# Distribution of the largest seq_no per session (seq_no starts at 0 in the recorded output).
log_df.group_by("session_id").agg(pl.col("seq_no").max())["seq_no"].value_counts()

seq_no,counts
i64,u32
6,65
1,82793
7,18
4,833
9,1
2,15350
3,4025
0,185386
5,223
8,4


In [16]:
# Rank yads by how often they appear in the log, within each location group.
# `exp.location_col` is absent from the config printed at the top of this
# notebook, so fall back to the column shown in the recorded output ("sml_cd")
# instead of failing on a fresh run.
location_col = OmegaConf.select(cfg, "exp.location_col", default="sml_cd")
count_df = log_df.get_column("yad_no").value_counts().sort(by="counts", descending=True)
yad_counts_df = yad_df.join(count_df, on="yad_no").with_columns(
    pl.col("counts").rank(descending=True).over(location_col).alias("rank")
)

In [21]:
# Show the per-location ranking of log counts.
# `exp.location_col` is absent from the config printed at the top of this
# notebook, so fall back to the column shown in the recorded output ("sml_cd").
location_col = OmegaConf.select(cfg, "exp.location_col", default="sml_cd")
yad_counts_df.select(["yad_no", location_col, "counts", "rank"]).sort(
    by=[location_col, "rank"]
)

yad_no,sml_cd,counts,rank
i64,str,u32,f64
10163,"""00e15b2eac75d3…",84,1.0
3714,"""00e15b2eac75d3…",75,2.0
1055,"""00e15b2eac75d3…",49,3.0
1664,"""00e15b2eac75d3…",44,4.0
12490,"""00e15b2eac75d3…",38,5.0
8098,"""00e15b2eac75d3…",34,6.0
4958,"""00e15b2eac75d3…",32,7.0
9266,"""00e15b2eac75d3…",29,8.0
708,"""00e15b2eac75d3…",28,9.0
4605,"""00e15b2eac75d3…",26,10.0


In [24]:
# Same ranking as for the log counts, but counting how often each yad is the label.
# `exp.location_col` is absent from the config printed at the top of this
# notebook, so fall back to the column shown in the recorded output ("sml_cd").
location_col = OmegaConf.select(cfg, "exp.location_col", default="sml_cd")
label_df = load_label_data(Path(cfg.dir.data_dir), "train")
count_label_df = (
    label_df.get_column("yad_no").value_counts().sort(by="counts", descending=True)
)
yad_label_counts_df = yad_df.join(count_label_df, on="yad_no").with_columns(
    pl.col("counts").rank(descending=True).over(location_col).alias("rank")
)
yad_label_counts_df.select(["yad_no", location_col, "counts", "rank"]).sort(
    by=[location_col, "rank"]
)

yad_no,sml_cd,counts,rank
i64,str,u32,f64
10163,"""00e15b2eac75d3…",49,1.0
3714,"""00e15b2eac75d3…",44,2.0
12490,"""00e15b2eac75d3…",41,3.0
1055,"""00e15b2eac75d3…",35,4.0
1664,"""00e15b2eac75d3…",33,5.0
9266,"""00e15b2eac75d3…",25,6.0
4605,"""00e15b2eac75d3…",23,7.0
1276,"""00e15b2eac75d3…",18,8.5
10689,"""00e15b2eac75d3…",18,8.5
8568,"""00e15b2eac75d3…",17,10.0
