In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os
import sys
import xgboost as xgb
import shap
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import r2_score, auc, f1_score


import warnings
warnings.filterwarnings('ignore')

sys.path.insert(0, '../src/visualization/')
import visualize as vis

In [None]:
# Read the processed dataset, then discard the 'Unnamed: 0' column
# (presumably an index written out by an earlier to_csv -- confirm).
df = pd.read_csv(
    '../data/processed/CriticalPath_Data_EM_Confidential_lessNoise.csv'
)
df = df.drop(columns=['Unnamed: 0'])

## XGBoost Model to see which features most impact enrollment.

### Starting Parameters

## Split into training and test data, and fit a regression model.

In [None]:
# One seed shared by the train/test split and the booster, so that a fresh
# "Restart & Run All" reproduces the same partition and the same model.
SEED = 42

# Keep only rows with an observed target. Filling a missing label with the
# -999 sentinel would make the regressor fit -999 as if it were a real value.
labelled = df[df['Enrolled'].notna()]

# Features: float/bool/int columns only, without the target, the admission
# status and the student identifier. Missing feature values become -999,
# which matches the `missing` sentinel passed to XGBoost below.
X = labelled.drop(columns=['Enrolled','Admission_status',
                           'Unique_student_ID']).select_dtypes([float,bool,int]).fillna(-999)

Y = labelled['Enrolled']


# random_state pins the shuffle; without it every run scores a different split.
X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=SEED)

# NOTE(review): `silent`, `nthread` and `seed` look like legacy XGBoost
# parameter names -- confirm they are still honoured by the installed version.
best_params = {'colsample_bytree': 0.8,
 'learning_rate': 0.05,
 'max_depth': 5,
 'min_child_weight': 11,
 'missing': -999,
 'n_estimators': 500,
 'nthread': 4,
 'seed': SEED,
 'silent': 1,
 'subsample': 0.8}  # found from GridSearchCV (013-st-model_paramters.ipynb)

In [None]:
# Train the XGBoost regressor on the training split (fit hands back the
# estimator itself), then predict the held-out rows and report their R^2.
model = xgb.XGBRegressor(**best_params).fit(X_train, y_train)
y_pred = model.predict(X_test)

print( "R2 Score: ", r2_score(y_true=y_test, y_pred=y_pred) )

## Plot feature importance

In [None]:
# Output directory for the importance figures; create it up front so that
# savefig does not fail when the directory is missing (e.g. a fresh checkout).
importance_fig_dir = "../reports/figures/feature_importance"
os.makedirs(importance_fig_dir, exist_ok=True)

# One figure per XGBoost importance type (currently only 'weight').
for importance_type in ['weight']:
    vis.my_plot_importance(model,figsize=(10,10),importance_type=importance_type);
    # Set the title before tight_layout so the layout pass accounts for it.
    plt.title('Feature Importance: importance_type = %s' %importance_type)
    plt.tight_layout()
    plt.savefig("%s/feature_importance_%s.png" % (importance_fig_dir, importance_type))

## Plot shapley values.
* *An intuitive way to understand the Shapley value is the following illustration: The feature values enter a room in random order. All feature values in the room participate in the game (= contribute to the prediction). The Shapley value of a feature value is the average change in the prediction that the coalition already in the room receives when the feature value joins them.*
 
* *The interpretation of the Shapley value is: Given the current set of feature values, the contribution of a feature value to the difference between the actual prediction and the mean prediction is the estimated Shapley value.*





Source: [christophm.github.io](https://christophm.github.io/interpretable-ml-book/shapley.html)

In [None]:
# Build a tree explainer for the fitted model and compute SHAP values for
# every row of X (the full feature frame, not just one split).
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Load SHAP's JavaScript into the notebook for the interactive plots below.
shap.initjs()

In [None]:
# Force plot for the first row of X: its SHAP values against the explainer's
# expected (base) value.
shap.force_plot(
    explainer.expected_value,
    shap_values[0],
    X.iloc[0],
)

The above explanation shows features each contributing to push the model output from the base value (the average model output over the training dataset we passed) to the model output. Features pushing the prediction higher are shown in red, those pushing the prediction lower are in blue (these force plots were introduced in the SHAP authors' Nature BME paper).

In [None]:
# Dot-style SHAP summary over all rows of X, limited to 10 features.
# show=False leaves the figure undisplayed so it can be laid out and saved.
sum_plot = shap.summary_plot(
    shap_values,
    X,
    plot_type='dot',
    max_display=10,
    show=False,
)
plt.tight_layout()
plt.savefig("../reports/figures/feature_importance/shapley_summary.png")

## Plot this as a bar chart.

In [None]:
# Bar-style SHAP summary over all rows of X, limited to 10 features.
# show=False leaves the figure undisplayed so it can be laid out and saved.
shap.summary_plot(
    shap_values,
    X,
    plot_type='bar',
    max_display=10,
    show=False,
)
plt.tight_layout()
plt.savefig("../reports/figures/feature_importance/shapley_summary_bar.png")

## Individual features.

#### X axis is feature value
#### Y axis is the associated Shapley value (output impact)

#### Color (red/blue) shows the value of a second feature that may interact with the plotted one

In [None]:
# Draw and save one SHAP dependence plot per selected feature.
for feature in (
    'State_grants',
    'ADMT_DEC_CODE',
    'Number_of_campus_visits',
    'Internal_Academic_rating',
    'HS_Percentile_rank',
    'Student_income_AGI',
    'Parent_income_AGI',
    'Pell_grant',
    'Need_by_FM',
    'Year_of_entry',
):
    # Open a new figure for this feature before plotting.
    plt.figure()
    # show=False leaves the plot undisplayed so it can be laid out and saved.
    shap.dependence_plot(feature, shap_values, X, show=False)
    plt.tight_layout()
    plt.savefig("../reports/figures/feature_importance/individual_feature-%s.png" % feature)