In [27]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, accuracy_score
from interpret.blackbox import LimeTabular
from interpret import show

In [28]:
%autoreload 2
from utils import DataLoader

## Model Training

In [29]:
data_loader = DataLoader()
data_loader.load_data("../preprocessing/ADNI_preprocessed.csv")

X_train, X_test, y_train, y_test = data_loader.get_data_split()
print(X_train.shape)
print(X_test.shape)

(3567, 40)
(892, 40)


In [30]:
# Fit onto Random Forest Classifier
rf = RandomForestClassifier(n_estimators=100)
rf.fit(X_train, y_train)
rf_pred = rf.predict(X_test)

f1 = f1_score(y_test, rf_pred)
accuracy = accuracy_score(y_test, rf_pred)
print(f"Random Forests model F1 score: {f1}, Accuracy: {accuracy}")

Random Forests model F1 score: 0.8044077134986226, Accuracy: 0.9204035874439462


### Balancing dataset

In [31]:
# Try oversampling to balance dataset and increase accuracy
X_train_over, y_train_over = data_loader.oversample_data(X_train, y_train)
print(X_train_over.shape)
print(X_test.shape)

(5476, 40)
(892, 40)


In [32]:
# Fit onto Random Forest Classifier
rf_over = RandomForestClassifier(n_estimators=100)
rf_over.fit(X_train_over, y_train_over)
rf_pred = rf_over.predict(X_test)

f1 = f1_score(y_test, rf_pred)
accuracy = accuracy_score(y_test, rf_pred)
print(f"Random Forests model F1 score: {f1}, Accuracy: {accuracy}")

Random Forests model F1 score: 0.7919799498746867, Accuracy: 0.9069506726457399


## LIME Application

In [33]:
# 0 means no AD, 1 means AD

lime = LimeTabular(predict_fn=rf.predict_proba, data=X_train, random_state=1, )

# Get local explanations
lime_local = lime.explain_local(X_test[-20:], y_test[-20:], name='LIME')

show(lime_local) 



The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`
  import dash_html_components as html
The dash_core_components package is deprecated. Please replace
`import dash_core_components as dcc` with `from dash import dcc`
  import dash_core_components as dcc
The dash_table package is deprecated. Please replace
`import dash_table` with `from dash import dash_table`

Also, if you're using any of the table format helpers (e.g. Group), replace 
`from dash_table.Format import Group` with 
`from dash.dash_table.Format import Group`
  import dash_table as dt


In [34]:
# Balanced dataset
lime = LimeTabular(predict_fn=rf_over.predict_proba, data=X_train_over, random_state=1, )

# Get local explanations
lime_local = lime.explain_local(X_test[-20:], y_test[-20:], name='LIME')

show(lime_local)


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, but RandomForestClassifier was fitted with feature names


X does not have valid feature names, bu