In [1]:
%cd ..

/kaggle/working


In [2]:
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="../cand_unsupervised/session2item"):
    cfg = compose(config_name="config.yaml", overrides=["debug=True"])
    print(OmegaConf.to_yaml(cfg))

debug: true
seed: 7
dir:
  data_dir: /kaggle/working/input/atmaCup16_Dataset
  output_dir: /kaggle/working/output
  exp_dir: /kaggle/working/output/exp
  cand_unsupervised_dir: /kaggle/working/output/cand_unsupervised
  cand_supervised_dir: /kaggle/working/output/cand_supervised
  datasets_dir: /kaggle/working/output/datasets
exp:
  num_candidate: 100
  k:
  - 1
  - 5
  - 10
  - 50
  - 100
  implicit:
    model: bpr
    params: null



In [3]:
import os
import sys
from pathlib import Path

import hydra
import numpy as np
import polars as pl
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

import utils
import wandb
from utils.load import load_label_data, load_log_data, load_session_data
from utils.metrics import calculate_metrics

In [4]:
import implicit
import scipy.sparse as sparse

In [11]:
with utils.timer("load data"):
    train_log_df = load_log_data(Path(cfg.dir.data_dir), "train")
    test_log_df = load_log_data(Path(cfg.dir.data_dir), "test")
    all_log_df = pl.concat([train_log_df, test_log_df])

[load data] done in 1.4 s


In [19]:
# session_id を連番に変換
all_log_df = all_log_df.with_columns(
    pl.col("session_id").cast(pl.Categorical).to_physical().alias("sid"),
)

# sid と yad_no を対応させる dict
unique_df = all_log_df.unique(["sid", "session_id"])
unique_sids = unique_df["sid"].to_numpy()
unique_session_ids = unique_df["session_id"].to_list()
unique_yad_nos = unique_df["yad_no"].unique()

In [21]:
len(unique_sids)

463398

In [46]:
sparse_item_user = sparse.csr_matrix(
    (
        np.ones(len(all_log_df)),
        (all_log_df["sid"].to_numpy(), all_log_df["yad_no"].to_numpy()),
    )
)
if cfg.exp.implicit.model == "bpr":
    from implicit.cpu.bpr import BayesianPersonalizedRanking

    model = BayesianPersonalizedRanking(
        **OmegaConf.to_container(cfg.exp.implicit.params, resolve=True)
    )

In [31]:
if cfg.exp.implicit.model == "bpr":
    from implicit.cpu.bpr import BayesianPersonalizedRanking

    model = BayesianPersonalizedRanking(
        **OmegaConf.to_container(cfg.exp.implicit.params, resolve=True)
    )

In [47]:
model.fit(sparse_item_user)

  0%|          | 0/100 [00:00<?, ?it/s]

In [71]:
session_ids = unique_session_ids
session_vectors = model.user_factors[unique_sids]
session_factor_df = pl.DataFrame({"session_id": session_ids}).with_columns(
    pl.Series(name=f"session_factor_{i}", values=session_vectors[:, i])
    for i in range(session_vectors.shape[1])
)
session_factor_df.head()

session_id,session_factor_0,session_factor_1,session_factor_2,session_factor_3,session_factor_4,session_factor_5,session_factor_6,session_factor_7,session_factor_8,session_factor_9,session_factor_10,session_factor_11,session_factor_12,session_factor_13,session_factor_14,session_factor_15,session_factor_16,session_factor_17,session_factor_18,session_factor_19,session_factor_20,session_factor_21,session_factor_22,session_factor_23,session_factor_24,session_factor_25,session_factor_26,session_factor_27,session_factor_28,session_factor_29,session_factor_30,session_factor_31,session_factor_32,session_factor_33,session_factor_34,session_factor_35,session_factor_36,session_factor_37,session_factor_38,session_factor_39,session_factor_40,session_factor_41,session_factor_42,session_factor_43,session_factor_44,session_factor_45,session_factor_46,session_factor_47,session_factor_48,session_factor_49,session_factor_50,session_factor_51,session_factor_52,session_factor_53,session_factor_54,session_factor_55,session_factor_56,session_factor_57,session_factor_58,session_factor_59,session_factor_60,session_factor_61,session_factor_62,session_factor_63,session_factor_64
object,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
00294065e14b11795e437b6725a31ff8,-0.036771,0.000164,0.026966,-0.012653,0.04514,-0.02769,0.003658,-0.063323,-0.01879,-0.049074,-0.035665,0.033155,-0.051288,-0.066057,-0.002031,-0.003474,0.01524,0.003185,-0.023771,0.031043,0.021323,-0.004541,0.003489,-0.00788,-0.044113,-0.031901,0.016195,-0.01755,-0.004789,0.029545,-0.007863,0.04089,0.041743,-0.059417,-0.002905,0.022054,-0.013864,-0.030507,0.04886,-0.014229,0.039684,0.013793,-0.019421,-0.035503,-0.027573,0.015012,-0.027347,0.022011,-0.047475,0.012779,0.003484,0.0262,-0.014723,0.031156,-0.007428,-0.01241,0.096701,0.026562,0.003461,-0.039273,-0.014747,0.025647,0.03234,-0.024101,1.0
002b091484c024d6b9504eb9c1a3c50a,0.059012,0.060202,5.6e-05,-0.053923,-0.040282,0.033557,0.045072,0.011441,-0.032807,-0.039396,-0.039178,0.017746,0.018871,-0.04083,0.016685,-0.053092,-0.018293,-0.042589,0.036984,-0.025973,-0.020713,-0.035783,-0.077183,-0.000798,-0.029774,0.007561,0.043933,0.064535,0.026312,-0.025284,0.043433,0.035394,0.006871,-0.034281,0.054639,0.043514,0.036593,-0.006707,-0.074407,-0.038002,-0.039956,0.055136,0.023354,0.047811,0.063801,-0.011288,-0.069044,-0.046459,-0.036728,0.028165,-0.032068,0.006776,0.025952,-0.036918,0.00888,0.028221,-0.067406,0.021893,-0.013964,-0.015134,-0.00995,0.042947,0.061038,0.005145,1.0
002e2ece9a05cd20bcc76f95d50aa910,0.002788,0.01935,-0.000218,-0.012647,0.013266,0.005118,-0.007905,0.007355,0.008918,0.005305,0.005698,0.021933,0.000152,0.001966,0.02101,-0.022875,0.043102,-0.003904,0.006291,0.000627,0.015834,0.030637,0.028937,-0.003354,0.010379,0.005265,0.001907,-0.004861,-0.010202,0.020113,-0.012528,-0.004365,0.003212,0.000864,-0.003111,0.021124,0.02415,-0.00216,-0.020098,0.00397,-0.019377,0.015713,-0.005888,-0.012549,0.010502,-0.020998,0.003017,0.000973,-0.002455,-0.016202,-0.022227,-0.007497,-0.011745,0.008557,0.007437,0.016202,0.009612,-0.020468,0.036794,0.017653,0.018153,0.003538,-0.007995,0.010818,1.0
0042d37af6dd345a9a894ca7655c942c,-0.036546,0.040001,-0.014125,-0.051478,0.033564,0.027352,0.015059,0.008505,0.01235,0.016998,0.024763,-0.004508,0.010197,-0.026252,0.024483,0.030378,0.019711,-0.011307,-0.004122,0.042425,0.004763,-0.036433,-0.014888,-0.022988,-0.002358,0.043713,0.061691,0.012252,0.010359,0.009214,0.005947,0.038487,0.000393,0.026606,-0.015697,-0.00516,0.011509,-0.000648,0.003824,0.041818,-0.07622,0.009325,-0.009421,-0.054425,-0.023548,-0.011478,-0.008401,0.056804,0.007975,-0.034334,0.019686,-0.042343,-0.013827,0.002503,0.029714,-0.007683,0.031995,-0.010224,-0.031221,-7.2e-05,0.012726,-0.039655,0.034917,0.018123,1.0
0048fe227980577a2b2047126ecbb786,0.084619,0.048285,0.108595,0.01772,0.031941,-0.018123,0.012628,-0.021612,-0.086186,-0.06223,-0.024024,-0.117529,0.009104,0.096366,0.11838,0.095569,0.006349,-0.070021,-0.116397,0.015271,0.021736,-0.002842,-0.066831,-0.017185,0.087705,0.032662,0.056455,0.019665,-0.015438,-0.003288,-0.02796,0.007036,-0.039092,-0.009971,0.093991,0.035768,0.03279,0.004002,-0.043429,0.018611,0.02609,-0.029287,0.061743,-0.115864,0.016999,0.024041,-0.096511,0.001174,0.031817,-0.009657,-0.036919,0.015383,0.011721,-0.021871,0.026098,0.012601,0.084723,-0.024929,0.09665,-0.026404,-0.081156,0.035022,0.094793,0.013281,1.0


In [75]:
yad_ids = unique_yad_nos
yad_vectors = model.user_factors[unique_yad_nos]
yad_factor_df = pl.DataFrame({"yad_no": yad_ids}).with_columns(
    pl.Series(name=f"yad_factor_{i}", values=yad_vectors[:, i])
    for i in range(yad_vectors.shape[1])
)
yad_factor_df.head()

yad_no,yad_factor_0,yad_factor_1,yad_factor_2,yad_factor_3,yad_factor_4,yad_factor_5,yad_factor_6,yad_factor_7,yad_factor_8,yad_factor_9,yad_factor_10,yad_factor_11,yad_factor_12,yad_factor_13,yad_factor_14,yad_factor_15,yad_factor_16,yad_factor_17,yad_factor_18,yad_factor_19,yad_factor_20,yad_factor_21,yad_factor_22,yad_factor_23,yad_factor_24,yad_factor_25,yad_factor_26,yad_factor_27,yad_factor_28,yad_factor_29,yad_factor_30,yad_factor_31,yad_factor_32,yad_factor_33,yad_factor_34,yad_factor_35,yad_factor_36,yad_factor_37,yad_factor_38,yad_factor_39,yad_factor_40,yad_factor_41,yad_factor_42,yad_factor_43,yad_factor_44,yad_factor_45,yad_factor_46,yad_factor_47,yad_factor_48,yad_factor_49,yad_factor_50,yad_factor_51,yad_factor_52,yad_factor_53,yad_factor_54,yad_factor_55,yad_factor_56,yad_factor_57,yad_factor_58,yad_factor_59,yad_factor_60,yad_factor_61,yad_factor_62,yad_factor_63,yad_factor_64
i64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
6182,-0.048326,-0.004528,-0.024669,-0.089255,0.010862,-0.033149,-0.028463,0.057671,-0.060024,0.014419,-0.011375,0.075867,-0.073627,-0.033328,-0.016734,-0.08195,0.062993,0.025594,-0.0225,-0.071677,-0.020026,-0.007773,-0.01956,0.030126,0.041875,-0.011655,0.032269,0.003096,-0.003822,0.029732,0.015008,0.032676,0.003491,-0.05299,-0.002994,0.051818,-0.061261,-0.035934,-0.0166,0.085077,-0.080078,-0.027553,-0.006894,-0.031826,0.071659,0.010862,0.00932,-0.000881,0.086484,0.011526,0.040588,-0.002206,0.009963,-0.021516,0.042936,-0.09884,-0.053337,-0.036408,0.014027,-0.05266,0.02113,-0.077207,-0.017149,-0.013563,1.0
12298,0.014495,0.001375,-0.00929,0.007667,-0.000402,-0.029706,0.000689,-0.001776,-0.010251,-0.002306,-0.005729,-0.005401,0.001226,-0.003998,-0.011896,0.011732,-0.010357,-0.003676,-0.008669,0.013002,0.019534,-0.008416,0.013202,-0.012791,0.000937,0.006825,0.015935,0.002565,0.001324,0.013638,0.001323,0.00539,-0.008881,-0.000687,0.004977,0.0025,-0.00459,-0.008292,-0.001602,-0.010882,0.003404,0.000196,-0.013993,-0.006648,0.002317,-0.000883,-0.01104,0.01576,-0.002638,-0.001733,-0.000834,0.009971,0.009188,-0.009179,0.001683,0.003499,0.004111,-0.001257,0.010511,0.001072,-0.006655,-6.7e-05,-0.002751,0.009406,1.0
2900,0.036425,0.008342,0.011304,-0.023282,-0.032702,0.012484,-0.043675,0.006696,-0.007789,0.010664,0.015559,0.017608,-0.032099,0.035661,0.0378,-0.009083,0.006619,0.016124,0.037166,0.061588,0.018872,-0.003306,0.01516,0.002733,0.023661,-0.001795,-0.014787,0.051053,0.01149,0.006545,0.0115,0.007263,0.047442,0.028293,0.0054,0.005152,-0.011257,0.009365,0.015843,0.032238,0.031456,0.021141,0.00197,0.023695,0.000123,-0.007382,0.010934,-0.043818,-0.022649,0.038427,0.009473,-0.008978,0.034308,0.013288,0.01531,0.020739,-0.044199,0.032694,-0.015589,0.028608,0.043591,0.014018,0.008433,0.01826,1.0
10674,-0.014719,0.008035,-0.020909,0.01438,0.001948,0.024846,0.001286,0.035624,0.045272,0.042396,-0.044118,0.027642,-0.016702,-0.045047,0.03419,0.00695,0.073381,-0.015003,0.055612,0.035086,0.014825,0.026952,0.009799,-0.033874,0.076922,0.001576,0.046452,0.030098,0.001523,0.003636,0.016463,-0.022815,0.004967,-0.001326,0.055203,0.041632,-0.072908,-0.075853,0.063959,0.025538,-0.013595,0.032738,0.003776,-0.013634,-0.022674,-0.029323,-0.046717,-0.002111,-0.06112,0.079175,0.109619,-0.012267,0.077302,-0.037915,-0.004437,-0.001089,0.012456,-0.053455,-0.026448,-0.013132,0.057654,-0.125646,0.040757,-0.015899,1.0
11699,0.007227,-0.037689,-0.03123,-0.00083,-0.009916,0.014247,0.036252,0.029992,-0.014858,0.009602,0.026991,0.022914,0.003722,0.000956,0.010084,-0.046201,-0.040152,0.015242,0.00661,-0.004807,0.009134,0.019995,-0.038628,0.001548,0.010822,-0.027628,0.031518,-0.027189,-0.036847,-0.00041,-0.018229,-0.039499,-0.004898,0.000968,0.002646,-0.032414,-0.002216,0.008194,0.026798,0.001135,-0.020248,0.026219,-0.015044,0.005389,0.006827,-0.019934,0.040647,-0.007452,-0.037088,0.056974,0.00598,-0.065331,0.02824,0.008805,0.032377,0.002783,0.03078,0.016393,-0.033144,0.045921,0.043949,-0.011463,0.045163,-0.02275,1.0


In [78]:
# 少し時間がかかる
candidates, scores = model.recommend(
    unique_sids,
    sparse_item_user[unique_sids],
    N=cfg.exp.num_candidate,
    filter_already_liked_items=False,
)

In [89]:
candidate_score_df = pl.DataFrame(
    {
        "session_id": unique_session_ids,  # unique_sids と同じ順番
        "candidates": candidates,
        "scores": scores,
    }
)

In [90]:
candidate_score_df.head()

session_id,candidates,scores
str,list[i32],list[f32]
"""00026fd325b5d6…","[6182, 10610, … 4771]","[0.490051, 0.360482, … 0.261777]"
"""001419ed5ca0e6…","[12298, 10610, … 2620]","[0.516867, 0.385657, … 0.267045]"
"""00177ad4bf0130…","[10610, 4807, … 456]","[0.360191, 0.343715, … 0.261102]"
"""001d2bfb0608cf…","[10610, 10573, … 468]","[0.359624, 0.35897, … 0.262962]"
"""001e016195c0da…","[8134, 11181, … 7615]","[0.739138, 0.68611, … 0.285167]"


In [87]:
train_session_df

session_id
str
"""000007603d533d…"
"""0000ca043ed437…"
"""0000d4835cf113…"
"""0000fcda1ae1b2…"
"""000104bdffaaad…"
"""00011afe25c343…"
"""000125c737df18…"
"""0001763050a10b…"
"""000178c4d4d567…"
"""0001e6a407a85d…"


In [83]:
with utils.timer("load session data"):
    train_session_df = load_session_data(Path(cfg.dir.data_dir), "train")
    test_session_df = load_session_data(Path(cfg.dir.data_dir), "test")

[load session data] done in 0.0 s


In [91]:
train_candidate_df = train_session_df.join(
    candidate_score_df, on="session_id", how="left"
)

In [5]:
yad_factor_df = pl.read_parquet(
    "/kaggle/working/output/cand_unsupervised/session2item/bpr001/yad_factor.parquet"
)

In [6]:
yad_factor_df.shape

(463398, 18)

In [7]:
yad_factor_df

yad_no,yad_factor_0,yad_factor_1,yad_factor_2,yad_factor_3,yad_factor_4,yad_factor_5,yad_factor_6,yad_factor_7,yad_factor_8,yad_factor_9,yad_factor_10,yad_factor_11,yad_factor_12,yad_factor_13,yad_factor_14,yad_factor_15,yad_factor_16
i64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
4177,0.074024,-0.034346,0.099663,0.055028,0.015707,-0.052419,-0.110205,0.053058,0.102324,-0.100173,0.115849,0.006051,-0.02324,-0.021158,0.027133,0.017828,1.0
712,-0.172432,-0.212058,-0.034804,0.023406,0.312044,0.017801,-0.056954,-0.367061,0.063719,0.098617,0.185253,-0.051754,-0.070032,-0.167549,-0.117107,-0.046513,1.0
9015,-0.117114,-0.026131,0.171715,-0.185116,0.194189,-0.143954,0.257885,0.007382,0.025175,-0.123284,0.127897,-0.058028,-0.041649,-0.023881,-0.144913,-0.080369,1.0
2570,0.045188,-0.012865,-0.010147,0.024632,0.042534,-0.027332,0.032062,0.041003,0.029313,-0.035977,-0.029259,-0.009981,-0.020431,-0.036406,0.070147,-0.057943,1.0
3485,0.137288,0.217753,0.001529,-0.104046,0.065735,0.09026,-0.017501,-0.121127,-0.123931,0.208412,0.068924,0.009135,0.052652,-0.104626,-0.045897,0.113347,1.0
3497,0.0734,0.051952,0.053088,-0.102765,0.13116,0.263129,0.090745,0.073372,-0.124377,-0.095768,0.14055,0.093831,-0.012505,0.076903,-0.174174,-0.160334,1.0
9977,-0.06722,-0.02374,-0.088081,-0.048094,-0.234928,0.105979,-0.101131,-0.07363,0.023331,-0.104018,0.185704,0.141995,0.315559,-0.054656,0.242043,0.161858,1.0
13067,0.060777,0.050222,-0.043942,0.042227,0.090325,0.021211,0.025697,-0.134353,0.015768,-0.084047,-0.033364,-0.109529,0.008306,-0.060011,-0.06808,-0.07319,1.0
13590,-0.02714,0.120466,0.108406,-0.09344,0.156054,0.039846,0.011179,0.308337,-0.103004,-0.231315,0.043372,0.231533,0.055855,0.188466,0.009308,0.11518,1.0
5800,0.099893,-0.139742,-0.101452,-0.016452,-0.123401,0.098688,-0.171821,0.289418,-0.034913,-0.058958,0.01424,-0.092762,-0.201267,-0.127277,-0.022915,-0.176774,1.0
