In [1]:
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

plt.style.use("ggplot")

In [2]:
class CFG:
    name = "exp001"
    seed = 42

    path_input = Path("../input")
    path_output = Path("../output")

In [3]:
def set_seed(seed=42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True


def setup(CFG):
    # expフォルダを作成
    CFG.path_exp = CFG.path_output / CFG.name
    CFG.path_exp.mkdir(parents=True, exist_ok=True)

    # seedの設定
    set_seed(CFG.seed)

    return CFG

In [4]:
CFG = setup(CFG)

In [5]:
# データの読み込み
train_log_df = pd.read_csv(CFG.path_input / "train_log.csv")
train_label_df = pd.read_csv(CFG.path_input / "train_label.csv")

test_log_df = pd.read_csv(CFG.path_input / "test_log.csv")
test_session_df = pd.read_csv(CFG.path_input / "test_session.csv")

yado_df = pd.read_csv(CFG.path_input / "yado.csv")
yado_embedding = pd.read_parquet(CFG.path_input / "image_embeddings.parquet")

In [6]:
# train_logとlabelを結合
train_label_df["label"] = 1
train_log_df = train_log_df.merge(
    train_label_df, how="left", on=["session_id", "yad_no"]
)
train_log_df["label"] = train_log_df["label"].fillna(0)
train_log_df = train_log_df.merge(yado_df, how="left", on="yad_no")

In [7]:
train_log_df

Unnamed: 0,session_id,seq_no,yad_no,label,yad_type,total_room_cnt,wireless_lan_flg,onsen_flg,kd_stn_5min,kd_bch_5min,kd_slp_5min,kd_conv_walk_5min,wid_cd,ken_cd,lrg_cd,sml_cd
0,000007603d533d30453cc45d0f3d119f,0,2395,0.0,0,113.0,1.0,0,,,,,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343
1,0000ca043ed437a1472c9d1d154eb49b,0,13535,0.0,0,40.0,1.0,0,1.0,,,1.0,b07b75d367ebece55a23ceecc939fff4,0a66f6ab9c0507059da6f22a0e1f1690,9ab5718fd88c6e5f9fec37a51827d428,7aff71bb47acb796d425c5ed5e6dfb3f
2,0000d4835cf113316fe447e2f80ba1c8,0,123,0.0,0,17.0,1.0,0,,,,,46e33861f921c3e38b81998fbf283f01,572d60f0f5212aacda515ebf81fb0a3a,dac434451fe9bd50068191f41fe792e3,b7c56c5d2855b39366b4ebe9a4eded93
3,0000fcda1ae1b2f431e55a7075d1f500,0,8475,0.0,0,65.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,107c7305a74c8dcc4f143de208bf7ec2,3a6cd37aa9e38fd96d9dafc2615643d0,f2fcbd8e62872147efde0acef474e1f2
4,000104bdffaaad1a1e0a9ebacf585f33,0,96,1.0,0,228.0,1.0,0,,,,1.0,e9316013ee1b03f4525fe361c46ce9c5,84efa50e52f9b471c95bfc3b21b854ad,a1370d90ed3b80ee41311bbbab46aec9,d72674f02c5340d90f245e3177727650
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
419265,ffffcd5bc19d62cad5a3815c87818d83,0,12230,0.0,0,354.0,1.0,0,,,,1.0,321b69d5eec98fe6253e26b86058e6a9,a2b54b288d51bb19085ed1d99c428397,0c92ce61d0bf83edefee7eea279a15c8,de9c306d6999d60160eaf17cdb20fe47
419266,ffffcd5bc19d62cad5a3815c87818d83,1,10619,1.0,0,,1.0,0,,,,1.0,321b69d5eec98fe6253e26b86058e6a9,a2b54b288d51bb19085ed1d99c428397,0c92ce61d0bf83edefee7eea279a15c8,de9c306d6999d60160eaf17cdb20fe47
419267,ffffcd5bc19d62cad5a3815c87818d83,2,12230,0.0,0,354.0,1.0,0,,,,1.0,321b69d5eec98fe6253e26b86058e6a9,a2b54b288d51bb19085ed1d99c428397,0c92ce61d0bf83edefee7eea279a15c8,de9c306d6999d60160eaf17cdb20fe47
419268,fffffa7baf370083ebcdd98f26a7e31a,0,2439,1.0,0,81.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,572d60f0f5212aacda515ebf81fb0a3a,8a623b960557e87bd1f4edf71b6255be,ab9480fd72a44d51690ab16c4ad4d49c


In [8]:
# yadoのtrainにおける予約回数をカウント。value_countsで実質ソートされている
yado_reserve_count = (
    train_log_df.query("label == 1")["yad_no"].value_counts().reset_index()
)
yado_reserve_count = yado_reserve_count.merge(yado_df, how="left", on="yad_no")

In [9]:
yado_reserve_count

Unnamed: 0,yad_no,count,yad_type,total_room_cnt,wireless_lan_flg,onsen_flg,kd_stn_5min,kd_bch_5min,kd_slp_5min,kd_conv_walk_5min,wid_cd,ken_cd,lrg_cd,sml_cd
0,3338,426,0,703.0,1.0,0,1.0,,,,46e33861f921c3e38b81998fbf283f01,572d60f0f5212aacda515ebf81fb0a3a,8a623b960557e87bd1f4edf71b6255be,1d9f09b9e2bd43cebc9885a46388739a
1,12350,358,0,696.0,1.0,0,,,,,46e33861f921c3e38b81998fbf283f01,572d60f0f5212aacda515ebf81fb0a3a,8a623b960557e87bd1f4edf71b6255be,1d9f09b9e2bd43cebc9885a46388739a
2,10095,302,0,2007.0,1.0,0,,,,1.0,46e33861f921c3e38b81998fbf283f01,572d60f0f5212aacda515ebf81fb0a3a,8a623b960557e87bd1f4edf71b6255be,f7b42d92528e7a88617c4b26e033d3e5
3,719,250,0,600.0,1.0,0,1.0,,,1.0,f0112abf369fb03cdc5f5309300913da,072c85e1653e10c9c7dd065ad007125a,ed62e66a5031c23c78bd03ccf9f3ef70,d3d1cf557f10fadb1fbc0b429bf14578
4,8553,247,0,550.0,1.0,0,,,,1.0,46e33861f921c3e38b81998fbf283f01,572d60f0f5212aacda515ebf81fb0a3a,8a623b960557e87bd1f4edf71b6255be,1d9f09b9e2bd43cebc9885a46388739a
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9449,8940,1,0,16.0,,1,,,,1.0,c312e07b7a5d456d53a5b00910a336e1,6692a692f80687411022c08e4f5a7a00,8cf750072f8520a726ff601894c2d39e,ae316115cdaf0cb1b67b65e75f260b4d
9450,8670,1,0,120.0,1.0,0,1.0,,,1.0,f0112abf369fb03cdc5f5309300913da,072c85e1653e10c9c7dd065ad007125a,449c52ef581d5f9ef311189469a0520e,f76e14a3f2ebd4c7efd873fe9b5a02fd
9451,5000,1,0,296.0,1.0,0,,,,1.0,f0112abf369fb03cdc5f5309300913da,072c85e1653e10c9c7dd065ad007125a,52d0a7d917cc19ddf5e0ee208f0acfed,5423b90b9624bbb2b47ce18b63fb9a82
9452,10344,1,0,,,0,,,1.0,1.0,8a1c0d3243bba111cbcd1ec6c692dc6d,ce83563814cff3080c8ae076f44b3020,1c3e1864d98151b856ba5f9a8c672e1e,f7778ca7631d62073b0f7feee455545b


In [10]:
test_pivot = test_log_df.pivot_table(index="session_id", columns="seq_no")
test_pivot = test_pivot.fillna(-1).astype(int)
test_log_df = test_log_df.merge(yado_df, how="left", on="yad_no")

In [11]:
test_log_df

Unnamed: 0,session_id,seq_no,yad_no,yad_type,total_room_cnt,wireless_lan_flg,onsen_flg,kd_stn_5min,kd_bch_5min,kd_slp_5min,kd_conv_walk_5min,wid_cd,ken_cd,lrg_cd,sml_cd
0,00001149e9c73985425197104712478c,0,3560,0,205.0,1.0,0,,,,1.0,46e33861f921c3e38b81998fbf283f01,107c7305a74c8dcc4f143de208bf7ec2,52ca3d2824fc3cc90bd4274423badeed,87d9490219b3778f73c41b8176cf30d0
1,00001149e9c73985425197104712478c,1,1959,0,173.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,107c7305a74c8dcc4f143de208bf7ec2,52ca3d2824fc3cc90bd4274423badeed,87d9490219b3778f73c41b8176cf30d0
2,0000e02747d749a52b7736dfa751e258,0,11984,0,224.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,572d60f0f5212aacda515ebf81fb0a3a,2e63024b11908f3729510051a6fc7d9e,d075eb4a9669452b8f07cfc0d13a03ab
3,0000f17ae2628237d78d3a38b009d3be,0,757,0,174.0,1.0,0,1.0,,,1.0,f0112abf369fb03cdc5f5309300913da,bd054cc265d68a400ccb976ac69c6463,dca13b5f308a0ae88ab8875a9ab56919,3267093e6bcad4a46af9d3e46350b22f
4,0000f17ae2628237d78d3a38b009d3be,1,8922,0,106.0,1.0,0,1.0,,,1.0,f0112abf369fb03cdc5f5309300913da,bd054cc265d68a400ccb976ac69c6463,dca13b5f308a0ae88ab8875a9ab56919,3267093e6bcad4a46af9d3e46350b22f
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250300,fffee3199ef94b92283239cd5e3534fa,1,8336,0,187.0,1.0,0,1.0,,,1.0,b07b75d367ebece55a23ceecc939fff4,0a66f6ab9c0507059da6f22a0e1f1690,da273b9909edbb8cdb40305868de155c,f3499131dfb1460d1ddc1af10b936c66
250301,ffff62c6bb49bc9c0fbcf08494a4869c,0,12062,0,51.0,1.0,0,1.0,,,1.0,3300cf6f774b7c6a5807110f244cbc21,013592a15b9a689232792f11da797ac7,989ce3ae2fc5f1649bd10e05917a27f8,ed85e7b17b271de96e7e22ab2bff4aa7
250302,ffff9a7dcc892875c7a8b821fa436228,0,8989,0,30.0,,0,,,,1.0,46e33861f921c3e38b81998fbf283f01,c86352f5b57e80fe545cfec1fd8505a1,9d6a46da05976cab8ac2b8583215c665,568887ea1e1d8c3cf3c60b5be585aa6d
250303,ffffb1d30300fe17f661941fd085b04b,0,6030,0,280.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,107c7305a74c8dcc4f143de208bf7ec2,d153c8fd78bfad6faadf8e769e5cb314,93bb8a3bdcfb298251b12efa3067d44f


In [12]:
# session_idごとに最後のsml_cdを取得
sml_cd_last_sml_df = (
    test_log_df.groupby("session_id")["sml_cd"].last().reset_index().reset_index()
)
sml_cd_last_sml_dict = dict(
    zip(sml_cd_last_sml_df["session_id"], sml_cd_last_sml_df["sml_cd"])
)

sml_cd_last_lrg_df = (
    test_log_df.groupby("session_id")["lrg_cd"].last().reset_index().reset_index()
)
sml_cd_last_lrg_dict = dict(
    zip(sml_cd_last_lrg_df["session_id"], sml_cd_last_lrg_df["lrg_cd"])
)

In [13]:
def extract_valid_yados(record):
    """test_logのセッションに含まれているyad_noを抽出"""
    return list(filter(lambda x: isinstance(x, int) and x >= 0, record))

In [14]:
preds = []

for record in tqdm(test_pivot.itertuples(), total=len(test_pivot)):
    session_id = record[0]
    session_yados = extract_valid_yados(record)
    session_yados_set = set(session_yados)

    # 予測値のlistを作る
    session_yados.pop()  # 最後はでない
    tmp_yados = session_yados.copy()

    # 各セッションの最後のsml_cdを取得
    target_sml_cd = sml_cd_last_sml_dict.get(session_id)
    target_lrg_cd = sml_cd_last_sml_dict.get(session_id)
    # 候補を絞る
    canditates_sml = yado_reserve_count.query("sml_cd == @target_sml_cd")[
        "yad_no"
    ].tolist()
    canditates_lrg = yado_reserve_count.query("sml_cd == @target_lrg_cd")[
        "yad_no"
    ].tolist()

    # sml_cdが一致している候補を逐次的に追加。ただし、session_yadosに含まれていないもののみ
    for cand in canditates_sml:
        if cand not in session_yados_set:
            tmp_yados.append(cand)
        if len(tmp_yados) == 10:
            break

    # 次はlrg_cdが一致している候補を逐次的に追加。ただし、session_yadosに含まれていないもののみ
    extra_num = 10 - len(tmp_yados)
    for _ in range(extra_num):
        for cand in canditates_lrg:
            if cand not in session_yados_set:
                tmp_yados.append(cand)
            if len(tmp_yados) == 10:
                break
        if len(tmp_yados) == 10:
            break

    # 足りないところは一番人気のやつを追加
    extra_num = 10 - len(tmp_yados)
    for _ in range(extra_num):
        tmp_yados.append(3338)

    pred = {f"predict_{idx}": cand for idx, cand in enumerate(tmp_yados)}
    preds.append(pred)

pred_df = pd.DataFrame(preds)

100%|██████████| 174700/174700 [03:11<00:00, 913.27it/s]


In [15]:
pred_df.to_csv(CFG.path_exp / "submission.csv", index=False)

In [16]:
pred_df

Unnamed: 0,predict_0,predict_1,predict_2,predict_3,predict_4,predict_5,predict_6,predict_7,predict_8,predict_9
0,3560,11561,5466,2680,10965,9534,4545,2811,10233,6563
1,143,6555,12862,4066,7014,4825,1266,5267,11237,6129
2,757,9190,9910,1774,2267,410,10104,11001,6721,6730
3,13610,12341,277,5657,2795,6991,2047,3359,9319,7413
4,9020,12524,6576,3187,5713,101,13590,5106,11442,2862
...,...,...,...,...,...,...,...,...,...,...
174695,1997,7888,1885,11123,2278,5744,7062,831,7780,8771
174696,1227,5331,6874,13702,4014,3802,12432,2232,13220,4962
174697,6199,11037,12425,12089,12986,2927,10155,12132,7379,6905
174698,6378,10287,3100,11496,2305,2373,8501,2692,1530,5513
