In [1]:
import numpy as np
import pandas as pd
from Branch import Branch
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from ModelGenerator import ModelGenerator
from sklearn.tree import export_graphviz
import pydot
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
pydot.graph_from_dot_data(open('tree.dot','r').read())

[<pydot.Dot at 0x1a7afac19e8>]

In [3]:
iris = load_iris()
data = pd.DataFrame(iris.data[:],columns=iris.feature_names)
data['class'] = iris.target


In [4]:
rf=RandomForestClassifier(n_estimators=2)
rf.fit(data[iris.feature_names],iris.target)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=2, n_jobs=1, oob_score=False, random_state=None,
            verbose=0, warm_start=False)

In [5]:
trees=rf.estimators_
mg=ModelGenerator(iris.feature_names,iris.target_names,rf)

In [6]:
mg.generateBranches()

In [7]:
mg.branches_lists[0][0].toString()

'petal length (cm) <= 2.45, labels: [setosa : 1.0 versicolor : 0.0 virginica : 0.0 ] Number of samples: 37'

In [8]:
mg.buildConjunctionSet()

Iteration 1: 8 trees


In [9]:
for b in mg.conjunctionSet:
    print(b.toString())
    print("")

petal length (cm) <= 2.45, petal width (cm) <= 0.8, labels: [setosa : 1.0 versicolor : 0.0 virginica : 0.0 ] Number of samples: 35.4682957019

petal width (cm) > 0.8, petal length (cm) <= 2.45, labels: [setosa : 0.5 versicolor : 0.5 virginica : 0.0 ] Number of samples: 32.7566787083

petal length (cm) > 2.45, petal length (cm) <= 4.95, petal width (cm) <= 0.8, labels: [setosa : 0.5 versicolor : 0.5 virginica : 0.0 ] Number of samples: 34.4963766213

petal length (cm) > 2.45, petal width (cm) > 0.8, petal length (cm) <= 4.75, petal width (cm) <= 1.75, labels: [setosa : 0.0 versicolor : 1.0 virginica : 0.0 ] Number of samples: 31.8590646441

petal length (cm) > 4.75, petal width (cm) > 0.8, sepal length (cm) <= 6.25, sepal width (cm) <= 3.1, petal length (cm) <= 4.95, petal width (cm) <= 1.75, labels: [setosa : 0.0 versicolor : 0.5 virginica : 0.5 ] Number of samples: 11.8321595662

sepal length (cm) > 6.25, petal length (cm) > 4.75, petal width (cm) > 0.8, sepal length (cm) <= 6.5, sepa

In [10]:
records=[]
for b in mg.conjunctionSet:
    records.extend(b.get_branch_records())


In [11]:
data=pd.DataFrame(records).fillna(0)

Unnamed: 0,0<=5.85,0<=5.95,0<=6.25,0<=6.5,0>5.85,0>5.95,0>6.25,0>6.5,1<=2.6,1<=2.9,...,3<=1.65,3<=1.7,3<=1.75,3>0.8,3>1.65,3>1.7,3>1.75,label,num_of_samples,weight
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,setosa,35.468296,1.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,versicolor,35.468296,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,virginica,35.468296,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,setosa,32.756679,0.5
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,versicolor,32.756679,0.5
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,virginica,32.756679,0.0
6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,setosa,34.496377,0.5
7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,versicolor,34.496377,0.5
8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,virginica,34.496377,0.0
9,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,1.0,0.0,0.0,0.0,setosa,31.859065,0.0


In [12]:
est=rf.estimators_[0]
t=est.tree_
export_graphviz(t,out_file='tree.dot',feature_names=iris.feature_names,class_names=iris.target_names,  
                         filled=True, rounded=True)