In [None]:
from Bio import Phylo
from matplotlib import pyplot as plt
import matplotlib as mpl
import numpy as np
import os
import statistics as stat
import pandas as pd
import seaborn as sns
import random
import copy

In [None]:
# Global figure style.
# plt.style.use('default') resets every rcParam to matplotlib's defaults, so it
# must come FIRST -- called after the assignments below it would discard them all.
plt.style.use('default')
# Select the sans-serif family so the ordered fallback list below is used
# (Helvetica when installed, otherwise Arial, DejaVu Sans, ...).
mpl.rcParams['font.family']       = 'sans-serif'
mpl.rcParams['font.sans-serif']   = ["Helvetica","Arial","DejaVu Sans","Lucida Grande","Verdana"]
mpl.rcParams['figure.figsize']    = [4,3]
mpl.rcParams['font.size']         = 9
mpl.rcParams["axes.labelcolor"]   = "#000000"
mpl.rcParams["axes.linewidth"]    = 1.0
mpl.rcParams["xtick.major.width"] = 1.0
mpl.rcParams["ytick.major.width"] = 1.0
# Qualitative palettes: 10 colours from tab10, 12 from Set3.
cmap1 = plt.cm.tab10
cmap2 = plt.cm.Set3
colors1 = [cmap1(i) for i in range(0,10)]
colors2 = [cmap2(i) for i in range(0,12)]

In [None]:
# Work inside the NK_0147 data directory; relative paths used later
# ("figures/...", "table/...", "gtdb/...") resolve against it.
os.chdir("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/")
# Output directories. exist_ok=True makes re-runs harmless without the bare
# `except:` that would also hide real failures (permissions, bad paths, ...).
os.makedirs('figures', exist_ok=True)
os.makedirs('table', exist_ok=True)

In [None]:
# Merge the GTDB r95 archaeal (ar122) and bacterial (bac120) trees: a fresh root
# gets the two original roots as its only children.
ar_tree, bac_tree = (
    Phylo.read(tree_path, 'newick')
    for tree_path in (
        "/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/gtdb/ar122_r95.tree",
        '/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/gtdb/bac120_r95.tree',
    )
)
merged_tree = Phylo.BaseTree.Tree()
merged_tree.clade.clades = [ar_tree.clade, bac_tree.clade]
# Relative path: resolved against the working directory chosen with os.chdir above.
Phylo.write(merged_tree, "gtdb/ar_bac_tree.nwk", 'newick')

In [None]:
# Project data root and the two reference trees that simulations are compared with.
WD = "/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/"
reftree_file1 = WD + "silva/LTPs132_SSU_tree.dereplicated.newick"  # SILVA reference tree
reftree_file2 = WD + "gtdb/ar_bac_tree.ext.nwk"  # GTDB archaea+bacteria tree (presumably post-processed from gtdb/ar_bac_tree.nwk -- confirm)

In [None]:
####### parameters #######
# Re-applies the global figure style (same settings as the configuration cell at the top).
# plt.style.use('default') resets every rcParam to matplotlib's defaults, so it
# must come FIRST -- called after the assignments below it would discard them all.
plt.style.use('default')
# Select the sans-serif family so the ordered fallback list below is used.
mpl.rcParams['font.family']       = 'sans-serif'
mpl.rcParams['font.sans-serif']   = ["Helvetica","Arial","DejaVu Sans","Lucida Grande","Verdana"]
mpl.rcParams['figure.figsize']    = [4,3]
mpl.rcParams['font.size']         = 9
mpl.rcParams["axes.labelcolor"]   = "#000000"
mpl.rcParams["axes.linewidth"]    = 1.0
mpl.rcParams["xtick.major.width"] = 1.0
mpl.rcParams["ytick.major.width"] = 1.0
cmap1 = plt.cm.tab10
cmap2 = plt.cm.Set3
colors1 = [cmap1(i) for i in range(0,10)]
colors2 = [cmap2(i) for i in range(0,12)]
##########################

In [None]:
# Tip counts of the simulated trees in the sigma_r_gamma2 sweep
# (directory suffixes 0.00, 0.01, ..., 4.00).
profile = []
for r in np.arange(0.00, 4.01, 0.01):
    newick_path = "/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/sigma_r_gamma2/sigma_r_gamma2_"+'{:.2f}'.format(r)+"/PRESUMEout/PRESUMEout.nwk"
    if not os.path.exists(newick_path):
        continue  # no simulation output for this value
    # Catch Exception rather than using a bare `except:` (which also traps
    # KeyboardInterrupt), and report the skipped file instead of hiding it.
    try:
        tree = Phylo.read(newick_path, 'newick')
    except Exception as err:
        print("skipped unreadable tree:", newick_path, err)
        continue
    profile.append([newick_path, len(tree.get_terminals())])
df_Ntip = pd.DataFrame(profile, columns = ['path', 'Ntips'])

In [None]:
# Tip count of each simulated tree collected above (one row per PRESUMEout.nwk that was read).
df_Ntip

In [None]:
def get_BBIs(treefile, min_tips=10):
    """Return the branch balance indices (BBIs) of a single-tree Newick file.

    For every bifurcating internal node whose clade holds at least ``min_tips``
    tips, the BBI is ``min(M, N) / max(M, N)``, where M and N are the tip counts
    of its two child clades (1.0 = perfectly even split).

    Parameters
    ----------
    treefile : str
        Path to a Newick file containing exactly one tree.
    min_tips : int, default 10
        Smallest clade size (in tips) for which a BBI is reported.

    Returns
    -------
    list of float
        One BBI per qualifying node, in preorder. An empty list if the file
        cannot be read or parsed (callers treat ``[]`` as "no data").
    """
    try:
        tree = Phylo.read(treefile, 'newick')
    except Exception:
        return []

    internal = tree.get_nonterminals()
    # Tip count per clade, keyed by object identity so duplicate/missing names
    # cannot collide. Reversed preorder visits each child before its parent;
    # children not yet in the dict are tips (count 1). All children are summed,
    # so counts stay correct below polytomies and unifurcations.
    ntips = {}
    for node in reversed(internal):
        ntips[id(node)] = sum(ntips.get(id(child), 1) for child in node.clades)

    BBIs = []
    for node in internal:
        if len(node.clades) == 2 and ntips[id(node)] >= min_tips:
            M, N = (ntips.get(id(child), 1) for child in node.clades)
            BBIs.append(min(M, N) / max(M, N))
    return BBIs
def HistogramIntersection(data1, data2, nbins=20, histrange=(0, 1)):
    """Similarity of two samples' distributions on a shared set of histogram bins.

    Both samples are binned with ``np.histogram(bins=nbins, range=histrange)``
    (values outside ``histrange`` are dropped) and each histogram is normalised
    to sum to 1. The score is ``sum(min(h1, h2)) / sum(max(h1, h2))``:
    1.0 for identical histograms, 0.0 when no bin is shared.

    Parameters
    ----------
    data1, data2 : array-like of float
        The two samples.
    nbins : int or sequence, default 20
        Passed to ``np.histogram`` as ``bins``.
    histrange : (float, float), default (0, 1)
        Passed to ``np.histogram`` as ``range``.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If either sample has no values inside ``histrange`` -- such a sample has
        no distribution to compare, and would otherwise yield NaN or a
        meaningless score.
    """
    hist1, _ = np.histogram(data1, bins=nbins, range=histrange)
    hist2, _ = np.histogram(data2, bins=nbins, range=histrange)
    total1, total2 = hist1.sum(), hist2.sum()
    if total1 == 0 or total2 == 0:
        raise ValueError("HistogramIntersection: each sample needs at least one value inside histrange")
    p1 = hist1 / total1
    p2 = hist2 / total2
    return float(np.minimum(p1, p2).sum() / np.maximum(p1, p2).sum())

In [None]:
# Score every simulated tree under WD/<parameter_label>/ against a reference tree.
def analyze(parameter_label, reftree_file):
    """Compare the BBI distribution of each simulated tree with a reference tree.

    Scans ``WD/<parameter_label>`` for run directories whose name ends in
    ``_<value>`` (anything from ``CV`` onwards in that suffix is ignored), reads
    ``<run>/PRESUMEout/PRESUMEout.nwk`` from each, and records
      * the parameter value parsed from the directory name,
      * ``medBBIs``: the median branch balance index of the simulated tree,
      * ``hist_intersection``: histogram intersection (20 bins on [0, 1])
        between the simulated and reference BBI distributions.
    Entries whose suffix is not a number (e.g. ``.DS_Store``), runs without a
    tree file, and trees yielding no BBIs are skipped.

    Returns
    -------
    pandas.DataFrame
        Columns ``[parameter_label, 'medBBIs', 'hist_intersection']``, sorted by
        the parameter value (empty, with those columns, if no run qualifies).

    Raises
    ------
    ValueError
        If the reference tree yields no BBIs (unreadable or too small), since
        every comparison against it would be meaningless.
    """
    refBBIs = get_BBIs(reftree_file)
    if len(refBBIs) == 0:
        raise ValueError("no branch balance indices obtained from reference tree: " + reftree_file)

    datadir = os.path.join(WD, parameter_label)
    records = []
    for dirname in os.listdir(datadir):
        try:
            parameter = float(dirname.split("_")[-1].split("CV")[0])
        except ValueError:
            continue  # not a run directory
        newick_path = os.path.join(datadir, dirname, "PRESUMEout", "PRESUMEout.nwk")
        if not os.path.exists(newick_path):
            continue
        BBIs = get_BBIs(newick_path)
        if len(BBIs) == 0:
            continue
        hist_intersection = HistogramIntersection(BBIs, refBBIs, nbins=20)
        records.append([parameter, stat.median(BBIs), hist_intersection])

    # Build from the record list directly: np.array([]) of an empty result
    # cannot be given three column names.
    columns = [parameter_label, 'medBBIs', 'hist_intersection']
    df = pd.DataFrame(records, columns=columns)
    # os.listdir order is arbitrary; sort so tables and plots are reproducible.
    return df.sort_values(parameter_label).reset_index(drop=True)

In [None]:
# Score the sigma_r_gamma2 sweep against both reference trees, then cache each result as CSV.
df_vs_silva, df_vs_gtdb = (
    analyze('sigma_r_gamma2', reference) for reference in (reftree_file1, reftree_file2)
)
df_vs_silva.to_csv("table/vs_silva.csv")
df_vs_gtdb.to_csv("table/vs_gtdb.csv")


In [None]:
# Fit of the sigma_r_gamma2 sweep to the reference trees, reloaded from the cached CSVs.
df_vs_silva = pd.read_csv("table/vs_silva.csv")
df_vs_gtdb  = pd.read_csv("table/vs_gtdb.csv")

parameter_label = "sigma_r_gamma2"
fig = plt.figure(figsize=(2.2,2.2))
# Left panel: median BBI (depends only on the simulated tree, so one frame suffices).
ax1 = fig.add_axes([0.1,0.1,0.8,0.8])
# Right panel: histogram intersection vs. SILVA (orange) and vs. GTDB (red).
ax2 = fig.add_axes([1.3,0.1,0.8,0.8])
# x/y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.regplot(x=parameter_label, y='medBBIs', data=df_vs_silva, ax=ax1, fit_reg=False, scatter_kws={'s': 5})
sns.regplot(x=parameter_label, y='hist_intersection', data=df_vs_silva, ax=ax2, fit_reg=False, scatter_kws={'s': 5}, color="#F3A83B")
sns.regplot(x=parameter_label, y='hist_intersection', data=df_vs_gtdb,  ax=ax2, fit_reg=False, scatter_kws={'s': 5}, color="#EB3223")

for ax in (ax1, ax2):
    ax.set_ylim(0,1.0)
    # Raw string: "\s" is an invalid escape sequence in a normal string literal.
    ax.set_xlabel(r"$\sigma$", fontsize=10)
ax1.set_ylabel('Median branch balance index',fontsize=10)
ax2.set_ylabel('Histogram intersection',fontsize=10)
fig.savefig("figures/NK_0147_fitting.pdf", bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Simulations ordered by histogram intersection with the SILVA reference (ascending, so the best agreement is last).
df_vs_silva.sort_values('hist_intersection')

In [None]:
# Simulations ordered by histogram intersection with the GTDB reference (ascending, so the best agreement is last).
df_vs_gtdb.sort_values('hist_intersection')

In [None]:
def BBI_plot(tree_file_list, outfile='figures/NK_0147_bbi_histogram.pdf'):
    """Save overlaid branch-balance-index (BBI) histograms of up to three trees.

    Trees are styled by position in ``tree_file_list``: 1st orange (#F3A83B,
    line width 1), 2nd red (#EB3223, width 1), 3rd blue (#4293F7, width 3,
    alpha 0.5); further files are ignored. Each histogram uses 25 bins on
    [0, 1] with weights 1/len(BBIs), so its bars sum to 1. Trees that yield
    no BBIs are skipped.

    Parameters
    ----------
    tree_file_list : sequence of str
        Newick files, passed to ``get_BBIs``.
    outfile : str, default 'figures/NK_0147_bbi_histogram.pdf'
        Destination of the figure; the figure is closed afterwards, not shown.
    """
    fig = plt.figure(figsize=(2.2, 2.2))
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    # (colour, line width, alpha) per position; all drawn as unfilled step outlines.
    styles = [("#F3A83B", 1, 1), ("#EB3223", 1, 1), ("#4293F7", 3, 0.5)]
    for tree_file, (color, width, alpha) in zip(tree_file_list, styles):
        BBIs = get_BBIs(tree_file)
        if len(BBIs) == 0:
            continue  # unreadable tree or no qualifying nodes: nothing to draw
        weights = np.ones(len(BBIs)) / len(BBIs)
        ax.hist(BBIs, bins=25, histtype="step", lw=width, range=(0, 1),
                color=color, alpha=alpha, weights=weights)

    ax.set_xlim(0, 1.0)
    ax.set_xlabel('Balance index', fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
    fig.savefig(outfile, bbox_inches='tight')
    plt.close(fig)

In [None]:
# BBI distributions of the SILVA and GTDB reference trees overlaid with the
# simulated tree from run sigma_r_gamma2_1.96.
BBI_plot([
    reftree_file1,
    reftree_file2,
    WD + "sigma_r_gamma2/sigma_r_gamma2_1.96/PRESUMEout/PRESUMEout.nwk",
])

In [None]:
# Branch-length fitting: depth and length of every clade of the RAxML 23S reference tree.
reftree = Phylo.read("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/branch_length/RAxML_result.prameter_fitting_for_23S",'newick')

# Iterative DFS from the root (depth 1); one row per clade.
depth_length_list = []
stack = [(reftree.clade, 1)]
while stack:
    node, depth = stack.pop()
    depth_length_list.append([depth, node.branch_length, len(node.clades) == 0])
    stack.extend((child, depth + 1) for child in node.clades)
# Build the frame from the list itself rather than np.array(...): np.array coerces
# every column to a single dtype (float, or object if any branch_length is None),
# whereas pandas keeps int depth, float length (None -> NaN) and bool is_tip.
df_depth_length = pd.DataFrame(depth_length_list, columns=['depth', 'length', 'is_tip'])
df_depth_length

In [None]:
# Branch length against clade depth for the reference tree (log-scaled y-axis); saved, not shown.
fig = plt.figure(figsize=(2.2, 2.2))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
ax.scatter(df_depth_length['depth'], df_depth_length['length'], alpha=0.01)
ax.set_yscale('log')
ax.set_ylim(0.0001, 1)
ax.set(xlabel='Depth', ylabel='Branch length')
fig.savefig('figures/NK_0147_depth_length.pdf', bbox_inches='tight')
plt.close(fig)

In [None]:
# assign branch length
# Load the combined PRESUME topology (the "1M" tree, per the variable name); the
# next cells copy it and assign branch lengths sampled from the reference tree.
tree_topology_1M = Phylo.read("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/branch_length/PRESUMEout_combined.nwk", 'newick')

In [None]:
# Branch-length assignment, scheme 1 ("tip or internal"): every branch of a copy of
# the 1M topology gets a length drawn with replacement from the reference tree's
# tip branch lengths (branch ends in a tip) or internal branch lengths (otherwise).
random.seed(10)
internal_length_list = [node.branch_length for node in reftree.get_nonterminals()]
tip_length_list      = [node.branch_length for node in reftree.get_terminals()]
# NOTE(review): internal_length_list includes the reference root's branch_length,
# which may be None when the Newick root carries no length -- confirm.

tree_branch_assigned_1M = copy.deepcopy(tree_topology_1M)
# Iterative DFS. The visiting order decides which random draw each branch gets,
# so it must stay as-is to reproduce the seed-10 output.
stack = [tree_branch_assigned_1M.clade]
while stack:
    node = stack.pop()
    pool = tip_length_list if len(node.clades) == 0 else internal_length_list
    node.branch_length = random.choice(pool)
    stack.extend(node.clades)
Phylo.write(tree_branch_assigned_1M, "/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/branch_length/PRESUMEout_combined_length_tip_or_internal.nwk", 'newick')

In [None]:
def return_df_length_pair(tree):
    """Pair the two child branch lengths at each internal node, ordered by subclade size.

    For every internal node (preorder) with at least two children, the first
    two children are compared by tip count and their branch lengths recorded as
    (smaller subclade, larger subclade); on a tie the second child counts as
    the smaller one. Children beyond the second (polytomies) form no pair but
    do count towards their parent's tip total. Nodes with a single child are
    skipped. The tree is not modified.

    Parameters
    ----------
    tree : Bio.Phylo tree
        Anything exposing ``get_nonterminals()`` (preorder) whose clades have
        ``clades`` and ``branch_length``.

    Returns
    -------
    pandas.DataFrame
        Columns ``'Small subclade'`` and ``'Large subclade'``, one row per pair.
    """
    internal = tree.get_nonterminals()
    # Tip counts keyed by object identity (no renaming of the caller's nodes,
    # no name collisions). Reversed preorder visits children before parents;
    # children missing from the dict are tips (count 1).
    ntips = {}
    for node in reversed(internal):
        ntips[id(node)] = sum(ntips.get(id(child), 1) for child in node.clades)

    length_pair_list = []
    for node in internal:
        if len(node.clades) < 2:
            continue  # unifurcation: no sibling branch to pair with
        child_L, child_R = node.clades[0], node.clades[1]
        if ntips.get(id(child_L), 1) < ntips.get(id(child_R), 1):
            length_pair_list.append([child_L.branch_length, child_R.branch_length])
        else:
            length_pair_list.append([child_R.branch_length, child_L.branch_length])

    return pd.DataFrame(length_pair_list, columns=['Small subclade', 'Large subclade'])

In [None]:
# Branch-length assignment, scheme 2 ("pair sampling"): at every internal node of a
# copy of the 1M topology, draw one (small-subclade, large-subclade) branch-length
# pair observed in the reference tree and give the small-subclade length to the
# child with fewer tips.
# Seeded for reproducibility: one random.choice per internal node, consumed in
# get_nonterminals() order, so changing the traversal changes the output.
random.seed(10)

# make name2ntips for reftree
# Tip count per clade name. Internal nodes are renamed "clade<i>" in place so every
# clade has a key; reversed preorder visits children before parents.
# NOTE(review): only the first two children are summed, so counts are too small
# below polytomies (e.g. a multifurcating root), and a tip already named "clade<i>"
# would collide with an internal node's key.
name2ntips_ref = {}
for i, node in enumerate(reftree.get_terminals()):
    name2ntips_ref[node.name] = 1
for i, node in enumerate(reversed(reftree.get_nonterminals())):
    node.name = "clade"+str(i)
    name2ntips_ref[node.name] = name2ntips_ref[node.clades[0].name] + name2ntips_ref[node.clades[1].name]

# list up pair of branch lengths
# One (length to smaller subclade, length to larger subclade) tuple per internal
# node of the reference tree, using its first two children; on a tie the second
# child is treated as the smaller subclade.
length_pair_list = []
for node in reftree.get_nonterminals():
    child_L = node.clades[0]
    child_R = node.clades[1]
    if (name2ntips_ref[child_L.name] < name2ntips_ref[child_R.name]):
        length_pair_list.append((child_L.branch_length, child_R.branch_length))
    else:
        length_pair_list.append((child_R.branch_length, child_L.branch_length))

# make name2ntips for 1M tree
# Same tip counting on a fresh copy of the 1M topology (its internal nodes are
# renamed "clade<i>" in place, with the same first-two-children caveat).
tree_branch_assigned_1M = copy.deepcopy(tree_topology_1M)
name2ntips = {}
for i, node in enumerate(tree_branch_assigned_1M.get_terminals()):
    name2ntips[node.name] = 1
for i, node in enumerate(reversed(tree_branch_assigned_1M.get_nonterminals())):
    node.name = "clade"+str(i)
    name2ntips[node.name] = name2ntips[node.clades[0].name] + name2ntips[node.clades[1].name]

# Assign a sampled pair to the first two children of every internal node: the child
# with fewer tips gets length_S; on a tie the first child gets length_L.
# NOTE(review): any children beyond the second keep their original branch lengths.
for node in tree_branch_assigned_1M.get_nonterminals():
    length_S, length_L = random.choice(length_pair_list)
    child_L = node.clades[0]
    child_R = node.clades[1]
    if (name2ntips[child_L.name] < name2ntips[child_R.name]):
        child_L.branch_length = length_S
        child_R.branch_length = length_L
    else:
        child_L.branch_length = length_L
        child_R.branch_length = length_S

Phylo.write(tree_branch_assigned_1M, "/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/branch_length/PRESUMEout_combined_length_pair_sampling.nwk", 'newick')

In [None]:
# Branch length to the smaller vs. the larger subclade at each internal node, for the
# reference tree (figure 0) and the two assignment schemes (1: tip/internal, 2: pair sampling).
reftree = Phylo.read("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/branch_length/RAxML_result.prameter_fitting_for_23S",'newick')
tree_branch_assigned_1M_tip_or_internal = Phylo.read("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/branch_length/PRESUMEout_combined_length_tip_or_internal.nwk",'newick')
tree_branch_assigned_1M_pair_sampling = Phylo.read("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/branch_length/PRESUMEout_combined_length_pair_sampling.nwk", 'newick')

for i, tree in enumerate([reftree, tree_branch_assigned_1M_tip_or_internal, tree_branch_assigned_1M_pair_sampling]):
    df_length_pair = return_df_length_pair(tree)
    fig = plt.figure(figsize=(2.2,2.2))
    ax = fig.add_axes([0.1,0.1,0.8,0.8])
    ax.scatter(df_length_pair["Small subclade"], df_length_pair['Large subclade'], alpha=0.01)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(0.0001,1)
    ax.set_ylim(0.0001,1)
    # y = x reference line across the visible (log) range.
    ax.plot([0.0001, 1], [0.0001, 1], color="blue")
    ax.set_xlabel('Branch length to\nsmall subclade')
    ax.set_ylabel('Branch length to\nlarge subclade')
    # Save BEFORE showing: the inline backend closes figures on plt.show(), so a
    # savefig issued afterwards writes a blank image.
    fig.savefig('figures/NK_0147_small_large_'+str(i)+'.jpg', bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
# Tip vs. internal branch-length distributions for the same three trees (0: reference,
# 1: tip/internal sampling, 2: pair sampling).
for i, tree in enumerate([reftree, tree_branch_assigned_1M_tip_or_internal, tree_branch_assigned_1M_pair_sampling]):
    fig = plt.figure(figsize=(4.4,2.2))
    ax = fig.add_axes([0.1,0.1,0.8,0.8])
    for label, nodes in (('tip', tree.get_terminals()), ('internal', tree.get_nonterminals())):
        # Skip missing lengths (e.g. a root written without one), which ax.hist cannot bin.
        lengths = [node.branch_length for node in nodes if node.branch_length is not None]
        ax.hist(x=lengths, label=label, range=(0,1.2), bins=120, alpha=0.5)
    ax.set_xlabel('Branch length')
    ax.set_ylabel('Count')  # raw counts: no density/weights are applied
    ax.legend()  # labels were set, but without a legend the two series are indistinguishable
    fig.savefig('figures/NK_0147_length_distribution'+str(i)+'.pdf', bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
# Agreement (histogram intersection, HI) of simulated NHD distributions with the
# natural one, as a function of the mutation-rate parameter mu.
df_HI_mu = pd.read_table("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/mutation_rate/HI.manual.txt", names=["ID", 'mu', 'HI'])
fig = plt.figure(figsize=(2.2,2.2))
ax = fig.add_axes([0.1,0.1,0.8,0.8])
# HI appears to be stored as a fraction; shown as a percentage (0-100 axis).
ax.scatter(df_HI_mu["mu"], df_HI_mu["HI"]*100, color = "#F3A83B", alpha = 1, s = 5)
ax.set_xlabel(r"$\mu$")  # raw string: "\m" is an invalid escape sequence
ax.set_ylabel("Agreement in\ndistribution of NHD (%)")
ax.set_xlim(0,0.5)
ax.set_ylim(0,100)
fig.savefig('figures/NK_0147_NHD_fitting.pdf', bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# NHD distribution of the natural sequences vs. the simulation in m_177.txt.
sim_nhd = pd.read_table("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/mutation_rate/m_177.txt", names = ["nhd"])
nat_nhd = pd.read_table("/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/mutation_rate/natural.txt", names = ["nhd"])

fig = plt.figure(figsize=(2.2,2.2))
ax = fig.add_axes([0.1,0.1,0.8,0.8])
# Natural (thin orange) first, simulated (thick translucent blue) on top.
# Weights of 1/len make each histogram's bars sum to 1.
for nhd, color, alpha, lw in ((nat_nhd["nhd"], "#F3A83B", 1, 1),
                              (sim_nhd["nhd"], "#4293F7", 0.5, 3)):
    weights = np.ones(len(nhd)) / len(nhd)
    ax.hist(nhd, color=color, alpha=alpha, weights=weights, range=(0,1),
            edgecolor=color, histtype="step", lw=lw, bins=50)

ax.set_xlabel("NHD")
ax.set_ylabel("Frequency")
ax.set_xlim(0,1)
fig.savefig('figures/NK_0147_NHD_hist.pdf', bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Per-subclade benchmark results. Accuracy = (1 - NRFD) * 100; rows with
# RunTime >= 259200 are dropped (72 h if RunTime is in seconds -- check Time_unit).
df_subclades = (
    pd.read_csv(
        "/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/result/result_subclades.all.csv",
        names=["TASK_ID", "SEQ_ID", "Nseq", "method", "Threshold", "Memory", "Mem_unit", "RunTime", "Time_unit", "Ntips", "NRFD"],
    )
    .assign(Accuracy=lambda d: (1 - d["NRFD"]) * 100)
    .loc[lambda d: d["RunTime"] < 259200]
)

# Replicate benchmark results, plus derived columns:
#   Coverage = Ntips / Nseq * 100
#   index    = TASK_ID % 5
#   x        = plotting position, 7 slots per SEQ_ID with `index` as the offset
df_replicates = (
    pd.read_csv(
        "/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0147/result/result_replicates.csv",
        names=["TASK_ID", "ID", "Nseq", "method", "Threshold", "SEQ_ID", "Memory", "Mem_unit", "RunTime", "Time_unit", "Ntips", "NRFD"],
    )
    .assign(
        Accuracy=lambda d: (1 - d["NRFD"]) * 100,
        Coverage=lambda d: d["Ntips"] / d["Nseq"] * 100,
        index=lambda d: d["TASK_ID"] % 5,
        x=lambda d: (d["SEQ_ID"] - 1) * 7 + d["index"],
    )
)

In [None]:
# Accuracy vs. size of the reconstructed tree, one figure per reconstruction method.
for method in ["rapidnjNJ", "raxmlMP", "fasttreeML"]:
    df_subclades_ext = df_subclades[df_subclades["method"] == method]
    # Every mask is built from df_subclades_ext so it shares that frame's index;
    # masks taken from the full df_subclades only worked through implicit
    # reindexing (with a pandas "Boolean Series key will be reindexed" warning).
    is_100nodes = (df_subclades_ext["TASK_ID"] <= 6) & (df_subclades_ext["SEQ_ID"] >= 82)
    df_subclades_ext_original         = df_subclades_ext[df_subclades_ext["Threshold"] == 10000000]
    df_subclades_ext_fractal_1node    = df_subclades_ext[(df_subclades_ext["Threshold"] == 10000) & ~is_100nodes]
    df_subclades_ext_fractal_100nodes = df_subclades_ext[is_100nodes]

    fig = plt.figure(figsize=(2.2,1.8))
    ax = fig.add_axes([0.1,0.1,0.8,0.8])
    # Largest markers first so the smaller ones stay visible on top.
    for df_points, color, size in ((df_subclades_ext_fractal_100nodes, '#7F33FF', 50),
                                   (df_subclades_ext_fractal_1node,    "#88C9D4", 20),
                                   (df_subclades_ext_original,         "#F8D686", 5)):
        ax.scatter(x=df_points["Ntips"], y=df_points["Accuracy"], color=color, s=size)

    ax.set_xlim(1000,2000000)
    ax.set_xscale("log")
    ax.set_title(method)
    ax.set_ylim(-5,105)
    ax.spines["top"].set_color("none")
    ax.spines["right"].set_color("none")
    ax.set_xlabel("Size of reconstructed tree")
    ax.set_ylabel("Accuracy (%)")
    fig.savefig('figures/NK_0147_subclades_'+method+'.pdf', bbox_inches='tight')
    plt.close(fig)

In [None]:
# Accuracy (left panel) and coverage (right panel) of every replicate, one figure per method.
for method in ["rapidnjNJ", "raxmlMP", "fasttreeML"]:
    df_replicates_ext = df_replicates[df_replicates["method"] == method]
    fig = plt.figure(figsize=(1.8,1.8))
    ax  = fig.add_axes([0.1,0.1,0.36,0.8])
    ax2 = fig.add_axes([0.8,0.1,0.36,0.8])

    for panel, metric in ((ax, "Accuracy"), (ax2, "Coverage")):
        panel.scatter(x="x", y=metric, data=df_replicates_ext, color='#7F33FF', alpha=1, s=2.5)
        # Positional argument: the `b=` keyword was renamed `visible` and later
        # removed from matplotlib, so `grid(b=False)` fails on current versions.
        panel.grid(False)
        panel.set_ylim(-5,105)
        panel.set_xlabel("Dataset")
        panel.set_ylabel(metric)
        panel.spines["top"].set_color("none")
        panel.spines["right"].set_color("none")
        # x = (SEQ_ID - 1) * 7 + TASK_ID % 5, so dataset k occupies 7(k-1) .. 7(k-1)+4,
        # centred at 7(k-1)+2.
        panel.set_xticks([2, 9, 16, 23])
        panel.set_xticklabels([1,2,3,4])

    fig.savefig('figures/NK_0147_replicates_'+method+'.pdf', bbox_inches='tight')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop