In [1]:
import os
import pickle
import random
import sys
import uuid
from pathlib import Path

import implicit
import lightgbm as lgb
import numpy as np
import pandas as pd
# NOTE(review): the `random` imported on the next line rebinds the name bound by
# `import random` above, so after this cell `random` is scipy.sparse.random, not the
# standard-library module. Confirm which one later cells expect and alias one of them.
from scipy.sparse import csr_matrix, random
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder

# Set the maximum number of displayed columns (50 here)
pd.set_option("display.max_columns", 50)

# Add the parent directory to the import search path
# (presumably so the `utils` package imported below can be found — TODO confirm).
sys.path.append(os.pardir)
from hydra import compose, initialize

from utils import load_datasets
from utils.embedding import TextEmbedder

# Compose the Hydra configuration stored under ../yamls.
with initialize(config_path="../yamls", version_base=None):
    config = compose(config_name="config.yaml")

# Every CSV is read from the directory named by the config.
input_dir = Path(config.input_path)

train_df = pd.read_csv(input_dir / "train.csv")
test_df = pd.read_csv(input_dir / "test.csv")
sample_submission_df = pd.read_csv(input_dir / "sample_submission.csv")
anime_df = pd.read_csv(input_dir / "anime.csv")

# Tidy up: remove the spaces inside the genres strings.
genres_without_spaces = anime_df["genres"].str.replace(" ", "")
anime_df["genres"] = genres_without_spaces

# Stack train and test rows, then attach the anime metadata to each row by anime_id.
# A left join keeps every train/test row even when no metadata row matches.
all_df = pd.concat([train_df, test_df]).merge(anime_df, on="anime_id", how="left")

In [8]:
# Count columns that the following cells normalise and aggregate.
user_num_cols = [
    "members",
    "watching",
    "completed",
    "on_hold",
    "dropped",
    "plan_to_watch",
]

# Keep only the count columns plus the two id columns, as an independent copy of all_df.
selected_cols = [*user_num_cols, "user_id", "anime_id"]
df = all_df.loc[:, selected_cols].copy()

In [9]:
# Show the selected count and id columns.
df

Unnamed: 0,members,watching,completed,on_hold,dropped,plan_to_watch,user_id,anime_id
0,542642,64809,383733,10625,5735,77740,0008e10fb39e55447333,0669cc0219d468761195
1,650309,29665,477257,13336,18054,111997,0008e10fb39e55447333,111adb8835b8a1a2cf54
2,137560,5153,113190,758,9431,9028,0008e10fb39e55447333,1fc8683c393432a2f9c7
3,1255830,68041,942402,26125,19213,200049,0008e10fb39e55447333,2290175205d55e81b197
4,97346,1565,82189,502,379,12711,0008e10fb39e55447333,28f173b60331d5cabb0d
...,...,...,...,...,...,...,...,...
254072,1160651,66549,815938,35566,20358,222240,ffe85a36cd20500faa58,f508b02efeac8ecb8cc0
254073,152465,6563,92215,6575,8356,38756,ffe85a36cd20500faa58,f5b8ecea3beea4b82d79
254074,375013,16267,261233,12050,15402,70061,ffe85a36cd20500faa58,f6c208226b6b69948053
254075,85634,3675,58233,3065,4208,16453,ffe85a36cd20500faa58,fe67592c312fc1e17745


In [10]:
# Every count column except "members" gets a "<col>_norm" companion holding its
# share of "members"; use_cols lists the raw counts followed by those shares.
share_sources = [name for name in user_num_cols if name != "members"]
use_cols = user_num_cols + [f"{name}_norm" for name in share_sources]

member_totals = df["members"]
for name in share_sources:
    df[f"{name}_norm"] = df[name] / member_totals

In [11]:
# Show the frame again, now including the *_norm ratio columns.
df

Unnamed: 0,members,watching,completed,on_hold,dropped,plan_to_watch,user_id,anime_id,watching_norm,completed_norm,on_hold_norm,dropped_norm,plan_to_watch_norm
0,542642,64809,383733,10625,5735,77740,0008e10fb39e55447333,0669cc0219d468761195,0.119432,0.707157,0.019580,0.010569,0.143262
1,650309,29665,477257,13336,18054,111997,0008e10fb39e55447333,111adb8835b8a1a2cf54,0.045617,0.733893,0.020507,0.027762,0.172221
2,137560,5153,113190,758,9431,9028,0008e10fb39e55447333,1fc8683c393432a2f9c7,0.037460,0.822841,0.005510,0.068559,0.065630
3,1255830,68041,942402,26125,19213,200049,0008e10fb39e55447333,2290175205d55e81b197,0.054180,0.750422,0.020803,0.015299,0.159296
4,97346,1565,82189,502,379,12711,0008e10fb39e55447333,28f173b60331d5cabb0d,0.016077,0.844298,0.005157,0.003893,0.130575
...,...,...,...,...,...,...,...,...,...,...,...,...,...
254072,1160651,66549,815938,35566,20358,222240,ffe85a36cd20500faa58,f508b02efeac8ecb8cc0,0.057338,0.703000,0.030643,0.017540,0.191479
254073,152465,6563,92215,6575,8356,38756,ffe85a36cd20500faa58,f5b8ecea3beea4b82d79,0.043046,0.604827,0.043125,0.054806,0.254196
254074,375013,16267,261233,12050,15402,70061,ffe85a36cd20500faa58,f6c208226b6b69948053,0.043377,0.696597,0.032132,0.041071,0.186823
254075,85634,3675,58233,3065,4208,16453,ffe85a36cd20500faa58,fe67592c312fc1e17745,0.042915,0.680022,0.035792,0.049139,0.192132


In [19]:
# Aggregate every feature per user (mean, max, min, sum).
agg_funcs = ["mean", "max", "min", "sum"]
user_stats = df.groupby("user_id")[use_cols].agg(agg_funcs)

# Flatten the (feature, statistic) column pairs into "<feature>_<statistic>" names.
user_stats_columns = [f"{feature}_{stat}".strip() for feature, stat in user_stats.columns]
user_stats.columns = user_stats_columns

# Turn the user_id index back into an ordinary column.
user_stats = user_stats.reset_index()

In [20]:
# Show the per-user aggregates (one row per user_id).
user_stats

Unnamed: 0,user_id,members_mean,members_max,members_min,members_sum,watching_mean,watching_max,watching_min,watching_sum,completed_mean,completed_max,completed_min,completed_sum,on_hold_mean,on_hold_max,on_hold_min,on_hold_sum,dropped_mean,dropped_max,dropped_min,dropped_sum,plan_to_watch_mean,plan_to_watch_max,plan_to_watch_min,plan_to_watch_sum,watching_norm_mean,watching_norm_max,watching_norm_min,watching_norm_sum,completed_norm_mean,completed_norm_max,completed_norm_min,completed_norm_sum,on_hold_norm_mean,on_hold_norm_max,on_hold_norm_min,on_hold_norm_sum,dropped_norm_mean,dropped_norm_max,dropped_norm_min,dropped_norm_sum,plan_to_watch_norm_mean,plan_to_watch_norm_max,plan_to_watch_norm_min,plan_to_watch_norm_sum
0,0008e10fb39e55447333,649808.705882,2589552,61808,44186992,34766.867647,171871,342,2364147,501709.294118,2182587,29240,34116232,13660.176471,87145,169,928892,12435.602941,90661,170,845621,87236.764706,329800,3291,5932100,0.049085,0.198173,0.005533,3.337785,0.729687,0.935736,0.305139,49.618734,0.020194,0.115484,0.002508,1.373177,0.020706,0.141399,0.001118,1.407981,0.180328,0.398138,0.039980,12.262323
1,001a7aed2546342e2602,405477.358156,2531397,61287,114344615,18052.769504,140753,938,5090881,292500.354610,2182587,28319,82485100,9479.443262,62664,426,2673203,11014.382979,90661,247,3106056,74430.407801,274277,10587,20989375,0.044693,0.118363,0.009023,12.603299,0.677189,0.902819,0.370414,190.967328,0.025854,0.073828,0.002508,7.290719,0.033090,0.141399,0.001118,9.331465,0.219174,0.460814,0.049805,61.807188
2,003d4b0257cc7849ffe1,403922.542373,1830540,63739,23831430,16583.406780,137167,1561,978421,300110.966102,1462223,41504,17706547,11057.372881,61734,630,652385,11765.118644,99806,329,694142,64405.677966,247847,7522,3799935,0.038227,0.099959,0.009023,2.255371,0.729051,0.904573,0.439820,43.013995,0.028130,0.092585,0.002508,1.659666,0.028816,0.080337,0.001118,1.700117,0.175777,0.431117,0.038027,10.370851
3,0054e700b5be6e074fb7,831010.363636,1830540,90794,9141114,41445.818182,137167,2897,455904,580483.909091,1462223,60971,6385323,26775.454545,62664,2794,294530,26652.000000,99806,2891,293172,155653.181818,274277,21241,1712185,0.042262,0.074933,0.018018,0.464882,0.693068,0.798793,0.586011,7.623744,0.028845,0.056390,0.009004,0.317291,0.027571,0.054523,0.006586,0.303283,0.208255,0.296958,0.038027,2.290800
4,0059344eed7e8ca0b6c5,336418.705882,1187921,63912,5719118,10896.764706,36600,993,185245,244391.000000,885356,44740,4154647,6956.294118,27617,536,118257,9443.470588,40191,346,160539,64731.176471,198157,7254,1100430,0.029191,0.050178,0.010975,0.496251,0.737793,0.868331,0.597632,12.542482,0.016648,0.039408,0.005782,0.283008,0.022434,0.046689,0.003459,0.381384,0.193934,0.325531,0.103618,3.296875
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1993,feef23df0d53eec7d697,679106.000000,1533289,340561,3395530,33290.600000,66219,15435,166453,484630.800000,1154210,199990,2423154,18174.200000,38657,9400,90871,23015.200000,55337,12473,115076,119995.200000,218866,78815,599976,0.050427,0.060830,0.039164,0.252133,0.682806,0.752767,0.587237,3.414030,0.028176,0.039432,0.023851,0.140882,0.034948,0.043023,0.023171,0.174738,0.203643,0.270040,0.142743,1.018217
1994,ff441af085c3522f62ba,438119.791908,2589552,60101,75794724,20551.768786,140753,342,3555456,308853.670520,2182587,30936,53431685,13629.693642,77117,169,2357937,14406.549133,99806,170,2492333,80678.109827,319373,3291,13957313,0.044454,0.102736,0.004896,7.690619,0.659878,0.973677,0.382085,114.158862,0.032585,0.115723,0.001260,5.637134,0.036797,0.137457,0.001118,6.365956,0.226286,0.447514,0.018179,39.147429
1995,ff5e8e9e3553b90f222a,619496.060000,2589552,68422,30974803,27442.140000,140753,2084,1372107,459571.800000,2182587,25057,22978590,15356.520000,75054,1164,767826,17441.140000,80834,414,872057,99684.460000,235032,13839,4984223,0.044619,0.156444,0.010305,2.230931,0.684446,0.890699,0.333359,34.222303,0.027935,0.100147,0.002694,1.396762,0.033878,0.130002,0.001118,1.693883,0.209122,0.491516,0.049805,10.456122
1996,ffa6ff8006f8630f3d11,653341.371429,2531397,68091,45733896,29509.071429,140753,1082,2065635,494093.828571,2182587,57869,34586568,14152.871429,77117,704,990701,15887.042857,90661,329,1112093,99698.557143,253277,7385,6978899,0.046450,0.205876,0.009023,3.251473,0.723521,0.902819,0.400914,50.646435,0.023339,0.063292,0.002508,1.633732,0.025599,0.077062,0.001118,1.791958,0.181091,0.372305,0.049805,12.676402


Unnamed: 0,members_mean_mean,members_mean_max,members_mean_min,members_mean_sum,members_max_mean,members_max_max,members_max_min,members_max_sum,members_min_mean,members_min_max,members_min_min,members_min_sum,members_sum_mean,members_sum_max,members_sum_min,members_sum_sum,watching_mean_mean,watching_mean_max,watching_mean_min,watching_mean_sum,watching_max_mean,watching_max_max,watching_max_min,watching_max_sum,watching_min_mean,...,dropped_norm_max_sum,dropped_norm_min_mean,dropped_norm_min_max,dropped_norm_min_min,dropped_norm_min_sum,dropped_norm_sum_mean,dropped_norm_sum_max,dropped_norm_sum_min,dropped_norm_sum_sum,plan_to_watch_norm_mean_mean,plan_to_watch_norm_mean_max,plan_to_watch_norm_mean_min,plan_to_watch_norm_mean_sum,plan_to_watch_norm_max_mean,plan_to_watch_norm_max_max,plan_to_watch_norm_max_min,plan_to_watch_norm_max_sum,plan_to_watch_norm_min_mean,plan_to_watch_norm_min_max,plan_to_watch_norm_min_min,plan_to_watch_norm_min_sum,plan_to_watch_norm_sum_mean,plan_to_watch_norm_sum_max,plan_to_watch_norm_sum_min,plan_to_watch_norm_sum_sum
0,633168.397387,1.298835e+06,334091.922680,1.272668e+08,2.546199e+06,2589552,1751054,511785919,70990.417910,542642,60101,14269074,1.179488e+08,239477232,10765595,23707714197,36727.856318,70935.416667,15631.448454,7.382299e+06,265907.995025,887333,90902,53447507,1172.119403,...,27.048278,0.001328,0.004931,0.001108,0.266986,5.607371,19.218431,0.129252,1127.081652,0.196053,0.271878,0.104574,39.406732,0.512026,0.854998,0.234862,102.917304,0.040176,0.125407,0.018179,8.075393,41.361792,109.759605,1.254891,8313.720145
1,623091.446973,1.137489e+06,321157.119617,1.719732e+08,2.542661e+06,2589552,1255830,701774360,68380.000000,169402,60101,18872880,1.132346e+08,247171836,9252341,31252751890,34619.969418,72434.796296,13842.435407,9.555112e+06,248822.394928,362124,78900,68674981,1080.742754,...,36.732455,0.001302,0.005017,0.001108,0.359445,5.577846,20.035034,0.444660,1539.485511,0.202187,0.293746,0.125788,55.803663,0.518538,0.854998,0.302896,143.116493,0.040028,0.122872,0.018179,11.047793,41.724793,125.255934,2.417608,11516.042993
2,547446.043639,1.137489e+06,226427.222222,5.748183e+07,2.478708e+06,2589552,482114,260264353,67043.323810,137560,60101,7039549,9.686452e+07,239477232,885320,10170774601,28634.344315,72434.796296,10966.456790,3.006606e+06,229413.895238,362124,28272,24088459,994.400000,...,13.953543,0.001419,0.012610,0.001108,0.148969,5.540018,16.878166,0.109216,581.701846,0.198575,0.329361,0.132386,20.850400,0.484337,0.838440,0.301385,50.855341,0.038031,0.065630,0.018179,3.993260,37.731704,107.913887,0.464829,3961.828945
3,630740.199411,1.328354e+06,334091.922680,3.260927e+08,2.541124e+06,2589552,1255830,1313760869,70957.941973,597674,60101,36685256,1.172443e+08,247171836,5441982,60615320350,35250.591361,95844.142857,15631.448454,1.822456e+07,255722.435203,887333,69924,132208499,1186.193424,...,69.319714,0.001397,0.008542,0.001108,0.722237,5.845833,21.127677,0.193782,3022.295445,0.199150,0.293206,0.108882,102.960802,0.499254,0.854998,0.211691,258.114422,0.039898,0.122872,0.018179,20.627511,42.402333,125.255934,1.356040,21922.006280
4,525115.026193,7.834395e+05,303694.631579,3.623294e+07,2.512315e+06,2589552,1251960,173349715,63659.855072,78073,60199,4392530,1.013711e+08,233468842,17316816,6994607678,27947.148343,48760.466258,12142.602871,1.928353e+06,243126.260870,362124,78950,16775712,692.362319,...,8.627648,0.001352,0.002850,0.001118,0.093273,5.630087,19.460402,1.155057,388.475995,0.204616,0.276994,0.108544,14.118474,0.496563,0.802440,0.282893,34.262839,0.037716,0.084217,0.018179,2.602435,41.829153,125.255934,6.610909,2886.211576
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
254072,544662.416781,1.254089e+06,262657.640523,3.257081e+08,2.487613e+06,2589552,1160651,1487592500,67138.571906,479832,60101,40148866,9.536879e+07,247171836,2997519,57030534781,27906.670354,71813.184211,7831.320755,1.668819e+07,212017.807692,362124,66549,126786649,951.595318,...,72.235184,0.001453,0.017540,0.001108,0.869189,5.206983,20.035034,0.082669,3113.775927,0.210020,0.304601,0.106177,125.591842,0.503466,0.854998,0.244234,301.072867,0.040804,0.191479,0.018179,24.401012,39.716546,122.669225,0.820551,23750.494423
254073,464050.616033,7.834395e+05,302165.314465,3.990835e+07,2.520348e+06,2589552,1726660,216749904,62778.593023,84272,60101,5398959,9.179121e+07,239477232,16471938,7894044392,23528.230429,49594.140000,12142.602871,2.023428e+06,211237.209302,362124,72442,18166400,785.802326,...,11.163646,0.001775,0.009826,0.001118,0.152662,6.714559,18.811469,0.656074,577.452083,0.211165,0.285407,0.127745,18.160214,0.490160,0.597713,0.328818,42.153801,0.037896,0.084217,0.018179,3.259015,43.642319,110.613422,4.215580,3753.239448
254074,498194.370271,9.303776e+05,289344.309859,1.091046e+08,2.513943e+06,2589552,1095634,550553483,65281.762557,166365,60101,14296706,1.045664e+08,242225705,5797033,22900041374,25184.911627,56340.709677,11178.324324,5.515496e+06,221613.178082,362124,49249,48533286,922.027397,...,29.051069,0.001561,0.010771,0.001108,0.341815,6.998555,19.218431,0.442166,1532.683588,0.204955,0.276994,0.132801,44.885117,0.478533,0.854998,0.246006,104.798702,0.039714,0.090578,0.018179,8.697446,45.053941,110.613422,2.583512,9866.812990
254075,439531.515973,1.076698e+06,290910.303030,3.296486e+07,2.445974e+06,2589552,1108591,183448053,63099.253333,85634,60101,4732444,8.748498e+07,200192846,10743491,6561373568,21127.987451,54709.560000,11178.324324,1.584599e+06,192033.466667,362124,49249,14402510,840.973333,...,9.393557,0.001852,0.010771,0.001118,0.138875,6.595842,17.804889,0.783665,494.688172,0.207970,0.259261,0.135844,15.597747,0.469863,0.854998,0.275630,35.239706,0.038858,0.109834,0.018179,2.914336,42.518929,109.639534,3.396106,3188.919706
