In [5]:
import os
import pickle
import random
import sys
import uuid
from pathlib import Path

import implicit
import lightgbm as lgb
import numpy as np
import pandas as pd
# NOTE(review): importing `random` from scipy.sparse rebinds the name `random`,
# so the standard-library module imported above is no longer reachable under
# that name once this cell has run. Alias one of the two if both are needed.
from scipy.sparse import csr_matrix, random
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder

# Maximum number of DataFrame columns shown when a frame is displayed (50 here).
pd.set_option("display.max_columns", 50)

In [3]:
# Load the project configuration through Hydra's compose API.
from hydra import compose, initialize

# `initialize` is used as a context manager around `compose`. The config
# directory is given as the relative path "../yamls", and the composed result is
# kept in `config` (later cells read `config.input_path`).
with initialize(config_path="../yamls", version_base=None):
    config = compose(config_name="config.yaml")

In [4]:
# Read every input CSV from the directory configured as `config.input_path`.
input_dir = Path(config.input_path)
train_df, test_df, sample_submission_df, anime_df = (
    pd.read_csv(input_dir / csv_name)
    for csv_name in ("train.csv", "test.csv", "sample_submission.csv", "anime.csv")
)

# Normalisation: strip every space character out of the `genres` strings.
anime_df["genres"] = anime_df["genres"].str.replace(" ", "")

# Stack the train and test rows, then left-join the anime metadata on `anime_id`.
all_df = pd.concat([train_df, test_df]).merge(anime_df, on="anime_id", how="left")

In [8]:
# Map raw user / anime IDs to contiguous integer labels (0..n-1), numbered in
# order of first appearance in `all_df`. The loop variables are named explicitly
# so the builtin `id` is not shadowed.
user_id_mapping = {user_id: label for label, user_id in enumerate(all_df["user_id"].unique())}
anime_id_mapping = {anime_id: label for label, anime_id in enumerate(all_df["anime_id"].unique())}

# Attach the integer labels to every interaction row.
all_df["user_label"] = all_df["user_id"].map(user_id_mapping)
all_df["anime_label"] = all_df["anime_id"].map(anime_id_mapping)

# Build the sparse interaction matrix: a 1 is contributed for every row of
# `all_df` at position (user_label, anime_label).
# NOTE: despite the variable name, rows are users and columns are anime. The name
# is kept because later cells pass `item_user_data` to `model.fit`.
# The shape is stated explicitly so it is tied to the two mappings instead of
# being inferred from the largest label present.
item_user_data = csr_matrix(
    (
        np.ones(len(all_df)),
        (all_df["user_label"].to_numpy(), all_df["anime_label"].to_numpy()),
    ),
    shape=(len(user_id_mapping), len(anime_id_mapping)),
)

In [76]:
# Fit a Bayesian Personalized Ranking model with 64 latent factors on the
# interaction matrix. `random_state` seeds the model's random number generation
# so that re-running the notebook is repeatable (GPU runs may still vary
# slightly - TODO confirm).
# NOTE(review): `model` is rebound to an ALS model further down, so the factors
# learned here are only reachable until that cell runs.
model = implicit.bpr.BayesianPersonalizedRanking(factors=64, random_state=42)
model.fit(item_user_data)

  0%|          | 0/100 [00:00<?, ?it/s]

In [75]:
# Inspect the user factor matrix of the model fitted above.
model.user_factors

Matrix([[-0.02247773  0.5257393  -0.26336163 ... -0.3229936   0.13023865
   1.2687534 ]
 [-0.46913132  0.28219548 -0.09577905 ... -0.40073723  0.33726084
   0.34613404]
 [-0.27175185 -0.24868305  0.02096217 ... -0.19495237  0.28704846
  -0.35636327]
 ...
 [-0.14548008  0.82067573  0.11754293 ...  0.01751893 -0.45704842
  -0.52017367]
 [-0.04495116 -0.15110163 -0.00816941 ...  0.06239529  0.11036722
   0.38465363]
 [ 0.38145843  0.04694916  0.04290849 ...  0.3505601   0.7687281
  -0.7389714 ]])

In [22]:
# Replace `model` with an Alternating Least Squares model using 64 latent
# factors. `random_state` seeds the factor initialisation so reruns are repeatable.
model = implicit.als.AlternatingLeastSquares(factors=64, random_state=42)

In [23]:
# Train the model on the sparse interaction matrix built earlier.
model.fit(item_user_data)

  0%|          | 0/15 [00:00<?, ?it/s]

In [60]:
def _to_numpy(factors):
    """Return `factors` as a NumPy array.

    Objects that expose `to_numpy()` (presumably implicit's GPU matrix type) are
    converted through it; anything else goes through `np.asarray`.
    """
    return factors.to_numpy() if hasattr(factors, "to_numpy") else np.asarray(factors)


# Take the learned factor matrices from the fitted model, so this cell no longer
# depends on `user_factors` / `item_factors` left over from an earlier kernel
# session (they are not defined anywhere else in the notebook).
user_factors = _to_numpy(model.user_factors)
item_factors = _to_numpy(model.item_factors)

# For each interaction row, look up the factor vector of its user and of its
# anime and place them side by side (user factors first, then item factors).
embeddings = np.concatenate(
    (
        user_factors[all_df["user_label"].to_numpy()],
        item_factors[all_df["anime_label"].to_numpy()],
    ),
    axis=1,
)

In [62]:
# Column names: one per user factor followed by one per item factor, matching
# the left-to-right order in which the two halves were concatenated.
user_factor_names = [f"user_factor_{i}" for i in range(user_factors.shape[1])]
item_factor_names = [f"item_factor_{j}" for j in range(item_factors.shape[1])]

# Wrap the embedding array in a DataFrame with those names and display it.
embeddings_df = pd.DataFrame(embeddings, columns=user_factor_names + item_factor_names)
embeddings_df

Unnamed: 0,user_factor_0,user_factor_1,user_factor_2,user_factor_3,user_factor_4,user_factor_5,user_factor_6,user_factor_7,user_factor_8,user_factor_9,user_factor_10,user_factor_11,user_factor_12,user_factor_13,user_factor_14,user_factor_15,user_factor_16,user_factor_17,user_factor_18,user_factor_19,user_factor_20,user_factor_21,user_factor_22,user_factor_23,user_factor_24,...,item_factor_39,item_factor_40,item_factor_41,item_factor_42,item_factor_43,item_factor_44,item_factor_45,item_factor_46,item_factor_47,item_factor_48,item_factor_49,item_factor_50,item_factor_51,item_factor_52,item_factor_53,item_factor_54,item_factor_55,item_factor_56,item_factor_57,item_factor_58,item_factor_59,item_factor_60,item_factor_61,item_factor_62,item_factor_63
0,-0.400157,0.576584,-0.320363,0.163411,0.179890,0.710548,0.275089,1.350491,-0.551021,0.570927,-0.840700,-1.320849,-1.043444,-0.830262,0.462310,0.545076,0.956000,-1.473941,1.359565,-0.499627,0.098780,-0.045765,-1.382141,0.265037,-0.223404,...,0.003759,0.009497,-0.005083,-0.023899,-0.004393,-0.038294,0.001070,0.018924,0.009528,0.027943,-0.003546,0.034754,0.030214,0.070854,0.019966,-0.032163,0.006835,0.039944,-0.012004,-0.001277,-0.009701,0.011147,-0.029691,0.001297,0.014161
1,-0.400157,0.576584,-0.320363,0.163411,0.179890,0.710548,0.275089,1.350491,-0.551021,0.570927,-0.840700,-1.320849,-1.043444,-0.830262,0.462310,0.545076,0.956000,-1.473941,1.359565,-0.499627,0.098780,-0.045765,-1.382141,0.265037,-0.223404,...,0.026397,-0.030462,0.026448,-0.041479,0.012010,-0.000467,-0.017531,-0.001195,-0.014470,-0.010783,-0.028763,0.013391,-0.023341,0.047001,-0.020750,-0.021187,0.007525,-0.010050,-0.023223,-0.012739,-0.006968,-0.007458,0.022049,-0.003509,0.004716
2,-0.400157,0.576584,-0.320363,0.163411,0.179890,0.710548,0.275089,1.350491,-0.551021,0.570927,-0.840700,-1.320849,-1.043444,-0.830262,0.462310,0.545076,0.956000,-1.473941,1.359565,-0.499627,0.098780,-0.045765,-1.382141,0.265037,-0.223404,...,0.020792,-0.005928,-0.021606,0.010971,-0.003663,0.001791,0.015570,0.007787,0.002330,-0.023625,-0.004566,0.007267,-0.001379,0.017787,0.008949,-0.012457,0.017698,-0.024005,0.017076,-0.006175,0.015813,-0.002343,-0.007575,-0.001028,0.003198
3,-0.400157,0.576584,-0.320363,0.163411,0.179890,0.710548,0.275089,1.350491,-0.551021,0.570927,-0.840700,-1.320849,-1.043444,-0.830262,0.462310,0.545076,0.956000,-1.473941,1.359565,-0.499627,0.098780,-0.045765,-1.382141,0.265037,-0.223404,...,-0.016766,0.004325,0.042523,-0.055396,0.032696,-0.029555,-0.025348,-0.051389,0.040119,-0.016376,-0.042023,0.014422,-0.036973,0.045541,0.049040,-0.026662,0.025756,0.057869,-0.028772,-0.004672,0.000861,-0.009265,-0.008501,0.018607,0.025384
4,-0.400157,0.576584,-0.320363,0.163411,0.179890,0.710548,0.275089,1.350491,-0.551021,0.570927,-0.840700,-1.320849,-1.043444,-0.830262,0.462310,0.545076,0.956000,-1.473941,1.359565,-0.499627,0.098780,-0.045765,-1.382141,0.265037,-0.223404,...,-0.018594,-0.010164,-0.000902,0.000911,-0.017389,0.014148,-0.006110,0.027536,0.006122,-0.017073,-0.004127,0.008139,0.011862,0.007589,-0.001428,-0.040582,-0.019582,-0.007315,-0.000653,-0.005088,-0.011372,-0.009883,-0.006204,-0.003889,-0.003247
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
254072,-0.417092,0.280127,0.189547,-0.221137,-0.400826,-0.579004,0.186541,1.180578,1.065188,0.401765,0.034966,0.284532,0.847514,-0.847343,-0.706831,-0.166989,-0.355063,-0.678944,-0.018316,-0.697292,0.584756,1.390854,-0.466070,-0.381619,0.967749,...,0.035328,-0.047377,-0.006039,0.025115,0.019244,-0.030323,-0.073078,0.006942,-0.064463,-0.014478,-0.037610,0.052859,0.015897,-0.046784,-0.003840,0.023055,0.002536,-0.045659,0.031457,-0.026724,0.014007,0.003064,0.011883,-0.006431,0.031380
254073,-0.417092,0.280127,0.189547,-0.221137,-0.400826,-0.579004,0.186541,1.180578,1.065188,0.401765,0.034966,0.284532,0.847514,-0.847343,-0.706831,-0.166989,-0.355063,-0.678944,-0.018316,-0.697292,0.584756,1.390854,-0.466070,-0.381619,0.967749,...,-0.008235,0.008715,-0.015733,-0.012648,-0.012698,0.014614,-0.009035,-0.006066,-0.020445,0.015588,-0.001055,-0.023811,-0.019366,-0.011829,0.006955,-0.007345,-0.009650,0.011074,-0.000533,0.012930,-0.008814,-0.009870,-0.010268,-0.011969,-0.009001
254074,-0.417092,0.280127,0.189547,-0.221137,-0.400826,-0.579004,0.186541,1.180578,1.065188,0.401765,0.034966,0.284532,0.847514,-0.847343,-0.706831,-0.166989,-0.355063,-0.678944,-0.018316,-0.697292,0.584756,1.390854,-0.466070,-0.381619,0.967749,...,0.017456,0.045535,-0.001329,-0.017803,-0.016746,0.019120,-0.024763,0.002361,-0.038948,0.017218,-0.020827,-0.032460,0.004236,-0.021754,-0.007122,-0.031132,-0.010945,0.033075,0.014512,0.004835,-0.018377,0.009043,0.040812,0.022192,-0.017618
254075,-0.417092,0.280127,0.189547,-0.221137,-0.400826,-0.579004,0.186541,1.180578,1.065188,0.401765,0.034966,0.284532,0.847514,-0.847343,-0.706831,-0.166989,-0.355063,-0.678944,-0.018316,-0.697292,0.584756,1.390854,-0.466070,-0.381619,0.967749,...,0.009653,0.013932,-0.015024,-0.000960,0.000111,0.012100,-0.005019,-0.025667,-0.030151,-0.002133,0.008734,-0.004045,-0.012580,-0.011190,0.014284,-0.015684,0.001997,0.034950,0.009947,-0.025895,0.016666,-0.010076,-0.014206,0.023812,-0.005617
