In [None]:
import numpy as np
import pandas as pd
# NOTE(review): sbn is imported again (and reloaded) in the next cell, so this
# import is redundant.
import sbn
# Enables the %%R cell magic used by the ggplot2 figure cells in this notebook.
%load_ext rpy2.ipython

import tensorflow.compat.v2 as tf
# Switch on TF2 behavior through the compat.v2 API (presumably so the notebook
# also behaves as TF2 when a TF1-era install is present).
tf.enable_v2_behavior()
import tensorflow_probability as tfp
# Conventional short alias for TFP distributions; not referenced in the cells
# visible here.
tfd = tfp.distributions

In [None]:
import continuous_parameter_models as models
import optimizers
import sbn

# Reload the project modules so edits to them are picked up without a kernel
# restart. The reloads run before the `from ... import` below so that
# TFContinuousParameterModel is bound to the freshly reloaded class.
# NOTE(review): if sbn is a compiled extension, importlib.reload will not pick
# up a rebuilt binary — a kernel restart is needed in that case.
import importlib 
importlib.reload(models)
importlib.reload(optimizers)
importlib.reload(sbn)

from continuous_parameter_models import TFContinuousParameterModel

In [None]:
data = "DS1"
# Dataset name -> (nexus tree file, fasta alignment file).
DATA_PATHS = {
    "DS1": ("../_ignore/mb/ds1/DS1_out.t", "../data/DS1.fasta"),
    "primates": ("../_ignore/mb/primates/primates_out.t", "../_ignore/primates.fasta"),
}
if data not in DATA_PATHS:
    # Without this check an unknown name fails later with a NameError on a
    # fresh kernel or, worse, silently reuses the paths from an earlier run.
    raise ValueError(f"unknown dataset {data!r}; expected one of {sorted(DATA_PATHS)}")
nexus_file, fasta_file = DATA_PATHS[data]

inst = sbn.instance("charlie")
inst.read_nexus_file(nexus_file)
inst.process_loaded_trees()

In [None]:
# Discard the first 10% of the loaded trees as burn-in.
burn_in = int(0.1 * inst.tree_count())
# Axis 1 indexes trees (it is what the burn-in slice cuts); transposing makes
# each DataFrame row one tree and each column one split.
lengths_by_split = np.array([np.array(split_row) for split_row in inst.branch_lengths_by_split()])
mb_branch_lengths = pd.DataFrame(lengths_by_split[:, burn_in:].T)
# Per-tree sum over all split columns.
mb_branch_lengths["total"] = mb_branch_lengths.sum(axis=1)

In [None]:
inst.read_fasta_file(fasta_file)
inst.make_beagle_instances(1)
inst.sample_trees(1)
tree = inst.tree_collection.trees[0]
# copy=False: writes into this array must reach the tree's own branch lengths,
# because log_like_with sets lengths by writing into `branch_lengths`.
# NOTE(review): under NumPy < 2 copy=False silently copies when a view is
# impossible; under NumPy >= 2 it raises ValueError instead.
branch_lengths_extended = np.array(tree.branch_lengths, copy=False)
# Drop the last (fake) element so we deal only with the actual branch lengths.
# Basic slicing returns a view, so branch_lengths still aliases the tree.
branch_lengths = branch_lengths_extended[:-1]
branch_lengths[:] = 0.1
# The ith entry of this array gives the index of the split corresponding to the ith branch.
branch_to_split = np.array(inst.get_psp_indexer_representations()[0][0])
# Inverting the map below is only meaningful if branch_to_split is a
# permutation; otherwise split_to_branch would silently hold stale entries.
assert np.array_equal(np.sort(branch_to_split), np.arange(len(branch_to_split))), (
    "branch_to_split is not a permutation of 0..n-1")
# The ith entry of this array gives the index of the branch corresponding to the ith split.
split_to_branch = np.empty_like(branch_to_split)
split_to_branch[branch_to_split] = np.arange(len(branch_to_split))

In [None]:
def translate_branches_to_splits(branch_vector):
    """
    Reorder a branch-indexed vector into split order.

    ``split_to_branch[i]`` is the branch that carries split ``i``, so entry
    ``i`` of the result is ``branch_vector[split_to_branch[i]]``.
    """
    branch_for_each_split = split_to_branch
    return branch_vector[branch_for_each_split]

In [None]:
def log_like_with(split_lengths, grad=False):
    """
    Evaluate the log likelihood (or its gradient) at the given split lengths,
    then restore the previous branch lengths.

    New lengths reach ``inst`` by being written into the module-level
    ``branch_lengths`` array in place (it is created with ``copy=False`` from
    the tree's branch lengths), so it must be restored afterwards.

    Parameters
    ----------
    split_lengths : array-like
        Lengths indexed by split; entry ``branch_to_split[b]`` is written to
        branch ``b``.
    grad : bool
        If True, return the log-likelihood gradient reordered into split
        order; otherwise return the log likelihood of the first tree.
    """
    saved_branch_lengths = branch_lengths.copy()
    split_lengths = np.asarray(split_lengths)
    try:
        # Gather split lengths into branch order in one vectorized step.
        branch_lengths[:] = split_lengths[branch_to_split[:len(branch_lengths)]]
        if grad:
            _, log_grad = inst.branch_gradients()[0]
            # NOTE(review): the last two gradient entries are dropped; confirm
            # they do not correspond to real branches.
            return translate_branches_to_splits(np.array(log_grad)[:-2])
        return np.array(inst.log_likelihoods())[0]
    finally:
        # Restore on both paths (the gradient path previously skipped this
        # and left the tree holding the last evaluated lengths).
        branch_lengths[:] = saved_branch_lengths

def phylo_log_like(x_arr):
    """
    Phylogenetic log likelihood for every row of ``x_arr``.

    Each slice along axis 1 is one assignment of split lengths and is scored
    independently by ``log_like_with``.
    """
    return np.apply_along_axis(log_like_with, axis=1, arr=x_arr)

def grad_phylo_log_like(x_arr):
    """
    Gradient of the phylogenetic log likelihood for every row of ``x_arr``,
    one gradient row per input row.
    """
    def row_gradient(row):
        return log_like_with(row, grad=True)

    return np.apply_along_axis(row_gradient, axis=1, arr=x_arr)

def log_exp_prior(x, rate=10):
    """
    Log density of independent Exponential(rate) priors on every component.

    Parameters
    ----------
    x : array-like
        Lengths along the last axis; each slice along that axis is one
        assignment (e.g. one row per particle).
    rate : float
        Rate of the exponential prior (default 10).

    Returns
    -------
    ``d * log(rate) - rate * sum(x)`` per assignment, where ``d`` is the
    number of components: the sum of per-component log densities
    ``log(rate) - rate * x_i``. Scalar for 1-D input.
    """
    x = np.asarray(x)
    # Each of the d independent components contributes its own log(rate)
    # normalizing term; previously it was added only once.
    return x.shape[-1] * np.log(rate) - rate * np.sum(x, axis=-1)

def grad_log_exp_prior(x, rate=10):
    """
    Gradient of the exponential log prior with respect to ``x``.

    The log density is linear in ``x``, so every component has partial
    derivative ``-rate``. The result has the same shape as ``x`` so it can be
    used on its own, not only added to another gradient by broadcasting.
    """
    return np.full(np.shape(x), -rate, dtype=float)

def phylo_log_upost(x_arr):
    """
    Unnormalized phylogenetic log posterior for every row of ``x_arr``:
    log likelihood plus the Exp(10) log prior.
    """
    log_likelihood = phylo_log_like(x_arr)
    log_prior = log_exp_prior(x_arr)
    return log_likelihood + log_prior

def grad_phylo_log_upost(x_arr):
    """
    Gradient of the unnormalized phylogenetic log posterior (Exp(10) prior)
    for every row of ``x_arr``.

    This is the log-likelihood gradient plus the log-prior gradient; the prior
    term is added by broadcasting against the likelihood gradient.
    """
    return grad_phylo_log_like(x_arr) + grad_log_exp_prior(x_arr)

In [None]:
# Alternatives tried previously (kept for reference):
#   TFContinuousParameterModel(models.gamma_factory, np.array([1.3, 3.]), len(branch_lengths), 100)
#   TFContinuousParameterModel(models.truncated_lognormal_factory, np.array([-1., 0.5, 0.1]), len(branch_lengths), 100)
#   m.mode_match(translate_branches_to_splits(branch_lengths))
family_factory = models.lognormal_factory
initial_params = np.array([-2., 0.5])
# NOTE(review): the meaning of the trailing 100 is not visible here — confirm
# against TFContinuousParameterModel's signature.
m = TFContinuousParameterModel(family_factory, initial_params, len(branch_lengths), 100)
# ELBO estimate of the initial model against the unnormalized log posterior.
m.elbo_estimate(phylo_log_upost, particle_count=1000)

In [None]:
# Step-size-adapting optimizer (per its class name) wrapping the model m.
opt = optimizers.AdaptiveStepsizeOptimizer(m)

In [None]:
# Manual tuning: halve the current step size.
# NOTE(review): not idempotent — every re-run halves it again, so the value
# depends on how many times this cell was executed. Prefer an explicit value.
opt.step_size /= 2

In [None]:
# Manual tuning flag. NOTE(review): its effect is defined in
# optimizers.AdaptiveStepsizeOptimizer — presumably lets the step size grow.
opt.stepsize_increasing = True

In [None]:
# Display the current tuning state as this cell's output.
opt.step_size, opt.stepsize_increasing

In [None]:
# Uncomment to clear the recorded trace before continuing; as written, the
# trace appears to accumulate across re-runs of this cell.
#opt.trace = []
# Optimize with the posterior and its gradient (100 is presumably the step count).
opt.gradient_steps(phylo_log_upost, grad_phylo_log_upost, 100)
# Trace as a frame with an explicit `index` column for the R plotting cell.
opt_trace = pd.DataFrame({"elbo": opt.trace}).reset_index()

In [None]:
%%R -i opt_trace -w 800 -h 400 -u px

library("ggplot2")
library("cowplot")

# Line plot of an ELBO trace; shared by the full and zoomed panels.
elbo_line_plot <- function(trace_df) {
    ggplot(trace_df) +
        theme_minimal() +
        geom_line(aes(x = index, y = elbo))
}

# Top: the whole run. Bottom: the last third of the run, so its tail can be
# read at a finer vertical scale.
normal = elbo_line_plot(opt_trace)
zoomed = elbo_line_plot(tail(opt_trace, nrow(opt_trace) / 3))

plot_grid(normal, zoomed, nrow = 2)

In [None]:
# Draw as many variational samples as there are post-burn-in MCMC trees so the
# two sets of densities are built from equally many draws.
fit_sample = pd.DataFrame(m.sample(len(mb_branch_lengths)))
# "total" must be computed before the string "type" column is added.
fit_sample["total"] = fit_sample.sum(axis=1)
fit_sample["type"] = "vb"
# NOTE(review): tags the MCMC frame in place, so mb_branch_lengths keeps a
# "type" column after this cell runs.
mb_branch_lengths["type"] = "mcmc"
long_frames = [frame.melt(id_vars="type") for frame in (fit_sample, mb_branch_lengths)]
plot_fit_df = pd.concat(long_frames)
# Column labels mix ints (split indices) and "total"; unify them as strings.
plot_fit_df["variable"] = plot_fit_df["variable"].astype(str)

In [None]:
%%R -i plot_fit_df -w 1600 -h 800 -u px

library("ggplot2")
library("cowplot")

# One panel per variable (split index or "total"), overlaying the vb and
# mcmc densities by color.
density_layer <- geom_density(aes(value, color = type))
ggplot(plot_fit_df) +
    theme_minimal_grid() +
    theme(axis.text.x = element_text(angle = -25)) +
    density_layer +
    facet_wrap("variable", scales = "free")

In [None]:
# ELBO estimate of the fitted model with more particles (5000) than the initial
# estimate, presumably to reduce Monte Carlo noise.
m.elbo_estimate(phylo_log_upost, particle_count=5000)