In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
%matplotlib inline
#%config InlineBackend.figure_formats = ['svg']

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions

In [None]:
import continuous_parameter_models as models
import optimizers
import sbn

import importlib 
importlib.reload(models)
importlib.reload(optimizers)
importlib.reload(sbn)

from continuous_parameter_models import TFContinuousParameterModel

In [None]:
# Load the dataset selected by `data` (a tree plus its alignment) into `inst`.
# `max_x` is a per-dataset value later passed to `m.plot`.
inst = sbn.instance("charlie")
data = "primates"
if data == "DS1":
    inst.read_newick_file("../data/ds1.raxml.tre")
    inst.read_fasta_file("../data/DS1.fasta")
    max_x = 0.02
elif data == "primates":
    inst.read_newick_file("../data/primates.tre")
    inst.read_fasta_file("../data/primates.fasta")
    max_x = 0.2
elif data == "hello":
    # Hand-built three-taxon toy tree (presumably entry i of the vector is the
    # parent id of node i, so nodes 0-2 all hang off node 3).
    inst.tree_collection = sbn.TreeCollection(
        [sbn.Tree.of_parent_id_vector([3, 3, 3])],
        ["mars", "saturn", "jupiter"])
    inst.read_fasta_file('../data/hello.fasta')
    max_x = 0.4
else:
    # An explicit exception (unlike `assert`) is not stripped under `python -O`
    # and tells the reader which value was wrong.
    raise ValueError(
        f"Unknown dataset {data!r}; expected 'DS1', 'primates' or 'hello'.")

inst.make_beagle_instances(1)
# `copy=False` is meant to give a view onto the tree's branch-length storage:
# `log_like_with` writes into `branch_lengths` and then asks `inst` for the
# likelihood, so the two must share memory. (Under NumPy >= 2, `copy=False`
# raises rather than silently copying.)
branch_lengths_extended = np.array(inst.tree_collection.trees[0].branch_lengths,
                                   copy=False)
# Slice off the last (fake) element so we deal only with the actual branch
# lengths. Basic slicing returns a view, so writes propagate to the full array.
branch_lengths = branch_lengths_extended[:-1]

if data == "hello":
    # Hand-set branch lengths for the toy tree; the fake last entry stays 0.
    branch_lengths_extended[:] = np.array([0.2, 0.07, 0.07, 0.])

In [None]:
def log_exp_prior(x, rate=10):
    """
    Log density of independent Exponential(rate) priors on branch lengths.

    Each slice of `x` along its last axis is one branch-length assignment
    (for the usual 2-D input: one assignment per row). One log prior value is
    returned per assignment.

    Exp(rate) has density rate * exp(-rate * t), so for d independent
    coordinates the joint log density is d * log(rate) - rate * sum(t).
    """
    x = np.asarray(x)
    # One log(rate) normalizing term per coordinate, not just one per row.
    return x.shape[-1] * np.log(rate) - rate * np.sum(x, axis=-1)

def grad_log_exp_prior(x, rate=10):
    """
    Gradient of the exponential log prior with respect to each branch length.

    The derivative of -rate * t is -rate for every coordinate, whatever the
    evaluation point, so a single scalar is returned; it relies on
    broadcasting when added to a gradient array.
    """
    del x  # the gradient does not depend on where it is evaluated
    return -rate

In [None]:
def log_like_with(in_branch_lengths, grad=False):
    """
    Evaluate the phylogenetic log likelihood, or its branch-length gradient,
    at `in_branch_lengths`, then restore the previous branch lengths.

    Works by temporarily writing into the global `branch_lengths` array, which
    is presumably a view onto the tree inside `inst` (it was built with
    `np.array(..., copy=False)`).

    Parameters
    ----------
    in_branch_lengths : array-like
        Values written into `branch_lengths` (must broadcast to its shape).
    grad : bool
        If True, return the gradient part of `inst.branch_gradients()[0]`;
        otherwise return the first entry of `inst.log_likelihoods()`.

    Returns
    -------
    numpy.ndarray or numpy scalar
    """
    # No `global` needed: the name is only read and slice-assigned, never rebound.
    saved_branch_lengths = branch_lengths.copy()
    branch_lengths[:] = in_branch_lengths
    try:
        if grad:
            _, log_grad = inst.branch_gradients()[0]
            return np.array(log_grad)
        return np.array(inst.log_likelihoods())[0]
    finally:
        # Restore on every path. Previously only the likelihood path restored,
        # so each gradient evaluation left the tree at the last query point.
        branch_lengths[:] = saved_branch_lengths

def phylo_log_like(x_arr):
    """
    Phylogenetic log likelihood of every branch-length assignment in `x_arr`.

    Each 1-D slice of `x_arr` along axis 1 is one assignment; `log_like_with`
    is evaluated on each slice and the results are stacked.
    """
    return np.apply_along_axis(func1d=log_like_with, axis=1, arr=x_arr)

def grad_phylo_log_like(x_arr):
    """
    Gradient of the phylogenetic log likelihood for every branch-length
    assignment laid out along axis 1 of `x_arr`.
    """
    def _gradient_at(branch_length_row):
        return log_like_with(branch_length_row, grad=True)

    gradients = np.apply_along_axis(_gradient_at, 1, x_arr)
    # NOTE(review): the last two gradient entries are dropped. Presumably they
    # are entries beyond the real branches (cf. the fake last element sliced
    # off `branch_lengths`); confirm the remaining width equals
    # len(branch_lengths).
    return gradients[:, :-2]

def phylo_log_upost(x_arr):
    """
    Unnormalized phylogenetic log posterior, one value per row of `x_arr`:
    the log likelihood plus `log_exp_prior` at its default rate (Exp(10)).
    """
    log_likelihood = phylo_log_like(x_arr)
    log_prior = log_exp_prior(x_arr)
    return log_likelihood + log_prior

def grad_phylo_log_upost(x_arr):
    """
    Gradient of the unnormalized phylogenetic log posterior
    (`phylo_log_upost`) with respect to the branch lengths, one row per row
    of `x_arr`.

    It is the likelihood gradient plus the gradient of the Exp(10) prior.
    The prior gradient comes from `grad_log_exp_prior` and is broadcast
    against the likelihood gradient.
    """
    return grad_phylo_log_like(x_arr) + grad_log_exp_prior(x_arr)

In [None]:
# Build the variational model over the real branch lengths (len(branch_lengths)
# of them). Arguments: distribution factory, initial parameters, branch count,
# and 100 (TODO(review): document what this last argument controls).
# Alternative variational families tried earlier:
#m = TFContinuousParameterModel(models.gamma_factory, np.array([4., 20.]), len(branch_lengths), 100)
#m = TFContinuousParameterModel(models.lognormal_factory, np.array([-2., 0.5]), len(branch_lengths), 100)
m = TFContinuousParameterModel(models.truncated_lognormal_factory, np.array([-1., 0.5, 0.1]), len(branch_lengths), 100)
# Initialize the parameters from the current tree's branch lengths
# (presumably placing each marginal's mode at that branch length).
m.mode_match(branch_lengths)
# Baseline ELBO of the unnormalized posterior before any optimization.
m.elbo_estimate(phylo_log_upost, particle_count=1000)

In [None]:
# Optimizer that updates the parameters of `m` (adaptive step size, per its name).
opt = optimizers.AdaptiveStepsizeOptimizer(m)

In [None]:
# Optimize against the unnormalized posterior, using its value and gradient
# functions (presumably 25 gradient steps), then plot the optimization trace.
opt.gradient_steps(phylo_log_upost, grad_phylo_log_upost, 25)
opt.plot_trace()

In [None]:
# Plot the fitted model alongside the phylogenetic likelihood; `max_x` is the
# per-dataset value chosen in the loading cell.
m.plot(phylo_log_like, max_x=max_x)

In [None]:
# NOTE(review): reloading `sbn` here means `inst` (created above) came from the
# pre-reload module; prefer a single reload in the import cell.
importlib.reload(sbn)

# Load a posterior tree sample (presumably MrBayes `.t` output) into a second instance.
# NOTE(review): the path is hard-coded to DS1, independent of `data` above
# (currently "primates").
mb_inst = sbn.instance("mb")
mb_inst.read_nexus_file("../_ignore/mb/DS1_out.t")
mb_inst.process_loaded_trees()

In [None]:
# Display the sampled branch lengths grouped by split (per the method name).
mb_inst.branch_lengths_by_split()

In [None]:
# Debugging aid: show the namespace of the `sbn.instance` class.
# TODO(review): remove before sharing the notebook.
sbn.instance.__dict__

In [None]:
# Stack the first component of every tree's indexer representation and check
# that all trees agree with the first one; then display that shared value.
a = np.array([representation[0]
              for representation in mb_inst.get_indexer_representations()])
assert (a[0, :] == a).all()
a[0]

In [None]:
# ELBO estimate using the log likelihood alone (no prior term); note the
# model-building cell used `phylo_log_upost` instead.
m.elbo_estimate(phylo_log_like, particle_count=1000)

ELBO estimates recorded from earlier runs:

* lognormal: -7153.707223708818
* good: -2934.5720311158702
* gamma with 200 steps: -2934.6846971854598
* lognormal with 200 steps: -2934.636675260846

In [None]:
# Re-plot the current fit (same call as the earlier plot cell).
m.plot(phylo_log_like, max_x=max_x)

---

In [None]:
# NOTE(review): duplicate of the earlier likelihood-only ELBO cell.
m.elbo_estimate(phylo_log_like, particle_count=1000)

In [None]:
# NOTE(review): duplicate of the earlier plot cells.
m.plot(phylo_log_like, max_x=max_x)

In [None]:
# NOTE(review): duplicate of the earlier likelihood-only ELBO cell.
m.elbo_estimate(phylo_log_like, particle_count=1000)

---

In [None]:
# Sanity check: gradient, with respect to the distribution parameters, of the
# squared distance between the lognormal's mode and a target of 0.025
# (parameters presumably location -2 and scale 0.5).
with tf.GradientTape() as g:
    tf_params = tf.constant(np.array([[-2., 0.5]]), dtype=tf.float32)
    g.watch(tf_params)
    # The factory lives in `models`; a bare `lognormal_factory` is undefined here.
    q_distribution = models.lognormal_factory(tf_params)
    mode_error = (0.025 - q_distribution.mode())**2
# Differentiate after leaving the tape's context so the gradient computation
# itself is not recorded on the tape.
grad = g.gradient(mode_error, tf_params)
grad

In [None]:
# Replace the model's optimizer with an SGD server sized to `m.param_matrix`.
m.optimizer = optimizers.SGD_Server({"params": m.param_matrix.shape})
# Leftover, disabled: a %%timeit cell magic only works as a cell's first line.
# %%timeit -n 1
# Scatter the log of the current branch lengths against column 1 of the
# model's parameter matrix.
pd.DataFrame({"bl": np.log(branch_lengths), "y": m.param_matrix[:,1]}).plot.scatter(x="bl", y="y")

In [None]:
    history[-1] = np.append(history[-1], m.elbo_estimate(phylo_log_like, particle_count=500))


In [None]:
# A Gamma(concentration=alpha, rate=beta) target with its log density and
# gradient, for testing against a distribution with a known closed form.
alpha = 2.
beta = 5.
gamma = tfd.Gamma(concentration=alpha, rate=beta)

def grad_log_like(x):
    """
    Elementwise derivative of the Gamma log density at `x`, as a NumPy array.

    `x` is converted to a float32 tensor; the result has the shape of `x`.
    Analytically this equals (alpha - 1) / x - beta.
    """
    tf_x = tf.constant(x, dtype=tf.float32)
    with tf.GradientTape() as g:
        g.watch(tf_x)
        log_prob = gamma.log_prob(tf_x)
    # Differentiate outside the tape's context so the gradient ops themselves
    # are not recorded on the tape.
    return g.gradient(log_prob, tf_x).numpy()

def log_like(x):
    """Gamma log density at `x`, returned as a TensorFlow tensor."""
    return gamma.log_prob(x)