In [1]:
%matplotlib inline

In [2]:
import warnings

import pandas as pd

from lib.report import mlflow_log_classification_report, mlflow_log_model
import lib.constants as constants

from sklearn.pipeline import Pipeline

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV

from sklearn.ensemble import RandomForestClassifier

import mlflow
import mlflow.sklearn
from mlflow.data.pandas_dataset import PandasDataset

In [3]:
warnings.filterwarnings("ignore", "Setuptools is replacing distutils.")
warnings.simplefilter("ignore", category=FutureWarning)
warnings.simplefilter("ignore", category=UserWarning)

# Music Genre Classifier Model Selection

## Optimizing Random Forest Classifier

In [4]:
sample_length = 3

In [5]:
data_file = f"../data/{sample_length}_seconds_song_features.csv".format(sample_length)

In [6]:
songs = pd.read_csv(data_file, low_memory=False)

In [7]:
songs

Unnamed: 0,zero_crossings_max,zero_crossings_min,zero_crossings_mean,zero_crossings_std,zero_crossings_kurtosis,zero_crossings_skew,centroid_max,centroid_min,centroid_mean,centroid_std,...,mfcc_11_skew,mfcc_12_max,mfcc_12_min,mfcc_12_mean,mfcc_12_std,mfcc_12_kurtosis,mfcc_12_skew,tempo,genre,file
0,0.210449,0.047363,0.081350,0.022665,11.503471,2.688745,4031.354256,1247.016418,1761.351126,389.970271,...,-0.461595,9.864202,-23.182499,-4.651091,5.882950,1.015412,-0.612315,129.199219,blues,blues.00000.wav
1,0.164551,0.037598,0.087709,0.032044,-0.850489,0.472236,3222.274307,1022.732614,1822.016049,304.312041,...,0.000940,14.760537,-17.337769,-7.867007,5.448279,2.613537,1.247251,123.046875,blues,blues.00000.wav
2,0.131348,0.032227,0.071626,0.020605,-0.307913,0.340260,2999.614979,1037.125090,1793.037434,331.091442,...,-0.091285,9.158812,-18.102434,-4.826131,6.191893,-0.505952,-0.358207,123.046875,blues,blues.00000.wav
3,0.125000,0.037598,0.069733,0.017171,0.342042,0.576718,2509.335877,1047.974918,1661.406968,331.384160,...,0.114771,19.266750,-20.557999,-4.382119,7.674499,-0.079892,0.245445,123.046875,blues,blues.00000.wav
4,0.113770,0.027344,0.070297,0.016781,-0.428186,-0.091625,2675.131530,1129.060230,1635.331464,281.302522,...,-0.587330,10.810911,-22.698643,-5.914538,6.539700,-0.368645,-0.088289,123.046875,blues,blues.00000.wav
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9265,0.255371,0.062988,0.149547,0.038969,-0.362665,0.455421,4219.843386,2126.538609,2828.911302,406.251026,...,0.459619,18.043543,-13.256772,2.580344,5.558560,0.478881,0.352665,151.999081,rock,rock.00027.wav
9266,0.348633,0.028809,0.125189,0.062630,1.906010,1.320978,5054.668975,1485.373859,2868.156921,685.789202,...,0.061615,9.562820,-24.992477,-10.230426,8.822643,-0.953396,0.389956,151.999081,rock,rock.00027.wav
9267,0.490723,0.017090,0.110215,0.091016,5.348066,2.280798,5436.874338,1299.288815,2465.205517,948.630703,...,-0.124412,14.526768,-27.832321,-4.105305,8.967973,-0.430528,0.117361,151.999081,rock,rock.00027.wav
9268,0.573242,0.012207,0.117123,0.118646,5.337525,2.434055,5770.027623,1378.406711,2462.673377,1028.748903,...,-0.457346,24.023312,-9.351840,4.406565,7.542349,-0.360459,0.427651,143.554688,rock,rock.00027.wav


In [8]:
song_genres = songs["genre"]

In [9]:
label_encoder = LabelEncoder()

In [10]:
encoded_song_genres = label_encoder.fit_transform(song_genres)

In [11]:
song_features = songs.drop(columns=["genre" , "file"], axis=1)

## Test, train and validation split

In [12]:
constants.RANDOM_STATE

1984

In [13]:
song_features_intermediate_train, song_features_test, song_genres_intermediate_train, song_genres_test = \
    train_test_split(song_features, song_genres, test_size = 0.1, random_state=constants.RANDOM_STATE)

In [14]:
song_features_train, song_features_val, song_genres_train, song_genres_val = \
    train_test_split(song_features_intermediate_train, song_genres_intermediate_train, test_size = 0.2,\
                     random_state=constants.RANDOM_STATE)

## Peraring traning pieline

In [15]:
train_pipeline = Pipeline([
    ("sndard_scaler", StandardScaler()),
    ("reduce_dimension", None),
    ("random_forest", RandomForestClassifier())
])

## Fiding Best Random Forest Classifier

In [16]:
dataset: PandasDataset = mlflow.data.from_pandas(songs, source=data_file)

### By min sample split

In [17]:
experiment = mlflow.create_experiment(name = f"Random Forest, min sample split - {sample_length} sec".format(sample_length))

In [18]:
run = mlflow.start_run(experiment_id=experiment)

In [19]:
mlflow.log_artifact(data_file)
mlflow.log_input(dataset)

In [20]:
grid_search = GridSearchCV(train_pipeline,
    param_grid={
        "random_forest__max_depth" : [ 10, 15, 20, 25, 50, 100, 150, 200, 400, None],
        "random_forest__min_samples_split": [2, 3, 5, 10, 20]
    }, cv = 5, n_jobs=8)

In [21]:
grid_search.fit(song_features_train, song_genres_train)

In [22]:
grid_search.best_params_

{'random_forest__max_depth': 200, 'random_forest__min_samples_split': 2}

In [23]:
train_score = grid_search.best_estimator_.score(song_features_train, song_genres_train)

In [24]:
train_score

0.9994006592747977

In [25]:
validation_score = grid_search.best_estimator_.score(song_features_val, song_genres_val)

In [26]:
validation_score

0.8585979628520072

In [27]:
mlflow_log_classification_report(song_features_val, song_genres_val,\
                            grid_search.best_estimator_, target_names=label_encoder.classes_)

              precision    recall  f1-score   support

       blues       0.86      0.85      0.85       198
   classical       0.95      0.88      0.91       200
     country       0.82      0.79      0.80       188
       disco       0.87      0.81      0.84       194
      hiphop       0.82      0.92      0.86       159
        jazz       0.84      0.84      0.84       166
       metal       0.92      0.88      0.90       192
         pop       0.91      0.91      0.91       176
      reggae       0.84      0.87      0.86       174
        rock       0.43      0.91      0.58        22

    accuracy                           0.86      1669
   macro avg       0.83      0.86      0.84      1669
weighted avg       0.87      0.86      0.86      1669



In [28]:
mlflow_log_model(grid_search, train_score, validation_score)
mlflow.end_run()

### By min sample leafs

In [29]:
experiment = mlflow.create_experiment(name = f"Random Forest, min sample leaf - {sample_length} sec".format(sample_length))

In [30]:
run = mlflow.start_run(experiment_id=experiment)

In [31]:
mlflow.log_artifact(data_file)
mlflow.log_input(dataset)

In [32]:
grid_search = GridSearchCV(train_pipeline,
    param_grid={
        "random_forest__max_depth" : [ 10, 20, 50, 100, 150, 200, 400, 600, None],
        "random_forest__min_samples_leaf": [1, 2, 3, 5, 10, 20]
    }, cv = 5, n_jobs=8)

In [33]:
grid_search.fit(song_features_train, song_genres_train)

In [34]:
grid_search.best_params_

{'random_forest__max_depth': 50, 'random_forest__min_samples_leaf': 1}

In [35]:
train_score = grid_search.best_estimator_.score(song_features_train, song_genres_train)

In [36]:
train_score

0.9994006592747977

In [37]:
validation_score = grid_search.best_estimator_.score(song_features_val, song_genres_val)

In [38]:
validation_score

0.8550029958058718

In [39]:
mlflow_log_classification_report(song_features_val, song_genres_val,\
                            grid_search.best_estimator_, target_names=label_encoder.classes_)

              precision    recall  f1-score   support

       blues       0.85      0.86      0.86       192
   classical       0.96      0.89      0.92       197
     country       0.81      0.78      0.80       187
       disco       0.85      0.82      0.83       187
      hiphop       0.83      0.91      0.87       163
        jazz       0.82      0.79      0.80       173
       metal       0.94      0.87      0.90       197
         pop       0.91      0.91      0.91       175
      reggae       0.83      0.86      0.85       176
        rock       0.43      0.91      0.58        22

    accuracy                           0.86      1669
   macro avg       0.82      0.86      0.83      1669
weighted avg       0.86      0.86      0.86      1669



In [40]:
mlflow_log_model(grid_search, train_score, validation_score)
mlflow.end_run()