Condensing Cluster Trees
=============

The goal is to condense a cluster tree down to a simpler tree based on `min_cluster_size`. Essentially we wish to view a node of the tree to continue to exist until it is split into clusters of size at least `min_cluster_size`. When a split occurs that has fewer than `min_cluster_size` points in it we view this as the cluster "losing" points rather than splitting into a new cluster. We want to record when it lost the points, but we wish to retain the cluster identity.

To start we gegin with a new node class that supports multiple children rather than just "left" and "right". We also need to have an id, and a dist at which the node split off from the parent.

In [38]:
class CondensedTreeNode:    
    def __init__(self, id, dist, children, size, is_leaf):
        self.id = id
        self.dist = dist
        self.children = children
        self.child_size = size
        self.is_leaf = is_leaf
    
    def add_child(self, child):
        self.children.append(child)

    def __repr__(self):
        return '<Node object at %s>' % (
            hex(id(self))
            )
        
    def __str__(self):
        return "ID: %d, Lambda %d, Number of Children %d, " \
               "Number of children %d, Leaf node: %s" % (self.id, self.dist, len(self.children), 
                                                         self.child_size, self.is_leaf)
    

Next we'll need a utility function to extract out all the leaf nodes under a given cluster node. This is essential since we want to make a new "leaf-cluster" for each node that is smaller than `min_cluster_size` and we'll want to gather up allm th actual leaves/data-points and place them flat within that node. But hey, scipy is awsome and comes with a pre-order function that takes a function (and defaults to `lambda x: x.id`) so we can just pass that the identity and get all the leaves.

In [73]:
#This function consumes a tree and returns the set of leaves
def get_leaves(tree):
    """Consume a tree object and return a list of leaf nodes"""
    return tree.pre_order(lambda x: x)

Now we can focus on the condense operation. Since we are working with a tree this is most easily implemented as a recursive function walking down the tree. At each stage we check if the left or right branch are under the `min_cluster_size` and then either recursively call `condense_tree` or add all the leaves of the "left-cluster" accordingly. We'll label leaf nodes with "POINT" to denote that it is indexing a data point and not a cluster id.

In [116]:
def condense_tree(tree, min_cluster_size=10, next_id=0):
        
    #Verbose assert
    if tree.count == 0:
        print("Invalid input: Null tree")
        result = Node(-1, -1, [], -1, False)
        return result
    elif tree.count == 1:
        #Passed in a single node. only that node
        result = Node(-1, tree.dist, [], 1, True)
        return result
        
    result = CondensedTreeNode(next_id, tree.dist, [], tree.left.count + tree.right.count, 0)
        
    #If the left node is too small, add a leaf
    if tree.left.count <= min_cluster_size:
        leaves = get_leaves(tree.left)
        for leaf in leaves:
            result.add_child(CondensedTreeNode("POINT %i" % leaf.id, tree.left.dist, [], 1, True))
    elif tree.right.count <= min_cluster_size:
        child, next_id = condense_tree(tree.left, min_cluster_size, next_id)
        result.add_child(child)
    else:
        child, next_id = condense_tree(tree.left, min_cluster_size, next_id + 1)
        result.add_child(child)
            
    #If the right node is too small, add a leaf
    if tree.right.count <= min_cluster_size:
        leaves = get_leaves(tree.right)
        for leaf in leaves:
            result.add_child(CondensedTreeNode("POINT %i" % leaf.id, tree.right.dist, [], 1, True))
    elif tree.left.count <= min_cluster_size:
        child, next_id = condense_tree(tree.right, min_cluster_size, next_id)
        result.add_child(child)        
    else:
        child, next_id = condense_tree(tree.right, min_cluster_size, next_id + 1)
        result.add_child(child)
        
    return result, next_id
    

Now we can load up some test data (iris will do for now) and try this out.

In [20]:
import pandas as pd
import numpy as np
import scipy.spatial.distance as dist

iris = pd.read_csv("iris.csv")
distance_matrix = dist.squareform(dist.pdist(iris.ix[:,:4].as_matrix()))

In [21]:
def mutual_reachability_distance_matrix(distance_matrix, min_points):
    dim = distance_matrix.shape[0]
    core_distances = np.partition(distance_matrix, min_points, axis=0)[min_points]
    core_distance_matrix = core_distances.repeat(dim).reshape((dim,dim))
    result = np.dstack((core_distance_matrix, core_distance_matrix.T, distance_matrix)).max(axis=2)
    return result

In [22]:
mr_dist_matrix = mutual_reachability_distance_matrix(distance_matrix, 10)

In [23]:
import fastcluster
import scipy.cluster.hierarchy as hclust

In [24]:
ctree = fastcluster.single(mr_dist_matrix)

In [25]:
hctree = hclust.to_tree(ctree)

In [117]:
condensed, final_id = condense_tree(hctree, 10, 0)

Now we need to flatten the tree. The easiest way is to flatten a node into a list of parent children relations, and thn have a flatten tree function that recursively calls `flatten_node` down the whole tree.

In [119]:
def flatten_node(tree_node):
    return [(tree_node.id, x.id, tree_node.dist, x.child_size) for x in tree_node.children
             if tree_node.id != x.id]

In [120]:
def flatten_tree(tree):
    if tree.is_leaf:
        return []
    result = flatten_node(tree)
    for subtree in tree.children:
        result.extend(flatten_tree(subtree))
    return result

In [121]:
flatten_tree(condensed)

[(0, 1, 15.353690797474098, 50),
 (0, 4, 15.353690797474098, 100),
 (1, 'POINT 41', 2.9913846843382017, 1),
 (1, 'POINT 13', 2.6725405663175401, 1),
 (1, 'POINT 22', 2.6725405663175401, 1),
 (1, 'POINT 15', 2.257239278760959, 1),
 (1, 'POINT 8', 1.8764282443617202, 1),
 (1, 'POINT 38', 1.8764282443617202, 1),
 (1, 'POINT 42', 1.8764282443617202, 1),
 (1, 'POINT 5', 1.584309084792646, 1),
 (1, 'POINT 18', 1.584309084792646, 1),
 (1, 'POINT 44', 1.4481717299591841, 1),
 (1, 'POINT 14', 1.3709128075488768, 1),
 (1, 'POINT 33', 1.3709128075488768, 1),
 (1, 'POINT 23', 1.3370718298018256, 1),
 (1, 'POINT 24', 1.3370718298018256, 1),
 (1, 'POINT 20', 1.3370718298018256, 1),
 (1, 'POINT 31', 1.3370718298018256, 1),
 (1, 'POINT 26', 1.3370718298018256, 1),
 (1, 'POINT 43', 1.3370718298018256, 1),
 (1, 'POINT 16', 1.2246852859847495, 1),
 (1, 'POINT 32', 1.2246852859847495, 1),
 (1, 'POINT 36', 1.166842817904191, 1),
 (1, 2, 1.0373698111110854, 15),
 (1, 3, 1.0373698111110854, 14),
 (2, 'POINT 

Well that looks not entirely unreasonable. Let's have a quick check that we actually have all the POINT data (one for each original data point) with no duplication.

In [94]:
point_ids = [int(x[1].split()[1]) for x in flatten_tree(condensed) if isinstance(x[1], str)]

In [95]:
len(point_ids), len(set(point_ids))

(150, 150)

Success! Of course we actually want to make it a numpy array and if we have strings in there we're going to end up with quite a mess, so let's not do that. We can make a new version of condense tree that gives negative values to leaf nodes.

In [122]:
def condense_tree(tree, min_cluster_size=10, next_id=0):
        
    #Verbose assert
    if tree.count == 0:
        print("Invalid input: Null tree")
        result = Node(-1, -1, [], -1, False)
        return result
    elif tree.count == 1:
        #Passed in a single node. only that node
        result = Node(-1, tree.dist, [], 1, True)
        return result
        
    result = CondensedTreeNode(next_id, tree.dist, [], tree.left.count + tree.right.count, 0)
        
    #If the left node is too small, add a leaf
    if tree.left.count <= min_cluster_size:
        leaves = get_leaves(tree.left)
        for leaf in leaves:
            result.add_child(CondensedTreeNode(-leaf.id, tree.left.dist, [], 1, True))
    elif tree.right.count <= min_cluster_size:
        child, next_id = condense_tree(tree.left, min_cluster_size, next_id)
        result.add_child(child)
    else:
        child, next_id = condense_tree(tree.left, min_cluster_size, next_id + 1)
        result.add_child(child)
            
    #If the right node is too small, add a leaf
    if tree.right.count <= min_cluster_size:
        leaves = get_leaves(tree.right)
        for leaf in leaves:
            result.add_child(CondensedTreeNode(-leaf.id, tree.right.dist, [], 1, True))
    elif tree.left.count <= min_cluster_size:
        child, next_id = condense_tree(tree.right, min_cluster_size, next_id)
        result.add_child(child)        
    else:
        child, next_id = condense_tree(tree.right, min_cluster_size, next_id + 1)
        result.add_child(child)
        
    return result, next_id

In [123]:
condensed, final_id = condense_tree(hctree)
np.array(flatten_tree(condensed))

array([[   0.        ,    1.        ,   15.3536908 ,   50.        ],
       [   0.        ,    4.        ,   15.3536908 ,  100.        ],
       [   1.        ,  -41.        ,    2.99138468,    1.        ],
       [   1.        ,  -13.        ,    2.67254057,    1.        ],
       [   1.        ,  -22.        ,    2.67254057,    1.        ],
       [   1.        ,  -15.        ,    2.25723928,    1.        ],
       [   1.        ,   -8.        ,    1.87642824,    1.        ],
       [   1.        ,  -38.        ,    1.87642824,    1.        ],
       [   1.        ,  -42.        ,    1.87642824,    1.        ],
       [   1.        ,   -5.        ,    1.58430908,    1.        ],
       [   1.        ,  -18.        ,    1.58430908,    1.        ],
       [   1.        ,  -44.        ,    1.44817173,    1.        ],
       [   1.        ,  -14.        ,    1.37091281,    1.        ],
       [   1.        ,  -33.        ,    1.37091281,    1.        ],
       [   1.        ,  -23.      

In [104]:
hctree.dist

15.353690797474098

In [105]:
ctree[-5:]

array([[  41.        ,  293.        ,    2.99138468,   50.        ],
       [ 285.        ,  292.        ,    3.04254505,   95.        ],
       [ 106.        ,  295.        ,    4.36741503,   96.        ],
       [ 289.        ,  296.        ,    5.09319235,  100.        ],
       [ 294.        ,  297.        ,   15.3536908 ,  150.        ]])