In [1]:
import pandas as pd
import numpy as np
from dm_tools import data_prep

In [2]:
from sklearn.model_selection import train_test_split

In [3]:
# Load the prepared modelling table, then separate the binary target
# column ('TargetB') from the predictor columns.
df = data_prep()

y = df['TargetB']
X = df.drop(labels=['TargetB'], axis='columns')

In [4]:
# Seed reused by every stochastic step below (split and tree models).
rs = 10

# Hold out 30% of the rows for testing, stratified on the target so both
# splits keep the same class proportions.
X_mat = X.values
Xtr, Xtest, ytr, ytest = train_test_split(
    X_mat,
    y,
    test_size=0.3,
    stratify=y,
    random_state=rs,
)

In [5]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, accuracy_score

# Baseline: an unconstrained decision tree (default depth/leaf limits),
# seeded with `rs`. `fit` returns the estimator itself, so ending the cell
# with `model` displays the same repr as displaying the `fit(...)` result.
model = DecisionTreeClassifier(random_state=rs).fit(Xtr, ytr)
model

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=10,
            splitter='best')

In [6]:
# Accuracy of the baseline tree on its own training data. A value of 1.0
# means the unconstrained tree fits the training set perfectly, which says
# nothing about how it generalises.
train_accuracy = model.score(Xtr, ytr)
train_accuracy

1.0

In [7]:
# Per-class precision/recall on the TRAINING split. This measures fit,
# not generalisation; compare against the held-out report further down.
y_pred = model.predict(Xtr)
train_report = classification_report(ytr, y_pred)
print(train_report)

             precision    recall  f1-score   support

          0       1.00      1.00      1.00      3390
          1       1.00      1.00      1.00      3390

avg / total       1.00      1.00      1.00      6780



In [8]:
# Rank the baseline tree's features by impurity-based importance and
# print the 20 largest, most important first.
importances = model.feature_importances_
feature_names = X.columns

top_n = 20
# argsort is ascending; reversing the index array gives descending order.
indices = np.argsort(importances)[::-1][:top_n]
for idx in indices:
    print(feature_names[idx], ':', importances[idx])

DemMedHomeValue : 0.10371353665921498
DemMedIncome : 0.07677600801105956
GiftAvgAll : 0.06980988595894197
DemPctVeterans : 0.06613214928597241
DemAge : 0.060750495114459305
GiftTimeFirst : 0.04594492638340621
GiftAvgLast : 0.039119529776787676
GiftAvgCard36 : 0.0380219076090902
GiftTimeLast : 0.03754481714205862
PromCntAll : 0.03536307758528952
PromCnt36 : 0.0344873233968644
GiftAvg36 : 0.03333456207629517
PromCntCardAll : 0.031160032819215713
GiftCntCardAll : 0.029090632497530206
GiftCnt36 : 0.02902465415187723
GiftCntAll : 0.02844273638690637
PromCntCard36 : 0.02602984397479075
PromCnt12 : 0.02179582564122642
GiftCntCard36 : 0.014245411759881576
PromCntCard12 : 0.009727393864117933


In [9]:
import pydot
from sklearn.tree import export_graphviz

# Render the baseline tree to a PNG via Graphviz. With out_file=None,
# export_graphviz returns the DOT source as a string instead of writing it.
dot_source = export_graphviz(model, out_file=None, feature_names=X.columns)
graph = pydot.graph_from_dot_data(dot_source)  # returns a list of graphs
graph[0].write_png('week 3 dt viz.png')

In [10]:
from sklearn.model_selection import GridSearchCV

# Tune tree complexity with 10-fold cross-validation on the training split.
# NOTE(review): if the selected min_samples_leaf lands on the edge of this
# range (e.g. 50), widen the range; the true optimum may lie beyond it.
params = {'criterion': ['gini', 'entropy'],
          'max_depth': range(2, 7),
          'min_samples_leaf': range(20, 60, 10)}

# Seed the searched estimator with `rs`, like the baseline tree. Without
# random_state, sklearn's per-split feature permutation can change the
# selected tree and every metric below from one run to the next.
cv = GridSearchCV(param_grid=params,
                  estimator=DecisionTreeClassifier(random_state=rs),
                  cv=10)
cv.fit(Xtr, ytr)

# cv.score uses the best estimator refit on the full training split.
print('Train Accuracy:', cv.score(Xtr, ytr))
print('Test Accuracy:', cv.score(Xtest, ytest))

y_pred = cv.predict(Xtest)
print(classification_report(ytest, y_pred))
print(cv.best_params_)

Train Accuracy: 0.5941002949852507
Test Accuracy: 0.5750172057811425
             precision    recall  f1-score   support

          0       0.57      0.63      0.60      1453
          1       0.58      0.52      0.55      1453

avg / total       0.58      0.58      0.57      2906

{'criterion': 'gini', 'max_depth': 5, 'min_samples_leaf': 50}


In [11]:
from dm_tools import analyse_feature_importance, visualize_decision_tree

In [12]:
# Inspect the tuned tree chosen by the grid search: print its top-20
# feature importances and render it to "optimal_tree.png" (both via dm_tools).
best_tree = cv.best_estimator_
analyse_feature_importance(best_tree, X.columns, 20)
visualize_decision_tree(best_tree, X.columns, "optimal_tree.png")

GiftCnt36 : 0.32372028991803176
DemMedHomeValue : 0.16870399488848986
GiftAvgLast : 0.13715027943543454
GiftTimeLast : 0.07216895789175108
StatusCatStarAll : 0.04624807594062364
GiftCntAll : 0.04526431066083868
GiftCntCardAll : 0.04444007634482767
PromCntCardAll : 0.038352830083886735
DemPctVeterans : 0.029606101805632952
PromCnt36 : 0.018147095366402428
GiftAvgAll : 0.017372969310631516
StatusCat96NK_A : 0.016862957062736873
GiftTimeFirst : 0.014129096904096769
GiftAvgCard36 : 0.013209693354314394
PromCntCard12 : 0.00872994822258077
PromCnt12 : 0.005893322809720318
DemCluster_13 : 0.0
DemCluster_10 : 0.0
DemCluster_11 : 0.0
DemCluster_12 : 0.0
