In [105]:
import numpy as np
from collections import defaultdict
import pandas as pd
import treeswift as ts
import json
import re

In [125]:
def compute_weighted_unifrac(tree_obj, true_labels, final_labels):
    """Normalized weighted UniFrac distance between two abundance profiles.

    Args:
        tree_obj: tree whose ``traverse_postorder()`` yields nodes providing
            ``is_root()``, ``is_leaf()``, ``child_nodes()``, ``label`` and
            ``edge_length`` (e.g. a treeswift ``Tree``).
        true_labels (dict): leaf label -> abundance for the first profile.
        final_labels (dict): leaf label -> abundance for the second profile.
            Leaves absent from a dict count as abundance 0; keys that are not
            leaf labels of ``tree_obj`` are ignored.

    Returns:
        float: ``sum_b l_b * |A_b - B_b| / sum_b l_b * (A_b + B_b)`` over all
        non-root branches ``b``, where ``l_b`` is the branch length (``None``
        counts as 0) and ``A_b`` / ``B_b`` are the total abundances of the two
        profiles below ``b``. Returns 0.0 when the denominator is 0 (e.g. both
        profiles empty) instead of dividing by zero.
    """
    # Subtree abundances keyed by the node object itself: keying by label made
    # unlabeled (None) or duplicate-labeled internal nodes overwrite each other.
    true_below = {}
    final_below = {}

    u = 0.0
    D = 0.0
    for n in tree_obj.traverse_postorder():
        if n.is_root():
            # The root has no branch above it, so it contributes nothing.
            continue
        if n.is_leaf():
            true_below[n] = true_labels.get(n.label, 0)
            final_below[n] = final_labels.get(n.label, 0)
        else:
            # Postorder guarantees every child was already visited.
            children = n.child_nodes()
            true_below[n] = sum(true_below[c] for c in children)
            final_below[n] = sum(final_below[c] for c in children)
        length = 0.0 if n.edge_length is None else n.edge_length
        u += length * abs(true_below[n] - final_below[n])
        D += length * (true_below[n] + final_below[n])
    if D == 0:
        return 0.0
    return u / D

def compute_unifrac(tree_obj, true_labels, final_labels):
    """Unweighted UniFrac distance between two presence profiles.

    Args:
        tree_obj: tree whose ``traverse_postorder()`` yields nodes providing
            ``is_root()``, ``is_leaf()``, ``child_nodes()``, ``label`` and
            ``edge_length`` (e.g. a treeswift ``Tree``).
        true_labels: container of leaf labels present in the first sample
            (membership is tested with ``in``; callers pass only non-zero taxa).
        final_labels: container of leaf labels present in the second sample.

    Returns:
        float: length of branches whose subtree holds taxa from exactly one
        sample, divided by length of branches whose subtree holds taxa from at
        least one sample (the observed tree). Branches with no taxa from either
        sample are not counted, matching the weighted variant, where such
        branches contribute 0 to the denominator. ``None`` branch lengths count
        as 0. Returns 0.0 when no branch is observed.
    """
    # Presence below each node, keyed by the node object so that unlabeled or
    # duplicate-labeled internal nodes cannot overwrite each other.
    true_below = {}
    final_below = {}

    unique = 0.0
    observed = 0.0
    for n in tree_obj.traverse_postorder():
        if n.is_root():
            # The root has no branch above it.
            continue
        if n.is_leaf():
            true_below[n] = n.label in true_labels
            final_below[n] = n.label in final_labels
        else:
            # Postorder guarantees every child was already visited.
            children = n.child_nodes()
            true_below[n] = any(true_below[c] for c in children)
            final_below[n] = any(final_below[c] for c in children)

        length = 0.0 if n.edge_length is None else n.edge_length
        if true_below[n] != final_below[n]:
            unique += length
        if true_below[n] or final_below[n]:
            observed += length
    if observed == 0:
        return 0.0
    return unique / observed

def distance_between(u, v):
    """Return the path length between nodes ``u`` and ``v`` of the same tree.

    Only each node's ``parent`` and ``edge_length`` attributes are used; an
    ``edge_length`` of ``None`` counts as 0 on every path.

    Args:
        ``u`` (``Node``): Node ``u``
        ``v`` (``Node``): Node ``v``

    Returns:
        ``float``: Sum of edge lengths on the path between ``u`` and ``v``

    Raises:
        ``ValueError``: if ``u`` and ``v`` share no common ancestor
        (e.g. they belong to different trees)
    """
    if u == v:
        return 0.0

    # Distance from u up to every ancestor of u (u itself included).
    u_dists = {u: 0.0}
    node = u
    while node.parent is not None:
        length = 0.0 if node.edge_length is None else node.edge_length
        u_dists[node.parent] = u_dists[node] + length
        node = node.parent
    if v in u_dists:
        # v is an ancestor of u.
        return u_dists[v]

    # Walk up from v until the first ancestor shared with u (the LCA).
    v_dist = 0.0
    node = v
    while node.parent is not None:
        v_dist += 0.0 if node.edge_length is None else node.edge_length
        node = node.parent
        if node in u_dists:
            return u_dists[node] + v_dist
    raise ValueError("u and v do not share a common ancestor")

In [155]:
# Feature table indexed by the "id" column (rows are iterated as samples below,
# columns are the feature labels placed on the tree). Each row is divided by its
# total to get relative abundances; a row whose total is 0 would become NaN.
df_16S = (
    pd.read_csv("./data_16S.tsv", sep="\t", index_col="id")
    .pipe(lambda counts: counts.div(counts.sum(axis=1), axis=0))
)
df_16S.head()

Unnamed: 0_level_0,1b158b8b2922d4fcad5d9cea607cbb7d,c1dc9ad5116d96b8ed863458fc0d0aec,6a6fcf8f9b8bb1ab9e5f8456ee7fb109,aed3f59201e3b9d21858f36557f42a80,f95cab37fba4160de15015f4d520839f,cda4e6f933bb3108ea3e92f9db411c00,afd87e82de329a1ed75b98b5b606843c,668fdb718997fc1589c7817655d4bb5f,51e441cbdcc80da0656e82293ae160b5,ee293984c0110b2eeceb8427fdf448fb,...,13a37c86712739ee091c0fbc2a13e1cb,d94a05561f7643eeb4a75f59435df2df,6b1597644d7942d2276e0facc93bc63b,ac0f7d14d0cc7d87687a7a72172dba00,e803ff46adaa0fa149ef151b082378a0,a11e3437a79a85d83c4bcdd1acd6ff27,c6adabe4621cc8ab924f25db04dec07c,e8120a2c1c4a6888ab33d17a033a6a56,7c02b294ddbf1c8eb03ac162df74780e,e48e37a0467ce3d2ff2c3a0f1167497b
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
50076,0.8456,0.059918,0.026361,0.010328,0.006898,0.004773,0.004027,0.003729,0.003207,0.002498,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
510,0.00176,0.233586,0.005046,0.0,0.0,0.000117,0.000235,0.000645,0.000235,0.003051,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
80129,0.029748,0.043962,0.011572,0.026226,0.0,0.0,0.012956,0.000566,0.0,0.005912,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50395,0.106462,0.178344,0.000666,0.000848,0.01623,0.000182,0.01629,0.00212,0.000606,0.047236,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
156,0.448909,0.108317,0.151102,0.009508,0.001329,0.001431,0.000358,0.0,0.0,0.001789,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [6]:
# Sample metadata table (tab-separated).
# NOTE(review): df_metadata is not used by the other cells visible here.
df_metadata = pd.read_csv("./metadata.tsv", delimiter="\t")
df_metadata.head()

Unnamed: 0,#SampleID,sample,group
0,3,S3,TD
1,10,S10,TD
2,45,S45,TD
3,76,S76,TD
4,78,S78,TD


In [10]:
# Load the placement file. jplace is JSON, which is UTF-8 encoded, so the encoding
# is fixed explicitly rather than left to the platform's locale default.
with open("./depp_placements.jplace", 'r', encoding="utf-8") as jplace_file:
    jplace = json.load(jplace_file)

dict_keys(['fields', 'metadata', 'placements', 'tree', 'version'])

In [84]:
# Strip the jplace edge numbers ("{N}") so treeswift can parse the Newick string.
tree_wp = ts.read_tree_newick(re.sub(r'\{[0-9]+\}', '', jplace["tree"]))
# Relabel every node (leaves included) N0, N1, ... in postorder.
# NOTE(review): the lookups below rely on the jplace edge number k being the edge
# above the k-th node in postorder -- confirm this for the placement tool used.
for ic, node in enumerate(tree_wp.traverse_postorder(leaves=True, internal=True)):
    node.set_label(f"N{ic}")
tree_wp.write_tree_newick("./tree_backbone.nwk")

label_to_node = tree_wp.label_to_node(selection='all')

# Inner nodes created below; lets the walk-up step move only through segments of
# the same original edge that earlier placements already split.
inserted_nodes = set()
for placement in jplace["placements"]:
    # Only the first placement of each query is used.
    edge_num, likelihood, like_weight_ratio, distal_length, pendant_length = placement["p"][0]
    # jplace lists query names either as "n" or as "nm" ([name, multiplicity] pairs).
    if "n" in placement:
        names = list(placement["n"])
    else:
        names = [name for name, _multiplicity in placement["nm"]]

    # distal_length is measured from the distal (child, away-from-root) end of the
    # edge. If earlier placements split this edge, climb the inserted segments until
    # reaching the one containing the attachment point.
    child = label_to_node[f"N{edge_num}"]
    while distal_length > child.get_edge_length() and child.get_parent() in inserted_nodes:
        distal_length -= child.get_edge_length()
        child = child.get_parent()
    # Guard against floating-point overshoot past the top of the original edge.
    distal_length = min(distal_length, child.get_edge_length())

    # Split child's edge: child --distal_length-- inner --remainder-- parent.
    parent = child.get_parent()
    inner = ts.Node(label=names[0] + "i", edge_length=child.get_edge_length() - distal_length)
    parent.remove_child(child)
    parent.add_child(inner)
    inner.set_parent(parent)
    child.set_parent(inner)
    child.set_edge_length(distal_length)
    inner.add_child(child)
    # Attach every query name carried by this placement as a pendant leaf.
    for name in names:
        placed = ts.Node(label=name, edge_length=pendant_length)
        placed.set_parent(inner)
        inner.add_child(placed)
    inserted_nodes.add(inner)

In [83]:
# Save the tree with query placements grafted on.
# NOTE(review): this cell's execution count (83) precedes the placement cell's (84);
# re-run it after that cell so the file reflects the grafted tree.
placements_tree_path = "./tree_wplacements.nwk"
tree_wp.write_tree_newick(placements_tree_path)

In [107]:
# All-vs-all weighted and unweighted UniFrac between the samples (rows) of df_16S.
# Both compute_* functions are symmetric in their two profile arguments, so only
# pairs with position(index2) >= position(index1) are computed and the rest are
# mirrored. Inner dicts are still filled in df_16S row order, so the later
# DataFrames keep the same row/column order.
assert df_16S.index.is_unique, "sample ids in df_16S must be unique"
# Convert each row once instead of once per pair.
abundance_by_sample = {sample: row.to_dict() for sample, row in df_16S.iterrows()}
present_by_sample = {sample: row[row != 0].to_dict() for sample, row in df_16S.iterrows()}
samples = list(df_16S.index)

wunifrac_pairs = defaultdict(dict)
unifrac_pairs = defaultdict(dict)
for i, index1 in enumerate(samples):
    for j, index2 in enumerate(samples):
        if j < i:
            wunifrac_pairs[index1][index2] = wunifrac_pairs[index2][index1]
            unifrac_pairs[index1][index2] = unifrac_pairs[index2][index1]
        else:
            wunifrac_pairs[index1][index2] = compute_weighted_unifrac(
                tree_wp, abundance_by_sample[index2], abundance_by_sample[index1])
            unifrac_pairs[index1][index2] = compute_unifrac(
                tree_wp, present_by_sample[index2], present_by_sample[index1])

In [109]:
# Square distance matrices: outer keys become columns, inner keys become the index.
df_wunifrac, df_unifrac = (
    pd.DataFrame.from_dict(pairs) for pairs in (wunifrac_pairs, unifrac_pairs)
)

In [149]:
# Write both distance matrices as TSV.
for matrix, out_path in (
    (df_wunifrac, "./distance_matrix-wunifrac-depp_only.tsv"),
    (df_unifrac, "./distance_matrix-unifrac-depp_only.tsv"),
):
    matrix.to_csv(out_path, sep="\t")

In [130]:
# Leaf label -> node in the tree with placements (replaces the all-nodes map above).
label_to_node = tree_wp.label_to_node(selection='leaves')
reference_distances = defaultdict(dict)
# Backbone leaves were relabeled "N<k>"; placed queries keep their own names.
backbone_references = [node for label, node in label_to_node.items() if label.startswith("N")]
# For every feature of df_16S: tree distance from its placed leaf to each backbone
# leaf, keyed by the backbone leaf's label (not the Node object) so the downstream
# mixture table is indexed by readable, addressable labels.
for placed in df_16S.columns:
    placed_node = label_to_node[placed]
    reference_distances[placed] = {
        reference.get_label(): distance_between(placed_node, reference)
        for reference in backbone_references
    }

1b158b8b2922d4fcad5d9cea607cbb7d
c1dc9ad5116d96b8ed863458fc0d0aec
6a6fcf8f9b8bb1ab9e5f8456ee7fb109
aed3f59201e3b9d21858f36557f42a80
f95cab37fba4160de15015f4d520839f
cda4e6f933bb3108ea3e92f9db411c00
afd87e82de329a1ed75b98b5b606843c
668fdb718997fc1589c7817655d4bb5f
51e441cbdcc80da0656e82293ae160b5
ee293984c0110b2eeceb8427fdf448fb
73bf8d1a5983e34a0cb84e3cae127815
dc2721103659fe9f1d3ead56a11df243
90a05d597112b554e4480a8eaae4e0aa
cc2d96099f530b503371e5ddca8c0a58
45a68a9eee3cf83e27f4ea309d57ffc3
a5189f77a2cfeab3bc1602ff5c8ac3e9
e553b9a0bb32467c71c89a4e97e55792
b20c095fd654b84cebdbfe4faa0a1409
8114b1d0274e9e4bb6c91f6af1b8fac8
8347bd34436f72573fcde614b95d4702
d5c7d97e6f4f5789d574d321dcca0992
24a60c6448e70d9198ad6ba93520958c
974f6d026e2fa28ed37bcc0308ed56cf
bd2ebc70501f7d867c204f94c4e483da
b3bd0d387d67ca01fe3197f5bf66b032
5d6ee23084c6b9c96deb9a83295abc8a
e59405a47acbc248ce61395366159d8d
a3bf9252a3063e2844dcade5e4192e50
9f8668eb1c5f9d9a992dd49245db090e
e083e2f58987c5f8db5d4dd16ddde91f
b65eb19257

In [137]:
# For each sample: abundance-weighted sum, over its features, of each feature's
# distance to every backbone reference.
mixture_distances = defaultdict(lambda: defaultdict(float))
contributions = (
    (sample_id, reference, weight * dist)
    for sample_id, abundances in df_16S.iterrows()
    for feature, weight in abundances.to_dict().items()
    for reference, dist in reference_distances[feature].items()
)
for sample_id, reference, term in contributions:
    mixture_distances[sample_id][reference] += term

In [143]:
# Columns are samples, rows are backbone references.
df_mixture = pd.DataFrame.from_dict(mixture_distances)
df_mixture.to_csv("mixture_distances-all_references.tsv", sep="\t")