Condensing Cluster Trees
=============

The goal is to condense a cluster tree down to a simpler tree based on `min_cluster_size`. Essentially we wish to view a node of the tree to continue to exist until it is split into clusters of size at least `min_cluster_size`. When a split occurs that has fewer than `min_cluster_size` points in it we view this as the cluster "losing" points rather than splitting into a new cluster. We want to record when it lost the points, but we wish to retain the cluster identity.

To start we gegin with a new node class that supports multiple children rather than just "left" and "right". We also need to have an id, and a dist at which the node split off from the parent.

In [38]:
class CondensedTreeNode:    
    def __init__(self, id, dist, children, size, is_leaf):
        self.id = id
        self.dist = dist
        self.children = children
        self.child_size = size
        self.is_leaf = is_leaf
    
    def add_child(self, child):
        self.children.append(child)

    def __repr__(self):
        return '<Node object at %s>' % (
            hex(id(self))
            )
        
    def __str__(self):
        return "ID: %d, Lambda %d, Number of Children %d, " \
               "Number of children %d, Leaf node: %s" % (self.id, self.dist, len(self.children), 
                                                         self.child_size, self.is_leaf)
    

Next we'll need a utility function to extract out all the leaf nodes under a given cluster node. This is essential since we want to make a new "leaf-cluster" for each node that is smaller than `min_cluster_size` and we'll want to gather up allm th actual leaves/data-points and place them flat within that node. But hey, scipy is awsome and comes with a pre-order function that takes a function (and defaults to `lambda x: x.id`) so we can just pass that the identity and get all the leaves.

In [73]:
#This function consumes a tree and returns the set of leaves
def get_leaves(tree):
    """Consume a tree object and return a list of leaf nodes"""
    return tree.pre_order(lambda x: x)

Now we can focus on the condense operation. Since we are working with a tree this is most easily implemented as a recursive function walking down the tree. At each stage we check if the left or right branch are under the `min_cluster_size` and then either recursively call `condense_tree` or add all the leaves of the "left-cluster" accordingly. We'll label leaf nodes with "POINT" to denote that it is indexing a data point and not a cluster id.

In [116]:
def condense_tree(tree, min_cluster_size=10, next_id=0):
        
    #Verbose assert
    if tree.count == 0:
        print("Invalid input: Null tree")
        result = Node(-1, -1, [], -1, False)
        return result
    elif tree.count == 1:
        #Passed in a single node. only that node
        result = Node(-1, tree.dist, [], 1, True)
        return result
        
    result = CondensedTreeNode(next_id, tree.dist, [], tree.left.count + tree.right.count, 0)
        
    #If the left node is too small, add a leaf
    if tree.left.count <= min_cluster_size:
        leaves = get_leaves(tree.left)
        for leaf in leaves:
            result.add_child(CondensedTreeNode("POINT %i" % leaf.id, tree.left.dist, [], 1, True))
    elif tree.right.count <= min_cluster_size:
        child, next_id = condense_tree(tree.left, min_cluster_size, next_id)
        result.add_child(child)
    else:
        child, next_id = condense_tree(tree.left, min_cluster_size, next_id + 1)
        result.add_child(child)
            
    #If the right node is too small, add a leaf
    if tree.right.count <= min_cluster_size:
        leaves = get_leaves(tree.right)
        for leaf in leaves:
            result.add_child(CondensedTreeNode("POINT %i" % leaf.id, tree.right.dist, [], 1, True))
    elif tree.left.count <= min_cluster_size:
        child, next_id = condense_tree(tree.right, min_cluster_size, next_id)
        result.add_child(child)        
    else:
        child, next_id = condense_tree(tree.right, min_cluster_size, next_id + 1)
        result.add_child(child)
        
    return result, next_id
    

Now we can load up some test data (iris will do for now) and try this out.

In [20]:
import pandas as pd
import numpy as np
import scipy.spatial.distance as dist

iris = pd.read_csv("iris.csv")
distance_matrix = dist.squareform(dist.pdist(iris.ix[:,:4].as_matrix()))

In [21]:
def mutual_reachability_distance_matrix(distance_matrix, min_points):
    dim = distance_matrix.shape[0]
    core_distances = np.partition(distance_matrix, min_points, axis=0)[min_points]
    core_distance_matrix = core_distances.repeat(dim).reshape((dim,dim))
    result = np.dstack((core_distance_matrix, core_distance_matrix.T, distance_matrix)).max(axis=2)
    return result

In [22]:
mr_dist_matrix = mutual_reachability_distance_matrix(distance_matrix, 10)

In [23]:
import fastcluster
import scipy.cluster.hierarchy as hclust

In [24]:
ctree = fastcluster.single(mr_dist_matrix)

In [25]:
hctree = hclust.to_tree(ctree)

In [117]:
condensed, final_id = condense_tree(hctree, 10, 0)

Now we need to flatten the tree. The easiest way is to flatten a node into a list of parent children relations, and thn have a flatten tree function that recursively calls `flatten_node` down the whole tree.

In [124]:
def flatten_node(tree_node):
    return [(tree_node.id, x.id, 1.0/tree_node.dist, x.child_size) for x in tree_node.children
             if tree_node.id != x.id]

In [141]:
def flatten_tree_recursion(tree):
    if tree.is_leaf:
        return []
    result = flatten_node(tree)
    for subtree in tree.children:
        result.extend(flatten_tree_recursion(subtree))
    return result

def flatten_tree(tree):
    result = flatten_tree_recursion(tree) + [(0, 0, 0.0, tree.child_size)]
    return pd.DataFrame(result, columns=("parent","child","lambda","child_size"))

In [130]:
flatten_tree(condensed)

Unnamed: 0,parent,child,lambda,child_size
0,0,1,0.065131,50
1,0,4,0.065131,100
2,1,-41,0.334293,1
3,1,-13,0.374176,1
4,1,-22,0.374176,1
5,1,-15,0.443019,1
6,1,-8,0.532927,1
7,1,-38,0.532927,1
8,1,-42,0.532927,1
9,1,-5,0.631190,1


In [136]:
def condense_tree(tree, min_cluster_size=10, next_id=0):
        
    #Verbose assert
    if tree.count == 0:
        print("Invalid input: Null tree")
        result = Node(-1, -1, [], -1, False)
        return result
    elif tree.count == 1:
        #Passed in a single node. only that node
        result = Node(-1, tree.dist, [], 1, True)
        return result
        
    result = CondensedTreeNode(next_id, tree.dist, [], tree.left.count + tree.right.count, 0)
        
    #If the left node is too small, add a leaf
    if tree.left.count <= min_cluster_size:
        leaves = get_leaves(tree.left)
        for leaf in leaves:
            result.add_child(CondensedTreeNode(-(leaf.id+1), tree.left.dist, [], 1, True))
    elif tree.right.count <= min_cluster_size:
        child, next_id = condense_tree(tree.left, min_cluster_size, next_id)
        result.add_child(child)
    else:
        child, next_id = condense_tree(tree.left, min_cluster_size, next_id + 1)
        result.add_child(child)
            
    #If the right node is too small, add a leaf
    if tree.right.count <= min_cluster_size:
        leaves = get_leaves(tree.right)
        for leaf in leaves:
            result.add_child(CondensedTreeNode(-(leaf.id+1), tree.right.dist, [], 1, True))
    elif tree.left.count <= min_cluster_size:
        child, next_id = condense_tree(tree.right, min_cluster_size, next_id)
        result.add_child(child)        
    else:
        child, next_id = condense_tree(tree.right, min_cluster_size, next_id + 1)
        result.add_child(child)
        
    return result, next_id

In [142]:
condensed, final_id = condense_tree(hctree)
ctree = flatten_tree(condensed)
ctree.head()

Unnamed: 0,parent,child,lambda,child_size
0,0,1,0.065131,50
1,0,4,0.065131,100
2,1,-42,0.334293,1
3,1,-14,0.374176,1
4,1,-23,0.374176,1


In [145]:
births = ctree.groupby("child").min()[["lambda"]]
births.head()

Unnamed: 0_level_0,lambda
child,Unnamed: 1_level_1
-150,0.729562
-149,0.624985
-148,0.624985
-147,0.688514
-146,0.624985


In [147]:
births.loc[7]

lambda    0.539257
Name: 7, dtype: float64

In [153]:
joined_table = ctree.join(births, on="parent", lsuffix="_death", rsuffix="_birth")
joined_table.head()

Unnamed: 0,parent,child,lambda_death,child_size,lambda_birth
0,0,1,0.065131,50,0.0
1,0,4,0.065131,100,0.0
2,1,-42,0.334293,1,0.065131
3,1,-14,0.374176,1,0.065131
4,1,-23,0.374176,1,0.065131


In [180]:
joined_table["stability"] = joined_table.apply(lambda row: 
                                               (row["lambda_death"] - 
                                                row["lambda_birth"]) * 
                                               row["child_size"], axis=1)
joined_table.groupby("parent")[["stability"]].sum()/pd.DataFrame(joined_table.parent.value_counts(), columns=["stability"])

Unnamed: 0,stability
0,3.256546
1,1.661413
2,0.378264
3,0.457217
4,2.00822
5,0.155334
6,0.401722
7,0.075426
8,0.10903


In [179]:
pd.DataFrame(joined_table.parent.value_counts(), columns=["stability"])

Unnamed: 0,stability
8,28
7,25
5,25
1,23
4,20
2,15
3,14
6,6
0,3


In [159]:
ctree[ctree.child == 1]

Unnamed: 0,parent,child,lambda,child_size
0,0,1,0.065131,50


In [161]:
ctree.parent.unique()

array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=int64)

In [181]:
def stability(row):
    return (row["lambda_death"] - row["lambda_birth"]) * row["child_size"]

def compute_stability(cluster_tree):
    births = cluster_tree.groupby("child").min()[["lambda"]]
    births_and_deaths = cluster_tree.join(births, on="parent", lsuffix="_death", rsuffix="_birth")
    births_and_deaths["stability"] = births_and_deaths.apply(stability, axis=1)
    return births_and_deaths.groupby("parent")[["stability"]].sum() / \
            pd.DataFrame(births_and_deaths.parent.value_counts(), columns=["stability"])

In [184]:
stability_table = compute_stability(ctree)
stability_table

Unnamed: 0,stability
0,3.256546
1,1.661413
2,0.378264
3,0.457217
4,2.00822
5,0.155334
6,0.401722
7,0.075426
8,0.10903


In [190]:
def is_stable_cluster(tree_node, stability_table):
    child_stability = sum(stability_table.loc[x.id] for x in tree_node.children if not x.is_leaf)
    return (stability_table.loc[tree_node.id] > child_stability)[0]

In [191]:
is_stable_cluster(condensed, stability_table)

False

In [255]:
def stability_score(tree_node, stability_table):
    node_stability = stability_table.loc[tree_node.id][0]
    child_stability = sum(max(stability_score(x, stability_table)) for x in tree_node.children if not x.is_leaf)
    return (node_stability, child_stability)

def get_clusters(tree_node, stability_table, results={}):
    tree_node.score = stability_score(tree_node, stability_table)
    if tree_node.score[0] > tree_node.score[1]:
        tree_node.is_cluster = True
        cluster_id = max(results.keys()) + 1 if results.keys() else 0 
        results[cluster_id] = tree_node.points
    else:
        tree_node.is_cluster = False
    if not tree_node.is_cluster:
        for node in tree_node.children:
            if not node.is_leaf:
                get_clusters(node, stability_table, results)
    return results

def get_leaf_point_ids(tree):
    results = []
    for node in tree.children:
        if node.is_leaf:
            results.append(-(node.id - 1))
        else:
            results.extend(get_leaf_point_ids(node))
    return results

def reduce_tree(tree):
    result = CondensedTreeNode(tree.id, 0, [], 0, 0)
    result.points = get_leaf_point_ids(tree)
    
    children_to_process = tree.children[:]
    
    for child in children_to_process:
        if child.is_leaf:
            continue
        if child.id == tree.id:
            children_to_process.extend(child.children)
        else:
            result.children.append(reduce_tree(child))
            
    return result

In [254]:
reduced = reduce_tree(condensed)
get_clusters(reduced, stability_table)

{0: [43,
  15,
  24,
  17,
  10,
  40,
  44,
  7,
  20,
  46,
  16,
  35,
  18,
  34,
  38,
  51,
  13,
  2,
  19,
  30,
  6,
  42,
  9,
  41,
  29,
  21,
  23,
  12,
  48,
  50,
  27,
  37,
  8,
  4,
  49,
  31,
  32,
  14,
  39,
  11,
  36,
  5,
  3,
  47,
  25,
  26,
  22,
  33,
  28,
  45],
 1: [62,
  100,
  59,
  95,
  108,
  120,
  107,
  124,
  119,
  133,
  137,
  111,
  109,
  132,
  70,
  89,
  66,
  76,
  99,
  63,
  86,
  57,
  68,
  64,
  81,
  92,
  73,
  90,
  97,
  98,
  96,
  101,
  84,
  69,
  94,
  83,
  61,
  82,
  71,
  55,
  91,
  131,
  110,
  102,
  52,
  138,
  106,
  126,
  142,
  122,
  146,
  127,
  104,
  145,
  113,
  112,
  149,
  105,
  118,
  139,
  150,
  143,
  117,
  147,
  130,
  134,
  114,
  141,
  54,
  79,
  87,
  78,
  88,
  80,
  75,
  65,
  93,
  58,
  56,
  60,
  53,
  67,
  77,
  121,
  148,
  123,
  74,
  115,
  103,
  144,
  85,
  151,
  125,
  135,
  128,
  129,
  72,
  140,
  116,
  136]}