In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.backends.backend_pdf as backend_pdf
import matplotlib.pyplot as plt
import re
from numpy import genfromtxt
from ete3 import Tree
import seaborn as sns
import arviz as az
from plotting import plot_posterior

In [None]:
# Show the working directory: every data/result path below is relative to it.
os.getcwd()

In [None]:
# Presumably the node count of the subtree loaded below — TODO confirm.
# NOTE: nnodes is reassigned to 59 in the settings cell further down.
nnodes = 9
# Parse the Newick file with ete3. `tree` is replaced by the full tree in a later cell.
tree = Tree('../../data/chazot_subtree.nw') 


In [None]:
# Walk the tree in level order and record, for every node, whether its
# position in that ordering belongs to a leaf or to an inner node.
leafidx, inneridx = [], []
for position, node in enumerate(tree.traverse('levelorder')):
    if not node.is_leaf():
        inneridx.append(position)
        continue
    # Leaves are echoed by name so the ordering can be checked by eye.
    print(node.name)
    leafidx.append(position)
print(leafidx)
print(inneridx)


In [None]:
# --- Run configuration ------------------------------------------------------
simseed = 78241558624040307   # seed naming the simulated data set and its result folders
MCMC_iter = 6000              # exclusive upper bound of the iterations kept
burnin = 2000                 # leading iterations dropped from every chain
nthin = 1                     # thinning of the run; only enters the tree-sample slicing later on
folder_runs = f'../results/_sim-30-leaves/runs_v2/{simseed}/'
folder_simdata = f'../results/_sim-30-leaves/simdata/{simseed}/'
nnodes = 59                   # nodes per stored tree (middle axis when tree_nodes.csv is reshaped)
nxd = 40                      # length of one node's flattened coordinate vector
pars_name = ['kalpha', 'gtheta']
rep_path = len(pars_name)+1

# Every entry of the seed folder counts as one chain, except entries whose
# name begins with '_' or '.'.
chains = os.listdir(folder_runs)
chains = [entry for entry in chains if entry[0] not in ['_', '.']]
print(chains)

# --- True parameter values used for the simulation --------------------------
true_pars = [np.genfromtxt(f"{folder_simdata}{par}_sim.csv", delimiter=",") for par in pars_name]
true_pars

# --- Parameter traces -------------------------------------------------------
# Kept as nested lists (parameter -> chain) until after slicing, because the
# raw chains may differ in length.
temp_name = [''] * len(chains)   # optional per-chain file-name prefix (currently empty)
raw_pars = [
    [np.genfromtxt(f"{folder_runs}{chain}/{prefix}{par}s.csv", delimiter=",")
     for chain, prefix in zip(chains, temp_name)]
    for par in pars_name
]
# NOTE(review): the file name below does not depend on `par`, so every entry of
# raw_acceptpars is a copy of the kalpha acceptance record. This looks like a
# copy-paste slip; confirm the per-parameter file name written by the run
# script before relying on acceptpars for gtheta.
raw_acceptpars = [
    [np.genfromtxt(f"{folder_runs}{chain}/{prefix}acceptkalpha.csv", delimiter=",")
     for chain, prefix in zip(chains, temp_name)]
    for par in pars_name
]

# Drop burn-in and stack the chains: one array per parameter, indexed [chain, draw].
pars = [np.array([draws[burnin:MCMC_iter] for draws in per_chain]) for per_chain in raw_pars]
[samples.shape for samples in pars]
acceptpars = [np.array([flags[burnin:MCMC_iter] for flags in per_chain]) for per_chain in raw_acceptpars]
[kept.shape for kept in acceptpars]


In [None]:
# Convergence diagnostics for the model parameters.
# Each entry of `pars` is indexed [chain, draw]; ArviZ reads the first two
# axes of a plain array as (chain, draw).
parsdict = dict(zip(pars_name, pars))
MCMC_result = parsdict #parsdict|innernodedict
parsres = az.convert_to_dataset(MCMC_result)
rhat = az.rhat(parsres)
mcse = az.mcse(parsres)
ess = az.ess(parsres)

# R-hat per parameter, in pars_name order, kept for plotting.
rhats_par = np.array([rhat[par_name] for par_name in pars_name])

# Last expression of the cell so the summary table is actually displayed
# (as a mid-cell statement its result was silently discarded).
az.summary(parsres)

In [None]:
# Load the simulated ("true") value of every parameter and keep them as a
# list ordered like pars_name, which is how the plotting cell indexes them.
pars_name = ['kalpha', 'gtheta']
true_vals = {
    par: np.genfromtxt(os.path.join(folder_simdata, f"{par}_sim.csv"), delimiter=",")
    for par in pars_name
}

print(true_vals)
true_vals = list(true_vals.values())
true_vals

In [None]:
# Convergence figure: one row per parameter, trace plot on the left and
# per-chain kernel density on the right, with the true value overlaid.
import matplotlib.ticker as mticker  # only needed for the optional tick formatting below

plt.rcParams.update({'font.size': 20})
keys = pars_name
colors = sns.color_palette('muted', len(chains))
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(25,15))

for row, (trace_ax, density_ax) in enumerate(axes):
    draws = pars[row]        # indexed [chain, draw]
    truth = true_vals[row]
    n_chains = draws.shape[0]

    # Left panel: overlaid traces of all chains plus the true value.
    for chain_id in range(n_chains):
        trace_ax.plot(draws[chain_id, :], color=colors[chain_id], alpha=0.3)
    trace_ax.hlines(y=truth, xmin=0, xmax=draws.shape[1], color='skyblue', linewidth=3)

    # Right panel: one density curve per chain plus the true value.
    for chain_id in range(n_chains):
        sns.kdeplot(draws[chain_id, :], ax=density_ax, color=colors[chain_id])
    density_ax.axvline(x=truth, ymin=0, ymax=1, color='skyblue', linewidth=3)
    density_ax.set_ylabel("")  # remove the y-axis label set by seaborn

# R-hat / ESS text shared by both panels of a row.
diag = {key: f'{round(float(np.array(rhat[key])),2)}, ESS={round(float(np.array(ess[key])))})' for key in keys}

axes[0,0].set_title(r'$\alpha$ ($\hat{R}$=' + diag[keys[0]])
axes[0,1].set_title(r'$\alpha$ ($\hat{R}$=' + diag[keys[0]])
axes[0,1].tick_params(axis='x', labelrotation=15)
# Optional: scientific notation on the density x-axis.
#axes[0,1].xaxis.set_major_formatter(mticker.ScalarFormatter(useMathText=True))
#axes[0,1].ticklabel_format(style='sci', axis='x', scilimits=(-3,-3))

axes[1,0].set_title(r'$\sigma$ ($\hat{R}$=' + diag[keys[1]])
axes[1,1].set_title(r'$\sigma$ ($\hat{R}$= ' + diag[keys[1]])

fig.subplots_adjust(hspace=0.3)  # more vertical room between the rows (default is 0.2)
fig.savefig(f'convergence_plot_sim-30-leaves_{simseed}.pdf')

## Plot posterior distribution

In [None]:
import gc

def plot_posterior(flat_trees, inneridx, outpath, flat_true_tree=False, sample_n=50, nxd=40):
    """Save a grid of thinned posterior shape samples, one panel per inner node.

    NOTE: this definition shadows the ``plot_posterior`` imported from
    ``plotting`` in the notebook's import cell.

    Parameters
    ----------
    flat_trees : np.ndarray
        4-D array indexed ``[a, b, node, coord]``. The first two axes are
        merged before thinning (the caller in this notebook passes
        chain x draw).
    inneridx : sequence of int
        Node indices (third axis of ``flat_trees``) to plot; must be non-empty.
    outpath : str
        File the figure is written to.
    flat_true_tree : np.ndarray, False or None
        Optional 2-D array indexed ``[node, coord]`` with the true shapes.
        ``False`` (default) or ``None`` disables the overlay.
    sample_n : int
        Keep every ``sample_n``-th sample after merging the first two axes.
    nxd : int
        Length of one node's coordinate vector; even entries are plotted as
        x and odd entries as y.

    Raises
    ------
    ValueError
        If ``inneridx`` is empty.
    """
    n_nodes = len(inneridx)
    if n_nodes == 0:
        raise ValueError("inneridx is empty: there are no inner nodes to plot")

    # Smallest square grid that holds all panels.
    grid_size = int(np.ceil(np.sqrt(n_nodes)))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; without it a
    # single inner node yields a bare Axes object and .flatten() fails.
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(20, 20),
                             sharex=True, sharey=True, dpi=600, squeeze=False)
    axes = axes.flatten()

    has_truth = flat_true_tree is not False and flat_true_tree is not None

    for panel, idx in zip(axes, inneridx):
        # Merge the leading two axes, then thin to every sample_n-th sample.
        samples = flat_trees[:, :, idx, :].reshape(-1, nxd)[::sample_n, :]
        for shape in samples:
            panel.plot(shape[::2], shape[1::2], '--.', color='steelblue', alpha=0.3)
        if has_truth:
            true_shape = flat_true_tree[idx, :]
            # Repeat the first (x, y) pair so the true outline is drawn closed.
            closed = np.concatenate((true_shape, true_shape[0:2]))
            panel.plot(closed[::2], closed[1::2], '--.', color='black', label='True shape')
        panel.set_title(f'Node {idx}', size=20)
        panel.set_xticklabels([])
        panel.set_yticklabels([])

    # Remove the grid cells that have no node assigned.
    for unused in axes[n_nodes:]:
        fig.delaxes(unused)

    # Save and close this figure explicitly instead of the implicit current one.
    fig.savefig(outpath, bbox_inches='tight')
    plt.close(fig)
    gc.collect()

In [None]:
# Re-list the chains for the posterior-shape section (same rule as in the
# settings cell: skip entries whose name starts with '_' or '.').
rep_path = len(pars_name)+1
chains = [entry for entry in os.listdir(folder_runs) if entry[0] not in ['_', '.']]
print(chains)
temp_name = [''] * len(chains)   # optional per-chain file-name prefix (currently empty)

In [None]:
# Switch to the full tree and rebuild the level-order index lists; this
# overwrites `tree`, `leafidx` and `inneridx` from the earlier subtree cell.
simtree = "../data/chazot_full_tree.nw"
tree = Tree(simtree)
leafidx, inneridx = [], []
for position, node in enumerate(tree.traverse('levelorder')):
    if not node.is_leaf():
        inneridx.append(position)
        continue
    # Leaves are echoed by name so the ordering can be checked by eye.
    print(node.name)
    leafidx.append(position)
print(leafidx)
print(inneridx)

In [None]:
# Load, per chain, the stored tree samples and the integer counter file that
# accompanies them (one list entry per chain, in `chains` order).
raw_trees = [
    np.genfromtxt(f"{folder_runs}{chain}/{prefix}tree_nodes.csv", delimiter=",")
    for chain, prefix in zip(chains, temp_name)
]
tree_counters = [
    np.genfromtxt(f"{folder_runs}{chain}/{prefix}tree_counter.csv", delimiter=",").astype(int)
    for chain, prefix in zip(chains, temp_name)
]


In [None]:
# Give every chain's flat CSV the layout (stored tree, node, coordinate).
flat_trees_raw = [
    stored.reshape(len(counts), nnodes, nxd)
    for stored, counts in zip(raw_trees, tree_counters)
]
flat_true_tree = np.genfromtxt(folder_simdata+"flat_true_tree.csv", delimiter = ",")

# Root start shape of each chain, closed by repeating its first (x, y) pair,
# then reduced to the distinct rows.
super_root = [
    np.genfromtxt(f"{folder_runs}{chain}/{prefix}inference_root_start.csv", delimiter=",")
    for chain, prefix in zip(chains, temp_name)
]
_super_root = np.unique(
    np.array([np.concatenate((root, root[0:2])) for root in super_root]),
    axis=0,
)

# Repeat each stored tree by its counter, then cut the burn-in window.
# Indices are scaled by rep_path (presumably the number of recorded states per
# MCMC iteration — TODO confirm).
# NOTE(review): only the upper bound is divided by nthin; harmless while
# nthin == 1, but check the lower bound if thinning is ever used.
first_kept = burnin*rep_path
last_kept = (MCMC_iter//nthin)*rep_path
flat_trees = np.array([
    np.repeat(stored, counts, axis=0)[first_kept:last_kept]
    for stored, counts in zip(flat_trees_raw, tree_counters)
])
flat_trees.shape

In [None]:
# Write the posterior-sample grid for all inner nodes, keeping every 50th
# sample and overlaying the true shapes.
plot_posterior(
    flat_trees,
    inneridx,
    outpath=f'posterior_samples_burnin={burnin}_MCMCiter={MCMC_iter}.png',
    flat_true_tree=flat_true_tree,
    sample_n=50,
)
