In [1]:
import os
import pickle
import random
import sys
import uuid
from pathlib import Path

import implicit
import lightgbm as lgb
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix, random
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder

# Allow up to 50 columns to be shown when a wide DataFrame is displayed.
pd.options.display.max_columns = 50

In [2]:
from hydra import compose, initialize

# Build `config` by composing config.yaml from the ../yamls directory with Hydra's compose API.
with initialize(version_base=None, config_path="../yamls"):
    config = compose(config_name="config.yaml")

In [18]:
# Load the interaction data, the submission template and the anime metadata
# from the directory configured in `config.input_path`.
input_dir = Path(config.input_path)
train_df = pd.read_csv(input_dir / "train.csv")
test_df = pd.read_csv(input_dir / "test.csv")

sample_submission_df = pd.read_csv(input_dir / "sample_submission.csv")
anime_df = pd.read_csv(input_dir / "anime.csv")

# Clean-up: remove every space from `genres` (literal replace, not a regex), so the
# later split on "," yields unpadded labels.
anime_df["genres"] = anime_df["genres"].str.replace(" ", "", regex=False)

# Stack train and test with a fresh 0..n-1 index, then attach anime metadata to each row.
# validate="many_to_one" makes the merge fail loudly if anime.csv has duplicate anime_id
# rows, which would otherwise silently multiply interaction rows.
all_df = pd.concat([train_df, test_df], ignore_index=True)
all_df = all_df.merge(anime_df, on="anime_id", how="left", validate="many_to_one")

In [19]:
import cuml
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import MultiLabelBinarizer

In [38]:
# Multi-hot encode the comma-separated multi-label columns of anime_df (one row per
# anime), then compress the resulting wide matrix with truncated SVD.
multilabel_cols = ["genres", "producers", "licensors", "studios"]
multilabel_dfs = []

all_cols = []
for c in multilabel_cols:
    # Split each value on ",". Only `genres` had its spaces removed earlier, so strip
    # every token: otherwise a ", "-separated value would yield " X" and "X" as two
    # distinct labels. Missing values become an empty label list instead of raising
    # AttributeError on NaN.split, and empty tokens are dropped rather than becoming
    # a "" class.
    list_srs = [
        [token.strip() for token in str(value).split(",") if token.strip()]
        for value in anime_df[c].fillna("")
    ]
    mlb = MultiLabelBinarizer()
    ohe_srs = mlb.fit_transform(list_srs)
    col_names = [f"mhe_{c}_{name}" for name in mlb.classes_]
    # Reuse anime_df's index so these rows stay aligned with anime_df downstream.
    col_df = pd.DataFrame(ohe_srs, columns=col_names, index=anime_df.index)
    all_cols += col_names
    multilabel_dfs.append(col_df)

multilabel_df = pd.concat(multilabel_dfs, axis=1)


n_components = 100
# There are many unique labels, so reduce the dimensionality with SVD.
# Pin the seed so reruns of the notebook produce the same components.
svd = cuml.TruncatedSVD(n_components=n_components, random_state=42)
svd_df = svd.fit_transform(multilabel_df.astype(float))
svd_df.columns = [f"svd_{ix}" for ix in range(n_components)]
svd_df

Unnamed: 0,svd_0,svd_1,svd_2,svd_3,svd_4,svd_5,svd_6,svd_7,svd_8,svd_9,svd_10,svd_11,svd_12,svd_13,svd_14,svd_15,svd_16,svd_17,svd_18,svd_19,svd_20,svd_21,svd_22,svd_23,svd_24,...,svd_75,svd_76,svd_77,svd_78,svd_79,svd_80,svd_81,svd_82,svd_83,svd_84,svd_85,svd_86,svd_87,svd_88,svd_89,svd_90,svd_91,svd_92,svd_93,svd_94,svd_95,svd_96,svd_97,svd_98,svd_99
0,1.125968,0.537574,0.303108,0.450072,-0.435339,-0.112942,-0.088667,0.813236,1.392803,-0.672237,-0.460651,0.243149,-0.165782,0.370291,0.363845,0.804028,0.393751,-0.181829,0.070837,-0.202997,0.145154,-0.258692,0.186200,0.325024,0.156588,...,-0.245437,0.023958,-0.023036,0.040660,0.137754,-0.130373,0.120417,0.172941,-0.113375,-0.179871,-0.020705,0.078527,0.062724,-0.050881,-0.097498,-0.247382,-0.097107,0.125160,-0.016909,0.153199,-0.055659,0.136429,0.048900,-0.028593,0.258843
1,1.273117,-0.644639,0.257210,0.038697,0.529087,0.293090,-1.284662,0.634612,-0.041634,0.576900,-0.099386,0.126467,-0.750616,0.529672,0.830686,0.235479,0.446934,-0.319099,0.064834,0.417913,0.122599,-0.032494,-0.030712,0.318232,-0.015503,...,-0.108949,-0.006638,0.040028,0.029339,-0.118082,-0.020382,-0.057419,0.123766,-0.065391,-0.173252,0.116513,0.136643,-0.045136,-0.059095,-0.291461,-0.165114,-0.021381,-0.039862,-0.245462,0.007444,-0.284799,-0.055078,0.004941,-0.113110,0.147748
2,1.193991,-0.536311,0.199199,1.269263,0.365038,0.487946,-1.227709,0.304122,0.662058,0.138300,-0.007904,0.315433,-0.625495,0.340051,0.844701,0.352524,0.543165,-0.435953,-0.042691,0.403155,-0.161400,0.062487,-0.261022,0.646092,-0.366104,...,-0.099905,-0.023757,-0.025118,-0.011762,-0.099159,0.050366,-0.082802,0.016708,0.066236,-0.103275,0.094815,0.018898,-0.033121,-0.111994,-0.219403,-0.067839,0.065894,-0.117554,-0.133468,0.028613,-0.233886,-0.073585,-0.102173,-0.196534,0.088071
3,1.405198,0.496818,-0.280110,-1.045505,0.299266,0.515405,-0.269061,0.011289,0.283431,0.454539,0.323017,0.472782,0.206079,-0.414064,-0.308780,-0.251482,0.195787,-0.327401,0.201104,-0.014216,0.033108,-0.401896,-0.202695,0.223224,-0.016682,...,0.331064,-0.356295,-0.189509,-0.433224,0.204377,-0.226221,-0.147335,0.056691,0.041705,0.038332,0.222119,0.262094,0.140063,-0.065547,0.277472,0.199398,0.083227,0.187308,0.386412,0.031037,0.074291,0.002605,0.033151,0.018574,-0.059669
4,1.548526,0.461096,-0.002371,-0.956949,-0.316338,-0.231009,0.076347,-0.654844,-0.244590,-0.641900,-0.814801,0.332265,0.188997,0.647829,0.094219,0.190069,0.299570,0.197155,0.264429,-0.198997,0.014886,-0.140021,0.098151,0.074409,-0.060787,...,-0.214225,0.121195,-0.295230,-0.011466,-0.137110,0.124184,0.190546,0.088402,0.276382,-0.232090,-0.088795,-0.171781,0.153269,0.117983,0.217357,-0.050122,0.058411,-0.031229,0.286129,-0.028112,0.151815,0.247568,-0.155562,-0.079232,0.231894
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,1.199449,-0.394725,1.319336,-0.561475,0.539240,-0.604012,0.337242,-0.435267,-0.094692,-0.391778,-0.661664,0.087713,-0.267872,0.505784,0.107179,-0.112913,0.176693,-0.226505,-0.523513,0.052946,0.359577,0.175284,-0.034261,0.183733,-0.084986,...,0.106527,0.097983,-0.040002,-0.262228,-0.046452,-0.163126,-0.013788,0.074171,0.028070,-0.023325,-0.237450,0.087101,0.132553,-0.070772,-0.204998,-0.066024,-0.111077,0.045921,-0.194292,-0.117283,0.011989,0.174452,-0.174795,-0.046798,0.164753
1996,0.784504,0.978392,0.388339,-0.012374,0.145536,0.677901,0.787277,-0.041633,-0.355933,-0.157014,0.328824,0.221666,-0.431718,-0.068087,0.242959,0.238493,0.113117,-0.129339,0.139969,-0.127598,0.130083,0.036021,0.163269,-0.148791,-0.186172,...,0.125429,-0.277730,0.080239,-0.116837,0.161472,0.207920,0.054492,-0.008049,-0.118404,-0.083271,-0.013655,-0.044500,-0.048886,0.019970,0.086899,0.174472,-0.228389,0.123415,-0.284776,-0.170204,-0.211894,-0.102104,-0.210701,0.431017,-0.108353
1997,0.946098,-0.997233,0.261124,0.085982,0.106847,0.182060,0.713756,0.026780,0.803796,-0.065660,0.099529,0.050898,0.120896,0.381654,-0.486037,-0.005638,0.652581,0.271213,-0.095213,0.553627,-0.812732,-0.250127,-0.397724,-0.327242,-0.233250,...,0.260465,-0.147963,-0.005950,-0.182957,-0.097227,0.016927,-0.056804,-0.119367,-0.111422,0.078680,0.017874,-0.016356,0.004813,0.031758,0.090270,-0.117736,0.065828,-0.015405,0.121927,-0.189413,0.003075,0.037729,-0.120091,-0.064527,0.039231
1998,1.449392,-0.904871,0.113465,-0.762358,0.360926,0.057602,0.371051,-0.069411,0.417296,0.107439,0.464277,0.189779,-0.718513,0.308131,-0.057838,-0.484183,0.333297,-0.421464,-0.059161,-0.139915,0.180111,0.142297,-0.069739,0.178520,0.258587,...,0.059036,-0.068067,0.046482,0.055409,0.039237,-0.042497,-0.060220,0.107416,-0.103874,0.031470,-0.017790,-0.066207,-0.092907,0.011587,-0.022760,0.039564,0.033185,0.008448,0.084291,-0.073751,0.015774,0.032316,-0.001106,-0.004961,0.032913


In [39]:
# Tag each SVD row with its anime_id. The SVD rows were produced in anime_df row order,
# so assign positionally. This avoids relying on the SVD output's index matching
# anime_df's index, and raises on a length mismatch instead of silently filling NaN.
svd_df["anime_id"] = anime_df["anime_id"].to_numpy()

In [40]:
# Broadcast the per-anime SVD features onto every (train + test) row of all_df.
df = all_df[["anime_id"]].copy()
# Each anime_id must appear at most once in svd_df. Otherwise the left merge would
# duplicate rows and the features would no longer line up with all_df.
df = df.merge(svd_df, on="anime_id", how="left", validate="many_to_one")
assert len(df) == len(all_df), "SVD feature merge changed the number of rows"
df = df.drop(columns=["anime_id"])
df

Unnamed: 0,svd_0,svd_1,svd_2,svd_3,svd_4,svd_5,svd_6,svd_7,svd_8,svd_9,svd_10,svd_11,svd_12,svd_13,svd_14,svd_15,svd_16,svd_17,svd_18,svd_19,svd_20,svd_21,svd_22,svd_23,svd_24,...,svd_75,svd_76,svd_77,svd_78,svd_79,svd_80,svd_81,svd_82,svd_83,svd_84,svd_85,svd_86,svd_87,svd_88,svd_89,svd_90,svd_91,svd_92,svd_93,svd_94,svd_95,svd_96,svd_97,svd_98,svd_99
0,0.969073,-0.802336,-0.617044,0.064773,-0.295588,-0.266226,0.312219,-0.018586,-0.500382,-0.523972,0.352309,-0.241894,0.216569,0.066247,0.770089,-0.234870,-0.175899,-0.725145,-0.281237,-0.141788,-0.468867,0.562902,-0.030012,-0.281524,0.215246,...,-0.169903,-0.098534,-0.021072,0.016028,0.077593,-0.190504,-0.002950,-0.156650,-0.116898,-0.021005,0.161586,-0.126147,0.098815,-0.048694,-0.039323,0.221118,0.057496,0.097266,-0.000069,-0.081120,-0.040543,-0.029439,0.030050,-0.180508,0.199374
1,0.918463,-0.599226,0.011068,1.119853,-0.649145,0.289067,-0.032645,-0.199284,0.012741,0.843803,-0.219486,-0.310036,0.006945,0.173255,-0.125063,-0.103542,0.057784,-0.181546,-0.112044,-0.035803,0.397814,-0.165278,-0.095916,-0.247082,0.123367,...,-0.071519,0.088022,-0.094032,-0.054688,-0.020085,-0.103920,0.151998,0.001529,0.023619,-0.107669,0.004437,0.027431,-0.077152,-0.051297,0.068957,0.055219,-0.171513,0.063922,-0.028814,0.166621,-0.052952,-0.168373,-0.047846,-0.063038,-0.019711
2,0.154275,0.096763,-0.205458,0.292924,0.008156,-0.341503,0.202374,-0.123738,0.075898,0.630773,-0.058460,0.101421,0.116606,0.092648,0.125548,-0.041882,-0.059426,0.088081,-0.019211,0.042894,0.220442,-0.025569,0.038155,-0.030263,0.014744,...,0.018308,0.002198,-0.028322,0.015506,-0.031202,0.012941,0.022335,-0.021387,-0.000746,-0.041943,0.013952,0.009168,0.045633,-0.020735,0.022186,-0.050391,0.020611,-0.003364,0.009163,-0.038905,-0.017039,0.011455,-0.008495,0.027896,0.006279
3,1.570335,-0.168720,0.045993,-0.449586,-0.594992,0.287230,-0.803906,0.617122,0.107206,0.508862,-0.228708,-0.741056,0.352703,0.316645,-0.113151,0.033809,-0.373382,-0.269880,-0.317106,0.070421,-0.082457,0.494435,-0.177161,-0.213916,0.323219,...,-0.214995,-0.239038,-0.323444,-0.001848,0.411781,0.391616,0.193976,-0.020713,0.116066,0.021772,-0.013149,-0.267133,0.070892,0.109408,-0.177313,0.085304,-0.069473,0.015018,-0.094289,0.031526,0.081405,-0.162488,0.120016,-0.115374,0.027159
4,0.863517,0.318579,-0.758036,-0.106236,-0.264214,-0.586377,-0.158922,0.024884,-0.303881,-0.330588,-0.289409,0.327747,-0.235599,-0.275490,-0.275357,0.381853,-0.175035,0.098524,-0.292543,-0.312439,-0.486107,-0.325869,0.080402,-0.276150,-0.203875,...,-0.215884,0.414746,-0.011005,-0.168932,-0.209648,0.045791,0.122723,0.084773,0.118239,-0.051075,0.015373,-0.005606,-0.084287,0.030736,0.105027,-0.044233,-0.008342,-0.012048,0.115305,-0.102302,-0.110508,-0.051766,0.026873,-0.161425,0.042201
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
254072,0.929004,-0.707761,0.996715,0.092789,-0.289466,-0.755036,0.748660,0.279070,0.000030,-0.191936,-0.428843,0.079946,0.467182,0.096589,-0.380632,-0.370252,0.287014,-0.103242,0.650993,-0.466046,-0.188904,0.085846,-0.148498,-0.060742,0.175597,...,0.344372,-0.251023,0.194437,0.256132,0.064150,-0.166225,0.156734,-0.029316,0.062565,0.052205,-0.023970,0.265333,0.207039,0.047307,-0.185227,0.088872,-0.038661,-0.238733,-0.237439,-0.082261,0.175621,-0.086242,-0.319020,-0.156333,-0.034428
254073,1.871571,-0.172797,-0.885135,0.306050,-0.515917,0.252627,-0.015314,0.060014,-0.592835,-0.235812,0.474439,-0.224653,-0.082475,-0.022023,-0.355738,-0.311063,0.487799,-0.287394,-0.035872,-0.196724,-0.650517,-0.043526,-0.106567,-0.144607,0.575281,...,-0.098405,0.007306,-0.225462,0.227717,0.418501,-0.232108,0.063296,-0.085332,-0.050261,0.225818,0.219674,0.039298,0.183286,0.154167,-0.026642,0.258733,0.110247,-0.307665,-0.228838,0.107233,-0.094596,-0.249216,0.322519,0.185683,0.089143
254074,1.600940,-0.741917,-0.274420,-0.475638,0.543142,0.204908,0.064439,-0.751348,-0.365603,0.124704,0.273919,-0.171351,-0.338061,0.391251,-0.210787,0.684738,-0.240798,0.656631,0.069383,-0.693975,-0.340118,-0.098305,0.139342,-0.202989,-0.166732,...,-0.045634,-0.109419,0.105692,-0.073049,-0.011347,-0.097292,-0.008840,-0.155685,0.166138,-0.020883,-0.052952,-0.087821,0.012222,-0.026534,0.046079,-0.000311,-0.017567,-0.038892,-0.101338,0.116341,0.037397,-0.020026,0.026469,0.034021,-0.028582
254075,1.586838,1.301305,0.361249,-0.200602,0.479276,0.432086,0.535784,-0.543036,-0.355524,0.036890,-0.196004,0.049752,0.474546,-0.710354,-0.462053,-0.123792,0.084388,-0.418411,0.092510,-0.075953,-0.374333,0.840167,0.638371,0.076303,0.096797,...,0.099425,0.184337,-0.029616,0.239389,0.016074,0.130965,-0.085821,0.119183,-0.104261,-0.004136,-0.058837,0.018183,-0.146152,-0.087867,-0.075908,-0.046697,0.145866,0.117566,0.081121,-0.236452,-0.291554,0.043837,0.086101,-0.163689,-0.044012
