In [1]:
import pickle, glob, json, os
import numpy as np
from copy import deepcopy
from matplotlib import pyplot as plt
from matplotlib import patches
from TreeEvaluater import dist_tree

## Tree Drawer

In [2]:
def prune(tree, i=0, current_depth=0, max_depth=3):
    """Collapse, in place, subtrees whose two children predict the same label.

    The tree is walked bottom-up starting at node ``i``. A node is collapsed
    (both of its child pointers are overwritten with ``-2``) when both of its
    children are themselves collapsible and ``np.argmax`` of their ``value``
    rows agree.

    Parameters
    ----------
    tree : object
        Exposes index-able, writable ``children_left`` / ``children_right``
        and an index-able ``value`` (per-node class scores).
    i : int
        Index of the node to process.
    current_depth : int
        Depth of node ``i`` relative to the node the walk started from.
    max_depth : int
        Nodes at this depth are treated as leaves and are not modified.

    Returns
    -------
    bool
        True if node ``i`` is (or has become) a leaf as far as this walk is
        concerned, False if it still splits into different labels.
    """
    if current_depth >= max_depth:
        return True

    left = tree.children_left[i]
    right = tree.children_right[i]

    # A negative child pointer marks a leaf. Recursing into it would index the
    # arrays with -1 / -2, i.e. silently wrap around to the *last* nodes of
    # the tree and let unrelated nodes decide whether this leaf is prunable.
    if left < 0 or right < 0:
        # Normalise the marker so that an original leaf and a collapsed node
        # look the same afterwards (this is also the state the wrap-around
        # recursion left behind whenever the last node was a leaf).
        tree.children_left[i] = -2
        tree.children_right[i] = -2
        return True

    flag_left = prune(tree, left, current_depth + 1, max_depth)
    flag_right = prune(tree, right, current_depth + 1, max_depth)

    same_label = np.argmax(tree.value[left]) == np.argmax(tree.value[right])
    if same_label and flag_left and flag_right:
        tree.children_left[i] = -2
        tree.children_right[i] = -2
        return True
    return False

In [3]:
class Drawer:
    """Draw a binary decision tree on a matplotlib axes and highlight where it
    differs from a reference ("base") tree.

    Trees are objects exposing index-able ``feature``, ``threshold``,
    ``children_left``, ``children_right`` and ``value`` attributes; a negative
    ``children_left`` entry marks a leaf. Nodes and edges that match the base
    tree are drawn white, everything below the first mismatch is drawn in
    semi-transparent red.

    Parameters
    ----------
    node_width, node_height : number
        Size of an internal-node box (leaf boxes are half as wide).
    vertical_gap : number
        Vertical distance between a node and its children.
    canvas_width : number
        Horizontal spread used for the root; halved at every level.
    fontsize : number
        Font size of all labels.
    """

    def __init__(self, node_width=5, node_height=1, vertical_gap=1, canvas_width=12, fontsize=10):
        self.node_width = node_width
        self.node_height = node_height
        self.vertical_gap = vertical_gap
        self.canvas_width = canvas_width
        self.fontsize = fontsize

    @staticmethod
    def _patch_style(eq):
        """Return patch keyword arguments: white if `eq` is truthy, translucent red otherwise."""
        if eq:
            return dict(edgecolor='k', facecolor='w', fill=True)
        return dict(edgecolor='k', facecolor='r', alpha=0.5, fill=True)

    def _label(self, ax, text, center_x, center_y):
        """Write `text` centred on the given point, above all patches."""
        ax.annotate(text, (center_x, center_y), fontsize=self.fontsize, ha='center', va='center', zorder=200)

    def draw(self, ax, tree, base, max_depth):
        """Draw `tree` on `ax`, compared against `base`, and hide the axes frame.

        Returns the bounding box ``[left, right, bottom, top]`` of the drawing,
        padded by 10% of a node width on the right and 10% of a node height at
        the bottom.
        """
        # NOTE(review): the walk starts at node index -1 (the last entry under
        # Python indexing) whereas `prune` starts at index 0 -- confirm which
        # one matches the tree layout used by the callers.
        bbox = self.grow_node(ax, tree, base, True, -1, max_depth, 0, 0, self.canvas_width)
        bbox[1] = bbox[1] + 0.1 * self.node_width
        bbox[2] = bbox[2] - 0.1 * self.node_height
        ax.axis('off')
        return bbox

    def draw_internal_node(self, ax, tree, i, pos_x, pos_y, w):
        """Label node `i` with its split rule ``x_{feature+1} <= threshold``.

        `w` is unused; it is kept so all drawing methods share one signature.
        """
        d = tree.feature[i]
        v = tree.threshold[i]
        # Raw string: '\l' is not a valid escape sequence in a normal literal.
        # '%d' truncates a fractional threshold towards zero.
        text = r'$x_{%d}\leq%d$' % (d+1, v)
        self._label(ax, text, pos_x + 0.5 * self.node_width, pos_y + 0.5 * self.node_height)

    def draw_leaf_node(self, ax, tree, i, pos_x, pos_y, w):
        """Label node `i` with its predicted class (argmax of ``tree.value[i]``).

        `w` is unused; it is kept so all drawing methods share one signature.
        """
        y = np.argmax(tree.value[i])
        self._label(ax, '$y=%d$' % (y,), pos_x + 0.5 * self.node_width, pos_y + 0.5 * self.node_height)

    def draw_edge(self, ax, tree, eq, i, pos_x, pos_y, w, side='l'):
        """Draw the edge from the node at (`pos_x`, `pos_y`) to one of its children.

        ``side='l'`` draws the left edge labelled 'T'; any other value draws
        the right edge labelled 'F'. The label sits in an ellipse at the edge
        midpoint, tinted red when `eq` is falsy. `tree` and `i` are unused.
        """
        [s, t] = [-1, 'T'] if side == 'l' else [1, 'F']
        source_x = pos_x + 0.5 * self.node_width
        width_x = 0.5 * s * w
        gap_y = self.vertical_gap
        ax.plot([source_x, source_x + width_x], [pos_y, pos_y - gap_y], 'k-')

        center = (source_x + 0.5 * width_x, pos_y - 0.5 * gap_y)
        ellipse_w = 0.2 * self.node_width
        ellipse_h = 0.7 * self.node_height
        # An opaque white ellipse always goes first so the edge line is hidden
        # behind the label; the red tint is layered on top of it.
        ax.add_patch(patches.Ellipse(center, ellipse_w, ellipse_h, zorder=100, **self._patch_style(True)))
        if not eq:
            ax.add_patch(patches.Ellipse(center, ellipse_w, ellipse_h, zorder=100, **self._patch_style(False)))
        self._label(ax, t, center[0], center[1])

    def grow_node(self, ax, tree, base, eq, i, max_depth, pos_x, pos_y, w):
        """Recursively draw node `i` of `tree` and everything below it.

        `eq` tells whether every ancestor matched `base`; it stays truthy only
        if node `i` also has the same feature, the same threshold and the same
        leaf/internal status as node `i` of `base`. (`pos_x`, `pos_y`) is the
        lower-left corner of the node box and `w` the horizontal spread
        available to its children.

        Returns the bounding box ``[left, right, bottom, top]`` of the subtree.
        """
        # All three comparisons are evaluated before combining them, so the
        # same node index is always looked up in both trees.
        same_feature = tree.feature[i] == base.feature[i]
        same_threshold = tree.threshold[i] == base.threshold[i]
        same_kind = (tree.children_left[i] < 0) == (base.children_left[i] < 0)
        eq = bool(eq and same_feature and same_threshold and same_kind)

        # NOTE(review): the index bound presumably assumes breadth-first
        # (heap-style) node numbering, where indices below 2**max_depth - 1
        # lie above `max_depth` -- confirm against the tree layout.
        is_internal = (tree.children_left[i] >= 0) and (i < 2**max_depth - 1)

        if not is_internal:
            # Leaves get a half-width box centred inside the node slot.
            ax.add_patch(
                patches.Rectangle(
                    (pos_x + 0.25 * self.node_width, pos_y),
                    0.5 * self.node_width,
                    self.node_height,
                    **self._patch_style(eq)
                ))
            self.draw_leaf_node(ax, tree, i, pos_x, pos_y, w)
            return [pos_x, pos_x + self.node_width, pos_y, pos_y + self.node_height]

        # Children first, then the edges, then this node's own box and label.
        pos_y_next = pos_y - self.node_height - self.vertical_gap
        w_next = 0.5 * w
        bbox_left = self.grow_node(ax, tree, base, eq, tree.children_left[i], max_depth, pos_x - 0.5 * w, pos_y_next, w_next)
        bbox_right = self.grow_node(ax, tree, base, eq, tree.children_right[i], max_depth, pos_x + 0.5 * w, pos_y_next, w_next)
        self.draw_edge(ax, tree, eq, i, pos_x, pos_y, w, side='l')
        self.draw_edge(ax, tree, eq, i, pos_x, pos_y, w, side='r')

        ax.add_patch(
            patches.Rectangle(
                (pos_x, pos_y),
                self.node_width,
                self.node_height,
                **self._patch_style(eq)
            ))
        self.draw_internal_node(ax, tree, i, pos_x, pos_y, w)
        return [bbox_left[0], bbox_right[1], min(bbox_left[2], bbox_right[2]), pos_y + self.node_height]

## Rescaling Breast-Cancer Data

In [4]:
# Download the raw Wisconsin breast-cancer data file into the working directory.
# -nc (no-clobber) makes wget skip the download when the file already exists.
!wget -nc https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data

File ‘breast-cancer-wisconsin.data’ already there; not retrieving.



In [5]:
# Parse the comma-separated data file into integer rows.
# Rows containing a missing value (a field that is exactly '?') are dropped;
# blank lines are skipped instead of crashing in int('').
rows = []
with open('./breast-cancer-wisconsin.data', 'r') as f:
    for line in f:
        fields = line.strip().split(',')
        if fields == [''] or '?' in fields:
            continue
        rows.append([int(v) for v in fields])

# Keep every column except the last one (presumably the class label -- confirm
# against the dataset description).
z = np.array(rows)[:, :-1]

# scaling factors: per-column minimum `m` and range `s` (max - min), used
# further down to map normalised thresholds back to the original units.
m = np.min(z, axis=0)
s = np.max(z, axis=0) - m

## Figure 3

In [6]:
eps = '00316227'
seed, seed_tree = 0, 2
depth_opt = 3   # depth used for pruning, tree distance and drawing
n_top = 5       # number of most frequent tree shapes to plot per method
fig_dir = '../fig/trees/'
os.makedirs(fig_dir, exist_ok=True)


def _save_tree_figure(tree, base, max_depth, path):
    """Draw `tree` against `base` with the Figure 3 layout and save it to `path`."""
    fig, ax = plt.subplots(figsize=(9, 5))
    drawer = Drawer(node_width=48, node_height=1, vertical_gap=1, canvas_width=100, fontsize=30)
    drawer.draw(ax, tree, base, max_depth)
    fig.tight_layout()
    fig.savefig(path)
    # Close explicitly so figures do not accumulate across the loop.
    plt.close(fig)


for method in ['greedy', 'stable']:

    # load
    if method == 'greedy':
        fn = '../res/breast_cancer/remove_043/breast_cancer_greedy_seed%02d_tree%02d.pkl' % (seed, seed_tree)
    elif method == 'stable':
        fn = '../res/breast_cancer/remove_043/breast_cancer_eps%s_seed%02d_tree%02d.pkl' % (eps, seed, seed_tree)
    else:
        raise ValueError('unknown method: %s' % (method,))
    # NOTE: pickle.load can execute arbitrary code -- only load result files
    # produced by this project.
    with open(fn, 'rb') as f:
        res = pickle.load(f)

    # rescale & prune trees
    # Thresholds are mapped back to the original feature units with the
    # scaling factors `m` (offset) and `s` (range) computed above; nodes with
    # a negative feature index carry no split and are left untouched.
    for t in res['trees']:
        for i, (u, v) in enumerate(zip(t.feature, t.threshold)):
            if u < 0:
                continue
            t.threshold[i] = int(np.round(v * s[u] + m[u]))
        prune(t, max_depth=depth_opt)

    # The first tree is the reference tree; the remaining ones are compared
    # against each other.
    base_tree = res['trees'][0]
    trees = res['trees'][1:]
    n_trees = len(trees)   # derived from the data instead of a hard-coded 100

    # distance between trees
    d = np.zeros((n_trees, n_trees))
    for i, t1 in enumerate(trees):
        for j, t2 in enumerate(trees):
            d[i, j] = dist_tree(t1, t2, max_depth=depth_opt)

    # cluster trees with zero distance
    idx = np.arange(n_trees)
    cluster = []
    while idx.size > 0:
        i = idx[0]
        # Always keep the seed tree in its own cluster: this guarantees that
        # `idx` shrinks (no endless loop if a self-distance is non-zero) and
        # that no cluster is empty.
        members = np.union1d(np.where(d[i] == 0)[0], [i])
        cluster.append(members)
        idx = np.setdiff1d(idx, members)

    # sort clusters by size, largest first
    order = np.argsort([c.size for c in cluster])[::-1]
    cluster = [cluster[k] for k in order]

    # original tree using all the data
    _save_tree_figure(base_tree, base_tree, depth_opt, fig_dir + '%s_base.pdf' % (method,))

    # plot frequent trees (one representative per cluster)
    for j, c in enumerate(cluster[:n_top]):
        _save_tree_figure(
            deepcopy(trees[c[0]]),
            deepcopy(base_tree),
            depth_opt,
            fig_dir + '%s%02d_count%03d.pdf' % (method, j, c.size),
        )