The focus of this notebook is the development of a high-performance method to retain only a subset of tips within a tree. It can be used to extract the subtree spanned by an arbitrary set of tips, rather than only the clade descending from an existing internal node. This method operates on the array-based postorder representation of a tree produced by [`to_array`](http://scikit-bio.org/docs/0.4.1/generated/skbio.tree.TreeNode.to_array.html#skbio.tree.TreeNode.to_array).
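
For reference, here is a minimal hand-written stand-in for the `to_array` output of the small tree `((a:1,b:2)c:3,(d:4,e:5)f:6)root;`, restricted to the keys that `shear` reads. The values reflect our reading of the layout: node IDs equal array positions, the children of a node occupy a contiguous ID range, and each `child_index` row is `[node ID, first child ID, last child ID]`. These values agree with the expected results in `test_shear_drop_clade`. The helper `children_by_name` is hypothetical and only shows how the arrays are walked. Writing the dict by hand keeps the snippet runnable without scikit-bio.

```python
import numpy as np

# Hand-written stand-in for TreeNode.to_array() on
# '((a:1,b:2)c:3,(d:4,e:5)f:6)root;' (assumed layout, "id_index" omitted).
tree_array = {
    'id': np.arange(7),
    'name': np.array(['a', 'b', 'd', 'e', 'c', 'f', 'root'], dtype=object),
    'length': np.array([1, 2, 4, 5, 3, 6, np.nan]),
    # each row: [node id, id of first child, id of last child]
    'child_index': np.array([[4, 0, 1],
                             [5, 2, 3],
                             [6, 4, 5]]),
}


def children_by_name(indexed):
    """Map each internal node's name to the names of its children."""
    result = {}
    for node_id, left, right in indexed['child_index']:
        # the children of a node occupy the contiguous ID range [left, right]
        result[indexed['name'][node_id]] = list(indexed['name'][left:right + 1])
    return result


print(children_by_name(tree_array))
# {'c': ['a', 'b'], 'f': ['d', 'e'], 'root': ['c', 'f']}
```

In this example each parent ID is larger than the IDs of its children. Iterating over `child_index` from top to bottom therefore visits children before parents, which is what the mask propagation in `shear` relies on.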

The motivation for this notebook is to minimize the time and space required to extract a subtree directly from the array representation. These arrays are the primary data structure used by Fast UniFrac. While developing a parallelized version of Fast UniFrac, it became apparent that this operation would be performed many times, and the time and space needed for the implicit copy of the tree were a concern. The performance of Fast UniFrac is tightly tied to the number of nodes in the tree being operated on, so minimizing the number of nodes represented is critical to its performance.

scikit-bio's TreeNode does provide a [`shear`](http://scikit-bio.org/docs/0.4.1/generated/skbio.tree.TreeNode.shear.html#skbio.tree.TreeNode.shear) method. However, it only operates on `TreeNode` and is very expensive in time and space on large trees (i.e., millions of tips) because it must make a copy of the tree. ete2 also provides a comparable method, [`prune`](https://github.com/jhcepas/ete/blob/cd1bdb532535194b714ed0a62580e3517e4728d4/ete3/coretype/tree.py#L431). However, this method operates in place on the tree and thus is not suitable for situations where multiple subtrees must be derived. Both ete2 and scikit-bio are subject to large memory requirements relative to a tree array (on the order of 10s of GB versus 10s of MB) due to the rich nature of their objects (which is generally a very, very good thing).
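
To show where the "10s of MB" figure for the array side comes from, here is a rough estimate for the numeric arrays alone. It assumes a strictly bifurcating rooted tree (so `2 * n_tips - 1` nodes and `n_tips - 1` internal nodes), 64-bit integer IDs and `child_index`, and 64-bit float lengths. The `name` array is ignored because its size depends on the strings. `estimate_tree_array_bytes` is a hypothetical helper, not part of scikit-bio.

```python
import numpy as np


def estimate_tree_array_bytes(n_tips):
    """Rough byte count for the numeric arrays of a bifurcating tree.

    Assumes 2 * n_tips - 1 nodes, n_tips - 1 internal nodes, int64 ids and
    child_index, and float64 lengths. Names are not counted.
    """
    n_nodes = 2 * n_tips - 1
    n_internal = n_tips - 1
    int_size = np.dtype(np.int64).itemsize      # 8 bytes
    float_size = np.dtype(np.float64).itemsize  # 8 bytes

    id_bytes = n_nodes * int_size
    length_bytes = n_nodes * float_size
    # one [node, first child, last child] row per internal node
    child_index_bytes = n_internal * 3 * int_size
    return id_bytes + length_bytes + child_index_bytes


total = estimate_tree_array_bytes(1_000_000)
print(total, round(total / 1e6))  # 55999960 56
```

For a million tips this comes to roughly 56 MB before names are counted. That is consistent with the order of magnitude quoted above for tree arrays.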

In [1]:
import numpy as np


def shear(indexed, to_keep):
    """Shear off nodes from a tree array
    
    Parameters
    ----------
    indexed : dict
        The result of TreeNode.to_array
    to_keep : set
        The tip IDs of the tree to keep
        
    Returns
    -------
    dict
        A TreeNode.to_array-like dict, except that "id_index" is not
        provided and any extraneous attributes formerly included are not
        passed on.
    
    Notes
    -----
    Unlike TreeNode.shear, this method does not prune (i.e., collapse nodes
    with a single descendant). This is an open development target.
    
    This method assumes that to_keep is a subset of names in the tree.
    
    The order of the nodes remains unchanged.
    """
    # nodes to keep mask
    mask = np.zeros(len(indexed['id']), dtype=bool)

    # set any tips marked "to_keep"
    tips_to_keep = [i for i, n in enumerate(indexed['name']) if n in to_keep]
    # dtype=int keeps the index array valid even if no tips match
    mask[np.asarray(tips_to_keep, dtype=int)] = True

    # perform a post-order traversal and identify any nodes that should be 
    # retained
    new_child_index = []
    for node_idx, child_left, child_right in indexed['child_index']:
        being_kept = mask[child_left:child_right + 1]

        # NOTE: the second clause is an explicit test to keep the root node. This
        # may not be necessary and may be a remnant of mucking around.
        if being_kept.sum() >= 1 or node_idx == indexed['id'][-1]:
            mask[node_idx] = True

    # we now know what nodes to keep, so we can create new IDs for assignment
    new_ids = np.arange(mask.sum(), dtype=int)
    
    # construct a map that associates old node IDs to the new IDs
    id_map = {i_old: i_new for i_old, i_new in zip(indexed['id'][mask], new_ids)}

    # perform another post-order traversal to construct the new child index
    # array, which provides the index positions of the descendants of each
    # internal node.
    for node_idx, child_left, child_right in indexed['child_index']:
        being_kept = mask[child_left:child_right + 1]

        # NOTE: the second clause is an explicit test to keep the root node. This
        # may not be necessary and may be a remnant of mucking around.
        if being_kept.sum() >= 1 or node_idx == indexed['id'][-1]:
            new_id = id_map[node_idx]
            child_indices = indexed['id'][child_left:child_right + 1][being_kept]
            left_child = id_map[child_indices[0]]
            right_child = id_map[child_indices[-1]]
            new_child_index.append([new_id, left_child, right_child])

    new_child_index = np.asarray(new_child_index)

    return {'child_index': new_child_index,
            'length': indexed['length'][mask],
            'name': indexed['name'][mask],
            'id': new_ids}

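The dict comprehension that builds `id_map` in `shear` runs in Python once per kept node. Because node IDs equal array positions, one possible vectorized alternative is a cumulative sum over the mask: the new ID of a kept node is the number of kept nodes before it. This is not what `shear` above does and has not been benchmarked here. `old_to_new_ids` is a hypothetical helper, and dropped nodes map to -1. The example mask assumes the ID layout a=0, b=1, d=2, e=3, c=4, f=5, root=6 for the tree in `test_shear_drop_clade`, and keeps `d`, `e`, `f` and `root`.

```python
import numpy as np


def old_to_new_ids(mask):
    """Map old node IDs to new IDs for the nodes retained by ``mask``.

    Assumes node IDs equal array positions. Dropped nodes map to -1.
    """
    mask = np.asarray(mask, dtype=bool)
    # cumsum counts kept nodes up to and including each position
    remap = np.cumsum(mask) - 1
    remap[~mask] = -1
    return remap


mask = np.array([False, False, True, True, False, True, True])
print(old_to_new_ids(mask))  # [-1 -1  0  1 -1  2  3]
```

Lookups such as `id_map[child_indices[0]]` would then become plain array indexing, e.g. `remap[child_indices]`, which also works on whole arrays at once.
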
In [3]:
from unittest import TestCase
from skbio import TreeNode
import numpy.testing as npt
import numpy as np


class ShearTests(TestCase):
    def test_shear_identity(self):
        tree = TreeNode.read(['((a:1,b:2)c:3,(d:4,e:5)f:6)root;']).to_array()
        to_keep = {'a', 'b', 'd', 'e'}
        obs = shear(tree, to_keep)
        npt.assert_equal(obs['length'], tree['length'])
        npt.assert_equal(obs['id'], tree['id'])
        npt.assert_equal(obs['name'], tree['name'])
        npt.assert_equal(obs['child_index'], tree['child_index'])

    def test_shear_drop_clade(self):
        tree = TreeNode.read(['((a:1,b:2)c:3,(d:4,e:5)f:6)root;']).to_array()
        to_keep = {'d', 'e'}
        exp = {'length': np.array([4, 5, 6, np.nan]),
               'name': np.array(['d', 'e', 'f', 'root']),
               'id': np.array([0, 1, 2, 3]),
               'child_index': np.array([[2, 0, 1],
                                        [3, 2, 2]])}

        obs = shear(tree, to_keep)
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])

    def test_shear_complex_identity(self):
        tree = TreeNode.read(['((a:1,b:2,c:3)d:4,'
                              '(((e:4)f:5)g:6)h:7,'
                              '((((i:8,j:9)k:10,(l:11)m:12)n:13)o:14,p:15)q:16'
                              ')root;'])
        to_keep = {n.name for n in tree.tips()}
        exp = {'length': np.array([1, 2, 3, 4, 5, 6, 8, 9, 11, 10, 12, 13, 14,
                                   15, 4, 7, 16,  np.nan]),
               'id': np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                               14, 15, 16, 17]),
               'child_index': np.array([[4, 3, 3],
                                        [5, 4, 4],
                                        [9, 6, 7],
                                        [10, 8, 8],
                                        [11, 9, 10],
                                        [12, 11, 11],
                                        [14, 0, 2],
                                        [15, 5, 5],
                                        [16, 12, 13],
                                        [17, 14, 16]]),
               'name': np.array(['a', 'b', 'c', 'e', 'f', 'g', 'i', 'j', 'l',
                                 'k', 'm', 'n', 'o', 'p', 'd', 'h', 'q',
                                 'root'], dtype=object)}

        tree_array = tree.to_array()

        obs = shear(tree_array, to_keep)
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])

    def test_shear_complex(self):
        tree = TreeNode.read(['((a:1,b:2,c:3)d:4,'
                              '(((e:4)f:5)g:6)h:7,'
                              '((((i:8,j:9)k:10,(l:11)m:12)n:13)o:14,p:15)q:16'
                              ')root;'])
        to_keep = {'b', 'c', 'i', 'l', 'p'}
        exp = {'length': np.array([2, 3, 8, 11, 10, 12, 13, 14, 15, 4, 16,
                                   np.nan]),
               'id': np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
               'child_index': np.array([[4, 2, 2],  # k
                                        [5, 3, 3],  # m
                                        [6, 4, 5],  # n
                                        [7, 6, 6],  # o
                                        [9, 0, 1],  # d
                                        [10, 7, 8],  # q
                                        [11, 9, 10]]),  # root
               'name': np.array(['b', 'c', 'i', 'l', 'k', 'm', 'n', 'o', 'p',
                                 'd', 'q', 'root'], dtype=object)}

        tree_array = tree.to_array()

        obs = shear(tree_array, to_keep)
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])

In [13]:
# adapted from http://amodernstory.com/2015/06/28/running-unittests-in-the-ipython-notebook/
from unittest import TestLoader, TextTestRunner
suite = TestLoader().loadTestsFromTestCase(ShearTests)
TextTestRunner().run(suite)

....
----------------------------------------------------------------------
Ran 4 tests in 0.016s

OK


<unittest.runner.TextTestResult run=4 errors=0 failures=0>