In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.ensemble import RandomForestClassifier

In [3]:
# Do not truncate wide DataFrames: None removes the cap on displayed columns.
pd.options.display.max_columns = None

In [4]:
# Available input tables:
#   'pyr_four_nostr.csv'         - the original classifier train set
#   'pyr_four_nostr_mmstats.csv' - adds in 10 additional mito-to-mito distance statistics
# The first CSV column is read as the row index.
pyr_four = pd.read_csv(filepath_or_buffer='pyr_four_nostr.csv', index_col=[0])

In [5]:
# --- Train/test split ---------------------------------------------------
testsize = 0.30     # fraction of rows held out as the test set
random_st = 1       # seed passed as random_state in the cells below

# --- Random forest ------------------------------------------------------
n_est = 100         # number of estimators (trees)

# --- Decision tree classifier conditions --------------------------------
min_sampleaf, min_sampsplit = 1, 2
min_weightfractionleaf = 0.0

# First run

## Test-train-split 1 using random state = 1 (from above variables)

In [6]:
# Target is the 'compartment' label; every other column is a feature.
y = pyr_four['compartment']
X = pyr_four.drop(columns='compartment')

# Hold out `testsize` of the rows for testing, seeded with `random_st`.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=testsize,
    random_state=random_st,
)

## Dtree1

In [7]:
# Configure the tree that is actually trained: the hyperparameters from the
# configuration cell (including random_state, which makes the tree reproducible)
# must be passed to the estimator that gets fitted, not to a throwaway instance.
dtree = DecisionTreeClassifier(
    class_weight=None,
    criterion='gini',
    max_depth=None,
    max_features=None,
    max_leaf_nodes=None,
    min_samples_leaf=min_sampleaf,
    min_samples_split=min_sampsplit,
    min_weight_fraction_leaf=min_weightfractionleaf,
    random_state=random_st,
    splitter='best',
)

# Fit once on the training split; fit() returns the estimator, so the cell
# displays the fitted, configured model.
dtree.fit(X_train, y_train)

DecisionTreeClassifier(random_state=1)

## Dtree1 classification report and confusion matrix

In [8]:
# Predicted compartment label for every row of the held-out test split.
predictions = dtree.predict(X=X_test)

In [9]:
# Per-class precision / recall / F1 of the decision tree on the test split.
print(classification_report(y_true=y_test, y_pred=predictions))

              precision    recall  f1-score   support

      Apical       0.78      0.79      0.78      5962
      Axonal       0.72      0.73      0.72      3751
       Basal       0.84      0.84      0.84     16534
     Somatic       0.96      0.96      0.96     28025

    accuracy                           0.89     54272
   macro avg       0.82      0.83      0.83     54272
weighted avg       0.89      0.89      0.89     54272



In [10]:
# Confusion matrix of the decision tree (rows: true class, columns: predicted class).
print(confusion_matrix(y_true=y_test, y_pred=predictions))

[[ 4685    78   975   224]
 [   70  2733   816   132]
 [ 1023   851 13836   824]
 [  230   135   808 26852]]


## Random forest 1 (rf1) using random state 1

In [11]:
# Train the random forest on the current training split and time the fit.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump if the system clock is adjusted mid-run.
start_time = time.perf_counter()
rfc = RandomForestClassifier(n_estimators=n_est, random_state=random_st)
rfc.fit(X_train, y_train)
elapsed_time = time.perf_counter() - start_time

In [12]:
# Report how long the random forest fit took (typo "Classifer" corrected).
print(f"Elapsed time to compute the Random Forest Classifier: {elapsed_time:.3f} seconds")

Elapsed time to compute the Random Forest Classifer: 43.332 seconds


## Rf1 confusion matrix and classification report

In [13]:
# Predict the test split with the random forest and print its confusion matrix
# (rows: true class, columns: predicted class).
rfc_pred = rfc.predict(X=X_test)
print(confusion_matrix(y_true=y_test, y_pred=rfc_pred))

[[ 4704    30   990   238]
 [   51  2907   637   156]
 [  640   414 14782   698]
 [   37    10   317 27661]]


In [14]:
# Per-class precision / recall / F1 of the random forest on the test split.
print(classification_report(y_true=y_test, y_pred=rfc_pred))

              precision    recall  f1-score   support

      Apical       0.87      0.79      0.83      5962
      Axonal       0.86      0.77      0.82      3751
       Basal       0.88      0.89      0.89     16534
     Somatic       0.96      0.99      0.97     28025

    accuracy                           0.92     54272
   macro avg       0.89      0.86      0.88     54272
weighted avg       0.92      0.92      0.92     54272



In [15]:
# X_colnames = [str(X.columns[i]) for i in range(X.shape[1])]

In [16]:
# from https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_importances.html#sphx-glr-auto-examples-ensemble-plot-forest-importances-py

#start_time = time.time()
#feature_names = X_colnames
#forest = RandomForestClassifier(random_state=random_st)
#forest.fit(X_train, y_train)

#importances = forest.feature_importances_
#std = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)
#elapsed_time = time.time() - start_time

In [17]:
# print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")

In [18]:
#forest_importances = pd.Series(importances, index=feature_names)

#fig, ax = plt.subplots()
#forest_importances.plot.bar(yerr=std, ax=ax)
#ax.set_title("Feature importances using MDI")
#ax.set_ylabel("Mean decrease in impurity")
#fig.tight_layout()

# Second run

## Test-train-split
To get the big jump in accuracy, the test-train-split must be different from run 1
### This is the first requirement to get the big jump in accuracy
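Here, the random state is set to 0 (different from run 1), so the big jump in accuracy is obtained.  
If you set the random state to 1 (same as run 1), the accuracy results will be identical to run 1.  
Note: the jump is not a genuine improvement. A different seed produces a different test set, and most of its rows (about 70% in expectation, given the 30% test size) were in the run 1 training set. A random forest fitted in run 1 is therefore scored largely on rows it has already seen (train/test leakage).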

In [19]:
# Re-run with a new train/test split, without updating rfc.
# These 3 lines are required for the increased accuracy; if you comment them out,
# the following analysis is no different from the first run.
#   - random state below equal to the first run's       -> no jump in accuracy
#   - random state below different from the first run's -> substantial jump in accuracy
# NOTE(review): the jump is a train/test leakage artifact, not a better model. A forest
# fitted on the first run's training rows has already seen most of the rows that land
# in this split's test set, so its scores on this split are inflated.
y = pyr_four['compartment']
X = pyr_four.drop(columns='compartment')
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=testsize,
    random_state=0,  # set random state here
)

## Dtree 2

In [20]:
# Build the run-2 tree with the hyperparameters from the configuration cell and
# train that same instance, so random_state and the min_* settings really apply.
dtree = DecisionTreeClassifier(
    class_weight=None,
    criterion='gini',
    max_depth=None,
    max_features=None,
    max_leaf_nodes=None,
    min_samples_leaf=min_sampleaf,
    min_samples_split=min_sampsplit,
    min_weight_fraction_leaf=min_weightfractionleaf,
    random_state=random_st,
    splitter='best',
)

# A single fit on the current training split is sufficient; fit() returns the
# estimator, which the cell then displays.
dtree.fit(X_train, y_train)

DecisionTreeClassifier(random_state=1)

## Dtree2 classification report and confusion matrix

In [21]:
# Predict the run-2 test split with the decision tree and print the
# per-class precision / recall / F1 report.
predictions = dtree.predict(X=X_test)
print(classification_report(y_true=y_test, y_pred=predictions))

              precision    recall  f1-score   support

      Apical       0.78      0.80      0.79      5941
      Axonal       0.73      0.74      0.73      3613
       Basal       0.85      0.85      0.85     16467
     Somatic       0.96      0.96      0.96     28251

    accuracy                           0.89     54272
   macro avg       0.83      0.84      0.83     54272
weighted avg       0.89      0.89      0.89     54272



In [22]:
# Confusion matrix of the run-2 decision tree (rows: true class, columns: predicted class).
print(confusion_matrix(y_true=y_test, y_pred=predictions))

[[ 4729    85   927   200]
 [   70  2683   740   120]
 [ 1010   784 13935   738]
 [  272   144   763 27072]]


## If the rfc from run 1 is reused, you get the big jump in accuracy
## If rfc is reset (refit on the new training split), you get the same accuracy results as run 1
### This is the second requirement to get the big jump in accuracy

In [23]:
# testing to see what happens when rfc is reset
# if you uncomment these two lines, which resets rfc, you no longer get the big jump in accuracy
#rfc = RandomForestClassifier(n_estimators=n_est, random_state=random_st)
#rfc.fit(X_train, y_train)

## Rf2 classification report and confusion matrix

In [24]:
# Predict the current test split with rfc and print the confusion matrix.
# NOTE(review): unless rfc was refitted in the cell above, this is still the forest
# trained on the first run's training rows. Many of the current test rows were part
# of that training set, so the scores below are inflated by train/test leakage.
rfc_pred = rfc.predict(X=X_test)
print(confusion_matrix(y_true=y_test, y_pred=rfc_pred))

[[ 5563     7   315    56]
 [   11  3373   181    48]
 [  201   121 15946   199]
 [    9     8    83 28151]]


In [25]:
# Per-class precision / recall / F1 of rfc on the current test split.
print(classification_report(y_true=y_test, y_pred=rfc_pred))

              precision    recall  f1-score   support

      Apical       0.96      0.94      0.95      5941
      Axonal       0.96      0.93      0.95      3613
       Basal       0.96      0.97      0.97     16467
     Somatic       0.99      1.00      0.99     28251

    accuracy                           0.98     54272
   macro avg       0.97      0.96      0.96     54272
weighted avg       0.98      0.98      0.98     54272



In [26]:
# X_colnames = [str(X.columns[i]) for i in range(X.shape[1])]

In [27]:
# from https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_importances.html#sphx-glr-auto-examples-ensemble-plot-forest-importances-py

#start_time = time.time()
#feature_names = X_colnames
#forest = RandomForestClassifier(random_state=random_st)
#forest.fit(X_train, y_train)

#importances = forest.feature_importances_
#std = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)
#elapsed_time = time.time() - start_time

In [28]:
#print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")

In [29]:
#forest_importances = pd.Series(importances, index=feature_names)

#fig, ax = plt.subplots()
#forest_importances.plot.bar(yerr=std, ax=ax)
#ax.set_title("Feature importances using MDI")
#ax.set_ylabel("Mean decrease in impurity")
#fig.tight_layout()