# Calculating triplet and quartet scores

Re-implemented based on Uyen Mai's [code](https://github.com/uym2/10kBacGenomes/blob/master/pergroup_stats.py).

In [51]:
import pandas as pd
from skbio import TreeNode

In [110]:
# Read the reference tree from 'tree.nwk' (presumably Newick, judging by the
# file extension -- no explicit format is passed to the reader).
tree = TreeNode.read('tree.nwk')

In [111]:
# Load the taxonomy table; the first column becomes the row index. Below, the
# index is looked up by tree tip name and every column is treated as one
# taxonomic rank.
df = pd.read_table('rank_names.tsv', index_col=0)

In [112]:
# Keep only the rows for tips present in the tree, in tip traversal order, so
# that the number of rows equals the number of tips.
# NOTE(review): a tip name missing from the table is expected to raise a
# KeyError in recent pandas -- confirm this is the desired behaviour.
df = df.loc[[x.name for x in tree.tips()]]

In [113]:
%%time
# Score every taxon at every rank. `scores` is defined in another cell of this
# notebook and must have been executed first.
#
# One table is built per rank and all tables are concatenated once at the end:
# `DataFrame.append` no longer exists in pandas >= 2.0, and appending inside
# the loop re-copies the accumulated table on every iteration.
frames = []
for rank in df.columns:
    # number of tips assigned to each taxon at this rank (unassigned dropped)
    _df = df[rank].dropna().value_counts().reset_index()
    _df.columns = ['taxon', 'count']

    # a taxon needs at least two members to form an in-group pair, and must
    # not cover the entire tree, otherwise there is no out-group;
    # copy so that the column assignments below do not target a view
    _df = _df[(_df['count'] > 1) & (_df['count'] < df.shape[0])].copy()
    _df['rank'] = rank

    # explicit lists (instead of unpacking `zip(*...)`) also work when no
    # taxon at this rank passes the filter
    results = [scores(set(df.index[df[rank] == taxon].tolist()), tree)
               for taxon in _df['taxon']]
    _df['triplet'] = [triplet for triplet, _ in results]
    _df['quartet'] = [quartet for _, quartet in results]
    frames.append(_df)

# as before, `res` stays None when the table has no rank columns at all
res = pd.concat(frames) if frames else None

CPU times: user 1min 30s, sys: 451 ms, total: 1min 31s
Wall time: 1min 31s


In [114]:
# Preview the first rows of the combined result table.
res.head()

Unnamed: 0,taxon,count,rank,triplet,quartet
0,Bacteria,9906,kingdom,1.0,1.0
1,Archaea,669,kingdom,1.0,1.0
0,Proteobacteria,2975,phylum,0.9933,0.989164
1,Firmicutes,1948,phylum,0.993025,0.986553
2,Actinobacteria,1097,phylum,0.990621,0.991678


In [115]:
# Write the result table to 'test.out.3' as tab-separated text; the row index
# is written as the first column.
res.to_csv('test.out.3', sep='\t')

In [109]:
def _cross_pairs(counts):
    """Return the sum of counts[i] * counts[j] over all pairs i < j."""
    # sum_{i<j} c_i * c_j == ((sum c)^2 - sum c^2) / 2, always an integer
    total = sum(counts)
    return (total * total - sum(c * c for c in counts)) // 2


def scores(taxa, tree):
    """Calculate triplet and quartet scores of a group of taxa in a tree.

    A triplet is made of two in-group taxa plus one out-group taxon; it fits
    the group if the out-group taxon lies outside the clade defined by the
    lowest common ancestor of the two in-group taxa. A quartet is made of two
    in-group plus two out-group taxa; it fits the group if the tree separates
    the in-group pair from the out-group pair. Each score is the fraction of
    all such triplets / quartets that fit.

    Parameters
    ----------
    taxa : set of str
        group of taxa (tip names)
    tree : skbio.TreeNode
        reference tree

    Returns
    -------
    tuple of (float, float)
        triplet and quartet scores; a score is NaN when no triplet (or no
        quartet) can be formed at all, e.g. when the group has fewer than two
        members or when fewer than two taxa are outside the group

    Notes
    -----
    The tree is not modified: per-node counts are kept in local dictionaries
    keyed by node identity instead of being attached to the nodes, so no copy
    of the subtree is needed.
    """
    # n: total number of taxa (tips) in the tree
    n = tree.count(tips=True)

    # p: number of taxa that belong to input group
    p = len(taxa)

    # q: number of taxa that do NOT belong to input group
    q = n - p

    # limit search to lowest common ancestor to save computation; every
    # in-group taxon descends from it, so no pair can be rooted elsewhere
    lca = tree.lca(taxa)

    # total numbers of triplets and quartets; both divisions are exact
    # because p * (p - 1) and q * (q - 1) are always even
    ntriplets = q * p * (p - 1) // 2
    nquartets = q * (q - 1) * p * (p - 1) // 4

    # numbers of triplets and quartets that fit the input group
    triplets = 0
    quartets = 0

    # per-node tallies, keyed by id(node):
    # ingroup: number of descending taxa that belong to input group
    # ntips: number of descending taxa in total
    ingroup = {}
    ntips = {}

    # internal nodes in post order, reused by the second pass
    internals = []

    # first pass, in post order (from tip to stem): in-group pairs whose
    # lowest common ancestor is the current node
    for node in lca.postorder():
        key = id(node)
        if node.is_tip():
            ingroup[key] = 1 if node.name in taxa else 0
            ntips[key] = 1
            continue
        counts = [ingroup[id(child)] for child in node.children]
        within = sum(counts)
        size = sum(ntips[id(child)] for child in node.children)
        ingroup[key] = within
        ntips[key] = size
        internals.append(node)

        # out-group taxa located outside the current clade
        outside = (n - size) - (p - within)
        pairs = _cross_pairs(counts)
        triplets += outside * pairs
        # both out-group taxa of the quartet lie outside the clade
        quartets += pairs * outside * (outside - 1) // 2

    # second pass: quartets whose two out-group taxa both lie inside the
    # current clade, while the in-group pair comes from two different
    # branches leaving the parent (siblings, or the rest of the tree)
    for node in internals:
        if node is lca:
            continue
        key = id(node)
        inside = ntips[key] - ingroup[key]
        counts = [ingroup[id(sibling)] for sibling in node.siblings()]
        counts.append(p - ingroup[id(node.parent)])
        quartets += _cross_pairs(counts) * inside * (inside - 1) // 2

    # a score is undefined when there is nothing to count
    nan = float('nan')
    triplet = triplets / ntriplets if ntriplets else nan
    quartet = quartets / nquartets if nquartets else nan
    return triplet, quartet