In [None]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import graphviz
# Global seaborn/matplotlib style for all subsequent plots.
sns.set_style(style="darkgrid")

# Seed passed as `random_state` to stochastic estimators for reproducibility.
GLOBAL_SEED = 42

In [None]:
# @title Installing SHAP
try:
  import shap
except:
  print("Installing SHAP")
  !pip install shap

In [None]:

# Names of the localizations to model; a later cell selects one by index
# (`loc`) and encodes it as the positive class.
Localization = [
    "Mitochondria",
]

In [None]:

# Index into `Localization` selecting the positive class for this run.
loc = 0
print("==============================================================")
print("Data set is", Localization[loc])
print("==============================================================")

# Derive the CSV path from the selected localization instead of hard-coding
# it, so changing `loc` actually loads the matching dataset.
data_path = f"./train/{Localization[loc]}_all_fea_ovr_dataset.csv"
df = pd.read_csv(data_path)

# Binary-encode the target: selected localization -> 1, "AllRest" -> 0.
# `.map` yields NaN for any label outside the mapping, which the assertion
# turns into a clear error. With `.replace`, such labels would silently
# survive as strings. `.replace` also relied on pandas' deprecated
# object->int downcasting to produce a numeric target.
label_map = {Localization[loc]: 1, "AllRest": 0}
encoded_label = df["label"].map(label_map)
unexpected_labels = df.loc[encoded_label.isna(), "label"].unique()
assert len(unexpected_labels) == 0, f"unexpected labels in {data_path}: {unexpected_labels}"
df["label"] = encoded_label.astype(int)

X = df.drop(["SampleName", "label"], axis=1)
y = df["label"]

# NOTE(review): the split seed (1) differs from GLOBAL_SEED used for the model;
# kept unchanged so previously reported results remain reproducible. Consider
# stratify=y if the two classes are imbalanced.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1, shuffle=True)
model = DecisionTreeClassifier(random_state=GLOBAL_SEED, min_samples_split=2)
# Train the model on the training split only; X_test/y_test stay held out.
model.fit(X_train, y_train)

# Optional: render the structure of the fitted decision tree (needs Graphviz).
#tree_graph = export_graphviz(model, out_file=None, feature_names = X_train.columns, rounded=True, filled=True)
#graphviz.Source(tree_graph)

In [None]:
# Initialize JavaScript visualizations in notebook environment
shap.initjs()
# Define a tree explainer for the fitted decision tree
explainer = shap.TreeExplainer(model)
# Obtain SHAP values for every row of the held-out test split
shap_values = explainer.shap_values(X_test)

# The return layout depends on the SHAP version.
# - Older releases give a list with one (n_samples, n_features) array per class.
# - Newer releases give one (n_samples, n_features, n_classes) array.
# Extract the positive-class (label 1) slice either way. Indexing the 3-D
# array with [1] would silently select the second *sample* instead.
positive_idx = list(model.classes_).index(1)
if isinstance(shap_values, list):
    shap_values_pos = shap_values[positive_idx]
elif np.ndim(shap_values) == 3:
    shap_values_pos = np.asarray(shap_values)[:, :, positive_idx]
else:
    shap_values_pos = np.asarray(shap_values)

# Alternatives: plot_type='bar' (e.g. with max_display=20) or plot_type='violin'.
shap.summary_plot(shap_values_pos, X_test)

# Global importance: mean |SHAP| of each feature over the test samples.
feature_names = X_test.columns
vals = np.abs(shap_values_pos).mean(axis=0)
feature_importance = (
    pd.DataFrame({"col_name": feature_names, "feature_importance_vals": vals})
    .sort_values(by=["feature_importance_vals"], ascending=False)
)
print(feature_importance.head(20))
print()
print("==================================================")
