In [1]:
import ete3 as et # phylogenetics library that provides tree data structure with nice methods for inspection and manipulation
import json # necessary to read in exported nextstrain JSON file
from pprint import pprint

# Load and parse the tree

In [2]:
def make_node(data_dict):
    '''
    Recursively convert one node of a nextstrain tree JSON into an ete3 TreeNode.

    data_dict: portion of the nextstrain tree JSON representing either an internal node or leaf
    returns: ete3 TreeNode object with all attributes and relationships populated from the nextstrain JSON.
             `node.dist` holds the *incremental* divergence from the parent (0 for the root);
             the cumulative `div` value is not stored as a feature.

    The input dictionary is not modified, so the same parsed JSON can be converted more than once.
    '''
    node = et.TreeNode()
    node.name = data_dict['name']  # copy name
    node.dist = 0.  # initialize for the root node; children get their dist via `add_child` below

    def add_features(attrs, node):
        '''
        attrs: dictionary of either `branch_attrs` or `node_attrs` from the nextstrain json
        node: node to attach these to
        side effect: modifies node in place to attach these attributes as `node.features` per ETE3 convention.
                     these can later be referenced via `node.attribute_name => value`
        '''
        flattened = {
            # nextstrain json quirk: single-value attributes are sometimes wrapped as {'value': ..., ...}
            k: v['value'] if isinstance(v, dict) and 'value' in v else v
            for k, v in attrs.items()
        }
        node.add_features(**flattened)

    if 'branch_attrs' in data_dict:
        add_features(data_dict['branch_attrs'], node)

    node_attrs = data_dict.get('node_attrs', {})
    # nextstrain records each node's *cumulative* divergence from the root. It is turned into branch
    # lengths below instead of being kept as a feature; a node without it is treated as divergence 0.
    cumulative_div = node_attrs.get('div', 0.)
    if 'node_attrs' in data_dict:
        # build a filtered copy rather than popping 'div', so the caller's dict is left intact
        add_features({k: v for k, v in node_attrs.items() if k != 'div'}, node)

    # recursively visit each child to build out the full tree with the right parental relationships
    for child in data_dict.get('children', []):
        child_div = child.get('node_attrs', {}).get('div', 0.)
        # convert cumulative divergence into the incremental branch length from this node
        node.add_child(make_node(child), dist=child_div - cumulative_div)
    return node
        
def make_tree(json_file):
    '''
    Build an ete3 tree from an exported nextstrain JSON.

    json_file: path to the JSON file (str or os.PathLike), or the already-parsed JSON as a dictionary
    returns: ete3.Tree object built from the JSON's top-level 'tree' entry.
             NB that ete3.Tree and ete3.TreeNode objects are synonymous; see their docs for more
    '''
    import os

    if isinstance(json_file, (str, os.PathLike)):
        with open(json_file, 'r') as fh:  # close the handle as soon as the JSON is parsed
            json_dict = json.load(fh)
    else:
        assert isinstance(json_file, dict), 'json_file must be a path or a parsed JSON dictionary'
        json_dict = json_file

    return make_node(json_dict['tree'])

In [3]:
# Parse the exported nextstrain JSON, tidy the topology, then print a quick summary to sanity-check it.
tree_file = './ncov_humboldt.json'
tree = make_tree(json_file=tree_file)
tree.ladderize()  # sort / rotate child nodes for a tidy looking tree
tree.describe()   # leaf count, node count, rootedness and most distant node

Number of leaf nodes:	2653
Total number of nodes:	5046
Rooted:	No
Most distant node:	Italy/CAM-AMES-30-68/2021
Max. distance:	41.000000


# Define and inspect the clade (unit of analysis)

In [7]:
def get_clade_mrca(targets, tree=None):
    '''
    Resolve the most recent common ancestor (MRCA) of a clade of interest.

    targets: a single internal node that represents the most recent common ancestor of the clade of interest OR
           an iterable containing multiple leaves/samples of interest, as ete3 nodes and/or node names (str)
    tree: root node of the tree that contains `targets`. Required when any target is given by name;
          when omitted, the root is found by walking up from the first target given as a node.
    returns: single node that represents the most recent common ancestor of the clade of interest
    (i.e., the smallest subtree that still contains all of the leaves/samples of interest).
    raises: ValueError if `targets` is empty, a name is not found in `tree`, or names are given without `tree`.

    N.B.: ETE toolkit does not distinguish between a node and a tree, as any node with children can be considered
    a subtree.
    '''
    if isinstance(targets, et.TreeNode):
        assert len(targets.children) > 0, 'a single target must be an internal node (the clade MRCA)'
        return targets

    targets = list(targets)
    if not targets:
        raise ValueError('need at least one target sample to find a common ancestor')

    if tree is None:
        first_node = next((t for t in targets if not isinstance(t, str)), None)
        if first_node is None:
            raise ValueError('`tree` is required when targets are given by name')
        tree = first_node.get_tree_root()

    samples = []
    for target in targets:
        if isinstance(target, str):
            matches = tree.search_nodes(name=target)  # list of all nodes with this name
            if not matches:
                raise ValueError(f'sample {target!r} not found in tree')
            target = matches[0]
        samples.append(target)

    if len(samples) == 1:
        # the smallest subtree containing a single sample is the sample itself
        return samples[0]
    return tree.get_common_ancestor(samples)

    
def samples_form_monophyletic_clade(samples, mrca=None, tree=None):
    '''
    Check whether the samples of interest, and nothing else, make up their clade.

    samples: list of samples of interest that are in the clade (ete3 leaf nodes, or leaf names when
             `mrca` or `tree` is supplied)
    mrca (optional): precomputed most recent common ancestor of `samples`; looked up if omitted
    tree (optional): root of the tree containing `samples`, used only to look up the MRCA when `mrca`
                     is omitted. Defaults to the root of the first sample given as a node.
    returns: boolean indicator of whether these samples form a monophyletic clade
            (i.e., all more closely related to each other than to anything else)
    raises: ValueError if the MRCA must be looked up but no tree can be determined
    '''
    if mrca is None:
        if tree is None:
            # get_clade_mrca needs the containing tree; recover it from a sample given as a node
            first_node = next((s for s in samples if hasattr(s, 'get_tree_root')), None)
            if first_node is None:
                raise ValueError('pass `tree` (or `mrca`) when samples are given by name')
            tree = first_node.get_tree_root()
        mrca = get_clade_mrca(samples, tree)  # find the smallest subtree that contains all of the samples
    # len() of an ete3 node counts the leaves beneath it; any extra leaf means other samples share the clade
    return len(mrca) == len(samples)

def get_parent(node, min_parent_muts=None):
    '''
    Find the ancestor of `node` used to define its "siblings".

    node: internal node, should be the mrca of the clade of interest
    min_parent_muts (optional): find the first ancestor that is at least `min_parent_muts` away from `node`
                        (may be the great-/grandparent, etc.). Falsy values (None, 0) mean the immediate parent.
    returns: the chosen ancestor; the root if no ancestor is far enough away; None if `node` is the root
    '''
    parent = node.up
    if parent is None or not min_parent_muts:
        return parent

    # walk upward, summing branch lengths (`x.dist` is the length of the branch from x to x.up),
    # and stop at the root rather than stepping past it
    dist_to_parent = node.dist
    while dist_to_parent < min_parent_muts and parent.up is not None:
        dist_to_parent += parent.dist
        parent = parent.up
    return parent

def get_niblings(node, min_parent_muts=None):
    '''
    Collect the samples that "flank" a clade: leaves under the chosen parent that are outside the clade.

    node: node representing the most recent common ancestor of the clade of interest
    min_parent_muts: passed as parameter to `get_parent` used to define siblings
    returns: list of leaves that descend from the parent found by `get_parent` but not from `node`
             (empty if `node` has no parent)
    '''
    parent = get_parent(node, min_parent_muts)
    if parent is None:  # the root has no siblings
        return []
    # precompute the clade's own leaves once rather than doing an `in` subtree lookup per leaf;
    # this also excludes `node` itself when it is a leaf
    own_leaves = set(node.iter_leaves())
    return [leaf for leaf in parent.iter_leaves() if leaf not in own_leaves]

# Contextualize the clade

In [8]:
def get_tmrca(mrca):
    '''
    Time of the most recent common ancestor (TMRCA) of a clade.

    mrca: node representing the most recent common ancestor of the clade of interest
    returns: decimal date, read from the node's `num_date` feature
    raises: AttributeError if the node has no `num_date` feature
    '''
    return mrca.num_date

def clade_is_monophyletic_wrt_attr(clade, attr = 'location'):
    '''
    clade: root node of (sub)tree being investigated
    attr: attribute to test monophyly with respect to
    returns: boolean indicator of whether all samples in the clade have the same non-null value for `attr`.
             Returns False as soon as a sample lacks the feature, has a None value, or disagrees with
             the first sample's value.
    '''
    unset = object()  # sentinel: no value seen yet (None is itself a disqualifying value)
    first_value = unset
    for sample in clade.iter_leaves():
        if attr not in sample.features:
            return False

        value = getattr(sample, attr)
        if value is None:
            return False
        # compare against the first value instead of collecting a set, so unhashable values also work
        if first_value is unset:
            first_value = value
        elif value != first_value:
            return False

    return True

def get_subclades(node, attr, attr_values):
    '''
    Split a (sub)tree into the largest subclades whose samples all share one valid `attr` value.

    node: mrca of subtree to search
    attr: attribute to search by
    attr_values: valid attr values for a subtree to be included (any container supporting `in`)
    returns: list of dictionaries, each describing a subclade as
             {'n_samples': number of leaves, attr: shared value, 'subtree': subclade root node}
    '''

    def get_first_leaf_attr_value(subtree):
        '''Value of `attr` on the first leaf of `subtree` (None if missing); stops after one leaf
        to avoid traversing the same subtree a gazillion times just to get one value.'''
        for leaf in subtree.iter_leaves():
            return getattr(leaf, attr, None)

    def is_subclade(candidate):
        '''split off a node if either:
        (1) it is a single leaf with a valid attr value,
            meaning none of its parents were monophyletic and it is a one-off
        (2) it is a monophyletic subclade where all descendent leaves have the same valid attr value
        '''
        if get_first_leaf_attr_value(candidate) not in attr_values:  # is the first (or only) value valid?
            return False
        if candidate.is_leaf():  # a leaf can be its own subclade
            return True
        return clade_is_monophyletic_wrt_attr(candidate, attr)  # do all leaves share the first value?

    # ete3 calls `is_leaf_fn` with a single node argument, treats nodes for which it returns True as
    # terminal (no further descent), and yields them
    subclades = node.iter_leaves(is_leaf_fn=is_subclade)

    return [{'n_samples': len(sc), attr: get_first_leaf_attr_value(sc), 'subtree': sc} for sc in subclades]
        
    
    
def min_transmissions_across_demes(node, niblings, attr_value, attr = 'location'):
    '''
    node: mrca of the clade of interest
    niblings: niblings that "flank" the clade of interest
    attr_value: Ignore subclades with this value. Usually the "home location" for the clade of interest.
    attr: usually 'location' or another geographic descriptor
    returns: list of dictionaries (from `get_subclades`), each describing a putative introduction into
             the "home location": a subclade of `node` carrying one of the niblings' non-home values.
    '''
    # candidate values are whatever the flanking niblings carry, minus the home value;
    # niblings lacking `attr` contribute nothing
    nibling_values = {getattr(n, attr, None) for n in niblings}
    nibling_values.discard(attr_value)
    nibling_values.discard(None)
    return get_subclades(node, attr, nibling_values)
    
def clade_uniqueness(node):
    '''
    node: node representing the most recent common ancestor of the clade of interest
    returns: N mutations (branch length) between node and its immediate parent, or None if node is the root
    '''
    if node.up is None:  # the root has no parent to compare against
        return None
    # in ete3, node.dist is the length of the branch joining node to node.up,
    # i.e. the same value as node.get_distance(node.up)
    return node.dist
    

def n_onward_with_accumulated_muts(node, min_muts, min_nodes_between=None):
    '''
    node (required): internal node, should be the mrca of the clade of interest
    min_muts (required): minimum number of mutations (summed branch length) between the internal node and each sample
    min_nodes_between (optional): minimum number of internal nodes strictly between `node` and each sample
    returns: N descendent samples (leaves below `node`) that satisfy requirements
            (if min_nodes_between is specified, must satisfy *both* requirements)
    '''
    n_qualifying = 0
    # depth-first walk carrying, for each descendant:
    # (node, summed branch length from `node`, number of internal nodes strictly between)
    stack = [(child, child.dist, 0) for child in node.children]
    while stack:
        current, n_muts, n_nodes_between = stack.pop()
        if current.children:  # internal node: keep descending, it now sits between `node` and its leaves
            stack.extend((child, n_muts + child.dist, n_nodes_between + 1) for child in current.children)
            continue
        if n_muts < min_muts:  # not enough mutations
            continue
        if min_nodes_between and n_nodes_between < min_nodes_between:  # not enough nodes in between
            continue
        n_qualifying += 1

    return n_qualifying

In [None]:
def describe_my_samples(samples, tree, monophyly_attr_value,
                        monophyly_attr = 'location',
                        min_parent_muts = None,
                        min_onward_muts = 3,
                        min_onward_nodes_bwn = None):

    '''
    Summarize the clade formed by a set of samples of interest.

    samples: list of samples of interest (ete3 leaf nodes or their names)
    tree: path to a nextstrain json, or pre-parsed et.TreeNode object
    monophyly_attr_value: the "home" value of `monophyly_attr`; subclades carrying it are ignored
                          when counting transmissions across demes
    monophyly_attr: what attribute to use when assessing if clade is monophyletic (usually location or other geographic designation)
    min_parent_muts: HEURISTIC how far back in the tree to look for "niblings"
    min_onward_muts: HEURISTIC threshold for N mutations req'd to consider a sample the likely result of onward transmission
    min_onward_nodes_bwn: HEURISTIC threshold for N intermediate nodes req'd to consider a sample the likely result of onward transmission
    returns: dictionary with keys 'samples', 'tmrca', 'samples_monophyletic', 'niblings', 'geo_monophyletic',
             'min_transmissions_across_demes', 'muts_from_parent', 'n_onward_with_accumulated_muts'
    '''

    if isinstance(tree, str):
        tree = make_tree(tree)
    assert isinstance(tree, et.TreeNode) and all(s in tree for s in samples), (
        'tree must be an ete3 node that contains every sample')

    data = {'samples': [getattr(s, 'name', s) for s in samples]}  # accept nodes or plain names

    mrca = get_clade_mrca(samples, tree)
    data['tmrca'] = get_tmrca(mrca)
    data['samples_monophyletic'] = samples_form_monophyletic_clade(samples, mrca)
    data['niblings'] = get_niblings(mrca, min_parent_muts)
    data['geo_monophyletic'] = clade_is_monophyletic_wrt_attr(mrca, monophyly_attr)
    # only a geographically mixed clade needs the cross-deme transmission search
    data['min_transmissions_across_demes'] = (
        [] if data['geo_monophyletic']
        else min_transmissions_across_demes(mrca, data['niblings'], monophyly_attr_value, monophyly_attr)
    )

    data['muts_from_parent'] = clade_uniqueness(mrca)
    data['n_onward_with_accumulated_muts'] = n_onward_with_accumulated_muts(mrca,
                                                                            min_onward_muts,
                                                                            min_onward_nodes_bwn)
    return data
    

# Describe the dataset