In [39]:
import os
import base64
from io import BytesIO
import pickle
from textwrap import dedent
import xml

import joblib
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz
from sklearn.model_selection import train_test_split
from sklearn import metrics
import plotly.express as px
import pydot
import plotly.graph_objects as go

Before running this notebook, download `clean_loan.csv` from AWS into the notebook's working directory (see the README).

In [6]:
# Load the cleaned loan table and normalise every column name to lower case
# so later cells can refer to columns consistently.
df = pd.read_csv('clean_loan.csv').rename(columns=str.lower)

In [8]:
# Target is the loan grade. int_rate is removed from the features too,
# presumably because it is set from the grade and would leak the label
# (TODO confirm).
y = df['grade']
X_df = df.drop(['grade', 'int_rate'], axis=1)

In [9]:
# Split columns into text-like (to be one-hot encoded) and the rest.
# Selecting by dtype family instead of `== 'object'` also catches pandas'
# dedicated string dtype (the default for CSV text in pandas 3) and
# categorical columns, which would otherwise reach the model unencoded.
text_dtypes = ['object', 'string', 'category']
cat_cols = X_df.select_dtypes(include=text_dtypes).columns
num_cols = X_df.select_dtypes(exclude=text_dtypes).columns

# One-hot encode the text columns; dummy columns are named '<column>: <value>'.
X = pd.get_dummies(X_df, columns=cat_cols, prefix_sep=': ')

In [10]:
# Hold out 20% of the rows for evaluation; the fixed seed keeps the split reproducible.
(X_train, X_test,
 y_train, y_test) = train_test_split(X, y, random_state=0, test_size=0.2)

## Shallow model

In [34]:
%%time
model = DecisionTreeClassifier(max_depth=3, max_leaf_nodes=20, max_features='auto')
model.fit(X_train, y_train)

joblib.dump(model, open('assets/model-shallow.joblib', 'wb'))

CPU times: user 2.34 s, sys: 208 ms, total: 2.55 s
Wall time: 2.55 s


## Deep Model

In [64]:
%%time
model = DecisionTreeClassifier(max_depth=8, max_features='auto', max_leaf_nodes=100)
model.fit(X_train, y_train)

joblib.dump(model, open('assets/model-deep.joblib', 'wb'))

CPU times: user 4.27 s, sys: 164 ms, total: 4.43 s
Wall time: 4.43 s


## Entropy

In [60]:
%%time
model = DecisionTreeClassifier(criterion='entropy', max_depth=5, max_leaf_nodes=20, max_features='auto')
model.fit(X_train, y_train)

joblib.dump(model, open('assets/model-with-entropy.joblib', 'wb'))

CPU times: user 3.89 s, sys: 164 ms, total: 4.05 s
Wall time: 4.05 s


## Random Split

In [37]:
%%time
model = DecisionTreeClassifier(splitter='random', max_features='auto', max_leaf_nodes=20, max_depth=5)
model.fit(X_train, y_train)

joblib.dump(model, open('assets/model-random-split.joblib', 'wb'))

CPU times: user 2.33 s, sys: 160 ms, total: 2.49 s
Wall time: 2.49 s


## Save feature names

In [50]:
# Persist the one-hot column names in training order (used to label tree
# nodes when exporting the tree). The context manager flushes and closes the file.
with open('feature_names.pickle', 'wb') as feature_file:
    pickle.dump(X_test.columns.values, feature_file)

## Try tree visualization

In [51]:
# Reload the saved feature names from disk; the context manager closes the handle.
# NOTE: pickle.load can execute arbitrary code — only load files this project wrote.
with open('feature_names.pickle', 'rb') as feature_file:
    feature_names = pickle.load(feature_file)

In [100]:
def _svg_length(value):
    """Return the numeric part of an SVG length attribute as a float.

    The root ``<svg>`` element's ``width``/``height`` carry a unit suffix
    (e.g. ``'8726pt'``); only the number is needed to size the axes.

    Raises:
        ValueError: if ``value`` does not start with a number.
    """
    import re

    match = re.match(r'\s*(\d+(?:\.\d*)?|\.\d+)', value)
    if match is None:
        raise ValueError(f'Cannot parse SVG length: {value!r}')
    return float(match.group(1))


def svg_to_fig(svg_bytes, title=None, plot_bgcolor='white', x_lock=False, y_lock=False):
    """Embed an SVG document in a Plotly figure so it can be panned and zoomed.

    The SVG is added as a base64 data-URI layout image covering
    ``[0, width] x [0, height]`` in data coordinates, with hidden axes.

    Args:
        svg_bytes: The SVG document as bytes (e.g. from ``pydot.Dot.create_svg()``).
        title: Optional figure title.
        plot_bgcolor: Background colour of the plotting area.
        x_lock: If truthy, constrain the x axis to its domain.
        y_lock: If truthy, anchor the y-axis scale to x (ratio 1) so the
            image keeps its aspect ratio.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    # `import xml` alone does not load the `xml.etree` submodule.
    import xml.etree.ElementTree as ET

    # No whitespace is allowed between the comma and the payload of a data URI.
    svg_uri = 'data:image/svg+xml;base64,' + base64.b64encode(svg_bytes).decode('ascii')

    # Read the intrinsic size from the root <svg> element. Parsing the bytes
    # directly lets the parser honour the document's own encoding declaration.
    root = ET.fromstring(svg_bytes)
    img_width = _svg_length(root.attrib['width'])
    img_height = _svg_length(root.attrib['height'])

    fig = go.Figure()
    # Invisible scatter trace spanning the image extent; added to help the
    # autoresize logic work.
    fig.add_trace(
        go.Scatter(
            x=[0, img_width],
            y=[img_height, 0],
            mode="markers",
            marker_opacity=0,
            hoverinfo="none",
        )
    )
    # Place the image at data (0, 0), sized in data units to match the axis ranges.
    fig.add_layout_image(
        dict(
            source=svg_uri,
            x=0,
            y=0,
            xref="x",
            yref="y",
            sizex=img_width,
            sizey=img_height,
            opacity=1,
            layer="below",
        )
    )

    # Fit the axes to the image; the y range is reversed so (0, 0) is the top-left corner.
    fig.update_xaxes(showgrid=False, visible=False, range=[0, img_width])
    fig.update_yaxes(showgrid=False, visible=False, range=[img_height, 0])

    if x_lock:
        fig.update_xaxes(constrain='domain')
    if y_lock:
        # Lock the aspect ratio: one y unit is drawn the same length as one x unit.
        fig.update_yaxes(scaleanchor="x", scaleratio=1)

    fig.update_layout(plot_bgcolor=plot_bgcolor)

    if title:
        fig.update_layout(title=title)

    return fig


In [92]:
path = 'assets/model-deep.joblib'
# joblib.load accepts a path and manages (and closes) the file handle itself.
model = joblib.load(path)

dot_data = export_graphviz(
    model,
    out_file=None,              # return the DOT source as a string
    filled=True,
    rounded=True,
    feature_names=feature_names,
    # export_graphviz concatenates class names into node labels as strings,
    # so coerce the labels to str in case the classes are not already strings.
    class_names=[str(label) for label in model.classes_],
    proportion=True,
    rotate=True,
    precision=2,
)

# graph_from_dot_data returns a list of graphs; take the first (only) one.
pydot_graph = pydot.graph_from_dot_data(dot_data)[0]
svg_bytes = pydot_graph.create_svg()

In [99]:
# Render the exported tree SVG as an interactive Plotly figure with both axes locked.
fig = svg_to_fig(
    svg_bytes,
    title='Decision Tree Explanation',
    x_lock=True,
    y_lock=True,
)
fig.show()