# Imports and Setup

In [None]:
import os

# Move from the notebook's folder up to the project root so that the
# `constants` and `modeling` imports in the next cell resolve.
# Guarded so that re-running this cell does not keep climbing directories.
# NOTE(review): assumes the notebook lives exactly one level below the
# directory that contains the `modeling` package.
if not os.path.isdir('modeling'):
    os.chdir('..')

In [None]:
import pandas as pd
import numpy as np

from lightgbm import LGBMClassifier
import shap

from constants import DATA_DIR, DROP_COLS
from modeling.load_train_test import load_train_test
from modeling.predict_and_evaluate import predict_and_evaluate

# Seed NumPy's global RNG for reproducible NumPy-based randomness.
# LightGBM takes its own `random_state` argument and is not seeded by this.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Load SHAP's JavaScript so interactive plots (e.g. force plots) render inline.
shap.initjs()

# Create and train model
##### Note: To change models, only this cell needs to be updated

In [None]:
X_train, X_test, y_train, y_test = load_train_test(scale=False)

# Model hyperparameters. Every key here is passed straight to the model below,
# so adding/changing an entry is enough to change the model.
HYPERPARAMS = {
    'learning_rate': 0.1,
    'n_estimators': 25,
    'num_leaves': 40,
}

# Genre prediction is multi-class. LightGBM's `is_unbalance` is only honored
# by binary / one-vs-all objectives, so for the default multi-class objective
# the documented way to balance classes is `class_weight='balanced'`.
model = LGBMClassifier(
    **HYPERPARAMS,
    class_weight='balanced',  # TODO: compare against SMOTE oversampling
)

model.fit(X_train, y_train)

# Predict out-of-sample genres and evaluate accuracy

In [None]:
# Score the fitted model on the held-out split and evaluate it against y_test.
# The helper returns a pair: predictions and (presumably) class probabilities.
(y_pred, y_prob) = predict_and_evaluate(
    model,
    new_data=X_test,
    truth_data=y_test,
)

# Explore predictions with SHAP

In [None]:
# Tree-model SHAP explainer for the fitted LightGBM classifier; attributions
# are computed for every row of the training features.
explainer = shap.TreeExplainer(model=model)
shap_values = explainer.shap_values(X=X_train)

In [30]:
CLASS_IDX = 4  # which output class (genre) to explain
OBS_IDX = 0    # which row of X_train to explain

# Multi-class SHAP output layout differs by shap version: older releases
# return a list with one (n_obs, n_features) array per class, newer ones a
# single (n_obs, n_features, n_classes) array. Select the class either way.
if isinstance(shap_values, list):
    class_shap_values = shap_values[CLASS_IDX]
else:
    class_shap_values = shap_values[..., CLASS_IDX]

# Kept as the cell's last expression so the notebook renders the plot.
shap.force_plot(
    explainer.expected_value[CLASS_IDX],
    class_shap_values[OBS_IDX, :],
    X_train.iloc[OBS_IDX, :],
)

In [None]:
#shap_obj = explainer(X_train)
#shap.plots.beeswarm(shap_obj)

## Notes:
* Balance classes
* Tune models
* Compare errors of different classes
* Is one class most often confused for another (e.g., country confused for rock)?