# Music Genre Prediction

In [24]:
class Config:
    NB = '001'
    dataset_NB = '101'
    emsemble_NB = ['nb301', 'nb302', 'nb303']

    raw_data_dir = '../data/raw/'
    processed_data_dir = '../data/processed/'
    interim_dir = '../data/interim/'
    submission_dir = '../data/submission/'

    random_seed = 42
    n_folds = 5

    row_id = 'index'
    target = 'genre'

## Import libralies

In [2]:
import gc
import warnings
warnings.filterwarnings('ignore')

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
import itertools

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline
sns.set(style='white', context='notebook', palette='deep')

In [3]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

plotly_template = dict(
    layout=go.Layout(
        template='plotly_dark',
        font=dict(
            family="Franklin Gothic",
            size=12
        ),
        height=500,
        width=1000,
    )
)

color_palette = {
    'Bin': ['#016CC9','#E876A3'],
    'Cat5': ['#E876A3', '#E0A224', '#63B70D', '#6BCFF6', '#13399E'],
    'Cat10': ['#E876A3', '#E0A224', '#63B70D', '#6BCFF6', '#13399E', '#E876A3', '#E0A224', '#63B70D', '#6BCFF6', '#13399E'],
}

## Load data
- data数
  - train: 4046x14
  - test: 4046x13
- columns
  - targetは「genre」で、int
  - tempo　と　region がobjectな模様
- 欠損
  - train
    - positiveness: 10
    - danceability: 8
    - liveness: 3
    - speechiness: 8
    - instrumentalness: 1
    - region: 370(unknownという名称で入っている)
  - test
    - acousticness: 1
    - positiveness: 14
    - danceability: 11
    - energy: 1
    - liveness: 6
    - speechiness: 11
    - instrumentalness: 2
    - region: 320(unknownという名称で入っている)

### Insight:
- 欠損はtrainにもtestにも存在する。およそ均等に存在する。
- regionでunknownになっている数が最も多い。
  - 欠損と捉えて補完するのか、このまま扱うかを検討する必要がありそう。

In [4]:
df_train = pd.read_csv(Config.raw_data_dir + 'train.csv')
df_test = pd.read_csv(Config.raw_data_dir + 'test.csv')

In [7]:
df_train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4046 entries, 0 to 4045
Data columns (total 14 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   index             4046 non-null   int64  
 1   genre             4046 non-null   int64  
 2   popularity        4046 non-null   int64  
 3   duration_ms       4046 non-null   int64  
 4   acousticness      4046 non-null   float64
 5   positiveness      4036 non-null   float64
 6   danceability      4038 non-null   float64
 7   loudness          4046 non-null   float64
 8   energy            4046 non-null   float64
 9   liveness          4043 non-null   float64
 10  speechiness       4038 non-null   float64
 11  instrumentalness  4045 non-null   float64
 12  tempo             4046 non-null   object 
 13  region            4046 non-null   object 
dtypes: float64(8), int64(4), object(2)
memory usage: 442.7+ KB


In [8]:
df_test.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4046 entries, 0 to 4045
Data columns (total 13 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   index             4046 non-null   int64  
 1   popularity        4046 non-null   int64  
 2   duration_ms       4046 non-null   int64  
 3   acousticness      4045 non-null   float64
 4   positiveness      4032 non-null   float64
 5   danceability      4035 non-null   float64
 6   loudness          4046 non-null   float64
 7   energy            4045 non-null   float64
 8   liveness          4040 non-null   float64
 9   speechiness       4035 non-null   float64
 10  instrumentalness  4044 non-null   float64
 11  tempo             4046 non-null   object 
 12  region            4046 non-null   object 
dtypes: float64(8), int64(3), object(2)
memory usage: 411.0+ KB


In [11]:
display(len(df_train[df_train['region'] == 'unknown']))
display(len(df_test[df_test['region'] == 'unknown']))

370

326

In [13]:
df_train.describe()

Unnamed: 0,index,genre,popularity,duration_ms,acousticness,positiveness,danceability,loudness,energy,liveness,speechiness,instrumentalness
count,4046.0,4046.0,4046.0,4046.0,4046.0,4036.0,4038.0,4046.0,4046.0,4043.0,4038.0,4045.0
mean,2022.5,7.28176,41.056105,242141.0,0.346455,0.4641,0.504347,-7.715659,0.603663,0.265986,0.198655,0.214336
std,1168.123923,2.887542,16.165708,85202.41,0.241004,0.225052,0.158415,4.10964,0.20102,0.155769,0.083557,0.154281
min,0.0,0.0,0.0,5998.0,0.0,0.0,0.013839,-37.820457,0.003383,0.0,0.0,0.0
25%,1011.25,7.0,31.0,204442.0,0.149705,0.276384,0.392581,-9.775363,0.462137,0.168527,0.148698,0.143295
50%,2022.5,8.0,42.0,235873.5,0.250711,0.450211,0.510993,-7.18946,0.634078,0.218486,0.18319,0.171708
75%,3033.75,10.0,52.0,272402.0,0.523088,0.644786,0.617371,-4.876553,0.768768,0.317773,0.224999,0.205446
max,4045.0,10.0,82.0,2135773.0,1.0,0.989661,1.0,0.0,1.0,1.0,0.886806,1.0


In [14]:
df_train

Unnamed: 0,index,genre,popularity,duration_ms,acousticness,positiveness,danceability,loudness,energy,liveness,speechiness,instrumentalness,tempo,region
0,0,10,11,201094,0.112811,0.157247,0.187841,-1.884852,0.893918,0.363568,0.390108,0.888884,121-152,region_H
1,1,8,69,308493,0.101333,0.346563,0.554444,-5.546495,0.874409,0.193892,0.161497,0.123910,153-176,region_I
2,2,3,43,197225,0.496420,0.265391,0.457642,-9.255670,0.439933,0.217146,0.369057,0.166470,64-76,region_E
3,3,10,45,301092,0.165667,0.245533,0.356578,-5.088788,0.868704,0.377025,0.226677,0.175399,177-192,region_C
4,4,3,57,277348,0.190720,0.777578,0.830479,-3.933896,0.650149,0.169323,0.222488,0.226030,97-120,unknown
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4041,4041,10,38,246309,0.128795,0.329718,0.575830,-5.252543,0.509531,0.193781,0.187331,0.158197,121-152,region_P
4042,4042,5,23,208734,0.778732,0.228757,0.394283,-14.225700,0.322871,0.253108,0.141218,0.295608,121-152,region_D
4043,4043,10,30,407016,0.402050,0.462374,0.672265,-10.711253,0.646571,0.228189,0.152270,0.163483,97-120,region_E
4044,4044,10,25,204957,0.208096,0.465511,0.523514,-7.855946,0.508910,0.247820,0.202238,0.210184,77-96,region_R


### 文字列カテゴリの確認
- tempoに含まれる文字列の種類はtrainとtestで同じ
- regionに含まれる文字列の種類は、region_Mがtrainのみに存在し、testには無い

In [20]:
display(sorted(df_train['tempo'].unique()))
display(sorted(df_test['tempo'].unique()))

['0-40',
 '121-152',
 '153-176',
 '177-192',
 '193-208',
 '209-220',
 '41-50',
 '51-56',
 '57-63',
 '64-76',
 '77-96',
 '97-120']

['0-40',
 '121-152',
 '153-176',
 '177-192',
 '193-208',
 '209-220',
 '41-50',
 '51-56',
 '57-63',
 '64-76',
 '77-96',
 '97-120']

In [21]:
display(sorted(df_train['region'].unique()))
display(sorted(df_test['region'].unique()))

['region_A',
 'region_B',
 'region_C',
 'region_D',
 'region_E',
 'region_F',
 'region_G',
 'region_H',
 'region_I',
 'region_J',
 'region_K',
 'region_L',
 'region_M',
 'region_N',
 'region_O',
 'region_P',
 'region_Q',
 'region_R',
 'region_S',
 'region_T',
 'unknown']

['region_A',
 'region_B',
 'region_C',
 'region_D',
 'region_E',
 'region_F',
 'region_G',
 'region_H',
 'region_I',
 'region_J',
 'region_K',
 'region_L',
 'region_N',
 'region_O',
 'region_P',
 'region_Q',
 'region_R',
 'region_S',
 'region_T',
 'unknown']

In [83]:
df_train

Unnamed: 0,index,genre,popularity,duration_ms,acousticness,positiveness,danceability,loudness,energy,liveness,speechiness,instrumentalness,tempo,region
0,0,10,11,201094,0.112811,0.157247,0.187841,-1.884852,0.893918,0.363568,0.390108,0.888884,121-152,region_H
1,1,8,69,308493,0.101333,0.346563,0.554444,-5.546495,0.874409,0.193892,0.161497,0.123910,153-176,region_I
2,2,3,43,197225,0.496420,0.265391,0.457642,-9.255670,0.439933,0.217146,0.369057,0.166470,64-76,region_E
3,3,10,45,301092,0.165667,0.245533,0.356578,-5.088788,0.868704,0.377025,0.226677,0.175399,177-192,region_C
4,4,3,57,277348,0.190720,0.777578,0.830479,-3.933896,0.650149,0.169323,0.222488,0.226030,97-120,unknown
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4041,4041,10,38,246309,0.128795,0.329718,0.575830,-5.252543,0.509531,0.193781,0.187331,0.158197,121-152,region_P
4042,4042,5,23,208734,0.778732,0.228757,0.394283,-14.225700,0.322871,0.253108,0.141218,0.295608,121-152,region_D
4043,4043,10,30,407016,0.402050,0.462374,0.672265,-10.711253,0.646571,0.228189,0.152270,0.163483,97-120,region_E
4044,4044,10,25,204957,0.208096,0.465511,0.523514,-7.855946,0.508910,0.247820,0.202238,0.210184,77-96,region_R


In [89]:
region_list = sorted(df_train['region'].unique())

fig = make_subplots(
    rows=4,
    cols=6,
    #start_cell='bottom-left', # どのセルを起点とするか
    subplot_titles=region_list,
    shared_xaxes=True, # x軸を共有する場合
    shared_yaxes=False, # y軸を共有する場合
)
for idx, region in enumerate(region_list):
    fig.add_trace(
    go.Histogram(
        x=df_train[df_train['region'] == region][Config.target],
        name=f'Target',
        #histnorm='probability',
        marker=dict(color=color_palette['Bin'][0]),
        #line=dict(color='black')
    ),
    row=int(idx/6) + 1,
    col=idx%6 + 1

)

fig.update_layout(
    title='Target Distribution',
    uniformtext_minsize=15,
    uniformtext_mode='hide',
    height=1500,
)

fig.show()

In [94]:
target_list = [i for i in range(11)]

fig = make_subplots(
    rows=3,
    cols=4,
    #start_cell='bottom-left', # どのセルを起点とするか
    subplot_titles=target_list,
    shared_xaxes=True, # x軸を共有する場合
    shared_yaxes=False, # y軸を共有する場合
)

for idx, i in enumerate(target_list):
    fig.add_trace(
    go.Histogram(
        x=df_train[df_train[Config.target] == i]['region'],
        name=f'Target',
        #histnorm='probability',
        marker=dict(color=color_palette['Bin'][0]),
        #line=dict(color='black')
    ),
    row=int(idx/4) + 1,
    col=idx%4 + 1

)

fig.update_layout(
    title='Target Distribution',
    uniformtext_minsize=15,
    uniformtext_mode='hide',
    height=1500,
)

fig.show()

### 数値カテゴリの確認
- trainとtestでほとんど分布は変わらなさそう
- popularityで、80くらいにひとつの山がある？

In [22]:
col_list = [col for col in df_test.columns if df_test[col].dtype != object]

fig = make_subplots(
    rows=3,
    cols=4,
    #start_cell='bottom-left', # どのセルを起点とするか
    subplot_titles=col_list,
    shared_xaxes=False, # x軸を共有する場合
    shared_yaxes=False, # y軸を共有する場合
)

for idx, col in enumerate(col_list):
    fig.add_trace(
        go.Histogram(
            x=df_train[col],
            name=f'{col}(Train)',
            histnorm='probability',
            marker=dict(color=color_palette['Bin'][0]),
            #line=dict(color='black')
        ),
        row=int(idx/4) + 1,
        col=idx%4 + 1
    )
    fig.add_trace(
        go.Histogram(
            x=df_test[col],
            name=f'{col}(Test)',
            histnorm='probability',
            marker=dict(color=color_palette['Bin'][1]),
            opacity=0.5
            #line=dict(color='black')
        ),
        row=int(idx/4) + 1,
        col=idx%4 + 1
    )


fig.update_layout(
    barmode='overlay',
    width=1500,
    height=1500,
)
fig.show()

## Targetの分布
- 8, 10に大きく偏っている
- 0, 4, 6, 9が非常に少ないので、どう予測させるかが重要な気がする

In [30]:
df_target_num = df_train[Config.target].value_counts().sort_index(ascending=False)

fig = go.Figure(layout=plotly_template['layout'])

fig.add_trace(
    go.Pie(
        labels=df_target_num.index,
        values=df_target_num,
        #hole=.45,
        showlegend=True,
        sort=False,
        marker=dict(
            colors=color_palette['Cat10'],
            line=dict(color=color_palette['Cat10'], width=2.5)
        ),
        hovertemplate="%{label} Genre: %{value:.2f}",
    )
)

fig.update_layout(
    title='Target Distribution',
    legend=dict(traceorder='reversed', y=1.05, x=0),
    uniformtext_minsize=15,
    uniformtext_mode='hide',
    width=700)

fig.show()

In [32]:
fig = go.Figure(layout=plotly_template['layout'])
fig.add_trace(
    go.Histogram(
        x=df_train[Config.target],
        name=f'Target',
        #histnorm='probability',
        marker=dict(color=color_palette['Bin'][0]),
        #line=dict(color='black')
    ),
)

fig.update_layout(
    title='Target Distribution',
    uniformtext_minsize=15,
    uniformtext_mode='hide',
    width=700)

fig.show()

## Targetと各変数の分布
- popularity
  - Target_0,9 の平均値が高い
  - Target_5が若干低い
- duration_ms
  - Target_10にのみめっちゃ長い曲がありそう。
    - 一定の長さ以上であるかという特徴量を加えることで、Target_10を特定するのに役立ちそう
- acousticness
  - 高めのTargetと低めのターゲットで分かれているような気がする
  - Target_4,6は平均値が高い
- energy
  - Target_4,6,7が若干低いかな
- speechiness
  - Target_3の平均値が高い
- instrumentalness
  - Target_6の分布が他と異なっている
    - 平均値も高い
  - Target_10も高い値を取りやすいようだが、Target_6の分類には十分寄与しそう

In [76]:
df_train.groupby(Config.target).describe()

Unnamed: 0_level_0,index,index,index,index,index,index,index,index,popularity,popularity,popularity,popularity,popularity,popularity,popularity,popularity,duration_ms,duration_ms,duration_ms,duration_ms,duration_ms,duration_ms,duration_ms,duration_ms,acousticness,acousticness,acousticness,acousticness,acousticness,acousticness,acousticness,acousticness,positiveness,positiveness,positiveness,positiveness,positiveness,positiveness,positiveness,positiveness,danceability,danceability,danceability,danceability,danceability,danceability,danceability,danceability,loudness,loudness,loudness,loudness,loudness,loudness,loudness,loudness,energy,energy,energy,energy,energy,energy,energy,energy,liveness,liveness,liveness,liveness,liveness,liveness,liveness,liveness,speechiness,speechiness,speechiness,speechiness,speechiness,speechiness,speechiness,speechiness,instrumentalness,instrumentalness,instrumentalness,instrumentalness,instrumentalness,instrumentalness,instrumentalness,instrumentalness
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max
genre,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2,Unnamed: 33_level_2,Unnamed: 34_level_2,Unnamed: 35_level_2,Unnamed: 36_level_2,Unnamed: 37_level_2,Unnamed: 38_level_2,Unnamed: 39_level_2,Unnamed: 40_level_2,Unnamed: 41_level_2,Unnamed: 42_level_2,Unnamed: 43_level_2,Unnamed: 44_level_2,Unnamed: 45_level_2,Unnamed: 46_level_2,Unnamed: 47_level_2,Unnamed: 48_level_2,Unnamed: 49_level_2,Unnamed: 50_level_2,Unnamed: 51_level_2,Unnamed: 52_level_2,Unnamed: 53_level_2,Unnamed: 54_level_2,Unnamed: 55_level_2,Unnamed: 56_level_2,Unnamed: 57_level_2,Unnamed: 58_level_2,Unnamed: 59_level_2,Unnamed: 60_level_2,Unnamed: 61_level_2,Unnamed: 62_level_2,Unnamed: 63_level_2,Unnamed: 64_level_2,Unnamed: 65_level_2,Unnamed: 66_level_2,Unnamed: 67_level_2,Unnamed: 68_level_2,Unnamed: 69_level_2,Unnamed: 70_level_2,Unnamed: 71_level_2,Unnamed: 72_level_2,Unnamed: 73_level_2,Unnamed: 74_level_2,Unnamed: 75_level_2,Unnamed: 76_level_2,Unnamed: 77_level_2,Unnamed: 78_level_2,Unnamed: 79_level_2,Unnamed: 80_level_2,Unnamed: 81_level_2,Unnamed: 82_level_2,Unnamed: 83_level_2,Unnamed: 84_level_2,Unnamed: 85_level_2,Unnamed: 86_level_2,Unnamed: 87_level_2,Unnamed: 88_level_2
0,32.0,1965.9375,1335.689092,20.0,794.25,1836.0,3406.75,4026.0,32.0,62.375,16.676717,11.0,65.0,67.0,68.25,82.0,32.0,228303.5625,28240.296733,183305.0,204670.75,224816.0,255958.25,274545.0,32.0,0.278668,0.169594,0.047048,0.175372,0.223835,0.367339,0.698777,32.0,0.559857,0.202935,0.147544,0.386968,0.619958,0.723526,0.940727,32.0,0.556328,0.117867,0.336912,0.45335,0.564099,0.638108,0.790989,32.0,-6.283481,2.995615,-13.970634,-7.851498,-6.596818,-4.043205,-0.769015,32.0,0.616083,0.161809,0.341714,0.470487,0.644423,0.763875,0.840018,32.0,0.224763,0.110404,0.061512,0.144804,0.200602,0.315316,0.465508,32.0,0.182526,0.049571,0.066649,0.137741,0.184146,0.225202,0.288472,32.0,0.165733,0.035955,0.108785,0.141346,0.166347,0.18593,0.253305
1,205.0,2132.95122,1213.947021,24.0,1044.0,2173.0,3227.0,4020.0,205.0,41.039024,15.316544,6.0,28.0,42.0,53.0,72.0,205.0,248802.326829,73429.66485,37491.0,210014.0,238734.0,272598.0,635217.0,205.0,0.251741,0.194464,0.023088,0.126788,0.175807,0.309016,0.92916,205.0,0.484411,0.23818,0.0,0.277062,0.479517,0.683069,0.941498,205.0,0.558864,0.174916,0.041755,0.460784,0.569628,0.686807,1.0,205.0,-7.774529,4.360201,-34.237218,-9.707902,-7.285388,-4.861342,0.0,205.0,0.651407,0.180889,0.04472,0.53818,0.685711,0.795094,0.917778,205.0,0.271373,0.156211,0.036888,0.164937,0.22177,0.349227,0.890162,205.0,0.18453,0.066419,0.034162,0.142933,0.177153,0.21282,0.467864,205.0,0.260955,0.186978,0.042803,0.153117,0.191503,0.263696,0.886419
2,191.0,1991.246073,1162.941107,9.0,1055.5,1922.0,3069.0,3988.0,191.0,38.617801,13.453839,9.0,30.0,35.0,50.0,64.0,191.0,240548.141361,66099.684527,81105.0,200339.5,235733.0,271333.5,583333.0,191.0,0.384089,0.255859,0.040297,0.16366,0.29016,0.590042,0.933463,191.0,0.521668,0.232109,0.039515,0.339147,0.526054,0.710067,0.96154,191.0,0.512655,0.139675,0.140677,0.411056,0.514344,0.621272,0.875194,191.0,-8.138027,3.470537,-21.213753,-9.795758,-7.845131,-5.859124,-0.155418,191.0,0.570896,0.202483,0.069692,0.432832,0.61207,0.737511,0.93722,191.0,0.273406,0.18678,0.0,0.172063,0.2198,0.293093,1.0,191.0,0.195158,0.087784,0.045658,0.147459,0.173836,0.218162,0.701682,191.0,0.178937,0.087207,0.083138,0.137639,0.171709,0.193605,0.73948
3,362.0,2032.856354,1158.649142,2.0,1050.5,2025.5,2986.75,4045.0,362.0,47.400552,12.641325,13.0,39.0,46.0,55.0,81.0,362.0,235437.743094,51746.631372,39237.0,207414.5,231612.5,263048.5,464024.0,362.0,0.262961,0.151635,0.033691,0.161968,0.221548,0.328465,0.854018,362.0,0.629642,0.187692,0.057641,0.514416,0.662265,0.771055,0.979528,362.0,0.659149,0.128489,0.268736,0.576428,0.671513,0.752819,0.986452,362.0,-6.563673,3.187407,-20.233032,-8.379485,-6.137369,-4.370457,0.0,362.0,0.662419,0.136002,0.147096,0.585815,0.679635,0.759897,0.926979,362.0,0.278008,0.157961,0.028644,0.172685,0.232028,0.34561,0.927603,362.0,0.317668,0.123557,0.04927,0.215768,0.306495,0.406426,0.869431,362.0,0.167692,0.067708,0.026447,0.138543,0.162737,0.186792,0.917872
4,45.0,1593.311111,956.094157,70.0,742.0,1650.0,2290.0,3712.0,45.0,48.822222,7.062349,24.0,48.0,50.0,53.0,61.0,45.0,230891.266667,49807.553429,138510.0,199439.0,218732.0,262266.0,383230.0,45.0,0.690632,0.194488,0.134646,0.563732,0.739324,0.821398,0.935842,44.0,0.454347,0.234524,0.056669,0.264715,0.428032,0.666871,0.919892,44.0,0.550516,0.127931,0.303742,0.453201,0.562645,0.637903,0.807066,45.0,-10.63491,3.949389,-17.736751,-14.249453,-10.007637,-7.516369,-3.93625,45.0,0.379426,0.202223,0.066562,0.211621,0.359715,0.535872,0.842153,45.0,0.214132,0.068262,0.10704,0.165312,0.206866,0.25977,0.403828,44.0,0.166744,0.049705,0.086803,0.135086,0.161337,0.19618,0.297551,45.0,0.20464,0.143057,0.095476,0.143683,0.167838,0.203527,0.850272
5,126.0,1876.857143,1174.590766,29.0,917.25,1898.5,2961.25,4042.0,126.0,35.777778,12.539786,4.0,23.0,40.0,43.0,62.0,126.0,245213.309524,61793.815606,120894.0,198761.75,245112.5,279699.25,460910.0,126.0,0.467456,0.22538,0.074521,0.25592,0.480274,0.659751,0.905483,126.0,0.567512,0.21905,0.077223,0.382555,0.588908,0.755605,0.95099,126.0,0.554175,0.146939,0.066774,0.440555,0.546764,0.664459,0.868273,126.0,-8.663272,3.598325,-18.897655,-10.459386,-8.69884,-6.107831,-0.183987,126.0,0.553453,0.175593,0.034354,0.433943,0.555594,0.69855,0.901512,126.0,0.306147,0.201582,0.109831,0.184914,0.221807,0.348455,0.973504,126.0,0.180474,0.064278,0.070258,0.145138,0.174425,0.202232,0.495933,126.0,0.220694,0.154136,0.060565,0.141492,0.173732,0.210508,0.901344
6,50.0,2074.52,1204.379546,49.0,1313.0,1861.0,3188.25,3997.0,50.0,41.94,16.301258,17.0,29.25,38.5,56.5,69.0,50.0,202543.18,66061.951662,64824.0,161182.25,204716.5,241695.5,350503.0,50.0,0.681698,0.24917,0.056032,0.583811,0.771561,0.873355,0.978858,50.0,0.362266,0.23916,0.026089,0.171402,0.286052,0.483587,0.846337,50.0,0.414836,0.141144,0.140307,0.31696,0.406,0.480466,0.814968,50.0,-11.78748,6.551953,-28.336552,-14.833411,-11.134587,-7.433227,0.0,50.0,0.37027,0.220322,0.00656,0.20519,0.322726,0.546991,0.919278,50.0,0.243444,0.139845,0.085121,0.154301,0.214776,0.272661,0.84327,50.0,0.169113,0.050288,0.041457,0.138266,0.167372,0.194245,0.335301,50.0,0.363768,0.285988,0.105765,0.157092,0.212529,0.653103,0.896772
7,334.0,2098.733533,1153.850867,8.0,1094.75,2279.5,3044.25,4038.0,334.0,40.215569,9.505882,10.0,35.0,40.0,47.0,62.0,334.0,240833.688623,70372.273149,44250.0,206178.5,235859.5,270208.75,687618.0,334.0,0.568815,0.229409,0.08742,0.369353,0.596983,0.765731,0.958145,334.0,0.441374,0.211451,0.000966,0.281647,0.417772,0.598114,0.968281,334.0,0.480768,0.146325,0.093303,0.368856,0.478476,0.585986,0.855663,334.0,-10.494408,4.444506,-31.585695,-13.473871,-10.036377,-7.498289,-0.700179,334.0,0.441347,0.186733,0.038723,0.300088,0.421483,0.560283,0.88995,334.0,0.256676,0.14906,0.040469,0.165256,0.215659,0.294277,0.930788,334.0,0.181769,0.064323,0.068001,0.141975,0.177688,0.20943,0.750974,334.0,0.183681,0.09232,0.04897,0.143636,0.170164,0.201521,0.869199
8,1305.0,2034.211494,1173.508758,1.0,1038.0,2000.0,3082.0,4039.0,1305.0,46.557854,13.517084,0.0,39.0,48.0,55.0,80.0,1305.0,238353.816092,53345.284383,12105.0,208792.0,235520.0,266935.0,496454.0,1305.0,0.38485,0.234314,0.025679,0.183119,0.324732,0.567157,1.0,1304.0,0.461033,0.21298,0.007439,0.286126,0.451695,0.631949,0.948828,1304.0,0.518784,0.139759,0.0771,0.426576,0.535433,0.618016,0.924704,1305.0,-7.860348,3.942626,-31.818595,-9.859515,-7.246037,-5.215312,0.0,1305.0,0.572793,0.183476,0.053861,0.448751,0.592599,0.71584,0.955176,1305.0,0.253459,0.149232,0.043054,0.164406,0.209489,0.295365,0.982437,1304.0,0.179042,0.061223,0.0,0.142849,0.173653,0.206931,0.611169,1305.0,0.186241,0.114011,0.0,0.14071,0.166006,0.192525,0.906675
9,59.0,2046.915254,1196.528878,45.0,1011.0,2214.0,3018.5,3926.0,59.0,69.305085,7.550415,60.0,63.0,65.0,78.0,81.0,59.0,235382.525424,46600.298569,103936.0,209581.5,245029.0,266583.5,366735.0,59.0,0.416741,0.187835,0.099372,0.240418,0.403909,0.554978,0.804591,59.0,0.479089,0.220505,0.126104,0.308498,0.434989,0.627294,0.93956,59.0,0.53567,0.139293,0.171039,0.434686,0.544113,0.634858,0.81482,59.0,-10.283998,4.250544,-21.273705,-13.566329,-9.7522,-7.281315,-1.484633,59.0,0.472082,0.143007,0.078777,0.379508,0.467775,0.556429,0.866568,59.0,0.236286,0.113263,0.097527,0.160338,0.221414,0.266982,0.826232,59.0,0.190613,0.080159,0.042651,0.14541,0.172052,0.224494,0.458855,59.0,0.174386,0.049992,0.097703,0.141274,0.173434,0.198262,0.404904


In [75]:
col_list = [col for col in df_test.columns if col not in [Config.row_id, Config.target]]

fig = make_subplots(
    rows=12,
    cols=1,
    #start_cell='bottom-left', # どのセルを起点とするか
    subplot_titles=col_list,
    shared_xaxes=False, # x軸を共有する場合
    shared_yaxes=True, # y軸を共有する場合
)

for idx, col in enumerate(col_list):
    for genre in range(11):
        fig.add_trace(
            go.Box(
                y=df_train[df_train[Config.target] == genre][col],
                name=f'Target_{genre}',
                marker=dict(color=color_palette['Bin'][0]),
                #line=dict(color='black')
            ),
            row=idx + 1,
            col=1,
        )

    fig.update_layout(
        title='Distribution',
        showlegend=False,
        barmode='overlay',
        width=1500,
        height=2500,
    )

fig.show()

In [66]:
col_list = [col for col in df_test.columns if col not in [Config.row_id, Config.target]]

for idx, col in enumerate(col_list):

    fig = make_subplots(
        rows=3,
        cols=4,
        #start_cell='bottom-left', # どのセルを起点とするか
        subplot_titles=[f'Target {i}' for i in range(11)],
        shared_xaxes=True, # x軸を共有する場合
        shared_yaxes=True, # y軸を共有する場合
    )

    for genre in range(11):
        fig.add_trace(
            go.Histogram(
                x=df_train[df_train[Config.target] == genre][col],
                #name=f'{col}(Train)',
                histnorm='probability',
                marker=dict(color=color_palette['Bin'][0]),
                #line=dict(color='black')
            ),
            row=int(genre/4) + 1,
            col=genre%4 + 1,
        )

    fig.update_layout(
        title=f'{col}',
        barmode='overlay',
        width=1500,
        height=1500,
    )
    fig.show()

In [63]:
# 相関行列を生成
from sklearn.preprocessing import OneHotEncoder

df_train_tmp = df_train.copy()

col_list = [Config.target]
output_col_list = [f'Target_{i}' for i in range(11)]

ohe = OneHotEncoder(sparse=False)
ohe.fit(df_train_tmp[col_list])

df_train_tmp[output_col_list] = ohe.transform(df_train_tmp[col_list])
df_train_tmp.drop(columns=col_list, inplace=True)

for i in range(11):
    col_list = [col for col in df_test.columns if col not in [Config.row_id]]
    col_list.append(f'Target_{i}')

    df_corr = df_train_tmp[col_list].corr()

    fig = go.Figure(layout=plotly_template['layout'])

    fig.add_trace(
        go.Heatmap(
            x=df_corr.columns,
            y=df_corr.index,
            z=np.array(df_corr),
            #annotation_text=np.around(np.array(df_corr), decimals=2),
            #y=total,
            name='Corr',
            colorscale='oxy'
            #mode='markers',
            #marker=dict(size=1, color='red'),
            #line=dict(color='black'),
            #customdata=total,
            #hovertemplate="<b>%{y:.2f}</b> <br>total: %{customdata:.2f}<extra></extra>",
        ),
    )

    fig.update_layout(
        width=1500,
        height=500,
    )

    fig.show()

Unnamed: 0,index,popularity,duration_ms,acousticness,positiveness,danceability,loudness,energy,liveness,speechiness,instrumentalness,Target_0,Target_1,Target_2,Target_3,Target_4,Target_5,Target_6,Target_7,Target_8,Target_9,Target_10
count,4046.0,4046.0,4046.0,4046.0,4036.0,4038.0,4046.0,4046.0,4043.0,4038.0,4045.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0,4046.0
mean,2022.5,41.056105,242141.0,0.346455,0.4641,0.504347,-7.715659,0.603663,0.265986,0.198655,0.214336,0.007909,0.050667,0.047207,0.089471,0.011122,0.031142,0.012358,0.082551,0.322541,0.014582,0.33045
std,1168.123923,16.165708,85202.41,0.241004,0.225052,0.158415,4.10964,0.20102,0.155769,0.083557,0.154281,0.088591,0.219345,0.212108,0.285458,0.104886,0.173723,0.110491,0.275236,0.467506,0.119888,0.470433
min,0.0,0.0,5998.0,0.0,0.0,0.013839,-37.820457,0.003383,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,1011.25,31.0,204442.0,0.149705,0.276384,0.392581,-9.775363,0.462137,0.168527,0.148698,0.143295,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,2022.5,42.0,235873.5,0.250711,0.450211,0.510993,-7.18946,0.634078,0.218486,0.18319,0.171708,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,3033.75,52.0,272402.0,0.523088,0.644786,0.617371,-4.876553,0.768768,0.317773,0.224999,0.205446,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
max,4045.0,82.0,2135773.0,1.0,0.989661,1.0,0.0,1.0,1.0,0.886806,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


## 欠損値の確認
- 必ずpositivenessが欠損になっている

In [80]:
df_train[df_train.isnull().any(axis=1)]

Unnamed: 0,index,genre,popularity,duration_ms,acousticness,positiveness,danceability,loudness,energy,liveness,speechiness,instrumentalness,tempo,region
169,169,10,29,22132,0.183041,,0.074635,-2.684752,0.424803,0.538192,0.27594,0.87466,57-63,region_L
332,332,10,18,5998,0.073058,,,-3.492185,0.482935,,,0.933298,0-40,region_P
1147,1147,10,17,5999,0.246239,,,-11.188019,0.138494,,,,0-40,region_P
1824,1824,10,29,12904,0.878264,,,-37.820457,0.071458,0.30617,,0.209245,0-40,region_B
1859,1859,10,17,6002,0.16667,,,-3.675506,0.630632,,,0.878543,0-40,region_P
1894,1894,8,28,12105,0.894437,,,-30.401654,0.294521,0.68064,,0.353691,0-40,region_E
1942,1942,10,30,13559,0.125577,,,0.0,0.856918,0.521328,,0.809948,0-40,region_L
2088,2088,10,30,20748,0.149761,,0.222094,-3.491594,0.814625,0.243798,0.302897,0.866409,64-76,region_L
2290,2290,4,54,256385,0.862298,,,-14.370051,0.127264,0.260063,,0.19574,0-40,region_P
2619,2619,10,28,12907,0.174814,,,-4.74726,0.887299,0.493169,,0.810391,0-40,region_L


In [81]:
df_test[df_test.isnull().any(axis=1)]

Unnamed: 0,index,popularity,duration_ms,acousticness,positiveness,danceability,loudness,energy,liveness,speechiness,instrumentalness,tempo,region
65,4111,32,8307,0.705052,,,-5.838485,0.677236,0.330491,,0.450735,0-40,region_L
161,4207,29,5826,0.303558,,,-3.612817,0.770469,,,0.834549,0-40,region_L
200,4246,16,5998,,,,-9.770362,,,,,0-40,region_P
313,4359,32,15150,0.021372,,0.207516,-1.013198,0.906015,0.516378,0.272722,0.833171,77-96,region_L
848,4894,32,28304,0.13387,,0.313016,-4.504644,0.654447,0.147979,0.144175,0.85165,77-96,region_L
935,4981,28,6199,0.800011,,,-30.216122,0.147273,,,0.125978,0-40,region_B
1750,5796,28,6506,0.162008,,,-2.325072,0.845391,,,0.876464,0-40,unknown
1790,5836,18,17375,0.093081,,0.280091,-6.356741,0.459276,0.339337,0.14023,0.712319,64-76,region_P
1834,5880,28,9705,0.088208,,,-6.473317,0.919014,0.506552,,0.742145,0-40,region_L
2018,6064,34,13629,0.067932,,,-3.460831,0.921675,0.386363,,0.7686,0-40,region_L
