# yado2vec - atmaCup16 with RECRUIT

- gensimでitem2vec
- 各itemに対し類似度が上位のアイテムK件と、すべてのベクトルデータを保存

In [1]:
!pip install gensim



In [2]:
import pandas as pd
import numpy as np

In [3]:
class CFG:
    """Notebook-wide settings for training and exporting the yado2vec model."""

    # Seed handed to Word2Vec.
    # NOTE(review): gensim documents that a fully reproducible run also needs
    # workers=1 (and a fixed PYTHONHASHSEED); with workers=4 below, results
    # may differ between runs -- confirm whether that is acceptable.
    seed = 127
    # Number of most-similar items exported per yad_no.
    top_n = 50

    # Word2Vec hyper-parameters
    vector_size = 64  # embedding dimensionality
    epochs = 50  # training passes over the corpus
    min_count = 1  # keep every yad_no, even ones that appear only once
    workers = 4  # training threads

## Load data

In [4]:
# Read the train and test session logs from the Kaggle dataset mount.
train_logs_df, test_logs_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "/kaggle/input/atmacup16-recruit/train_log.csv",
        "/kaggle/input/atmacup16-recruit/test_log.csv",
    )
)

## Preprocess

In [5]:
# Concatenate the train and test logs so embeddings are learned from every session.
all_logs_df = pd.concat([train_logs_df, test_logs_df], axis=0)

# Same frame this cell has always exposed: the combined logs without seq_no.
logs_df = all_logs_df.drop(columns=["seq_no"])

# Build one "sentence" per session_id: the list of its yad_no values.
# groupby keeps rows in their incoming order within each group, so sort by
# seq_no explicitly instead of relying on the CSV row order (seq_no is taken
# to be the within-session position -- confirm against the data dictionary).
sentences = (
    all_logs_df
    .sort_values(["session_id", "seq_no"])
    .groupby("session_id")["yad_no"]
    .apply(list)
    .tolist()
)

In [6]:
# Peek at the first ten session sequences (each is a list of yad_no values).
sentences[:10]

[[2395],
 [3560, 1959],
 [13535],
 [123],
 [11984],
 [757, 8922],
 [8475],
 [96, 898],
 [6868],
 [8602]]

## Fit

In [7]:
from gensim.models.word2vec import Word2Vec

# Hyper-parameters forwarded to gensim's Word2Vec; every value comes from CFG.
w2v_params = dict(
    vector_size=CFG.vector_size,
    epochs=CFG.epochs,
    seed=CFG.seed,
    min_count=CFG.min_count,
    workers=CFG.workers,
)

# Each "sentence" is one session's list of yad_no values, so the model learns
# one embedding per accommodation. Supplying the corpus to the constructor
# builds the vocabulary and trains in a single call.
model = Word2Vec(sentences=sentences, **w2v_params)

## 定性評価

In [8]:
# Accommodation (yado) master table; used below to inspect neighbours and to enumerate yad_no values.
yados_df = pd.read_csv("/kaggle/input/atmacup16-recruit/yado.csv")

In [9]:
# Qualitative check: pick one accommodation and list its 10 nearest
# neighbours in embedding space as (yad_no, similarity) pairs.
target_yad_no = 11089
res = model.wv.most_similar(target_yad_no, topn=10)

res

[(6800, 0.9979557394981384),
 (3432, 0.9975354075431824),
 (3782, 0.9974309802055359),
 (6799, 0.9973302483558655),
 (9609, 0.9969385266304016),
 (7211, 0.9969350099563599),
 (4771, 0.9966328740119934),
 (5671, 0.9960944652557373),
 (4516, 0.9958534240722656),
 (11690, 0.9940141439437866)]

In [10]:
# Put the query accommodation first, followed by its neighbours' yad_no values
# (the similarity scores are dropped).
similar_items = [target_yad_no, *(neighbor for neighbor, _score in res)]

similar_items

[11089, 6800, 3432, 3782, 6799, 9609, 7211, 4771, 5671, 4516, 11690]

In [11]:
# Look up the query and its neighbours in the master table to compare their attributes.
yados_df[yados_df["yad_no"].isin(similar_items)]

Unnamed: 0,yad_no,yad_type,total_room_cnt,wireless_lan_flg,onsen_flg,kd_stn_5min,kd_bch_5min,kd_slp_5min,kd_conv_walk_5min,wid_cd,ken_cd,lrg_cd,sml_cd
3431,3432,0,38.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
3781,3782,0,161.0,1.0,0,,,,,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
4515,4516,0,35.0,1.0,0,,,,,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
4770,4771,0,110.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
5670,5671,0,84.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
6798,6799,0,190.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
6799,6800,0,62.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
7210,7211,0,245.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
9608,9609,0,29.0,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706
11088,11089,0,,1.0,0,1.0,,,1.0,46e33861f921c3e38b81998fbf283f01,83522288daa2f3a0010f79df27c38ea5,848dd393c353d8d0e63080af42d99e49,95eac355ee029b439925bf7098836706


## Infer

In [12]:
# For every accommodation in the master table that the model has an embedding
# for, collect its top-N neighbours and its raw vector.
similar_items = {}
vecs = {}
for yad_no in yados_df["yad_no"].values:
    # Accommodations that never appear in the logs have no embedding; skip them.
    if yad_no in model.wv.key_to_index:
        res = model.wv.most_similar(yad_no, topn=CFG.top_n)

        # Neighbour yad_no values, most similar first (scores dropped)
        similar_items[yad_no] = [neighbor for neighbor, _score in res]

        # Embedding vector
        vecs[yad_no] = model.wv[yad_no]

In [13]:
# Coverage check: accommodations in the master table vs. those that received an embedding.
print(f"yad.csvに含まれるアイテム数: {yados_df['yad_no'].nunique()}")
print(f"yad2vecに含まれるアイテム数: {len(similar_items)}")

yad.csvに含まれるアイテム数: 13806
yad2vecに含まれるアイテム数: 13562


In [14]:
# One row per yad_no (index) and one column per neighbour rank: item0, item1, ...
item_columns = [f"item{rank}" for rank in range(CFG.top_n)]
similar_items_df = pd.DataFrame.from_dict(similar_items, orient="index").set_axis(
    item_columns, axis=1
)

# Save
similar_items_df.to_csv("y2v_items.csv", index_label="Key")

In [15]:
# Preview the exported neighbour table.
similar_items_df

Unnamed: 0,item0,item1,item2,item3,item4,item5,item6,item7,item8,item9,...,item40,item41,item42,item43,item44,item45,item46,item47,item48,item49
1,10870,5198,1254,1503,6192,3796,10950,1474,9789,10248,...,2245,3595,1400,2988,4221,8787,13647,8207,4224,3350
2,13783,3860,12162,3847,299,7597,12232,36,7765,13333,...,5478,13174,440,10590,11080,446,6509,11459,7177,8375
3,6579,9624,846,10439,11295,12154,10415,420,5294,6247,...,12178,5941,2974,4825,3931,2284,12646,9545,7561,1229
4,430,3054,11408,10599,8899,9041,10186,9631,1896,8617,...,9545,2018,1779,4879,11445,8609,7093,2574,3137,4829
5,12782,1901,117,9493,5606,13654,10549,7277,6241,3928,...,3030,11067,5655,7602,4324,3137,12684,8919,11445,1983
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13801,12634,205,10601,2747,627,5909,4674,9310,9448,9198,...,3401,10767,119,7891,10150,496,3579,7182,1841,12672
13803,12752,1362,6883,12558,12775,7251,9470,3512,8724,7558,...,9161,5241,9240,1408,1052,3879,253,13092,6074,11330
13804,12835,8908,11633,1174,10809,518,376,13382,3389,4416,...,8848,7903,9244,2983,9584,10137,11383,1450,12289,8582
13805,12692,2917,5068,10359,4635,6615,3413,2906,11610,61,...,11053,11582,5257,433,6919,5637,3788,8195,11020,11325


In [16]:
# One row per yad_no (index) and one column per embedding dimension: vec0, vec1, ...
vec_columns = [f"vec{dim}" for dim in range(CFG.vector_size)]
vecs_df = pd.DataFrame.from_dict(vecs, orient="index").set_axis(
    vec_columns, axis=1
)

# Save
vecs_df.to_csv("y2v_vecs.csv", index_label="Key")

In [17]:
# Preview the exported embedding table.
vecs_df

Unnamed: 0,vec0,vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8,vec9,...,vec54,vec55,vec56,vec57,vec58,vec59,vec60,vec61,vec62,vec63
1,-0.056160,0.461323,-0.246454,-0.097147,-0.747485,0.106123,-0.327350,-0.604765,-0.297258,0.207103,...,-0.157627,0.238811,-0.049458,0.214988,0.130053,0.306387,-0.045955,0.308469,-0.010767,0.501859
2,-0.380085,0.017703,-0.180225,-0.431436,-0.240020,0.061820,0.259663,-0.285451,0.108400,0.000869,...,-0.297499,0.388696,0.510876,0.191163,-0.259338,0.047039,-0.535272,0.248897,-0.288526,-0.151147
3,0.159336,-0.469314,0.011676,-1.861795,-0.705575,0.016262,0.661346,-0.469468,0.279515,-0.695082,...,0.283484,-0.366958,1.687001,0.727565,0.043843,0.316751,-0.492664,-0.411019,-0.298784,-0.889309
4,-0.304735,-0.183908,-0.284408,-1.307295,-0.493841,-0.378975,0.623894,-0.324846,0.180742,-0.619462,...,-0.279539,0.340182,1.288328,0.591720,-0.228116,-0.138246,-0.428144,-0.731898,-0.320868,-0.465543
5,-0.037937,0.050301,-0.093290,-0.216820,-0.272093,-0.036300,0.226241,-0.000820,0.580219,-0.732586,...,0.127412,0.818064,0.852673,0.914156,-0.192551,0.087662,-0.227779,-0.336038,0.195412,-0.117620
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13801,-0.408353,0.464413,0.379729,-0.791758,-0.762315,-0.129494,0.942803,-0.683183,-0.297677,0.141463,...,-0.167150,0.682974,0.646733,-0.749047,-0.358897,-0.264624,-0.363424,0.111278,0.726910,-0.574798
13803,-0.231795,0.198019,0.330224,-0.366481,-0.144736,0.045544,-0.263635,-0.306815,0.101774,0.238183,...,-0.884420,0.493421,0.477872,0.596923,0.124741,-0.118186,-0.222710,-0.187648,0.094433,0.400136
13804,0.076345,0.452165,0.019554,-0.028466,0.079052,0.281892,0.566528,-0.403330,0.042952,-0.933820,...,-0.231951,0.103499,1.004862,0.509584,-0.047070,0.019413,-0.444506,0.065724,-0.498774,-0.253588
13805,-0.059402,0.010639,-0.064069,-0.123300,-0.155746,0.096167,0.010204,-0.181731,-0.033994,-0.062645,...,0.001008,0.111483,0.241450,0.114779,0.037387,0.088836,-0.048148,-0.055917,0.095358,-0.075880
