# TF

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

In [2]:
import tensorflow as tf
import tensorflow.keras.backend as K

import sys
sys.path.append(r'C:\Users\81908\jupyter_notebook\poetry_work\tfgpu\atmaCup_#8\notebook\tabnet')
from tabnet_tf import TabNetRegressor, StackedTabNetRegressor

Tensorflow version 2.3.1


In [3]:
from adabelief_tf import AdaBeliefOptimizer

In [4]:
import tensorflow as tf


def build_callbacks(
    model_path, factor=0.1, mode="auto", monitor="val_loss", patience=0, verbose=0
):
    early_stopping = tf.keras.callbacks.EarlyStopping(
        mode=mode, monitor=monitor, patience=patience, verbose=verbose
    )
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        model_path, mode=mode, monitor=monitor, save_best_only=True, verbose=verbose
    )
    reduce_lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
        factor=factor, monitor=monitor, mode=mode, verbose=verbose
    )

    return [early_stopping, model_checkpoint, reduce_lr_on_plateau]

# base

In [5]:
import os
import gc
import re
import math
import pickle
import joblib
import warnings

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from sklearn.model_selection import KFold, GroupKFold

import lightgbm as lgb

warnings.simplefilter("ignore")
pd.set_option('display.max_columns', None)

In [6]:
import random as rn
import numpy as np


def set_seed(seed=0):
    os.environ["PYTHONHASHSEED"] = str(seed)

    rn.seed(seed)
    np.random.seed(seed)

In [7]:
from sklearn.metrics import mean_squared_log_error


def score(y, y_pred):
    RMSLE = np.sqrt(np.mean(((np.log(y + 1) - np.log(y_pred + 1)) ** 2)))
    #RMSLE = mean_squared_log_error(y, y_pred) ** 0.5
    return RMSLE

# Data load

In [8]:
import pandas as pd

DATADIR = r"C:\Users\81908\jupyter_notebook\poetry_work\tfgpu\atmaCup_#8\data\atmacup08-dataset"
train = pd.read_csv(f"{DATADIR}/train.csv")
test = pd.read_csv(f"{DATADIR}/test.csv")
df = pd.concat([train, test], axis=0)

# 前処理

In [9]:
import numpy as np

# tbd(確認中)を欠損にする
df["User_Score"] = df["User_Score"].replace("tbd", np.nan)

In [10]:
# -1で補完
cate_cols = [
    "Name",
    "Platform",
    "Year_of_Release",
    "Genre",
    "Publisher",
    "Developer",
    "Rating",
]
for col in cate_cols:
    df[col].fillna(-1, inplace=True)

In [11]:
def impute_null_add_flag_col(
    df, strategy="mean", cols_with_missing=None, fill_value=None
):
    """欠損値を補間して欠損フラグ列を追加する
    fill_value はstrategy="constant"の時のみ有効になる補間する定数
    """
    from sklearn.impute import SimpleImputer

    df_plus = df.copy()

    for col in cols_with_missing:
        # 欠損フラグ列を追加
        df_plus[col + "_was_missing"] = df[col].isnull()
        df_plus[col + "_was_missing"] = df_plus[col + "_was_missing"].astype(int)
        # 欠損値を平均値で補間
        my_imputer = SimpleImputer(strategy=strategy, fill_value=fill_value)
        df_plus[col] = my_imputer.fit_transform(df[[col]])

    return df_plus


df = impute_null_add_flag_col(
    df, strategy="most_frequent", cols_with_missing=["User_Score"]
)  # 最頻値で補間

df = impute_null_add_flag_col(
    df,
    cols_with_missing=[
        "Critic_Score",
        "Critic_Count",
        "User_Count",
    ],
)  # 平均値で補間（数値列のみ）

In [12]:
# User_Scoreを数値列にする
df["User_Score"] = df["User_Score"].astype("float")

# User_Scoreを文字列にする
df["Year_of_Release"] = df["Year_of_Release"].astype("str")

In [13]:
# ラベルエンコディング
cate_cols = df.select_dtypes(include=["object", "category", "bool"]).columns.to_list()
for col in cate_cols:
    df[col], uni = pd.factorize(df[col])

In [14]:
train = df.iloc[: train.shape[0]]
test = df.iloc[train.shape[0] :].reset_index(drop=True)

# 目的変数
sales_cols = [
    "NA_Sales",
    "EU_Sales",
    "JP_Sales",
    "Other_Sales",
    "Global_Sales",
]

train_drop_sales = train.drop(sales_cols, axis=1)
test = test.drop(sales_cols, axis=1)

In [15]:
# 欠損データ確認
_df = pd.DataFrame({"is_null": df.isnull().sum()})
_df[_df["is_null"] > 0]

Unnamed: 0,is_null
NA_Sales,8360
EU_Sales,8360
JP_Sales,8360
Other_Sales,8360
Global_Sales,8360


# FE

In [16]:
df = pd.concat([train_drop_sales, test], axis=0)

In [17]:
# 行単位で統計量とる
def add_num_row_agg(df_all, agg_num_cols):
    """行単位の統計量列追加
    agg_num_cols は数値列だけでないとエラー"""
    import warnings

    warnings.filterwarnings("ignore")

    df = df_all[agg_num_cols]
    cols = df.columns.to_list()
    cols = map(str, cols)  # 文字列にする
    col_name = "_".join(cols)

    df_all[f"row_{col_name}_sum"] = df.sum(axis=1)
    df_all[f"row_{col_name}_mean"] = df.mean(axis=1)
    df_all[f"row_{col_name}_std"] = df.std(axis=1)
    # df_all[f"row_{col_name}_skew"] = df.skew(axis=1)  # 歪度  # nanになるから入れない
    
    return df_all

    
df = add_num_row_agg(df, ["Critic_Score", "User_Score"])
df = add_num_row_agg(df, ["Critic_Count", "User_Count"])

In [18]:
import numpy as np
import pandas as pd


# A列でグループして集計したB列は意味がありそうと仮説たててから統計値列作ること
# 目的変数をキーにして集計するとリークしたターゲットエンコーディングになるため説明変数同士で行うこと
def grouping(df, cols, agg_dict, prefix=""):
    """特定のカラムについてgroup化された特徴量の作成を行う
    Args:
        df (pd.DataFrame): 特徴量作成のもととなるdataframe
        cols (str or list): group by処理のkeyとなるカラム (listで複数指定可能)
        agg_dict (dict): 特徴量作成を行いたいカラム/集計方法を指定するdictionary
        prefix (str): 集約後のカラムに付与するprefix name

    Returns:
        df (pd.DataFrame): 特定のカラムについてgroup化された特徴量群
    """
    group_df = df.groupby(cols).agg(agg_dict)
    group_df.columns = [prefix + c[0] + "_" + c[1] for c in list(group_df.columns)]
    group_df.reset_index(inplace=True)

    return group_df

class AggUtil():
    ############## カテゴリ列 vs. 数値列について ##############
    @staticmethod
    def percentile(n):
        """パーセンタイル"""
        def percentile_(x):
            return np.percentile(x, n)
        percentile_.__name__ = "percentile_%s" % n
        return percentile_

    @staticmethod
    def diff_percentile(n1, n2):
        """パーセンタイルの差"""
        def diff_percentile_(x):
            p1 = np.percentile(x, n1)
            p2 = np.percentile(x, n2)
            return p1 - p2
        diff_percentile_.__name__ = f"diff_percentile_{n1}-{n2}"
        return diff_percentile_

    @staticmethod
    def ratio_percentile(n1, n2):
        """パーセンタイルの比"""
        def ratio_percentile_(x):
            p1 = np.percentile(x, n1)
            p2 = np.percentile(x, n2)
            return p1 / p2
        ratio_percentile_.__name__ = f"ratio_percentile_{n1}-{n2}"
        return ratio_percentile_
    
    @staticmethod
    def mean_var():
        """平均分散"""
        def mean_var_(x):
            x = x.dropna()
            return np.std(x) / np.mean(x)
        mean_var_.__name__ = f"mean_var"
        return mean_var_
    
    @staticmethod
    def diff_mean():
        """平均との差の中央値(aggは集計値でないとエラーになるから中央値をとる)"""
        def diff_mean_(x):
            x = x.dropna()
            return np.median(x - np.mean(x))
        diff_mean_.__name__ = f"diff_mean"
        return diff_mean_
    
    @staticmethod
    def ratio_mean():
        """平均との比の中央値(aggは一意な値でないとエラーになるから中央値をとる)"""
        def ratio_mean_(x):
            x = x.dropna()
            return np.median(x / np.mean(x))
        ratio_mean_.__name__ = f"ratio_mean"
        return ratio_mean_
    
    @staticmethod
    def hl_ratio():
        """平均より高いサンプル数と低いサンプル数の比率"""
        def hl_ratio_(x):
            x = x.dropna()
            n_high = x[x >= np.mean(x)].shape[0]
            n_low = x[x < np.mean(x)].shape[0]
            if n_low == 0:
                return 1.0
            else:
                return n_high / n_low
        hl_ratio_.__name__ = f"hl_ratio"
        return hl_ratio_
    
    @staticmethod
    def ratio_range():
        """最大/最小"""
        def ratio_range_(x):
            x = x.dropna()
            if np.min(x) == 0:
                return 1.0
            else:
                return np.max(x) / np.min(x)
        ratio_range_.__name__ = f"ratio_range"
        return ratio_range_
    
    @staticmethod
    def beyond1std():
        """1stdを超える比率"""
        def beyond1std_(x):
            x = x.dropna()
            return x[np.abs(x) > np.abs(np.std(x))].shape[0] / x.shape[0]
        beyond1std_.__name__ = "beyond1std"
        return beyond1std_
    
    @staticmethod
    def zscore():
        """Zスコアの中央値(aggは一意な値でないとエラーになるから中央値をとる)"""
        def zscore_(x):
            x = x.dropna()
            return np.median((x - np.mean(x)) / np.std(x))
        zscore_.__name__ = "zscore"
        return zscore_
    ######################################################
    
    ############## カテゴリ列 vs. カテゴリ列について ##############
    @staticmethod
    def freq_entropy():
        """出現頻度のエントロピー"""
        from scipy.stats import entropy
        def freq_entropy_(x):
            return entropy(x.value_counts().values)
        freq_entropy_.__name__ = "freq_entropy"
        return freq_entropy_
    
    @staticmethod
    def freq1name():
        """最も頻繁に出現するカテゴリの数"""
        def freq1name_(x):
            return x.value_counts().sort_values(ascending=False)[0]
        freq1name_.__name__ = "freq1name"
        return freq1name_
    
    @staticmethod
    def freq1ratio():
        """最も頻繁に出現するカテゴリ/グループの数"""
        def freq1ratio_(x):
            frq = x.value_counts().sort_values(ascending=False)
            return frq[0] / frq.shape[0]
        freq1ratio_.__name__ = "freq1ratio"
        return freq1ratio_
    #########################################################


# 集計する数値列指定
value_agg = {
    "User_Count": [
        "max",
        "min",
        "mean",
        #"std",  # 標準偏差
        #"skew",  # 歪度
        #pd.DataFrame.kurt,  # 尖度
    ],
    "Critic_Count": [
        "max",
        "min",
        "mean",
        #"std",  # 標準偏差
        #"skew",  # 歪度
        #pd.DataFrame.kurt,  # 尖度
    ],
    "User_Score": [
        "max",
        "min",
        "mean",
        #"std",  # 標準偏差
        #"skew",  # 歪度
        #pd.DataFrame.kurt,  # 尖度
    ],
    "Critic_Score": [
        "max",
        "min",
        "mean",
        #"std",  # 標準偏差
        #"skew",  # 歪度
        #pd.DataFrame.kurt,  # 尖度
    ],
}
# グループ化するカテゴリ列でループ
for key in ["Platform", "Genre", "Publisher", "Developer", "Rating"]:
    feature_df = grouping(df, key, value_agg, prefix=key + "_")
    df = pd.merge(df, feature_df, how="left", on=key)

In [19]:
df["Critic_Score_*_Critic_Count"] = df["Critic_Score"] * df["Critic_Count"]
df["User_Score_*_User_Count"] = df["User_Score"] * df["User_Count"]
df["Critic_Score_*_User_Score"] = df["Critic_Score"] * df["User_Score"]
df["Critic_Count_*_User_Count"] = df["Critic_Count"] * df["User_Count"]
df["Critic_Count_+_User_Count"] = df["Critic_Count"] + df["User_Count"]
df["Critic_Count_-_User_Count"] = df["Critic_Count"] - df["User_Count"]
df["Critic_Count_/_all_Count"] = df["Critic_Count"] / df["Critic_Count_+_User_Count"]

In [20]:
## KMeansでクラスタリングした列追加
#
#from sklearn.cluster import KMeans
#
#def fe_cluster(df, kind, features, n_clusters=100, SEED=42, is_dummies=False):
#    df_ = df[features].copy()
#    kmeans_cells = KMeans(n_clusters=n_clusters, random_state=SEED).fit(df_)
#    df[f'clusters_{kind}'] = kmeans_cells.predict(df_.values)
#    df = pd.get_dummies(df, columns=[f'clusters_{kind}']) if is_dummies else df
#
#    return df
#
#df = fe_cluster(df, kind="cate_cols", 
#                features=["Name", "Genre", "Publisher", "Developer", "Rating"], 
#                n_clusters=3000)
#df = fe_cluster(df, kind="Score_Count", 
#                features=["Critic_Score", "User_Score", "Critic_Count", "User_Count"], 
#                n_clusters=300)

In [21]:
# 種類数少な目の列はダミー化して残す

float_cols = df.columns.to_list()

for col in ["Critic_Count_was_missing", "Critic_Score_was_missing", "User_Count_was_missing", "User_Score_was_missing"]:
    float_cols.remove(col)

df = pd.get_dummies(df, columns=[f'Rating'])
df = pd.get_dummies(df, columns=[f'Genre'])
df = pd.get_dummies(df, columns=[f'Platform'])

# 種類数少な目の列は消しとく
df = df.drop(["Name", "Year_of_Release", "Publisher", "Developer"], axis=1)

float_cols = np.intersect1d(np.array(float_cols), np.array(df.columns.to_list()))
float_cols

array(['Critic_Count', 'Critic_Count_*_User_Count',
       'Critic_Count_+_User_Count', 'Critic_Count_-_User_Count',
       'Critic_Count_/_all_Count', 'Critic_Score',
       'Critic_Score_*_Critic_Count', 'Critic_Score_*_User_Score',
       'Developer_Critic_Count_max', 'Developer_Critic_Count_mean',
       'Developer_Critic_Count_min', 'Developer_Critic_Score_max',
       'Developer_Critic_Score_mean', 'Developer_Critic_Score_min',
       'Developer_User_Count_max', 'Developer_User_Count_mean',
       'Developer_User_Count_min', 'Developer_User_Score_max',
       'Developer_User_Score_mean', 'Developer_User_Score_min',
       'Genre_Critic_Count_max', 'Genre_Critic_Count_mean',
       'Genre_Critic_Count_min', 'Genre_Critic_Score_max',
       'Genre_Critic_Score_mean', 'Genre_Critic_Score_min',
       'Genre_User_Count_max', 'Genre_User_Count_mean',
       'Genre_User_Count_min', 'Genre_User_Score_max',
       'Genre_User_Score_mean', 'Genre_User_Score_min',
       'Platform_Critic_C

In [22]:
## RankGauss
#from sklearn.preprocessing import QuantileTransformer
#
#_df = df[float_cols]
#
#qt = QuantileTransformer(n_quantiles=100, random_state=42, output_distribution="normal")
#df[_df.columns] = qt.fit_transform(_df)

In [23]:
# 欠損データ確認
_df = pd.DataFrame({"is_null": df.isnull().sum()})
_df[_df["is_null"] > 0]

Unnamed: 0,is_null


In [24]:
train_drop_sales = df.iloc[: train.shape[0]]
test = df.iloc[train.shape[0] :].reset_index(drop=True)

In [25]:
train_drop_sales

Unnamed: 0,Critic_Score,Critic_Count,User_Score,User_Count,User_Score_was_missing,Critic_Score_was_missing,Critic_Count_was_missing,User_Count_was_missing,row_Critic_Score_User_Score_sum,row_Critic_Score_User_Score_mean,row_Critic_Score_User_Score_std,row_Critic_Count_User_Count_sum,row_Critic_Count_User_Count_mean,row_Critic_Count_User_Count_std,Platform_User_Count_max,Platform_User_Count_min,Platform_User_Count_mean,Platform_Critic_Count_max,Platform_Critic_Count_min,Platform_Critic_Count_mean,Platform_User_Score_max,Platform_User_Score_min,Platform_User_Score_mean,Platform_Critic_Score_max,Platform_Critic_Score_min,Platform_Critic_Score_mean,Genre_User_Count_max,Genre_User_Count_min,Genre_User_Count_mean,Genre_Critic_Count_max,Genre_Critic_Count_min,Genre_Critic_Count_mean,Genre_User_Score_max,Genre_User_Score_min,Genre_User_Score_mean,Genre_Critic_Score_max,Genre_Critic_Score_min,Genre_Critic_Score_mean,Publisher_User_Count_max,Publisher_User_Count_min,Publisher_User_Count_mean,Publisher_Critic_Count_max,Publisher_Critic_Count_min,Publisher_Critic_Count_mean,Publisher_User_Score_max,Publisher_User_Score_min,Publisher_User_Score_mean,Publisher_Critic_Score_max,Publisher_Critic_Score_min,Publisher_Critic_Score_mean,Developer_User_Count_max,Developer_User_Count_min,Developer_User_Count_mean,Developer_Critic_Count_max,Developer_Critic_Count_min,Developer_Critic_Count_mean,Developer_User_Score_max,Developer_User_Score_min,Developer_User_Score_mean,Developer_Critic_Score_max,Developer_Critic_Score_min,Developer_Critic_Score_mean,Rating_User_Count_max,Rating_User_Count_min,Rating_User_Count_mean,Rating_Critic_Count_max,Rating_Critic_Count_min,Rating_Critic_Count_mean,Rating_User_Score_max,Rating_User_Score_min,Rating_User_Score_mean,Rating_Critic_Score_max,Rating_Critic_Score_min,Rating_Critic_Score_mean,Critic_Score_*_Critic_Count,User_Score_*_User_Count,Critic_Score_*_User_Score,Critic_Count_*_User_Count,Critic_Count_+_User_Count,Critic_Count_-_User_Count,Critic_Count_/_all_Count,Rating_0,Rating_1,Rating_2,Rating_3,Rating_4,Rating_5,Rating_6,Rating_7,Rating_8,Genre_0,Genre_1,Genre_2,Genre_3,Genre_4,Genre_5,Genre_6,Genre_7,Genre_8,Genre_9,Genre_10,Genre_11,Genre_12,Platform_0,Platform_1,Platform_2,Platform_3,Platform_4,Platform_5,Platform_6,Platform_7,Platform_8,Platform_9,Platform_10,Platform_11,Platform_12,Platform_13,Platform_14,Platform_15,Platform_16,Platform_17,Platform_18,Platform_19,Platform_20,Platform_21,Platform_22,Platform_23,Platform_24,Platform_25,Platform_26,Platform_27,Platform_28,Platform_29,Platform_30
0,74.000000,17.000000,7.9,22.000000,0,0,0,0,81.900000,40.950000,46.739758,39.000000,19.500000,3.535534,2147.000000,4.000000,115.942220,87.000000,4.000000,24.094851,9.3,0.2,7.351439,97.000000,19.000000,66.244882,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,2679.000000,4.0,177.911378,89.0,4.0,25.355691,8.9,2.5,7.232340,96.0,24.0,69.427020,162.229908,4.0,48.974924,60.000000,4.0,21.126255,9.1,5.2,7.472727,86.0,51.000000,71.973186,5999.0,4.0,101.723715,89.0,4.0,23.408428,9.3,0.6,7.106056,95.0,19.0,67.229041,1258.000000,173.800000,584.600000,374.00000,39.000000,-5.000000,0.435897,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,78.000000,22.000000,6.6,28.000000,0,0,0,0,84.600000,42.300000,50.487424,50.000000,25.000000,4.242641,2147.000000,4.000000,115.942220,87.000000,4.000000,24.094851,9.3,0.2,7.351439,97.000000,19.000000,66.244882,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,1509.000000,4.0,112.854538,74.0,4.0,24.819702,9.1,0.6,7.088889,93.0,43.0,69.938698,162.229908,4.0,48.974924,60.000000,4.0,21.126255,9.1,5.2,7.472727,86.0,51.000000,71.973186,5999.0,4.0,101.723715,89.0,4.0,23.408428,9.3,0.6,7.106056,95.0,19.0,67.229041,1716.000000,184.800000,514.800000,616.00000,50.000000,-6.000000,0.440000,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,73.000000,5.000000,7.4,10.000000,0,0,0,0,80.400000,40.200000,46.386205,15.000000,7.500000,3.535534,565.000000,4.000000,116.331817,79.000000,4.000000,25.228729,9.7,0.6,7.600910,91.000000,28.000000,68.377879,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,2679.000000,4.0,177.911378,89.0,4.0,25.355691,8.9,2.5,7.232340,96.0,24.0,69.427020,162.229908,4.0,48.974924,60.000000,4.0,21.126255,9.1,5.2,7.472727,86.0,51.000000,71.973186,5999.0,4.0,101.723715,89.0,4.0,23.408428,9.3,0.6,7.106056,95.0,19.0,67.229041,365.000000,74.000000,540.200000,50.00000,15.000000,-5.000000,0.333333,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,68.967679,26.360821,7.8,162.229908,1,1,1,1,76.767679,38.383839,43.252080,188.590729,94.295364,96.073953,162.229908,162.229908,162.229908,26.360821,26.360821,26.360821,7.8,7.8,7.800000,68.967679,68.967679,68.967679,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,1660.000000,4.0,114.034369,70.0,4.0,25.188571,9.2,2.0,7.507629,93.0,26.0,67.128215,162.229908,4.0,162.001986,26.360821,5.0,26.350529,8.2,3.2,7.797207,83.0,43.000000,68.961517,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,1818.044624,1265.393281,537.947892,4276.51355,188.590729,-135.869087,0.139778,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,76.000000,8.000000,7.8,13.000000,0,0,0,0,83.800000,41.900000,48.224682,21.000000,10.500000,3.535534,2147.000000,4.000000,115.942220,87.000000,4.000000,24.094851,9.3,0.2,7.351439,97.000000,19.000000,66.244882,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,2679.000000,4.0,177.911378,89.0,4.0,25.355691,8.9,2.5,7.232340,96.0,24.0,69.427020,162.229908,4.0,48.974924,60.000000,4.0,21.126255,9.1,5.2,7.472727,86.0,51.000000,71.973186,5999.0,4.0,101.723715,89.0,4.0,23.408428,9.3,0.6,7.106056,95.0,19.0,67.229041,608.000000,101.400000,592.800000,104.00000,21.000000,-5.000000,0.380952,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8354,78.000000,57.000000,8.0,569.000000,0,0,0,0,86.000000,43.000000,49.497475,626.000000,313.000000,362.038672,10665.000000,4.000000,520.406469,104.000000,4.000000,27.505598,9.3,1.4,7.216940,96.000000,33.000000,74.077648,6157.0,4.0,177.707944,104.0,3.0,27.205670,9.4,0.6,7.570132,94.0,35.0,70.346538,2191.000000,6.0,298.277367,71.0,5.0,29.082025,8.9,3.3,7.247826,89.0,49.0,70.858323,569.000000,6.0,240.000000,57.000000,13.0,31.072164,8.7,6.8,7.500000,83.0,68.967679,75.393536,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,4446.000000,4552.000000,624.000000,32433.00000,626.000000,-512.000000,0.091054,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
8355,68.967679,26.360821,7.8,162.229908,1,1,1,1,76.767679,38.383839,43.252080,188.590729,94.295364,96.073953,10665.000000,4.000000,520.406469,104.000000,4.000000,27.505598,9.3,1.4,7.216940,96.000000,33.000000,74.077648,6157.0,4.0,177.707944,104.0,3.0,27.205670,9.4,0.6,7.570132,94.0,35.0,70.346538,1340.000000,4.0,174.415834,30.0,6.0,23.725285,7.8,2.8,6.976471,76.0,33.0,64.683067,162.229908,4.0,162.001986,26.360821,5.0,26.350529,8.2,3.2,7.797207,83.0,43.000000,68.961517,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,1818.044624,1265.393281,537.947892,4276.51355,188.590729,-135.869087,0.139778,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
8356,68.967679,26.360821,7.8,162.229908,1,1,1,1,76.767679,38.383839,43.252080,188.590729,94.295364,96.073953,1228.000000,4.000000,140.417693,86.000000,4.000000,26.556889,9.3,1.8,7.646528,93.000000,33.000000,69.474342,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,162.229908,5.0,151.242697,27.0,4.0,25.669944,8.3,6.5,7.768421,79.0,47.0,68.548570,162.229908,4.0,162.001986,26.360821,5.0,26.350529,8.2,3.2,7.797207,83.0,43.000000,68.961517,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,1818.044624,1265.393281,537.947892,4276.51355,188.590729,-135.869087,0.139778,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
8357,68.967679,26.360821,7.8,162.229908,1,1,1,1,76.767679,38.383839,43.252080,188.590729,94.295364,96.073953,10179.000000,4.000000,394.985413,113.000000,4.000000,34.444976,9.2,1.5,7.112214,97.000000,19.000000,70.970592,10665.0,4.0,240.748693,98.0,4.0,29.372204,9.7,1.4,7.710600,96.0,35.0,70.778226,3742.000000,4.0,140.489653,86.0,4.0,28.183545,9.4,2.1,7.616301,97.0,28.0,69.356880,162.229908,4.0,162.001986,26.360821,5.0,26.350529,8.2,3.2,7.797207,83.0,43.000000,68.961517,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,1818.044624,1265.393281,537.947892,4276.51355,188.590729,-135.869087,0.139778,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0


In [26]:
test

Unnamed: 0,Critic_Score,Critic_Count,User_Score,User_Count,User_Score_was_missing,Critic_Score_was_missing,Critic_Count_was_missing,User_Count_was_missing,row_Critic_Score_User_Score_sum,row_Critic_Score_User_Score_mean,row_Critic_Score_User_Score_std,row_Critic_Count_User_Count_sum,row_Critic_Count_User_Count_mean,row_Critic_Count_User_Count_std,Platform_User_Count_max,Platform_User_Count_min,Platform_User_Count_mean,Platform_Critic_Count_max,Platform_Critic_Count_min,Platform_Critic_Count_mean,Platform_User_Score_max,Platform_User_Score_min,Platform_User_Score_mean,Platform_Critic_Score_max,Platform_Critic_Score_min,Platform_Critic_Score_mean,Genre_User_Count_max,Genre_User_Count_min,Genre_User_Count_mean,Genre_Critic_Count_max,Genre_Critic_Count_min,Genre_Critic_Count_mean,Genre_User_Score_max,Genre_User_Score_min,Genre_User_Score_mean,Genre_Critic_Score_max,Genre_Critic_Score_min,Genre_Critic_Score_mean,Publisher_User_Count_max,Publisher_User_Count_min,Publisher_User_Count_mean,Publisher_Critic_Count_max,Publisher_Critic_Count_min,Publisher_Critic_Count_mean,Publisher_User_Score_max,Publisher_User_Score_min,Publisher_User_Score_mean,Publisher_Critic_Score_max,Publisher_Critic_Score_min,Publisher_Critic_Score_mean,Developer_User_Count_max,Developer_User_Count_min,Developer_User_Count_mean,Developer_Critic_Count_max,Developer_Critic_Count_min,Developer_Critic_Count_mean,Developer_User_Score_max,Developer_User_Score_min,Developer_User_Score_mean,Developer_Critic_Score_max,Developer_Critic_Score_min,Developer_Critic_Score_mean,Rating_User_Count_max,Rating_User_Count_min,Rating_User_Count_mean,Rating_Critic_Count_max,Rating_Critic_Count_min,Rating_Critic_Count_mean,Rating_User_Score_max,Rating_User_Score_min,Rating_User_Score_mean,Rating_Critic_Score_max,Rating_Critic_Score_min,Rating_Critic_Score_mean,Critic_Score_*_Critic_Count,User_Score_*_User_Count,Critic_Score_*_User_Score,Critic_Count_*_User_Count,Critic_Count_+_User_Count,Critic_Count_-_User_Count,Critic_Count_/_all_Count,Rating_0,Rating_1,Rating_2,Rating_3,Rating_4,Rating_5,Rating_6,Rating_7,Rating_8,Genre_0,Genre_1,Genre_2,Genre_3,Genre_4,Genre_5,Genre_6,Genre_7,Genre_8,Genre_9,Genre_10,Genre_11,Genre_12,Platform_0,Platform_1,Platform_2,Platform_3,Platform_4,Platform_5,Platform_6,Platform_7,Platform_8,Platform_9,Platform_10,Platform_11,Platform_12,Platform_13,Platform_14,Platform_15,Platform_16,Platform_17,Platform_18,Platform_19,Platform_20,Platform_21,Platform_22,Platform_23,Platform_24,Platform_25,Platform_26,Platform_27,Platform_28,Platform_29,Platform_30
0,84.000000,23.000000,8.0,19.000000,0,0,0,0,92.000000,46.000000,53.740115,42.000000,21.000000,2.828427,1283.0,4.0,64.110095,91.000000,4.0,26.369807,9.3,0.5,7.587015,97.0,19.0,69.752185,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,3552.0,4.0,126.002039,78.0,4.0,25.415598,9.3,2.5,7.515152,92.0,28.0,68.641884,528.000000,8.0,113.300000,53.000000,9.0,28.800000,9.1,5.5,8.045000,85.0,69.0,79.250000,10665.0,4.0,472.765416,107.0,4.0,38.224482,9.4,1.0,7.178311,98.0,13.0,71.652216,1932.000000,152.000000,672.000000,437.00000,42.000000,4.000000,0.547619,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,91.000000,17.000000,9.0,132.000000,0,0,0,0,100.000000,50.000000,57.982756,149.000000,74.500000,81.317280,1282.0,4.0,150.268968,26.360821,4.0,23.696523,9.4,1.2,7.808855,98.0,26.0,69.393296,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,3552.0,4.0,126.002039,78.0,4.0,25.415598,9.3,2.5,7.515152,92.0,28.0,68.641884,1288.000000,6.0,102.458393,102.000000,6.0,33.108866,9.3,5.8,7.876000,91.0,65.0,75.517414,7064.0,4.0,116.653145,113.0,4.0,28.086735,9.7,0.5,7.356062,98.0,21.0,68.843507,1547.000000,1188.000000,819.000000,2244.00000,149.000000,-115.000000,0.114094,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0
2,87.000000,28.000000,8.5,39.000000,0,0,0,0,95.500000,47.750000,55.507882,67.000000,33.500000,7.778175,1283.0,4.0,64.110095,91.000000,4.0,26.369807,9.3,0.5,7.587015,97.0,19.0,69.752185,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,3943.0,4.0,148.985213,86.0,4.0,24.693735,9.6,2.2,7.632614,96.0,30.0,68.722596,1026.000000,4.0,284.185826,68.000000,9.0,27.750000,9.6,6.3,8.516667,96.0,48.0,77.750000,10665.0,4.0,472.765416,107.0,4.0,38.224482,9.4,1.0,7.178311,98.0,13.0,71.652216,2436.000000,331.500000,739.500000,1092.00000,67.000000,-11.000000,0.417910,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,70.000000,54.000000,6.9,180.000000,0,0,0,0,76.900000,38.450000,44.618438,234.000000,117.000000,89.095454,8713.0,4.0,181.309867,100.000000,3.0,37.800193,9.0,0.7,6.945404,98.0,19.0,68.713009,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,3943.0,4.0,148.985213,86.0,4.0,24.693735,9.6,2.2,7.632614,96.0,30.0,68.722596,180.000000,4.0,56.667715,59.000000,4.0,31.425492,9.1,2.3,6.550000,79.0,42.0,60.089440,10665.0,4.0,472.765416,107.0,4.0,38.224482,9.4,1.0,7.178311,98.0,13.0,71.652216,3780.000000,1242.000000,483.000000,9720.00000,234.000000,-126.000000,0.230769,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,71.000000,41.000000,6.9,143.000000,0,0,0,0,77.900000,38.950000,45.325545,184.000000,92.000000,72.124892,8003.0,4.0,190.998781,107.000000,4.0,31.006296,9.1,0.2,7.092712,98.0,13.0,69.839582,8003.0,4.0,176.706842,106.0,4.0,27.157274,9.5,0.3,7.394926,98.0,19.0,67.656132,3943.0,4.0,148.985213,86.0,4.0,24.693735,9.6,2.2,7.632614,96.0,30.0,68.722596,180.000000,4.0,56.667715,59.000000,4.0,31.425492,9.1,2.3,6.550000,79.0,42.0,60.089440,10665.0,4.0,472.765416,107.0,4.0,38.224482,9.4,1.0,7.178311,98.0,13.0,71.652216,2911.000000,986.700000,489.900000,5863.00000,184.000000,-102.000000,0.222826,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8355,68.967679,26.360821,7.8,162.229908,1,1,1,1,76.767679,38.383839,43.252080,188.590729,94.295364,96.073953,1228.0,4.0,140.417693,86.000000,4.0,26.556889,9.3,1.8,7.646528,93.0,33.0,69.474342,6157.0,4.0,177.707944,104.0,3.0,27.205670,9.4,0.6,7.570132,94.0,35.0,70.346538,10665.0,4.0,177.218272,89.0,4.0,26.529438,9.3,0.9,7.644089,98.0,26.0,68.219029,162.229908,4.0,162.001986,26.360821,5.0,26.350529,8.2,3.2,7.797207,83.0,43.0,68.961517,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,1818.044624,1265.393281,537.947892,4276.51355,188.590729,-135.869087,0.139778,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
8356,68.967679,26.360821,7.8,162.229908,1,1,1,1,76.767679,38.383839,43.252080,188.590729,94.295364,96.073953,10179.0,4.0,394.985413,113.000000,4.0,34.444976,9.2,1.5,7.112214,97.0,19.0,70.970592,6157.0,4.0,177.707944,104.0,3.0,27.205670,9.4,0.6,7.570132,94.0,35.0,70.346538,10665.0,4.0,177.218272,89.0,4.0,26.529438,9.3,0.9,7.644089,98.0,26.0,68.219029,162.229908,4.0,162.001986,26.360821,5.0,26.350529,8.2,3.2,7.797207,83.0,43.0,68.961517,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,1818.044624,1265.393281,537.947892,4276.51355,188.590729,-135.869087,0.139778,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0
8357,71.000000,15.000000,7.1,71.000000,0,0,0,0,78.100000,39.050000,45.184123,86.000000,43.000000,39.597980,10665.0,4.0,520.406469,104.000000,4.0,27.505598,9.3,1.4,7.216940,96.0,33.0,74.077648,6157.0,4.0,177.707944,104.0,3.0,27.205670,9.4,0.6,7.570132,94.0,35.0,70.346538,1473.0,4.0,172.128316,30.0,4.0,21.558370,8.6,2.3,7.327027,83.0,42.0,66.924981,162.229908,7.0,96.057477,32.000000,9.0,20.590205,8.6,7.1,7.925000,82.0,63.0,71.241920,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,1065.000000,504.100000,504.100000,1065.00000,86.000000,-56.000000,0.174419,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
8358,68.967679,26.360821,7.8,162.229908,1,1,1,1,76.767679,38.383839,43.252080,188.590729,94.295364,96.073953,10665.0,4.0,520.406469,104.000000,4.0,27.505598,9.3,1.4,7.216940,96.0,33.0,74.077648,6157.0,4.0,177.707944,104.0,3.0,27.205670,9.4,0.6,7.570132,94.0,35.0,70.346538,1473.0,4.0,172.128316,30.0,4.0,21.558370,8.6,2.3,7.327027,83.0,42.0,66.924981,162.229908,4.0,162.001986,26.360821,5.0,26.350529,8.2,3.2,7.797207,83.0,43.0,68.961517,1010.0,4.0,161.611977,57.0,4.0,26.273076,9.2,3.1,7.789393,93.0,31.0,68.963347,1818.044624,1265.393281,537.947892,4276.51355,188.590729,-135.869087,0.139778,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


# tabnet

In [27]:
n_seeds = 1
n_splits = 5
shuffle = True

batch_size = 64
#batch_size = 8
factor = 0.5
patience = 30
lr = 0.001
#fit_params = {"epochs": 1_000, "verbose": 1}
fit_params = {"epochs": 300, "verbose": 1}

params = dict(
    epsilon=1e-05,
    feature_columns=None,
    virtual_batch_size=None,
    # num_layers=2,
    num_decision_steps=1,
    norm_type="batch",
    #norm_type="group",
    num_groups=-1,
    batch_momentum=0.9,
    relaxation_factor=1.2,
    sparsity_coefficient=0.0001,
    feature_dim=1024,
    output_dim=209,
)
with open(f"params.pkl", "wb") as f:
    pickle.dump(params, f)

#DEBUG = True
DEBUG = False
if DEBUG:
    n_seeds = 1
    n_splits = 2
    batch_size = 1024
    fit_params = {"epochs": 10, "verbose": 1}
    print("DEBUG")

In [28]:
## test setとあまり被りのない特徴量は除外  https://www.guruguru.science/competitions/13/discussions/df06ef19-981d-4666-a0c0-22f62ee26640/
#
#inbalance = ["Publisher", "Developer", "Name"]
#train_drop_sales = train_drop_sales.drop(inbalance, axis=1)
#test = test.drop(inbalance, axis=1)
#
#cate_cols = list(set(cate_cols) - set(inbalance))
#print("cate_cols:", cate_cols)

In [29]:
#_train = pd.read_csv(f"{DATADIR}/train.csv")

# group_col = "Name"
group_col = "Publisher"  # https://www.guruguru.science/competitions/13/discussions/42fc473d-4450-4cfc-b924-0a5d61fd0ca7/

# GroupKFold
group = train[group_col].copy()  

# seed値が指定できるGroupKFold  https://www.guruguru.science/competitions/13/discussions/cc7167cb-3627-448a-b9eb-7afcd29fd122/
group_uni = train[group_col].copy().unique()

In [30]:
features = test.columns.to_list()

global_target_col = "Global_Sales"
target_cols = [
    "NA_Sales",
    "EU_Sales",
    "JP_Sales",
    "Other_Sales",
]

X = train_drop_sales[features]
Y = train[target_cols]
Y_global = train[global_target_col]

train_size, n_features = X.shape
_, n_classes = Y.shape

In [31]:
def root_mean_squared_logarithmic_error(y_true, y_pred):
    msle = tf.keras.losses.MeanSquaredLogarithmicError()
    return K.sqrt(msle(y_true, y_pred))


def fit_tabnet(X_train, Y_train, X_val, Y_val, model_path):
    K.clear_session()
    model = TabNetRegressor(num_regressors=n_classes, num_features=n_features, **params)
    model.compile(
        optimizer=AdaBeliefOptimizer(learning_rate=lr), 
        #loss=tf.keras.losses.MeanSquaredLogarithmicError(),  # MSLE
        loss=root_mean_squared_logarithmic_error,  # RMSLE
    )
    
    callbacks = build_callbacks(model_path, factor=factor, patience=patience)
    
    history = model.fit(
        X_train, 
        Y_train,
        batch_size=batch_size,
        callbacks=callbacks,
        validation_data=(X_val, Y_val),
        **fit_params,
    )
    
    model.load_weights(model_path)
    
    return model

In [32]:
def fit_model(Y, features_imp=None, out_pkl="Y_pred.pkl"): 
    
    Y_pred = np.zeros((train_size, n_classes))
    Y_pred = pd.DataFrame(Y_pred, columns=Y.columns, index=Y.index)

    for i in tqdm(range(n_seeds)):
        set_seed(seed=i)

        cv = KFold(n_splits=n_splits, random_state=i, shuffle=shuffle)
        #cv_split = cv.split(X, Y)  # KFold
        cv_split = cv.split(group_uni)  # GroupRandomKFold
        #cv = GroupKFold(n_splits=n_splits)
        #cv_split = cv.split(X, Y, group)

        for j, (trn_idx, val_idx) in enumerate(cv_split):        
            print(f"\n------------ fold:{j} ------------")
            # GroupRandomKFold
            tr_groups, va_groups = group_uni[trn_idx], group_uni[val_idx]
            trn_idx = group[group.isin(tr_groups)].index
            val_idx = group[group.isin(va_groups)].index
            print("len(trn_idx), len(val_idx):", len(trn_idx), len(val_idx))

            X_train, X_val = X.iloc[trn_idx], X.iloc[val_idx]
            Y_train, Y_val = Y.iloc[trn_idx], Y.iloc[val_idx]

            # importance 上位100位までの特徴量だけにする
            if features_imp is not None:
                X_train, X_val = X_train[features_imp], X_val[features_imp]
                
            model = fit_tabnet(X_train, Y_train, X_val, Y_val, f"model_seed_{i}_fold_{j}.h5")
            
            va_pred = model.predict(X_val)
            
            va_pred = np.where(va_pred<0, 0, va_pred)
            
            Y_pred.iloc[val_idx] += va_pred / n_seeds
            
            #del model
            #gc.collect()

    with open(out_pkl, "wb") as f:
        pickle.dump(Y_pred, f)

    return Y_pred, model

In [33]:
%%time
Y_pred, model = fit_model(Y)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))


------------ fold:0 ------------
len(trn_idx), len(val_idx): 6174 2185
[31mPlease check your arguments if you have upgraded adabelief-tf from version 0.0.1.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  -------------
adabelief-tf=0.0.1       1e-08  Not supported      Not supported
Current version (0.1.0)  1e-14  supported          default: True
[31mFor a complete table of recommended hyperparameters, see
[31mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[0m
Epoch 1/300


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 

Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300

------------ fold:2 ------------
len(trn_idx), len(val_idx): 5807 2552
[31mPlease check your arguments if you have upgraded adabelief-tf from version 0.0.1.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  -------------
adabelief-tf=0.0.1       1e-08  Not supported      Not 

Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300

------------ fold:3 ------------
len(trn_idx), len(val_idx): 6856 1503
[31mPlease check your arguments if you have upgraded adabelief-tf from version 0.0.1.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  -------------
adabelief-tf=0.0.1       1e-08  Not supported      Not supported
Current version (0.1.0)  1e-14  supported          default: 

Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300

------------ fold:4 ------------
len(trn_idx), len(val_idx): 7059 1300
[31mPlease check your arguments if you have upgraded adabelief-tf from version 0.0.1.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  -------------
adabelief-tf=0.0.1       1e-08  Not supported      Not supported
Current version (0.1.0)  1e-14  supported          default: True
[31mFor a complete table of recommended hyperparameters, see
[31mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[0m
Epoch 1/300


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change jus

Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 79/300
Epoch 80/300
Epoch 81/300
Epoch 82/300
Epoch 83/300
Epoch 84/300
Epoch 85/300
Epoch 86/300
Epoch 87/300
Epoch 88/300
Epoch 89/300
Epoch 90/300

Wall time: 5min 13s


In [34]:
score(Y, Y_pred)

NA_Sales       1.314835
EU_Sales       1.217597
JP_Sales       1.205844
Other_Sales    0.847330
dtype: float64

In [35]:
# デフォルトパラメ
## RankGaussなし: 1.18526
## RankGaussあり: 1.23384
# パラメ変えて
## bauch_size=8: 1.16616
## bauch_size=64: 1.14625
score(Y_global, Y_pred.sum(axis=1))

1.146255058286068

# oof

In [36]:
path = r"Y_pred.pkl"
with open(path, "rb") as f:
    Y_pred = pickle.load(f)
Y_pred

Unnamed: 0,NA_Sales,EU_Sales,JP_Sales,Other_Sales
0,17.294683,6.585598,0.567128,3.727135
1,36.258926,15.381187,1.361468,8.673420
2,10.828993,3.831598,0.876301,1.860105
3,18.843252,7.936788,1.044856,2.325323
4,15.814473,6.157887,0.478688,3.589277
...,...,...,...,...
8354,2.009439,12.884875,0.264278,3.090280
8355,1.033238,4.692636,0.000000,0.819955
8356,0.624857,0.131230,1.480749,0.162746
8357,6.802473,7.420416,2.153712,3.286767


# predict test

In [37]:
Y_test_pred = np.zeros((test.shape[0], n_classes))
Y_test_pred = pd.DataFrame(Y_test_pred, columns=target_cols, index=test.index)

for i in range(n_seeds):
    for j in range(n_splits):
        model_path = f"model_seed_{i}_fold_{j}.h5"
        model = TabNetRegressor(num_regressors=n_classes, num_features=n_features, **params)
        model(np.zeros((1, n_features)))
        model.load_weights(model_path)
            
        Y_test_pred += model.predict(test) / (n_seeds * n_splits)
            
for col in Y_test_pred.columns:
    idx = Y_test_pred[Y_test_pred[col] < 0].index
    Y_test_pred[col][idx] = 0.0
    
print(Y_test_pred.shape)

submission = pd.read_csv(f"{DATADIR}/atmaCup8_sample-submission.csv")
submission["Global_Sales"] = Y_test_pred
submission.to_csv("submission.csv", index=False)
submission



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float

Unnamed: 0,Global_Sales
0,13.087490
1,31.999891
2,18.447356
3,29.122060
4,22.518407
...,...
8355,11.180462
8356,4.882340
8357,1.346418
8358,1.541184
