Simple example showing how to train and visualize a classification tree.

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# seed NumPy's global RNG; scikit-learn estimators with random_state=None draw from it
np.random.seed(0)

Let us load the Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris). Note that we only consider two of its four features, selected via the 'pair' list: indices 0 and 1 correspond to sepal length and sepal width.

In [None]:
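from sklearn.datasets import load_iris

# indices of the two features to use: 0 = sepal length, 1 = sepal width
pair = [0, 1]
iris = load_iris()

# load_iris returns a Bunch (a dict-like object); list its attributes
print(type(iris))
print(dir(iris))

# feature matrix restricted to the selected columns, and the class labels
X = iris.data[:, pair]
y = iris.target

#print(iris.DESCR)  # full description of the dataset
print(iris.filename)
print(iris.target_names)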
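To double-check which columns the 'pair' list selects, the following standalone snippet reloads the data and prints the names of the chosen features together with the shape of the resulting feature matrix. The variable name `selected` is introduced here only for illustration.

```python
from sklearn.datasets import load_iris

# Indices of the two features kept for this example (same as `pair` above)
pair = [0, 1]
iris = load_iris()
X = iris.data[:, pair]

# Human-readable names of the selected columns
selected = [iris.feature_names[i] for i in pair]
print(selected)           # ['sepal length (cm)', 'sepal width (cm)']
print(X.shape)            # (150, 2)
print(iris.target_names)  # ['setosa' 'versicolor' 'virginica']
```

Changing 'pair' (for example to `[2, 3]` for the petal measurements) changes both the plotted axes and the features the tree can split on.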

Let us fit a classification tree using the popular scikit-learn library! We limit the tree to a depth of two (max_depth=2) so that it stays small and easy to inspect.

In [None]:
from sklearn.tree import DecisionTreeClassifier

# instantiate the model
model = DecisionTreeClassifier(max_depth=2)

# train the model
model.fit(X, y)
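By default, DecisionTreeClassifier chooses its splits using the Gini impurity criterion (criterion="gini"). The sketch below is a simplified, pure-NumPy illustration of the greedy search for a single split; it is not scikit-learn's actual implementation, and the helper names `gini` and `best_split` as well as the toy data are hypothetical. A full tree would repeat this search recursively in each child node until a stopping rule such as max_depth=2 applies.

```python
import numpy as np

def gini(labels):
    """Gini impurity 1 - sum_k p_k^2 of a 1-D array of class labels."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return float(1.0 - np.sum(p ** 2))

def best_split(X, y):
    """Exhaustively search (feature, threshold) pairs for the split that
    minimizes the size-weighted Gini impurity of the two children.
    Candidate thresholds are midpoints between consecutive distinct values."""
    n_samples, n_features = X.shape
    best = (None, None, gini(y))  # (feature, threshold, impurity); no split yet
    for j in range(n_features):
        values = np.unique(X[:, j])  # sorted distinct values of feature j
        for t in (values[:-1] + values[1:]) / 2.0:
            left = X[:, j] <= t
            right = ~left
            impurity = (left.sum() * gini(y[left]) +
                        right.sum() * gini(y[right])) / n_samples
            if impurity < best[2]:
                best = (j, float(t), float(impurity))
    return best

# Toy data: feature 0 separates the two classes perfectly, feature 1 does not
X_toy = np.array([[1.0, 5.0], [2.0, 1.0], [3.0, 4.0], [4.0, 2.0]])
y_toy = np.array([0, 0, 1, 1])
print(best_split(X_toy, y_toy))  # (0, 2.5, 0.0)
```

The split "feature 0 <= 2.5" yields two pure children, so its weighted impurity drops from 0.5 at the root to 0.0.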

Next, let us visualize the resulting decision surface, i.e., the class boundaries induced by the tree. Since we only consider two features, we can draw it as a two-dimensional plot (see http://scikit-learn.org/stable/auto_examples/tree/plot_iris.html).

In [None]:
plot_step = 0.02     # grid resolution
n_classes = 3
plot_colors = "ryb"  # red, yellow, blue: roughly matches the RdYlBu colormap

plt.figure()

# build a dense grid covering the data range plus a margin of 1
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                     np.arange(y_min, y_max, plot_step))

# predict a class for every grid point and draw the filled regions
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)

plt.xlabel(iris.feature_names[pair[0]])
plt.ylabel(iris.feature_names[pair[1]])

# Plot the training points, one color per class
for i, color in zip(range(n_classes), plot_colors):
    idx = y == i  # boolean mask selecting the samples of class i
    plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
                edgecolor='black', s=15)
plt.suptitle("Decision surface of a decision tree using paired features")
plt.legend(loc='lower right', borderpad=0, handletextpad=0)
plt.axis("tight")
plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)  # adjust spacing once everything is drawn
plt.show()
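The decision surface is obtained by predicting a label for every point of a regular grid. The small example below illustrates the three array steps on a coarse 2 x 3 grid: `np.meshgrid` builds the coordinates, `np.c_[xx.ravel(), yy.ravel()]` stacks them into one query point per row (the same layout as X), and `reshape` restores the grid shape that `contourf` expects. The function `toy_predict` is a hypothetical stand-in for `model.predict`.

```python
import numpy as np

# Coarse grid: 3 points along x, 2 along y (the notebook uses plot_step=0.02)
xx, yy = np.meshgrid(np.arange(0.0, 3.0, 1.0),   # x: 0, 1, 2
                     np.arange(0.0, 2.0, 1.0))   # y: 0, 1
print(xx.shape)  # (2, 3): one row per y value, one column per x value

# Flatten and stack so each row is one (x, y) query point, like a row of X
points = np.c_[xx.ravel(), yy.ravel()]
print(points.shape)  # (6, 2)

def toy_predict(P):
    """Hypothetical stand-in for model.predict: class 1 if x >= 1, else 0."""
    return (P[:, 0] >= 1).astype(int)

# Predict one label per point, then restore the grid shape for contourf
Z = toy_predict(points).reshape(xx.shape)
print(Z)
# [[0 1 1]
#  [0 1 1]]
```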

One can also visualize the induced tree structure. **Note that this requires both the graphviz Python package and the Graphviz software itself to be installed on your system.**

In [None]:
import graphviz
from sklearn import tree

# export the fitted tree in DOT format and render it with graphviz
dot_data = tree.export_graphviz(model,
                                out_file=None,
                                feature_names=[iris.feature_names[i] for i in pair],
                                class_names=iris.target_names,
                                filled=True,
                                rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
graph  # displays the tree inline in the notebook
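If Graphviz is not available, scikit-learn can also print the fitted tree as plain text with `sklearn.tree.export_text`. The following standalone sketch refits the same depth-2 tree; passing `random_state=0` instead of relying on the global seed is an assumption made so the snippet runs on its own.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

pair = [0, 1]
iris = load_iris()
X, y = iris.data[:, pair], iris.target

# same depth limit as above; random_state fixed for reproducibility
model = DecisionTreeClassifier(max_depth=2, random_state=0)
model.fit(X, y)

# text rendering of the fitted tree; no Graphviz installation needed
report = export_text(model, feature_names=[iris.feature_names[i] for i in pair])
print(report)
```

Each leaf appears as a line ending in "class: k", where k is the class label (0, 1 or 2). For a graphical view that only needs matplotlib, `sklearn.tree.plot_tree` is another option.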