In [1]:
import pandas as pd

from helpers import clean_and_backfill_data

from node import Node
from tree import Tree

In [2]:
male_wages = pd.read_csv('https://vincentarelbundock.github.io/Rdatasets/csv/plm/Males.csv').drop(['Unnamed: 0', 'nr'], axis=1)
male_wages = clean_and_backfill_data(male_wages)
male_wages.head()

Unnamed: 0,year,school,exper,union,ethn,married,health,wage,industry,occupation,residence
0,1980,14,1,False,other,False,False,1.19754,Business_and_Repair_Service,Service_Workers,north_east
1,1981,14,2,True,other,False,False,1.85306,Personal_Service,Service_Workers,north_east
2,1982,14,3,False,other,False,False,1.344462,Business_and_Repair_Service,Service_Workers,north_east
3,1983,14,4,False,other,False,False,1.433213,Business_and_Repair_Service,Service_Workers,north_east
4,1984,14,5,False,other,False,False,1.568125,Personal_Service,"Craftsmen, Foremen_and_kindred",north_east


### Instantiating Node Object

In [3]:
n = Node()
print(n, '\n')
_ = n.split_data(male_wages, 'wage')
print(n)

feature: None
split value: None
# Samples: 0
Avg y: inf 

feature: school
split value: 12.0
# Samples: 4360
Avg y: 1.649


### Training a Tree

In [4]:
t = Tree(min_samples_leaf=5, max_depth=10)
t.fit(male_wages, 'wage')

In [5]:
t.root_node

feature: school
split value: 12.0
# Samples: 4360
Avg y: 1.649

#### Traversing Tree to Group w/ Max Y

In [7]:
ct = 0
node = t.root_node
while not node.is_terminal:
    print('Split {}'.format(ct))
    print(node)
    if node.left_child.root_node.avg > node.right_child.root_node.avg:
        print("==> left branch")
        node = node.left_child.root_node
    else:
        node = node.right_child.root_node
        print("==> right branch")
    ct += 1
    print('\n')
    
print('Split {}'.format(ct))
print(node)

Split 0
feature: school
split value: 12.0
# Samples: 4360
Avg y: 1.649
==> right branch


Split 1
feature: year
split value: 1984.0
# Samples: 2888
Avg y: 1.741
==> right branch


Split 2
feature: school
split value: 13.0
# Samples: 1444
Avg y: 1.857
==> right branch


Split 3
feature: industry
split value: Trade
# Samples: 520
Avg y: 2.010
==> left branch


Split 4
feature: industry
split value: Agricultural
# Samples: 417
Avg y: 2.051
==> left branch


Split 5
feature: year
split value: 1985
# Samples: 406
Avg y: 2.065
==> right branch


Split 6
feature: industry
split value: Finance
# Samples: 305
Avg y: 2.114
==> right branch


Split 7
feature: residence
split value: north_east
# Samples: 32
Avg y: 2.339
==> right branch


Split 8
feature: school
split value: 14
# Samples: 8
Avg y: 2.689


### Compare w/ Scikit DecisionTreeRegressor

In [3]:
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import export_graphviz

from helpers import get_dummie_data

In [8]:
dummied_df = get_dummie_data(male_wages)
train_cols = list(filter(lambda col: col!='wage', dummied_df.columns))
X = dummied_df[train_cols]
y = dummied_df['wage']

dtr = DecisionTreeRegressor(min_samples_leaf=5, max_depth=10)
dtr.fit(X, y)

DecisionTreeRegressor(criterion='mse', max_depth=10, max_features=None,
           max_leaf_nodes=None, min_impurity_decrease=0.0,
           min_impurity_split=None, min_samples_leaf=5,
           min_samples_split=2, min_weight_fraction_leaf=0.0,
           presort=False, random_state=None, splitter='best')

In [12]:
export_graphviz(dtr, out_file='images/wage_tree.dot',
                feature_names = train_cols, max_depth=4)

! dot -Tpng images/wage_tree.dot -o images/wage_tree.png

##### Visualizing the Scikit Tree

<img src="images/wage_tree.png">