In [None]:
# Tree mean


In [2]:
import ete3
import networkx as nx

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


In [4]:
# Get a basic tree

tree = ete3.Tree("((A:0.1, B:0.2):0.3, C:0.4);", format=1)

print(tree)



      /-A
   /-|
--|   \-B
  |
   \-C


In [12]:
def weighted_sum_leaves(
    node: ete3.TreeNode, tree: ete3.Tree, leaf_values: dict
) -> float:
    """Return sum_{leaves} (distance to leaf) * leaf_value"""
    leaves = tree.get_leaves()
    for leaf in leaves:
        if leaf.name not in leaf_values:
            leaf_values[leaf.name] = 0

    dists = [node.get_distance(leaf) for leaf in leaves]
    weights = np.array([leaf_values[leaf.name] for leaf in leaves])
    return dists @ weights


def find_tree_mean(tree: ete3.Tree, leaf_values: dict) -> ete3.TreeNode:
    """
    Find the node that minimizes the sum of the squared distances to the leaves
    """

    # Start at root
    node = tree.get_tree_root()
    weighted_sum = np.inf

    # Compute the weighted sum of distances to leaves
    while not node.is_leaf():
        print(node.name, )
        candidates = [
            weighted_sum_leaves(child, tree, leaf_values)
            for child in node.children
        ]
        min_idx = np.argmin(candidates)
        min_weighted_sum = candidates[min_idx]

        if min_weighted_sum < weighted_sum:
            node = node.children[min_idx]
            weighted_sum = min_weighted_sum
        else:
            break

    return node


find_tree_mean(tree, {"A": 1, "B": 2, "C": 3})






Tree node '' (0x7f0fb89c2b9)

In [14]:
# Some basic sanity checks:

# Minimizer of a single leaf is the leaf
sc_1 = find_tree_mean(tree, {"A": 1})
print(sc_1)

# Minimizer of a balanced tree is the root
tree_sc_2 = ete3.Tree("((A:0.1, B:0.1):0.1, (C:0.1, D:0.1):0.1);", format=1)
sc_2 = find_tree_mean(tree_sc_2, {"A": 1, "B": 1, "C": 1, "D": 1})
print(sc_2)

# Branch length balances out leaf value





--A



   /-A
--|
   \-B


In [16]:
print(sc_2)


   /-A
--|
   \-B
