# Model Visualization

In [1]:
from IPython.display import Image
import joblib
from sklearn import tree as sklearn_tree
from six import StringIO
from tqdm import tqdm
import pandas as pd
import pydotplus

In [2]:
def _parse_leaf_label(text):
    """Pull the predicted class and per-class values out of a leaf label.

    ``text`` is the raw label attribute as stored on the parsed graph node,
    i.e. still wrapped in double quotes and with line breaks written as the
    two characters backslash + ``n``.

    Returns a ``(class_name, values)`` tuple where ``values`` is a list with
    one number per class, in the order they appear inside ``value = [...]``.
    """
    body = text[1:-1]  # drop the surrounding double quotes
    # Splitting on "nclass = " (with the leading "n") swallows the tail of
    # the literal "\n" separator that precedes the class line.
    cls = body.split("nclass = ")[1]
    raw_vals = body.split("value = [")[1].split("]")[0].split(", ")
    vals = []
    for raw in raw_vals:
        number = float(raw)
        # Keep whole numbers as ints so the rendered "(x / y)" text is
        # unchanged for ordinary count-valued trees; non-integer values
        # (which int() would reject) are kept as floats.
        vals.append(int(number) if number.is_integer() else number)
    return cls, vals


def visualize_tree(model, feature_names, class_names, fill_colors, progress=False):
    """Render a fitted decision tree as a PNG with simplified node labels.

    The tree is exported to Graphviz DOT, parsed with pydotplus, and every
    labelled node is rewritten:

    * leaf nodes get a filled box showing the share of the second class
      (``value[1] / sum(value)``) together with the raw numbers;
    * internal nodes keep only the part of their label that precedes the
      text ``gini`` (the split rule).

    Parameters
    ----------
    model
        Fitted tree estimator accepted by ``sklearn.tree.export_graphviz``.
    feature_names : list of str
        Passed through to ``export_graphviz``.
    class_names : list of str
        Passed through to ``export_graphviz``. ``class_names[0]`` and
        ``class_names[1]`` are also compared against each leaf's class to
        choose its fill colour.
    fill_colors : sequence of str
        ``fill_colors[0]`` / ``fill_colors[1]`` are the fill colours used for
        leaves predicting ``class_names[0]`` / ``class_names[1]``.
    progress : bool, default False
        When true, print stage messages and wrap the edge and node loops in
        ``tqdm`` progress bars.

    Returns
    -------
    IPython.display.Image
        The rendered PNG.

    Notes
    -----
    Leaves are the node ids that occur as an edge destination but never as an
    edge source, so a graph without any edges has no node treated as a leaf.
    """
    dot_data = StringIO()
    if progress:
        print("Exporting graphviz...")
    sklearn_tree.export_graphviz(
        model,
        out_file=dot_data,
        feature_names=feature_names,
        class_names=class_names
    )
    if progress:
        print("Building graph...")
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())

    if progress:
        print("Processing edges...")
    edges = graph.get_edges()
    destinations = set()
    sources = set()
    for edge in (tqdm(edges) if progress else edges):
        destinations.add(int(edge.get_destination()))
        sources.add(int(edge.get_source()))
    # A set keeps the per-node membership test below O(1).
    leaf_ids = destinations - sources

    note = "Catch Probability\n{:.1f}% ({} / {})"
    if progress:
        print("Processing nodes...")
    nodes = graph.get_nodes()
    for node in (tqdm(nodes) if progress else nodes):
        text = node.get_label()
        if not text:
            # Pseudo-nodes created from DOT default statements carry no label.
            continue
        # Identify the node by its own id rather than by its position in the
        # node list: the number of pseudo-nodes ahead of the real tree nodes
        # depends on which default statements the exporter emitted.
        node_name = str(node.get_name()).strip('"')
        is_terminal = node_name.isdigit() and int(node_name) in leaf_ids
        if is_terminal:
            cls, vals = _parse_leaf_label(text)
            comp = vals[1]
            att = sum(vals)
            # Guard the division so an all-zero value row cannot raise.
            prob = 100 * (comp / att) if att else 0.0
            node.set("label", note.format(prob, comp, att))
            node.set("style", "filled")
            if cls == class_names[0]:
                node.set("fillcolor", fill_colors[0])
            elif cls == class_names[1]:
                node.set("fillcolor", fill_colors[1])
        else:
            rule = text[1:-1].split("gini")[0]
            node.set("label", rule)
    if progress:
        print("Done.")
    return Image(graph.create_png())

In [3]:
# Load the previously saved decision tree model from disk and display its repr.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load model files from a trusted source.
clf_tree = joblib.load("../models/demo_DecisionTree.pkl")
clf_tree

DecisionTreeClassifier(random_state=0)

In [4]:
# Display settings for the tree plot. visualize_tree fills a leaf with
# colors[0] when its class is classes[0] and with colors[1] when it is classes[1].
# NOTE(review): the order of `classes` and `features` presumably has to match
# the class order and feature columns the model was trained with — confirm.
classes = ["Non-Goal", "Goal"]
colors = ["#f76b5c", "#5cff87"]
features = ["goal_distance", "goal_angle"]
img = visualize_tree(clf_tree, features, classes, colors, progress=True)

Exporting graphviz...
Building graph...
Processing edges...


100%|██████████| 3302/3302 [00:00<00:00, 332890.87it/s]
100%|██████████| 3304/3304 [00:00<00:00, 42580.59it/s]

Processing nodes...
Done.





In [5]:
# Write the rendered tree image (the bytes held in img.data) to a PNG file
# in the models directory.
with open("../models/demo_DecisionTree.png", "wb") as f:
    f.write(img.data)

In [6]:
def visualize_logit(model, features, cmap):
    """Tabulate a fitted linear model's intercept and coefficients.

    Builds a two-column frame (``feature``, ``coef``) whose first row is the
    intercept, followed by one row per entry of ``features`` paired with the
    first row of ``model.coef_``.

    Parameters
    ----------
    model
        Object exposing ``intercept_`` (iterable) and ``coef_`` (indexable,
        first row is used).
    features : list of str
        Names for the coefficients, in coefficient order.
    cmap
        Colormap forwarded to ``Styler.background_gradient``.

    Returns
    -------
    pandas Styler wrapping the coefficient table with a background gradient.
    """
    row_labels = ["intercept"] + features
    weights = [*model.intercept_, *model.coef_[0]]
    table = pd.DataFrame(dict(feature=row_labels, coef=weights))
    styled = table.style.background_gradient(cmap=cmap)
    return styled

In [7]:
# Load the previously saved logistic regression model from disk and display its repr.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load model files from a trusted source.
clf_logit = joblib.load("../models/demo_logistic_regression.pkl")
clf_logit

LogisticRegression(random_state=0)

In [8]:
# Show the intercept and per-feature coefficients as a table shaded with the
# "RdYlBu" colormap; reuses the `features` list defined for the tree plot.
visualize_logit(clf_logit, features, cmap="RdYlBu")

Unnamed: 0,feature,coef
0,intercept,-2.144355
1,goal_distance,-0.007052
2,goal_angle,1.784609
