In [1]:
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import subprocess
from skimage import io
import os
import glob

In [2]:
# Output directory for the exported Graphviz (.dot) files and their rendered PNGs.
# exist_ok=True makes the cell idempotent and avoids the exists()-then-makedirs race.
os.makedirs('tree_diagrams/', exist_ok=True)

In [3]:
# Reproducibility: seed NumPy's global RNG (used by any scikit-learn estimator left
# with random_state=None); the split below is pinned by its own random_state.
np.random.seed(42)

# Iris dataset, keeping only the first two feature columns.
data = load_iris()
X = data.data[:, :2]
y = data.target

# 60% train / 40% test.
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)


In [4]:
# Labels for drawing trees: names of the two feature columns kept in X, and the
# class (species) names.
fn, cn = data.feature_names[:2], data.target_names

In [5]:
# Default hyper-parameters for a single tree.
# NOTE(review): none of these are read by the visible cells (get_tree receives its
# values as arguments and the sweep passes its own) — confirm they are still needed.
max_depth, max_leaf_nodes, min_samples_split = 6, None, 2

In [6]:
def get_tree(max_depth, min_samples_split, max_leaf_nodes, DEBUG=False, random_state=42):
    """Fit a decision tree on the training split and export it as a Graphviz .dot file.

    The classifier is trained on the notebook-level ``x_train`` / ``y_train`` and
    written to
    ``tree_diagrams/tree_{max_depth}_{min_samples_split}_{max_leaf_nodes}.dot``.

    Parameters
    ----------
    max_depth : int or None
        Passed to ``DecisionTreeClassifier(max_depth=...)``.
    min_samples_split : int or float
        Passed to ``DecisionTreeClassifier(min_samples_split=...)``.
    max_leaf_nodes : int or None
        Passed to ``DecisionTreeClassifier(max_leaf_nodes=...)``.
    DEBUG : bool, default False
        If True, also draw the fitted tree inline with matplotlib (dpi=300).
    random_state : int, numpy RandomState or None, default 42
        Seed for the classifier. Pinned so each exported tree is reproducible
        on its own, instead of depending on how many trees were fitted before
        it from the global NumPy RNG. Pass None for the previous behaviour.

    Returns
    -------
    DecisionTreeClassifier
        The fitted classifier.
    """
    # Read labels straight from the dataset rather than the `fn`/`cn` globals:
    # keeps the plot and the export consistent, and `fn` can be reassigned by
    # later cells.
    feature_names = data.feature_names[:2]  # X keeps only the first two columns
    class_names = data.target_names

    clf = tree.DecisionTreeClassifier(
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        max_leaf_nodes=max_leaf_nodes,
        random_state=random_state,
    )
    clf.fit(x_train, y_train)

    if DEBUG:
        _, ax = plt.subplots(dpi=300)
        tree.plot_tree(clf,
                       ax=ax,
                       feature_names=feature_names,
                       class_names=class_names,
                       filled=True, rounded=True)

    out_path = f'tree_diagrams/tree_{max_depth}_{min_samples_split}_{max_leaf_nodes}.dot'
    tree.export_graphviz(clf, out_file=out_path,
                         feature_names=feature_names,
                         class_names=class_names,
                         filled=True, rounded=True,
                         special_characters=True)
    return clf

In [7]:
# Export one tree per hyper-parameter combination: 6 depths x 6 leaf limits x
# 5 split sizes = 180 .dot files. Iteration order: depth, then leaves, then split.
depth_grid = range(1, 7)
leaf_grid = [3, 6, 9, 12, 15, 18]
split_grid = [2, 5, 10, 20, 40]

for depth in depth_grid:
    for leaves in leaf_grid:
        for split in split_grid:
            get_tree(depth, split, leaves)

In [8]:
# All exported tree diagrams. glob.glob returns matches in arbitrary, OS-dependent
# order, so sort for a deterministic, reproducible listing.
flist = sorted(glob.glob("tree_diagrams/*.dot"))

In [9]:
# Render every exported .dot file to a PNG beside it using Graphviz's `dot` CLI.
# The loop variable is deliberately not `fn`: that name holds the feature names
# defined earlier in the notebook, and reusing it here would clobber them.
for dot_path in flist:
    png_path = os.path.splitext(dot_path)[0] + '.png'
    # Argument list (no shell) so paths with spaces or shell metacharacters are
    # passed verbatim; check=True raises if `dot` fails instead of continuing silently.
    subprocess.run(['dot', '-Tpng', dot_path, '-o', png_path], check=True)