In [1]:
from sklearn import tree
import numpy as np
import pandas as pd
from sklearn import model_selection
import matplotlib.pyplot as plt
import graphviz

## Build a Decision Tree Classifier

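The following function accepts the training data plus optional hyperparameters (`max_depth`, `random_state`, `criterion`, `min_samples_leaf`) as keyword arguments and returns a fitted decision tree classifier.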

In [2]:
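def dt_train(X, y, **kwargs):
    # Read hyperparameters from kwargs, falling back to defaults
    max_depth = kwargs.get('max_depth', 10)
    random_state = kwargs.get('random_state', 7)
    criterion = kwargs.get('criterion', 'gini')
    # An int is a sample count; a float would be read as a fraction of the samples
    min_samples_leaf = kwargs.get('min_samples_leaf', 1)

    dt = tree.DecisionTreeClassifier(max_depth=max_depth, random_state=random_state,
                                     criterion=criterion, min_samples_leaf=min_samples_leaf)
    dt.fit(X, y)
    return dt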

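The following function accepts a decision tree classifier and computes its cost-complexity pruning path, the first step of post-pruning. The returned object holds the candidate `ccp_alphas` and the corresponding total leaf `impurities`.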

In [3]:
def prune(dt, X, y):
    path = dt.cost_complexity_pruning_path(X, y)
    return path
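
The pruning path only lists candidate alphas; the tree still has to be refit with a chosen `ccp_alpha`. The sketch below shows one common way to do that, picking the alpha with the best cross-validated accuracy. It is not part of the original notebook: the iris dataset is a stand-in, and the names `base`, `best_alpha`, and `pruned` are illustrative.

```python
from sklearn import tree, model_selection
from sklearn.datasets import load_iris

# Stand-in data (assumption): the iris dataset bundled with scikit-learn
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.3, random_state=7)

# Candidate alphas from the pruning path of an unpruned tree
base = tree.DecisionTreeClassifier(random_state=7)
path = base.cost_complexity_pruning_path(X_train, y_train)
# Clamp at zero as a precaution against tiny negative values from rounding
ccp_alphas = [max(float(a), 0.0) for a in path.ccp_alphas]

# Refit one tree per alpha and keep the best cross-validated accuracy
best_alpha, best_score = None, -1.0
for alpha in ccp_alphas:
    candidate = tree.DecisionTreeClassifier(random_state=7, ccp_alpha=alpha)
    score = model_selection.cross_val_score(candidate, X_train, y_train, cv=5).mean()
    if score > best_score:
        best_alpha, best_score = alpha, score

# Final pruned tree, fit on the full training split
pruned = tree.DecisionTreeClassifier(random_state=7, ccp_alpha=best_alpha)
pruned.fit(X_train, y_train)
print("chosen ccp_alpha:", best_alpha, "leaves:", pruned.get_n_leaves())
```

Larger alphas prune more aggressively, so the last alpha on the path typically collapses the tree to a single node.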

## Utilities for analysis

The following function plots a validation curve for a decision tree classifier: the mean training and cross-validation scores over a range of values for a single hyperparameter. The figure is saved to `fig_name`.

In [4]:
def plot_validation_curve(dt, X, y, param_name, param_range, cross_validation, fig_name):
    # Each score array has shape (len(param_range), n_folds)
    tr_sc, tst_sc = model_selection.validation_curve(dt, X, y, param_name=param_name,
                                                     param_range=param_range, cv=cross_validation)
    plt.figure()
    # Average over folds to get one point per parameter value
    plt.plot(param_range, np.mean(tr_sc, axis=1), label='Training score')
    plt.plot(param_range, np.mean(tst_sc, axis=1), label='Cross-validation score')

    plt.title('Decision Tree - Validation Curve')
    plt.xticks(param_range)
    plt.xlabel(param_name)
    plt.ylabel("Score")
    # The legend must be created after the plot calls so the labels exist
    plt.legend(loc="best")

    plt.savefig(fig_name)
    plt.show()
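
To make the shape of the data behind the curve concrete, here is a minimal sketch that computes the same scores without plotting. The iris dataset and the `max_depth` range are assumptions chosen for illustration, not values from the original analysis.

```python
import numpy as np
from sklearn import tree, model_selection
from sklearn.datasets import load_iris

# Stand-in data (assumption): the iris dataset bundled with scikit-learn
X, y = load_iris(return_X_y=True)
depths = [1, 2, 3, 4, 5]

train_scores, cv_scores = model_selection.validation_curve(
    tree.DecisionTreeClassifier(random_state=7), X, y,
    param_name="max_depth", param_range=depths, cv=5)

# One row per parameter value, one column per fold; average over folds
train_mean = np.mean(train_scores, axis=1)
cv_mean = np.mean(cv_scores, axis=1)

for d, tr, cv in zip(depths, train_mean, cv_mean):
    print(f"max_depth={d}: train={tr:.3f}, cv={cv:.3f}")
```

A widening gap between the training and cross-validation means as the parameter grows is the usual sign of overfitting.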


The following function returns the classification score (mean accuracy) of the model on the given data.

In [5]:
def get_classification_score(dt, X, y):
    return dt.score(X, y)
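
For a scikit-learn classifier, `score` is the mean accuracy, i.e. the fraction of correctly predicted labels. The short check below illustrates this on stand-in data (the iris dataset, an assumption) by comparing it with `accuracy_score`.

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# Stand-in data (assumption): the iris dataset bundled with scikit-learn
X, y = load_iris(return_X_y=True)
clf = tree.DecisionTreeClassifier(max_depth=2, random_state=7).fit(X, y)

score = clf.score(X, y)                     # mean accuracy reported by the model
manual = accuracy_score(y, clf.predict(X))  # same quantity computed explicitly
print(score, manual)
```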

The following function provides a visual representation of the decision tree.

In [6]:
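def visualize_dt(dt, feature_names, target_names, file_name):
    # out_file=None makes export_graphviz return the DOT source as a string
    dot_data = tree.export_graphviz(dt, out_file=None,
                      feature_names=feature_names,
                      class_names=target_names,
                      filled=True, rounded=True,
                      special_characters=True)
    # Keep a copy of the DOT source on disk
    with open(file_name, 'w') as f:
        f.write(dot_data)
    # Returning the Source object lets the notebook render it inline
    return graphviz.Source(dot_data)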
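
Where Graphviz is not installed, a plain-text view of the tree can serve as a fallback. The sketch below uses scikit-learn's `tree.export_text`; the iris dataset and the shallow `max_depth=2` tree are assumptions made only to keep the output short.

```python
from sklearn import tree
from sklearn.datasets import load_iris

# Stand-in data (assumption): the iris dataset bundled with scikit-learn
iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=2, random_state=7)
clf.fit(iris.data, iris.target)

# Indented text rendering of the split rules and leaf classes
text_view = tree.export_text(clf, feature_names=list(iris.feature_names))
print(text_view)
```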
    

## Bank Marketing Data Analysis
The data comes from direct marketing campaigns (phone calls) run by a Portuguese banking institution. The classification goal is to predict whether the client will subscribe to a term deposit (variable y).