# Justifying a random forest's predictions
## Libraries and experimental data set

In [1]:
# Libraries to be used
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from lime.lime_tabular import LimeTabularExplainer
from treeinterpreter import treeinterpreter
from scipy.sparse import hstack

# Hide warnings
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Load data set for experiments
from sklearn.datasets import load_wine
dataset = load_wine()
X = pd.DataFrame(dataset.data, columns = dataset.feature_names)
y = dataset.target

In [3]:
# Split data into training and test sets
Xtrain, Xtest, ytrain, ytest = train_test_split(
    X, y,
    stratify = y,
    train_size = 100
)

## Exploratory analysis and random forest classifiers

In [4]:
# Take a glance at feature distributions, broken down by class
Xtrain.groupby(ytrain).quantile([.1,.9])

Unnamed: 0,Unnamed: 1,alcohol,malic_acid,ash,alcalinity_of_ash,magnesium,total_phenols,flavanoids,nonflavanoid_phenols,proanthocyanins,color_intensity,hue,od280/od315_of_diluted_wines,proline
0,0.1,13.054,1.574,2.154,15.04,94.2,2.402,2.43,0.204,1.442,3.83,0.882,2.78,802.0
0,0.9,14.228,3.758,2.65,19.88,120.2,3.294,3.556,0.34,2.188,7.02,1.166,3.588,1424.0
1,0.1,11.641,0.985,1.746,16.0,80.9,1.609,1.277,0.208,0.938,2.049,0.799,2.109,371.3
1,0.9,13.056,3.904,2.631,24.05,112.7,2.908,2.994,0.523,2.302,4.465,1.283,3.381,721.2
2,0.1,12.632,1.808,2.242,19.3,87.2,1.334,0.502,0.244,0.79,4.7,0.57,1.318,476.0
2,0.9,13.856,5.1,2.68,24.7,120.0,2.02,1.262,0.6,1.554,10.224,0.832,1.866,782.0


In [5]:
# Train a random forest classifier for each of the three classes
clf = []
for clas in range(3):
    clf.append(
        RandomForestClassifier(
            n_estimators = 100, n_jobs = -1
        ).fit(Xtrain, ytrain == clas)
    )

In [6]:
# Check the AUC's of the classifiers on the test data
for clas in range(3):
    print(roc_auc_score(ytest == clas, clf[clas].predict_proba(Xtest)[:,1]))

0.99926035503
0.999313658202
1.0


## Explanations from LIME

In [7]:
# Create a LIME explainer for tabular data
explainer = LimeTabularExplainer(Xtrain.values, feature_names = Xtrain.columns)

def explain_row(clf, row, num_reasons = 2):
    '''
    Produce LIME explanations for a single row of data.
        * `clf` is a binary classifier (with a predict_proba method),
        * `row` is a row of features data,
        * `num_reasons` (default 2) is the number of 
          reasons/explanations to be produced.
          
    '''
    exp = [
        exp_pair[0] for exp_pair in     # Get each explanation (a string)
        explainer.explain_instance(     # from the LIME explainer
            row, clf.predict_proba,     # for the given row and classifier
            labels = [1],               # and label 1 ("positives")
            num_features = num_reasons  # for up to `num_reasons` explanations
        ).as_list()
        if exp_pair[1] > 0              # but only for positive explanations 
    ][:num_reasons]
    
    # Fill in any missing explanations with blanks
    exp += [''] * (num_reasons - len(exp))  
    return exp


def predict_explain(rf, X, num_reasons = 2):
    '''
    Produce scores and LIME explanations for every row in a data frame.
        * `rf` is a binary classifier with a predict_proba method,
        * `X` is the features data frame,
        * `num_reasons` (default 2) is the number of 
          reasons/explanations to be produced for each row.
          
    '''
    # Prepare the structure to be returned
    pred_ex = Xtest[[]]
    
    # Get the scores from the classifier
    pred_ex['SCORE'] = rf.predict_proba(X)[:,1]
    
    # Get the reasons/explanations for each row
    cols = zip(
        *Xtest.apply(
            lambda x: explain_row(rf, x, num_reasons), 
            axis = 1, raw = True
        )
    )
    
    # Return the results
    for n in range(num_reasons):
        pred_ex['REASON%d' % (n+1)] = next(cols)
    return pred_ex


### Explanations for top cases predicted to belong to class 0

In [8]:
%%time
predict_explain(clf[0], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(26)

Wall time: 40.5 s


Unnamed: 0,SCORE,REASON1,REASON2,TRUE_CLASS
56,1.0,proline > 941.25,flavanoids > 2.76,0
53,1.0,proline > 941.25,flavanoids > 2.76,0
14,1.0,proline > 941.25,flavanoids > 2.76,0
10,1.0,proline > 941.25,flavanoids > 2.76,0
51,1.0,proline > 941.25,flavanoids > 2.76,0
54,1.0,proline > 941.25,flavanoids > 2.76,0
15,1.0,proline > 941.25,flavanoids > 2.76,0
8,0.99,proline > 941.25,flavanoids > 2.76,0
48,0.99,proline > 941.25,flavanoids > 2.76,0
9,0.99,proline > 941.25,flavanoids > 2.76,0


### Explanations for top cases predicted to belong to class 1

In [9]:
%%time
predict_explain(clf[1], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(26)

Wall time: 43.4 s


Unnamed: 0,SCORE,REASON1,REASON2,TRUE_CLASS
80,1.0,color_intensity <= 3.20,alcohol <= 12.36,1
106,1.0,alcohol <= 12.36,magnesium <= 88.00,1
89,1.0,color_intensity <= 3.20,alcohol <= 12.36,1
108,0.99,color_intensity <= 3.20,alcohol <= 12.36,1
88,0.98,color_intensity <= 3.20,alcohol <= 12.36,1
114,0.98,color_intensity <= 3.20,alcohol <= 12.36,1
93,0.98,color_intensity <= 3.20,alcohol <= 12.36,1
86,0.98,color_intensity <= 3.20,alcohol <= 12.36,1
97,0.98,color_intensity <= 3.20,alcohol <= 12.36,1
94,0.97,alcohol <= 12.36,proline <= 498.75,1


### Explanations for top cases predicted to belong to class 2

In [10]:
%%time
predict_explain(clf[2], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(26)

Wall time: 42 s


Unnamed: 0,SCORE,REASON1,REASON2,TRUE_CLASS
148,0.99,flavanoids <= 1.17,od280/od315_of_diluted_wines <= 1.83,2
164,0.97,flavanoids <= 1.17,od280/od315_of_diluted_wines <= 1.83,2
165,0.96,flavanoids <= 1.17,od280/od315_of_diluted_wines <= 1.83,2
163,0.96,flavanoids <= 1.17,od280/od315_of_diluted_wines <= 1.83,2
157,0.95,flavanoids <= 1.17,od280/od315_of_diluted_wines <= 1.83,2
147,0.95,flavanoids <= 1.17,hue <= 0.79,2
135,0.94,flavanoids <= 1.17,od280/od315_of_diluted_wines <= 1.83,2
154,0.94,flavanoids <= 1.17,od280/od315_of_diluted_wines <= 1.83,2
153,0.94,flavanoids <= 1.17,od280/od315_of_diluted_wines <= 1.83,2
151,0.91,od280/od315_of_diluted_wines <= 1.83,hue <= 0.79,2


## Explanations by tree interpretation

In [11]:
def splitlist(tree_):
    '''
    For each of the nodes in a given decision tree, find the feature and 
    threshold used to arrive at that node, along with the additional score
    gained at that node.
    
    Returns a list of triples (feature, threshold, score gained), with one for
    each node in the tree.
    
    '''
    # Prepare the list to be returned
    l = [] 
    for _ in range(len(tree_.children_left)): l.append((-2,-2,0))
        
    # For each node's children, find the feature, threshold, and score gained
    for i,c in enumerate(zip(
        tree_.children_left, tree_.children_right, 
        tree_.feature, tree_.threshold
    )):
        l[c[0]] = (c[2], c[3], tree_.value[c[0], 0, 1] - tree_.value[i, 0, 1])
        l[c[1]] = (c[2], c[3], tree_.value[c[1], 0, 1] - tree_.value[i, 0, 1])
        
    return l


def predict_explain(rf, X, num_reasons = 2):
    '''
    Produce scores and explanations for an entire data frame.
        * `rf` is a RandomForestClassifier,
        * `X` is the features data frame,
        * `num_reasons` (default 2) is the number of 
          reasons/explanations to be produced for each row.
          
    '''    
    # Prepare the structure to be returned    
    pred_ex = X[[]]
    
    # Get scores and feature contributions from a tree interpreter
    pred, _, contrib = treeinterpreter.predict(rf, X)
    pred = pred[:,1]    
    pred_ex['SCORE'] = pred
    
    # Reformat the contributions: the final result is a list of the 
    # top `num_reasons` contributors for each data point and score
    contrib = [[c[1] for c in l] for l in contrib]
    contrib = [[
        tup for tup in
        sorted(enumerate(c), key = lambda tup: -tup[1])[:num_reasons]
        if tup[1] > 0
    ] for c in contrib]

    # Get the "splitlist" for each tree in the random forest & concatenate them
    splits = sum([splitlist(tree.tree_) for tree in rf.estimators_], [])
    
    # For each data point, get the list of tree nodes actually visited
    paths = hstack([tree.decision_path(X) for tree in rf.estimators_]).tocsr()
    
    # Find the reasons/explanations
    for n in range(num_reasons):
        reason = []
        for i, c in enumerate(contrib):
            if len(c) > n:
                line_thresh = [
                    t[1] for j,t          # Get each threshold
                    in enumerate(splits)  # from the list of splits
                    if t[0] == c[n][0]    # for this feature
                    and paths[i,j] != 0   # and each visited node
                    and t[2] > 0          # where the split had positive effect
                ]
                name = X.columns[c[n][0]] # The feature's name
                val = X.iloc[i, c[n][0]]  # The feature's value in this row
                
                # Get the lower and upper thresholds that contributed to the
                # score of the current row
                low = max([t for t in line_thresh if t < val], default = None)
                high = min([t for t in line_thresh if t > val], default = None)
                
                # Formulate the reason/explanation as a human-readable string
                if high is None and low is None: reason.append('%s' % name)
                elif high is None: reason.append('%s > %.2f' % (name, low))
                elif low is None: reason.append('%s <= %.2f' % (name, high))
                else: reason.append('%.2f < %s <= %.2f' % (low, name, high))
            else:
                reason.append('')
                
        pred_ex['REASON%d' % (n+1)] = reason
        
    return pred_ex


### Explanations for top cases predicted to belong to class 0

In [12]:
%%time
predict_explain(clf[0], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(26)

Wall time: 680 ms


Unnamed: 0,SCORE,REASON1,REASON2,TRUE_CLASS
56,1.0,proline > 877.50,2.53 < flavanoids <= 3.08,0
53,1.0,proline > 877.50,2.53 < flavanoids <= 3.08,0
14,1.0,proline > 1050.00,2.53 < flavanoids <= 4.51,0
10,1.0,proline > 1050.00,2.53 < flavanoids <= 4.51,0
51,1.0,proline > 1050.00,2.55 < flavanoids <= 3.08,0
54,1.0,proline > 995.50,2.53 < flavanoids <= 3.08,0
15,1.0,proline > 995.50,2.53 < flavanoids <= 3.08,0
8,0.99,proline > 877.50,2.53 < flavanoids <= 3.08,0
48,0.99,proline > 877.50,2.53 < flavanoids <= 3.08,0
9,0.99,proline > 877.50,2.53 < flavanoids <= 4.51,0


### Explanations for top cases predicted to belong to class 1

In [13]:
%%time
predict_explain(clf[1], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(26)

Wall time: 813 ms


Unnamed: 0,SCORE,REASON1,REASON2,TRUE_CLASS
80,1.0,alcohol <= 12.14,color_intensity <= 3.41,1
106,1.0,color_intensity <= 3.41,alcohol <= 12.40,1
89,1.0,alcohol <= 12.14,color_intensity <= 3.38,1
108,0.99,color_intensity <= 3.38,alcohol <= 12.47,1
88,0.98,color_intensity <= 3.38,alcohol <= 12.14,1
114,0.98,alcohol <= 12.14,color_intensity <= 3.41,1
93,0.98,color_intensity <= 3.41,alcohol <= 12.44,1
86,0.98,alcohol <= 12.40,color_intensity <= 3.35,1
97,0.98,color_intensity <= 3.41,alcohol <= 12.47,1
94,0.97,alcohol <= 12.14,color_intensity <= 3.38,1


### Explanations for top cases predicted to belong to class 2

In [14]:
%%time
predict_explain(clf[2], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(26)

Wall time: 593 ms


Unnamed: 0,SCORE,REASON1,REASON2,TRUE_CLASS
148,0.99,0.66 < flavanoids <= 0.91,hue <= 0.67,2
164,0.97,flavanoids <= 0.91,hue <= 0.73,2
165,0.96,flavanoids <= 0.89,od280/od315_of_diluted_wines <= 1.76,2
163,0.96,hue <= 0.68,flavanoids <= 0.77,2
157,0.95,flavanoids <= 0.89,hue <= 0.67,2
147,0.95,hue <= 0.65,flavanoids <= 0.91,2
135,0.94,flavanoids <= 0.78,od280/od315_of_diluted_wines <= 1.80,2
154,0.94,hue <= 0.65,flavanoids <= 0.78,2
153,0.94,hue <= 0.68,flavanoids <= 0.91,2
151,0.91,hue <= 0.63,od280/od315_of_diluted_wines <= 1.58,2
