In [2]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

## Generate `sklearn` side:  Iris dataset

In [3]:
iris = load_iris()
feature_names = iris['feature_names']
X = pd.DataFrame(iris.data, columns=feature_names)
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

In [4]:
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
                       max_features=None, max_leaf_nodes=3,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=0, splitter='best')

In [5]:
node_count = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

In [6]:
print(node_count, 'nodes in the tree')

5 nodes in the tree


In [7]:
feature

array([ 3, -2,  2, -2, -2])

## Put in graph format for cyto
nodes and edges

In [10]:
elements = []
for node in range(node_count):
    #nodes
    elements.append(
        {'data':
            {'id': str(node), 
             'label': str(feature_names[feature[node]])
            }
        } 
    )
    
    #edges
    if children_left[node] != -1:
        elements.append(
            {'data':
                {'source': node,
                 'target': children_left[node]
                }
            }
        )
    if children_right[node] != -1:
        elements.append(
            {'data':
                {'source': node,
                 'target': children_right[node]
                }
            }
        )

In [11]:
elements

[{'data': {'id': '0', 'label': 'petal width (cm)'}},
 {'data': {'source': 0, 'target': 1}},
 {'data': {'source': 0, 'target': 2}},
 {'data': {'id': '1', 'label': 'petal length (cm)'}},
 {'data': {'id': '2', 'label': 'petal length (cm)'}},
 {'data': {'source': 2, 'target': 3}},
 {'data': {'source': 2, 'target': 4}},
 {'data': {'id': '3', 'label': 'petal length (cm)'}},
 {'data': {'id': '4', 'label': 'petal length (cm)'}}]

## IMDB

In [10]:
from sklearn.feature_extraction.text import CountVectorizer

In [11]:
imdb = pd.read_csv('../data/imdb_labelled.txt', sep='\t')
imdb.columns = ['text', 'sentiment']

cv = CountVectorizer()
X = cv.fit_transform(imdb['text'])
y = imdb['sentiment']

feature_names = cv.get_feature_names()

estimator = DecisionTreeClassifier(max_leaf_nodes=10, random_state=0)
estimator.fit(X, y)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
                       max_features=None, max_leaf_nodes=10,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=0, splitter='best')

In [24]:
# -2 means leaf
print(estimator.tree_.feature)
print(estimator.tree_.TREE_LEAF)

[ 221  124   -2    0 3012  103   -2  439   -2 1005   -2 2992   -2 1563
   -2   -2   -2   -2   -2]


AttributeError: 'sklearn.tree._tree.Tree' object has no attribute 'TREE_LEAF'

In [17]:
feature_names[-2]

'zombie'

In [20]:
leaf_id = estimator.apply(X)

In [21]:
leaf_id

array([15,  8, 15, 15, 15, 15, 17, 15, 15, 17, 15, 17, 15, 15, 17, 12, 12,
       15,  6, 10, 15, 17, 17, 15, 10, 17, 15, 15, 15, 15, 15, 15, 15, 17,
       17, 15, 17, 17, 15,  2,  2, 15, 15, 15, 17, 15, 17, 17, 17, 17, 15,
       17, 15, 15, 17, 15, 15, 16, 12, 17, 12, 17, 15, 10, 16, 15, 17, 15,
        6, 17, 15, 15, 12, 15, 15, 15, 17, 15, 17, 15, 17, 15, 17, 17, 17,
        2, 17, 15, 15, 15, 15, 15, 15, 15, 17, 15, 15, 15, 15, 15, 15, 17,
       15, 17,  2,  2, 15, 15, 17, 17, 15, 15, 12, 15, 15, 15, 15, 15, 15,
       15, 15, 15, 15, 15, 15,  2, 15, 15, 15, 15, 15, 15, 17, 15,  2,  2,
       17, 17, 17, 15, 15, 17, 15, 17, 12, 17, 15, 15,  2, 15, 15, 15,  2,
       15, 10, 17, 17, 10, 17, 17, 15, 15, 17, 15, 15, 17, 15, 17, 17, 15,
       15, 15, 15, 15, 15, 15, 17, 15, 17, 17, 17, 15, 17, 14, 17, 15, 17,
       12, 15, 17, 17, 17, 15, 17, 17, 17,  8, 15, 15, 17, 17, 17, 12, 17,
       17, 17, 15, 16, 15, 15, 15, 15, 15, 15, 15, 15, 17,  2, 15,  2, 15,
       15, 17, 17, 15, 17