# Music Genre Prediction

In [1]:
class Config:
    NB = '207'
    dataset_NB = '103'

    raw_data_dir = '../data/raw/'
    processed_data_dir = '../data/processed/'
    interim_dir = '../data/interim/'
    submission_dir = '../data/submission/'

    random_seed = 42
    n_folds = 5

    row_id = 'index'
    target = 'genre'

## Import libralies

In [2]:
import os
import gc
import warnings
warnings.filterwarnings('ignore')

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
import itertools

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline
sns.set(style='white', context='notebook', palette='deep')

In [28]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.figure_factory as ff

plotly_template = dict(
    layout=go.Layout(
        template='plotly_dark',
        font=dict(
            family="Franklin Gothic",
            size=12
        ),
        height=500,
        width=1000,
    )
)


color_palette = {
    'Bin': ['#016CC9','#E876A3'],
    'Cat5': ['#E876A3', '#E0A224', '#63B70D', '#6BCFF6', '#13399E'],
}

In [17]:
import random
import joblib
import itertools
from itertools import combinations
from imblearn import FunctionSampler
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import SMOTE

from sklearn.model_selection import StratifiedKFold, GroupKFold, train_test_split
from sklearn.preprocessing import LabelEncoder

from sklearn.metrics import roc_auc_score, roc_curve, auc, f1_score, confusion_matrix
import scipy.stats as stats
import time

import tensorflow as tf
from tensorflow.keras import datasets
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import optimizers
from keras.utils import np_utils

In [5]:
df_train = pd.read_pickle(Config.processed_data_dir + f'nb{Config.dataset_NB}_train.pkl', compression='zip')
df_test = pd.read_pickle(Config.processed_data_dir + f'nb{Config.dataset_NB}_test.pkl', compression='zip')

submission = pd.read_csv(Config.raw_data_dir + 'sample_submit.csv', header=None)

df_train.shape

(4046, 285)

In [33]:
df_train.describe()

Unnamed: 0,index,genre,popularity,duration_ms,acousticness,positiveness,danceability,loudness,energy,liveness,speechiness,instrumentalness,tempo_int,region_A,region_B,region_C,region_D,region_E,region_F,region_G,region_H,region_I,region_J,region_K,region_L,region_M,region_N,region_O,region_P,region_Q,region_R,region_S,region_T,unknown,duration_long,popularity_add_duration_ms,popularity_sub_duration_ms,popularity_mul_duration_ms,popularity_bigger_duration_ms,popularity_add_acousticness,popularity_sub_acousticness,popularity_mul_acousticness,popularity_bigger_acousticness,popularity_add_positiveness,popularity_sub_positiveness,popularity_mul_positiveness,popularity_bigger_positiveness,popularity_add_danceability,popularity_sub_danceability,popularity_mul_danceability,popularity_bigger_danceability,popularity_add_loudness,popularity_sub_loudness,popularity_mul_loudness,popularity_bigger_loudness,popularity_add_energy,popularity_sub_energy,popularity_mul_energy,popularity_bigger_energy,popularity_add_liveness,popularity_sub_liveness,popularity_mul_liveness,popularity_bigger_liveness,popularity_add_speechiness,popularity_sub_speechiness,popularity_mul_speechiness,popularity_bigger_speechiness,popularity_add_instrumentalness,popularity_sub_instrumentalness,popularity_mul_instrumentalness,popularity_bigger_instrumentalness,popularity_add_tempo_int,popularity_sub_tempo_int,popularity_mul_tempo_int,popularity_bigger_tempo_int,duration_ms_add_acousticness,duration_ms_sub_acousticness,duration_ms_mul_acousticness,duration_ms_bigger_acousticness,duration_ms_add_positiveness,duration_ms_sub_positiveness,duration_ms_mul_positiveness,duration_ms_bigger_positiveness,duration_ms_add_danceability,duration_ms_sub_danceability,duration_ms_mul_danceability,duration_ms_bigger_danceability,duration_ms_add_loudness,duration_ms_sub_loudness,duration_ms_mul_loudness,duration_ms_bigger_loudness,duration_ms_add_energy,duration_ms_sub_energy,duration_ms_mul_energy,duration_ms_bigger_energy,duration_ms_add_liveness,duration_ms_sub_liveness,duration_ms_mul_liveness,duration_ms_bigger_liveness,duration_ms_add_speechiness,duration_ms_sub_speechiness,duration_ms_mul_speechiness,duration_ms_bigger_speechiness,duration_ms_add_instrumentalness,duration_ms_sub_instrumentalness,duration_ms_mul_instrumentalness,duration_ms_bigger_instrumentalness,duration_ms_add_tempo_int,duration_ms_sub_tempo_int,duration_ms_mul_tempo_int,duration_ms_bigger_tempo_int,acousticness_add_positiveness,acousticness_sub_positiveness,acousticness_mul_positiveness,acousticness_bigger_positiveness,acousticness_add_danceability,acousticness_sub_danceability,acousticness_mul_danceability,acousticness_bigger_danceability,acousticness_add_loudness,acousticness_sub_loudness,acousticness_mul_loudness,acousticness_bigger_loudness,acousticness_add_energy,acousticness_sub_energy,acousticness_mul_energy,acousticness_bigger_energy,acousticness_add_liveness,acousticness_sub_liveness,acousticness_mul_liveness,acousticness_bigger_liveness,acousticness_add_speechiness,acousticness_sub_speechiness,acousticness_mul_speechiness,acousticness_bigger_speechiness,acousticness_add_instrumentalness,acousticness_sub_instrumentalness,acousticness_mul_instrumentalness,acousticness_bigger_instrumentalness,acousticness_add_tempo_int,acousticness_sub_tempo_int,acousticness_mul_tempo_int,acousticness_bigger_tempo_int,positiveness_add_danceability,positiveness_sub_danceability,positiveness_mul_danceability,positiveness_bigger_danceability,positiveness_add_loudness,positiveness_sub_loudness,positiveness_mul_loudness,positiveness_bigger_loudness,positiveness_add_energy,positiveness_sub_energy,positiveness_mul_energy,positiveness_bigger_energy,positiveness_add_liveness,positiveness_sub_liveness,positiveness_mul_liveness,positiveness_bigger_liveness,positiveness_add_speechiness,positiveness_sub_speechiness,positiveness_mul_speechiness,positiveness_bigger_speechiness,positiveness_add_instrumentalness,positiveness_sub_instrumentalness,positiveness_mul_instrumentalness,positiveness_bigger_instrumentalness,positiveness_add_tempo_int,positiveness_sub_tempo_int,positiveness_mul_tempo_int,positiveness_bigger_tempo_int,danceability_add_loudness,danceability_sub_loudness,danceability_mul_loudness,danceability_bigger_loudness,danceability_add_energy,danceability_sub_energy,danceability_mul_energy,danceability_bigger_energy,danceability_add_liveness,danceability_sub_liveness,danceability_mul_liveness,danceability_bigger_liveness,danceability_add_speechiness,danceability_sub_speechiness,danceability_mul_speechiness,danceability_bigger_speechiness,danceability_add_instrumentalness,danceability_sub_instrumentalness,danceability_mul_instrumentalness,danceability_bigger_instrumentalness,danceability_add_tempo_int,danceability_sub_tempo_int,danceability_mul_tempo_int,danceability_bigger_tempo_int,loudness_add_energy,loudness_sub_energy,loudness_mul_energy,loudness_bigger_energy,loudness_add_liveness,loudness_sub_liveness,loudness_mul_liveness,loudness_bigger_liveness,loudness_add_speechiness,loudness_sub_speechiness,loudness_mul_speechiness,loudness_bigger_speechiness,loudness_add_instrumentalness,loudness_sub_instrumentalness,loudness_mul_instrumentalness,loudness_bigger_instrumentalness,loudness_add_tempo_int,loudness_sub_tempo_int,loudness_mul_tempo_int,loudness_bigger_tempo_int,energy_add_liveness,energy_sub_liveness,energy_mul_liveness,energy_bigger_liveness,energy_add_speechiness,energy_sub_speechiness,energy_mul_speechiness,energy_bigger_speechiness,energy_add_instrumentalness,energy_sub_instrumentalness,energy_mul_instrumentalness,energy_bigger_instrumentalness,energy_add_tempo_int,energy_sub_tempo_int,energy_mul_tempo_int,energy_bigger_tempo_int,liveness_add_speechiness,liveness_sub_speechiness,liveness_mul_speechiness,liveness_bigger_speechiness,liveness_add_instrumentalness,liveness_sub_instrumentalness,liveness_mul_instrumentalness,liveness_bigger_instrumentalness,liveness_add_tempo_int,liveness_sub_tempo_int,liveness_mul_tempo_int,liveness_bigger_tempo_int,speechiness_add_instrumentalness,speechiness_sub_instrumentalness,speechiness_mul_instrumentalness,speechiness_bigger_instrumentalness,speechiness_add_tempo_int,speechiness_sub_tempo_int,speechiness_mul_tempo_int,speechiness_bigger_tempo_int,instrumentalness_add_tempo_int,instrumentalness_sub_tempo_int,instrumentalness_mul_tempo_int,instrumentalness_bigger_tempo_int,PCA1,PCA2,PCA3,PCA4,PCA5,PCA6,PCA7,PCA8,PCA9,PCA10,PCA11,PCA12,PCA13,PCA14,PCA15,PCA16,PCA17,PCA18,PCA19,PCA20,PCA21,PCA22,PCA23,PCA24,PCA25,PCA26,PCA27,PCA28,PCA29,PCA30
count,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0
mean,2022.5,7.28176,0.513847,0.42715,0.344084,0.468205,0.51977,0.644845,0.602985,0.234547,0.302811,0.17338,134.53213,0.003213,0.080326,0.020761,0.049184,0.175482,0.033613,0.014829,0.044736,0.185863,0.006179,0.056846,0.027929,0.000741,0.010875,0.034108,0.086258,0.008898,0.012358,0.014335,0.042017,0.091448,0.0,0.218106,0.225913,0.218106,0.68611,0.184147,0.310929,0.184147,0.71132,0.245723,0.266373,0.245723,0.552397,0.276778,0.221065,0.276778,0.478744,0.330583,0.249351,0.330583,0.319081,0.302921,0.283785,0.302921,0.410776,0.11814,0.346472,0.11814,0.836876,0.156906,0.29355,0.156906,0.78349,0.083207,0.4109695,0.083207,0.880129,69.041753,134.018283,69.041753,0.0,0.142524,0.291033,0.142524,0.650272,0.193655,0.270971,0.193655,0.452299,0.219816,0.239184,0.219816,0.335887,0.276127,0.279156,0.276127,0.175482,0.257328,0.27812,0.257328,0.263964,0.100435,0.264279,0.100435,0.825754,0.126811,0.226689,0.126811,0.758774,0.074069,0.319594,0.074069,0.899654,57.686709,134.104981,57.686709,0.0,0.147264,0.367864,0.147264,0.338112,0.175315,0.348578,0.175315,0.283243,0.193129,0.457244,0.193129,0.23134,0.159125,0.488242,0.159125,0.281513,0.077056,0.278478,0.077056,0.583292,0.097219,0.2857788,0.097219,0.475779,0.057612,0.287136,0.057612,0.707612,45.212122,134.188046,45.212122,0.0,0.2759,0.1901843,0.2759,0.420662,0.310672,0.284319,0.310672,0.29041,0.302636,0.253651,0.302636,0.333169,0.110767,0.320527,0.110767,0.775828,0.14739,0.2740896,0.14739,0.686604,0.068479,0.393858,0.068479,0.845774,63.412627,134.063926,63.412627,0.0,0.337171,0.254669,0.337171,0.327731,0.316844,0.261004,0.316844,0.420662,0.11852,0.353719,0.11852,0.836629,0.161451,0.291194,0.161451,0.796342,0.078199,0.428699,0.078199,0.886802,69.358547,134.012361,69.358547,0.0,0.421535,0.148117,0.421535,0.580326,0.153282,0.442677,0.153282,0.915472,0.199755,0.377407,0.199755,0.895947,0.108233,0.512521,0.108233,0.924864,87.369172,133.887285,87.369172,0.0,0.148519,0.406282,0.148519,0.892486,0.191223,0.3500123,0.191223,0.846762,0.105382,0.473166,0.105382,0.90954,82.262444,133.929145,82.262444,0.0,0.076055,0.187455,0.076055,0.339595,0.040346,0.18838,0.040346,0.663618,31.600806,134.297583,31.600806,0.0,0.052525,0.2252,0.052525,0.798319,40.541308,134.229319,40.541308,0.0,23.368508,134.35875,23.368508,0.0,-7.30563e-15,9.328727e-15,-1.77021e-15,-9.300629e-15,-1.152042e-15,6.434574e-15,8.766756e-15,-3.989998e-15,5.451124e-15,1.208239e-15,1.494844e-14,3.108405e-16,-1.13459e-15,-1.352244e-15,3.093039e-16,-1.467712e-15,-1.714891e-15,3.195115e-16,-6.058755e-17,1.028452e-16,-2.138126e-16,6.717316e-17,-1.475175e-16,1.016817e-15,3.354267e-16,7.711742e-16,4.814076e-16,8.848856e-16,-4.978716e-16,2.651803e-16
std,1168.123923,2.887542,0.211821,0.158453,0.283517,0.263596,0.223505,0.197972,0.24342,0.189897,0.182977,0.198511,30.432382,0.0566,0.271831,0.142602,0.21628,0.380426,0.180254,0.120885,0.206748,0.389044,0.078373,0.231577,0.164789,0.027223,0.103727,0.181528,0.280779,0.093919,0.110491,0.118883,0.200652,0.288281,0.0,0.118272,0.170918,0.118272,0.46413,0.176562,0.20662,0.176562,0.453205,0.176793,0.187577,0.176793,0.497308,0.175172,0.16303,0.175172,0.49961,0.177945,0.201425,0.177945,0.466178,0.175564,0.212983,0.175564,0.492035,0.111015,0.208924,0.111015,0.369525,0.128539,0.184631,0.128539,0.411917,0.105046,0.2074966,0.105046,0.324851,33.178434,30.435986,33.178434,0.0,0.130122,0.191248,0.130122,0.476943,0.12289,0.188448,0.12289,0.497781,0.116641,0.175579,0.116641,0.472358,0.132389,0.179979,0.132389,0.380426,0.143747,0.195934,0.143747,0.440835,0.099274,0.167078,0.099274,0.379368,0.092797,0.166497,0.092797,0.42788,0.101604,0.163536,0.101604,0.300498,25.7615,30.425518,25.7615,0.0,0.144328,0.240001,0.144328,0.473125,0.15861,0.216178,0.15861,0.450629,0.153652,0.241762,0.153652,0.421741,0.116032,0.254895,0.116032,0.449792,0.10161,0.24119,0.10161,0.493075,0.099732,0.218692,0.099732,0.499475,0.110195,0.265631,0.110195,0.454916,38.620157,30.46912,38.620157,0.0,0.222949,0.1442885,0.222949,0.493726,0.206497,0.203762,0.206497,0.454008,0.221479,0.204724,0.221479,0.471404,0.12021,0.23556,0.12021,0.417087,0.141108,0.2096513,0.141108,0.463931,0.076794,0.257046,0.076794,0.36121,39.656076,30.419584,39.656076,0.0,0.180932,0.18953,0.180932,0.469444,0.188037,0.202809,0.188037,0.493726,0.111282,0.221408,0.111282,0.36975,0.140969,0.19386,0.140969,0.402767,0.083538,0.222243,0.083538,0.316874,32.349596,30.451838,32.349596,0.0,0.231726,0.113281,0.231726,0.493567,0.138462,0.2087,0.138462,0.278212,0.14526,0.195547,0.14526,0.305368,0.135942,0.212915,0.135942,0.263643,34.562145,30.412747,34.562145,0.0,0.150619,0.227603,0.150619,0.309803,0.155064,0.2074911,0.155064,0.360261,0.144796,0.240106,0.144796,0.286875,39.97625,30.395813,39.97625,0.0,0.09727,0.170262,0.09727,0.47363,0.06554,0.210606,0.06554,0.47253,27.524141,30.431441,27.524141,0.0,0.075072,0.197153,0.075072,0.401305,26.715474,30.43939,26.715474,0.0,27.999528,30.431607,27.999528,0.0,122.9942,69.44027,59.14307,40.36015,38.62141,36.6133,33.2218,32.91647,28.48274,26.50752,20.75149,0.5450737,0.4932492,0.4426273,0.4275965,0.4225538,0.4101197,0.3867268,0.3791355,0.372003,0.3661723,0.3503816,0.3402042,0.332474,0.3286413,0.3248159,0.3194684,0.3113954,0.3056855,0.2999086
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,40.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.3e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.8e-05,0.0,0.0,0.0,2e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.110223e-16,0.0,0.0,0.0,39.315789,0.0,0.0,0.0,6.9e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000114,0.0,0.0,0.0,6.8e-05,0.0,0.0,0.0,39.536055,0.0,0.0,0.0,0.000383,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000527,0.0,0.0,0.0,0.000777,0.0,0.0,0.0,0.000111,0.0,0.0,0.0,1.110223e-16,0.0,0.0,0.0,0.0,0.0,0.0,0.0,39.008829,0.0,0.0,0.0,2.220446e-16,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.1e-05,0.0,0.0,0.0,1.110223e-16,0.0,0.0,0.0,4.1e-05,0.0,0.0,0.0,39.541326,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.6e-05,0.0,0.0,0.0,9.8e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000243,0.0,0.0,0.0,39.428007,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000464,0.0,0.0,0.0,0.0,0.0,0.0,0.0,39.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.110223e-16,0.0,0.0,0.0,0.000249,0.0,0.0,0.0,39.051619,0.0,0.0,0.0,7.3e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,39.252893,0.0,0.0,0.0,6.1e-05,0.0,0.0,0.0,39.484971,0.0,0.0,0.0,39.0,0.0,0.0,-388.672,-111.0692,-171.4855,-127.9318,-115.3427,-121.2453,-126.0679,-129.503,-111.3049,-114.2452,-96.13118,-1.812536,-1.870845,-1.438806,-1.155813,-1.403131,-1.452708,-1.126972,-1.424939,-0.8547488,-1.077573,-1.389176,-1.247187,-1.361196,-1.263367,-1.177883,-1.231024,-1.297213,-1.228253,-1.037534
25%,1011.25,7.0,0.381579,0.344101,0.112032,0.247631,0.360279,0.538739,0.430235,0.114511,0.188647,0.081236,120.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.136452,0.090067,0.136452,0.0,0.047082,0.142144,0.047082,0.0,0.102787,0.11526,0.102787,0.0,0.139244,0.09127,0.139244,0.0,0.202091,0.091588,0.202091,0.0,0.171362,0.111318,0.171362,0.0,0.048278,0.173125,0.048278,1.0,0.076098,0.147216,0.076098,1.0,0.033529,0.2607847,0.033529,1.0,46.736842,119.263158,46.736842,0.0,0.045592,0.143632,0.045592,0.0,0.097291,0.119057,0.097291,0.0,0.136692,0.099277,0.136692,0.0,0.195963,0.131727,0.195963,0.0,0.163355,0.118178,0.163355,0.0,0.042258,0.137616,0.042258,1.0,0.068902,0.101518,0.068902,1.0,0.029346,0.218631,0.029346,1.0,41.09785,119.42415,41.09785,0.0,0.041032,0.156502,0.041032,0.0,0.049354,0.170186,0.049354,0.0,0.07374,0.257468,0.07374,0.0,0.071866,0.279791,0.071866,0.0,0.017826,0.078114,0.017826,0.0,0.027142,0.1114441,0.027142,0.0,0.011987,0.063277,0.011987,0.0,14.881754,119.222846,14.881754,0.0,0.089038,0.07283262,0.089038,0.0,0.144319,0.116185,0.144319,0.0,0.110526,0.09464,0.110526,0.0,0.033062,0.121623,0.033062,1.0,0.051795,0.1028609,0.051795,0.0,0.02568,0.166071,0.02568,1.0,31.130009,119.181118,31.130009,0.0,0.201302,0.105883,0.201302,0.0,0.163595,0.099871,0.163595,0.0,0.046622,0.168599,0.046622,1.0,0.073763,0.12898,0.073763,1.0,0.033438,0.256558,0.033438,1.0,46.01141,119.16527,46.01141,0.0,0.242391,0.058992,0.242391,0.0,0.063933,0.294177,0.063933,1.0,0.104968,0.226763,0.104968,1.0,0.045919,0.383698,0.045919,1.0,63.963271,119.148078,63.963271,0.0,0.051649,0.215484,0.051649,1.0,0.082,0.171502,0.082,1.0,0.037986,0.287673,0.037986,1.0,53.337285,119.12483,53.337285,0.0,0.023409,0.065124,0.023409,0.0,0.010843,0.048077,0.010843,0.0,14.30125,119.568107,14.30125,0.0,0.017337,0.086097,0.017337,1.0,23.744005,119.556523,23.744005,0.0,10.039664,119.78187,10.039664,0.0,-95.73937,-52.67799,-41.1651,-25.57952,-23.06539,-20.81497,-20.58495,-18.65693,-18.27341,-15.64752,-13.4418,-0.3688912,-0.3327176,-0.2930228,-0.2634863,-0.286466,-0.2780524,-0.2693093,-0.2457235,-0.2996339,-0.2576071,-0.2326647,-0.2327916,-0.217543,-0.2185583,-0.2198702,-0.2122103,-0.2043557,-0.1986672,-0.17689
50%,2022.5,8.0,0.526316,0.41662,0.231268,0.450772,0.529204,0.669539,0.63978,0.176306,0.268539,0.118155,120.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.213219,0.191911,0.213219,1.0,0.119256,0.284066,0.119256,1.0,0.21414,0.235355,0.21414,1.0,0.261254,0.19131,0.261254,0.0,0.321311,0.198264,0.321311,0.0,0.288772,0.235855,0.288772,0.0,0.087411,0.337764,0.087411,1.0,0.129447,0.276081,0.129447,1.0,0.058362,0.4183971,0.058362,1.0,66.947368,120.0,66.947368,0.0,0.096648,0.268827,0.096648,1.0,0.177755,0.243467,0.177755,0.0,0.215831,0.207633,0.215831,0.0,0.274207,0.261655,0.274207,0.0,0.249894,0.251675,0.249894,0.0,0.071491,0.251516,0.071491,1.0,0.106001,0.196904,0.106001,1.0,0.047157,0.303938,0.047157,1.0,55.218718,120.0,55.218718,0.0,0.101474,0.345779,0.101474,0.0,0.127062,0.336344,0.127062,0.0,0.142172,0.479534,0.142172,0.0,0.128386,0.496785,0.128386,0.0,0.043985,0.202985,0.043985,1.0,0.062335,0.2297548,0.062335,0.0,0.027911,0.178816,0.027911,1.0,30.994379,119.999867,30.994379,0.0,0.219348,0.1642473,0.219348,0.0,0.28402,0.253547,0.28402,0.0,0.27054,0.20425,0.27054,0.0,0.072399,0.270036,0.072399,1.0,0.107557,0.2207003,0.107557,1.0,0.04961,0.374472,0.04961,1.0,58.639908,120.0,58.639908,0.0,0.333621,0.214249,0.333621,0.0,0.310338,0.218462,0.310338,0.0,0.087306,0.33522,0.087306,1.0,0.128148,0.268584,0.128148,1.0,0.058592,0.436518,0.058592,1.0,69.666718,119.908753,69.666718,0.0,0.427004,0.124596,0.427004,1.0,0.111542,0.457065,0.111542,1.0,0.168446,0.382508,0.168446,1.0,0.074689,0.541665,0.074689,1.0,86.608177,120.0,86.608177,0.0,0.098888,0.413849,0.098888,1.0,0.152638,0.348455,0.152638,1.0,0.068132,0.497452,0.068132,1.0,81.61586,120.0,81.61586,0.0,0.04632,0.141486,0.04632,0.0,0.02082,0.107832,0.02082,1.0,23.646902,120.0,23.646902,0.0,0.03098,0.16971,0.03098,1.0,34.919424,120.0,34.919424,0.0,15.350679,120.0,15.350679,0.0,-7.027652,-18.63096,-2.670537,-0.3057021,-4.507955,-4.80967,-0.9958473,-0.7483429,-0.7272713,-0.03050716,-0.4089575,-0.00623636,-0.004672502,0.005274793,-0.03501363,0.009163934,-0.01182609,-0.02075138,0.002499268,-0.1329774,-0.02439341,-0.01437039,-0.002850868,-0.009849603,-0.007281657,-0.002393234,-0.000612775,0.0121552,0.002546144,-0.01729942
75%,3033.75,10.0,0.657895,0.500899,0.552802,0.680469,0.681067,0.786531,0.803927,0.29884,0.365878,0.161997,152.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.291503,0.323598,0.291503,1.0,0.286462,0.450938,0.286462,1.0,0.354333,0.384544,0.354333,1.0,0.391773,0.320822,0.391773,1.0,0.442422,0.367092,0.442422,1.0,0.415315,0.413706,0.415315,1.0,0.150529,0.494854,0.150529,1.0,0.200123,0.419084,0.200123,1.0,0.090362,0.556251,0.090362,1.0,90.0,151.592105,90.0,0.0,0.211466,0.395884,0.211466,1.0,0.271606,0.392424,0.271606,1.0,0.295211,0.340277,0.295211,1.0,0.346241,0.399855,0.346241,0.0,0.336481,0.400162,0.336481,1.0,0.122949,0.358783,0.122949,1.0,0.158959,0.312882,0.158959,1.0,0.071155,0.399818,0.071155,1.0,70.558884,151.638994,70.558884,0.0,0.206809,0.555424,0.206809,1.0,0.264396,0.501269,0.264396,1.0,0.289657,0.647301,0.289657,0.0,0.224519,0.701533,0.224519,1.0,0.096891,0.429181,0.096891,1.0,0.138556,0.4238117,0.138556,1.0,0.065435,0.488317,0.065435,1.0,69.179228,151.885097,69.179228,0.0,0.425962,0.2724867,0.425962,1.0,0.453563,0.419043,0.453563,1.0,0.461597,0.355989,0.461597,1.0,0.141942,0.49504,0.141942,1.0,0.19624,0.4103904,0.19624,1.0,0.086508,0.601704,0.086508,1.0,89.648462,151.731522,89.648462,0.0,0.459664,0.366172,0.459664,1.0,0.453581,0.372383,0.453581,1.0,0.152456,0.511627,0.152456,1.0,0.201362,0.425412,0.201362,1.0,0.091436,0.586747,0.091436,1.0,91.771233,151.588363,91.771233,0.0,0.597281,0.212057,0.597281,1.0,0.197337,0.598561,0.197337,1.0,0.254018,0.521563,0.254018,1.0,0.109277,0.66713,0.109277,1.0,111.105095,151.417215,111.105095,0.0,0.191518,0.585308,0.191518,1.0,0.254482,0.5071929,0.254482,1.0,0.105648,0.66879,0.105648,1.0,109.706306,151.508957,109.706306,0.0,0.089521,0.248982,0.089521,1.0,0.040574,0.246925,0.040574,1.0,39.556622,151.874841,39.556622,0.0,0.053466,0.290306,0.053466,1.0,50.178558,151.80597,50.178558,0.0,22.498492,151.912237,22.498492,0.0,87.84358,40.23905,33.76015,25.56389,16.21593,12.94016,19.63346,17.22486,17.2128,15.68059,12.71048,0.3658082,0.316483,0.3061084,0.2513082,0.2927919,0.277206,0.2619235,0.263109,0.3909325,0.2401957,0.2267318,0.2256634,0.2128364,0.2163098,0.2178044,0.2031693,0.2106144,0.1970278,0.1599871
max,4045.0,10.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,220.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.842548,1.0,0.842548,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.994298,1.0,0.994298,1.0,0.921053,1.0,0.921053,1.0,0.890378,0.973684,0.890378,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,208.0,219.394737,208.0,0.0,0.964195,0.991171,0.964195,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.979616,1.0,0.979616,1.0,0.944988,1.0,0.944988,1.0,1.0,0.999573,1.0,1.0,0.943012,1.0,0.943012,1.0,0.945146,1.0,0.945146,1.0,192.0,219.659969,192.0,0.0,0.924222,0.983985,0.924222,1.0,0.903439,1.0,0.903439,1.0,0.850599,1.0,0.850599,1.0,0.732629,1.0,0.732629,1.0,0.930415,0.97857,0.930415,1.0,1.0,0.9929697,1.0,1.0,1.0,1.0,1.0,1.0,192.0,219.920696,192.0,0.0,1.0,0.8496157,1.0,1.0,1.0,1.0,1.0,1.0,0.983583,0.996306,0.983583,1.0,0.926729,1.0,0.926729,1.0,1.0,1.0,1.0,1.0,0.869862,1.0,0.869862,1.0,208.0,219.819983,208.0,0.0,1.0,1.0,1.0,1.0,0.954504,1.0,0.954504,1.0,0.959881,1.0,0.959881,1.0,1.0,1.0,1.0,1.0,0.811908,1.0,0.811908,1.0,185.130503,219.79637,185.130503,0.0,1.0,0.808457,1.0,1.0,1.0,0.959834,1.0,1.0,0.946596,0.991962,0.946596,1.0,1.0,1.0,1.0,1.0,208.0,219.786776,208.0,0.0,1.0,0.998358,1.0,1.0,1.0,1.0,1.0,1.0,0.989612,1.0,0.989612,1.0,218.918515,220.0,218.918515,0.0,0.996113,0.96627,0.996113,1.0,0.827783,1.0,0.827783,1.0,208.0,219.837625,208.0,0.0,0.826666,1.0,0.826666,1.0,220.0,219.823937,220.0,0.0,208.0,219.980909,208.0,0.0,430.427,344.9739,276.3462,156.3338,266.3832,239.0963,205.7707,180.8422,181.2115,132.7488,135.054,2.253403,3.181155,2.225019,1.337326,1.553355,1.570804,1.551564,1.609956,1.003781,1.310329,1.642654,1.276857,1.632823,1.444511,1.790409,1.378429,1.535372,1.258885,1.225825


## Parameter Setting

In [6]:
features = [col for col in df_train.columns if col not in [Config.row_id, Config.target]]

In [None]:
# bool値をintに変換
col_list = [col for col in df_train.columns if df_train[col].dtypes == bool]

for df in [df_train, df_test]:
    df[col_list] = df[col_list] * 1

df_train.dtypes

## Validation data Setting

In [7]:
X_test = df_test[features]

'''
for c in TARGET_ENCODING_CATEGORY:
    data_tmp = pd.DataFrame({c: df_train[c], 'target': df_train[TARGET]})
    target_mean = data_tmp.groupby(c)['target'].mean()
    X_test.loc[:, c] = X_test[c].map(target_mean)
'''

X_test = (X_test.values).astype(np.float32)
X_test.shape

(4046, 283)

## Modeling

### Multi Layer Perceptron
- 隠れ層3層のMLP
- kernel_initializerにHeの初期化を採用
- Batch Normalizationを採用
- 活性化関数にReLUを採用
- Optimizerを採用（SGD、Adamなど。）
- Dropoutを採用
  - DropoutとBatchNormalizationを同時に使うと学習がうまくできない場合がある。
  - その場合、Dropoutを外す
- モデルの順序は、BatchNormalization、活性化関数、Dropoutであることに注意

In [31]:
def root_mean_squared_error(y_true, y_pred):
    return tf.keras.backend.sqrt(tf.keras.backend.mean(tf.keras.backend.square(y_pred - y_true)))

def setup_model():
    activation = 'relu'
    kernel_initializer = 'he_normal'

    model = Sequential()

    model.add(Dense(256, kernel_initializer=kernel_initializer))
    model.add(BatchNormalization())
    model.add(Activation(activation))
    model.add(Dropout(0.25))

    model.add(Dense(192, kernel_initializer=kernel_initializer))
    model.add(BatchNormalization())
    model.add(Activation(activation))
    model.add(Dropout(0.25))

    model.add(Dense(128, kernel_initializer=kernel_initializer))
    model.add(BatchNormalization())
    model.add(Activation(activation))
    # model.add(Dropout(0.25))

    model.add(Dense(64, kernel_initializer=kernel_initializer))
    model.add(BatchNormalization())
    model.add(Activation(activation))
    # model.add(Dropout(0.25))

    model.add(Dense(32, kernel_initializer=kernel_initializer))
    model.add(BatchNormalization())
    model.add(Activation(activation))
    # model.add(Dropout(0.25))

    model.add(Dense(11, activation='softmax'))

    optimizer = optimizers.Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.999, amsgrad=True)
    # optimizer = optimizers.SGD(learning_rate=0.001)

    # model.compile(optimizer=optimizer, loss=root_mean_squared_error, metrics=[root_mean_squared_error])
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['categorical_crossentropy'])

    return model


def setup_callbacks():
    es = EarlyStopping(monitor='val_loss', patience=10, verbose=1)
    lr = ReduceLROnPlateau(monitor="val_loss", factor=0.7, patience=5, verbose=1)
    callbacks = [es, lr]

    return callbacks


mlp_param = {
    'epochs': 300,
    'batch_size': 100,
    'verbose': 1,
}


### Training & Validation with TargetEncoding

In [32]:
np.random.seed(Config.random_seed)
tf.random.set_seed(Config.random_seed)

# Create a numpy array to store test predictions
test_predictions = np.zeros((len(df_test), Config.n_folds))

# Create a numpy array to store out of folds predictions
oof_predictions = np.zeros(len(df_train))

feature_importance_df = pd.DataFrame(index=features)
y_valids, val_preds =[],[]

kfold = StratifiedKFold(n_splits=Config.n_folds, shuffle=True, random_state=Config.random_seed)

for fold, (train_idx, valid_idx) in enumerate(kfold.split(df_train, df_train[Config.target])):

    print(' ')
    print('-'*50)
    print(f'Training fold {fold+1} with {len(features)} features...')

    X_train, X_val = df_train[features].iloc[train_idx], df_train[features].iloc[valid_idx]
    y_train, y_val = df_train[Config.target].iloc[train_idx], df_train[Config.target].iloc[valid_idx]

    # Over Sampling
    sm = SMOTE(random_state=Config.random_seed)
    X_train, y_train = sm.fit_resample(X_train, y_train)
    # print(y_train.value_counts())

    dummy_y_train = np_utils.to_categorical(y_train)
    dummy_y_val = np_utils.to_categorical(y_val)

    # training
    model = setup_model()
    callbacks = setup_callbacks()
    hist = model.fit(X_train, dummy_y_train, validation_data=(X_val, dummy_y_val), epochs=mlp_param['epochs'], batch_size=mlp_param['batch_size'], callbacks=callbacks, verbose=mlp_param['verbose'])

    print(f'================================== training {fold+1} fin. ==================================')

    # Predict validation data
    print(f'================================== validation-data predicting ... ==================================')
    val_pred = model.predict(X_val)
    val_pred = np.argmax(val_pred, axis=1)
    oof_predictions[valid_idx] = val_pred

    # Predict test data
    print(f'================================== test-data predicting ... ==================================')
    test_pred = model.predict(df_test[features])
    test_pred = np.argmax(test_pred, axis=1)

    test_predictions[:, fold] += test_pred

    # save results
    y_valids.append(y_val)
    val_preds.append(val_pred)
    # feature_importance_df["Importance_Fold"+str(fold+1)]=model.feature_importance(importance_type='gain')

    # Compute fold metric
    val_pred = pd.DataFrame(data={'prediction': val_pred})
    y_val = pd.DataFrame(data={'target': y_val.reset_index(drop=True)})
    score = f1_score(y_val, val_pred, average='macro')

    print(f'Fold {fold+1} CV result')
    print(f'metric : {score}')

    del X_train, X_val, y_train, y_val
    _ = gc.collect()

# Compute out of folds metric
oof_predictions = pd.DataFrame(data={'prediction': oof_predictions})
y_true = pd.DataFrame(data={Config.target: df_train[Config.target]})

print(' ')
print('-'*50)
print(f'TOTAL socre : {f1_score(df_train[Config.target], oof_predictions["prediction"], average="macro")}')
print('-'*50)

# Create a dataframe to store out of folds predictions
oof_df = pd.DataFrame({Config.row_id: df_train[Config.row_id], Config.target: df_train[Config.target], 'prediction': oof_predictions['prediction']})

# Create a dataframe to store test prediction
test_predictions, _ = stats.mode(test_predictions, axis=1)
test_predictions = test_predictions.reshape(-1)

test_df = pd.DataFrame({Config.row_id: df_test[Config.row_id], Config.target: test_predictions})

 
--------------------------------------------------
Training fold 1 with 283 features...
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 00012: ReduceLROnPlateau reducing learning rate to 0.006999999843537807.
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 00020: ReduceLROnPlateau reducing learning rate to 0.004899999825283885.
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 00030: ReduceLROnPlateau reducing learning rate to 0.0034300000406801696.
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 00035: ReduceLROnPlateau reducing learning rate to 0.002401000028476119.
Epoch 00035: early stopping
Fold 1 CV result
metric : 0.36272237466195895
 
--------------------------------------------------
Training fold 2 with

KeyboardInterrupt: 

In [None]:
# Save results
oof_df_tmp = oof_df.drop(columns=[Config.target])
oof_df_tmp.columns = [Config.row_id, f'nb{Config.NB}']
oof_df_tmp.to_csv(Config.interim_dir + f'nb{Config.NB}.csv', index=False)
oof_df_tmp

## 結果の可視化

In [29]:
cm = confusion_matrix(oof_df[Config.target], oof_df['prediction'], normalize='true')

names = [f'Target_{i}' for i in range(11)]

fig = ff.create_annotated_heatmap(cm, x=names, y=names)
fig.update_layout(
    yaxis_title='True Label',
    xaxis_title='Pred Label',
)
fig.show()

In [30]:
fig = go.Figure(layout=plotly_template['layout'])
fig.add_trace(
    go.Histogram(
        x=test_df[Config.target],
        name=f'Prediction',
        histnorm='probability',
        marker=dict(color=color_palette['Bin'][0]),
        #line=dict(color='black')
    ),
)

fig.add_trace(
    go.Histogram(
        x=df_train[Config.target],
        name=f'Train',
        histnorm='probability',
        marker=dict(color=color_palette['Bin'][1]),
        opacity=0.5
        #line=dict(color='black')
    ),
)

fig.update_layout(
    title='Prediction Distribution',
    barmode='overlay',
    uniformtext_minsize=15,
    uniformtext_mode='hide',
    width=700)

fig.show()

## Submission

In [None]:
Config.NB

In [None]:
test_df.to_csv(Config.submission_dir + f'nb{Config.NB}.csv', index=False)

## 検証メモ

In [None]:
df_train[features].dtypes