In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
%matplotlib inline
#%config InlineBackend.figure_formats = ['svg']

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions

In [None]:
import continuous_parameter_models
import optimizers
import sbn

import importlib 
importlib.reload(continuous_parameter_models)
importlib.reload(optimizers)
importlib.reload(sbn)

from continuous_parameter_models import TFContinuousParameterModel

The Gamma log density below (and its gradient) can stand in for the phylogenetic likelihood when testing.

In [None]:
alpha = 2.
beta = 5.
gamma = tfd.Gamma(concentration=alpha, rate=beta)

def grad_log_like(x):
    """
    Gradient of the Gamma(alpha, beta) log density at `x`, as a NumPy array.

    `x` is converted to a float32 tensor so it matches the distribution's dtype.
    """
    tf_x = tf.constant(x, dtype=tf.float32)
    with tf.GradientTape() as g:
        g.watch(tf_x)
        log_prob = gamma.log_prob(tf_x)
    # Record the forward pass inside the tape and take the gradient after
    # leaving the context (the pattern used in the TF autodiff guide).
    return g.gradient(log_prob, tf_x).numpy()

def log_like(x):
    """
    Gamma(alpha, beta) log density at `x` (returned as a TF tensor).

    Like `grad_log_like`, the input is converted to float32 so that e.g. float64
    NumPy arrays do not mix dtypes with the float32 distribution.
    """
    return gamma.log_prob(tf.constant(x, dtype=tf.float32))

In [None]:
inst = sbn.instance("charlie")
# Which data set to load: "DS1", "primates" or "hello" (a 3-taxon toy example).
data = "primates"
# For each data set: load the tree(s) and the alignment, and choose `max_x`,
# which is later passed to m.plot().
if data == "DS1":
    inst.read_newick_file("../data/ds1.raxml.tre")
    inst.read_fasta_file("../data/DS1.fasta")
    max_x = 0.02
elif data == "primates":
    inst.read_newick_file("../data/primates.tre")
    inst.read_fasta_file("../data/primates.fasta")
    max_x = 0.2
elif data == "hello":
    inst.tree_collection = sbn.TreeCollection(
        [sbn.Tree.of_parent_id_vector([3, 3, 3])],
        ["mars", "saturn", "jupiter"])
    inst.read_fasta_file('../data/hello.fasta')
    max_x = 0.4
else:
    # `assert False` gives no message and is stripped under `python -O`;
    # fail loudly with the offending value instead.
    raise ValueError(
        "Unknown data set {!r}; expected 'DS1', 'primates' or 'hello'.".format(data))

inst.make_beagle_instances(1)
# `copy=False` is meant to give a view onto the first tree's branch-length
# storage, so that writes through `branch_lengths` are seen by `inst`
# (log_like_with relies on this). Under NumPy >= 2 it raises instead of
# silently copying when a view is impossible.
branch_lengths_extended = np.array(inst.tree_collection.trees[0].branch_lengths,
                                   copy=False)
# Here we are getting a slice that excludes the last (fake) element.
# Thus we can just deal with the actual branch lengths.
branch_lengths = branch_lengths_extended[:-1]

if data == "hello":
    # Writes through the view, so `branch_lengths` sees these values too.
    branch_lengths_extended[:] = np.array([0.2, 0.07, 0.07, 0.])

In [None]:
def log_like_with(in_branch_lengths, grad=False):
    """
    Evaluate the phylogenetic log likelihood (or its gradient) with the branch
    lengths temporarily set to `in_branch_lengths`.

    The values are written into the module-level `branch_lengths` array, which is
    expected to share memory with the first tree held by `inst` (it is created
    with `copy=False`). The requested quantity is then computed, and the previous
    branch lengths are always restored afterwards. That includes the gradient path
    and the case where the computation raises.

    Parameters
    ----------
    in_branch_lengths : array-like
        Values assigned to `branch_lengths[:]`; must broadcast to its shape.
    grad : bool, default False
        If True, return the gradient from `inst.branch_gradients()`;
        otherwise return the log likelihood from `inst.log_likelihoods()`.

    Returns
    -------
    numpy.ndarray (gradient) or NumPy scalar (log likelihood of the first tree).
    """
    # Slice assignment mutates the global array in place, so no `global` is needed.
    saved_branch_lengths = branch_lengths.copy()
    branch_lengths[:] = in_branch_lengths
    try:
        if grad:
            _, log_grad = inst.branch_gradients()[0]
            return np.array(log_grad)
        return np.array(inst.log_likelihoods())[0]
    finally:
        # Previously the lengths were restored only on the likelihood path, so
        # gradient calls left the last sampled branch lengths in the tree.
        branch_lengths[:] = saved_branch_lengths

def phylo_log_like(x_arr):
    """
    Phylogenetic log likelihood for every row of a 2-D `x_arr`.

    Each row (a 1-D slice taken along axis 1) is one complete assignment of
    branch lengths; the result contains one log likelihood per row.
    """
    return np.apply_along_axis(func1d=log_like_with, axis=1, arr=x_arr)

def grad_phylo_log_like(x_arr):
    """
    Gradient of the phylogenetic log likelihood for every row of a 2-D `x_arr`.

    Each row is one branch-length assignment. Its gradient comes from
    `log_like_with(row, grad=True)`, and the last two entries of every
    gradient are dropped.
    """
    def _row_gradient(row):
        return log_like_with(row, grad=True)

    # NOTE(review): `branch_lengths` drops only the final (fake) entry, yet two
    # trailing gradient entries are removed here; confirm against the layout
    # returned by sbn's branch_gradients().
    all_gradients = np.apply_along_axis(_row_gradient, 1, x_arr)
    return all_gradients[:, :-2]

In [None]:
def exponential_factory(params):
    """Batch of Exponential distributions; column 0 of the 2-D `params` is the rate."""
    rates = params[:, 0]
    return tfd.Exponential(rate=rates)

def gamma_factory(params):
    """
    Batch of Gamma distributions from a 2-D parameter matrix.

    Column 0 is the concentration, column 1 the rate.
    """
    concentration, rate = params[:, 0], params[:, 1]
    return tfd.Gamma(concentration=concentration, rate=rate)

def inverse_gamma_factory(params):
    """
    Batch of InverseGamma distributions from a 2-D parameter matrix.

    Column 0 is the concentration, column 1 the scale.
    """
    concentration, scale = params[:, 0], params[:, 1]
    return tfd.InverseGamma(concentration=concentration, scale=scale)

def lognormal_factory(params):
    """
    Batch of LogNormal distributions from a 2-D parameter matrix.

    Column 0 is the loc and column 1 the scale of the underlying normal.
    """
    loc, scale = params[:, 0], params[:, 1]
    return tfd.LogNormal(loc=loc, scale=scale)

def truncated_lognormal_factory(params):
    """
    Batch of shifted, truncated log-normal distributions.

    If X ~ TruncatedNormal(loc=params[:,0], scale=params[:,1],
    low=log(params[:,2]), high=999), the returned distribution is that of
    exp(X) - params[:,2]. Since exp(X) >= params[:,2], its support starts at 0.

    Parameters
    ----------
    params : 2-D tensor-like with at least three columns
        Column 0: loc and column 1: scale of the underlying (log-scale) normal.
        Column 2: truncation point c on the original scale; must be positive
        because its log is taken.
    """
    # Chain applies its bijectors right-to-left: Exp first, then shift by -c.
    # tfp.bijectors.Shift replaces AffineScalar(shift=...), which is deprecated
    # and no longer available in recent TFP releases.
    exp_shift = tfp.bijectors.Chain(
        [tfp.bijectors.Shift(-params[:, 2]), tfp.bijectors.Exp()])
    return tfd.TransformedDistribution(
        distribution=tfd.TruncatedNormal(
            loc=params[:, 0],
            scale=params[:, 1],
            low=tf.math.log(params[:, 2]),
            high=999),  # effectively no upper truncation on the log scale
        bijector=exp_shift,
        name="TruncatedLogNormal")

# Build the variational model from gamma_factory, with initial parameter row
# [4., 20.] (concentration, rate as read by gamma_factory), for
# len(branch_lengths) branches.
# NOTE(review): the meaning of the last constructor argument (100) is not
# visible here; confirm in continuous_parameter_models.
m = TFContinuousParameterModel(gamma_factory, np.array([4., 20.]), len(branch_lengths), 100)
# Alternative variational families (swap in for the line above):
#m = TFContinuousParameterModel(lognormal_factory, np.array([-2., 0.5]), len(branch_lengths), 100)
#m = TFContinuousParameterModel(truncated_lognormal_factory, np.array([-1., 0.5, 0.1]), len(branch_lengths), 100)
# Presumably fits the parameters so each distribution's mode sits at the
# current branch length; verify in continuous_parameter_models.
m.mode_match(branch_lengths)
m.set_step_size()
# Initial ELBO estimate, before any gradient steps.
m.elbo_estimate(phylo_log_like, particle_count=1000)

In [None]:
# Plot the current model; `max_x` (chosen per data set) is presumably the
# x-axis upper limit.
m.plot(phylo_log_like, max_x=max_x)

In [None]:
# Current step size: set by m.set_step_size() during setup, and modified in
# place by the training loop.
m.step_size

In [None]:
# Gradient ascent on the ELBO with a simple adaptive step size:
# * while `stepsize_increasing`, the step size grows by `stepsize_increasing_rate`;
# * once the mean ELBO of the last `window_size` steps falls below the mean of
#   the window before it, or a gradient step fails, the parameters are rolled
#   back to the best seen so far and the step size is halved. From then on the
#   step size decays by `stepsize_decreasing_rate` per step.
n_steps = 30
elbo_particle_count = 500
history = []
trace = []
window_size = 5
stepsize_increasing_rate = 1.2
stepsize_decreasing_rate = 1-5e-3
stepsize_increasing = True
best_elbo = -np.inf
# Start from the current parameters rather than zeros: if the very first
# gradient step fails we roll back to this matrix, and all-zero parameters
# (e.g. a Gamma concentration/rate of 0) are not valid.
best_param_matrix = m.param_matrix.copy()


def _turn_around(reason):
    """Roll back to the best parameters seen so far and halve the step size."""
    np.copyto(m.param_matrix, best_param_matrix)
    m.step_size /= 2
    print("\nturning around " + reason)


for step in range(n_steps):
    if stepsize_increasing and step >= 2*window_size:
        last_epoch = trace[-window_size:]
        prev_epoch = trace[-2*window_size:-window_size]
        if np.mean(last_epoch) < np.mean(prev_epoch):
            _turn_around("decreasing")
            stepsize_increasing = False
    if stepsize_increasing:
        m.step_size *= stepsize_increasing_rate
    else:
        m.step_size *= stepsize_decreasing_rate
    # Space-separated so consecutive step sizes stay readable.
    print(m.step_size, end=' ')
    m.sample_and_prep_gradients()
    # A falsy return from gradient_step is treated as a failed (NaN) step.
    if not m.gradient_step(grad_phylo_log_like(m.z), history):
        _turn_around("nan")
        stepsize_increasing = False
    trace.append(m.elbo_estimate(phylo_log_like, particle_count=elbo_particle_count))
    if trace[-1] > best_elbo:
        best_elbo = trace[-1]
        np.copyto(best_param_matrix, m.param_matrix)
pd.Series(trace).plot.line()

In [None]:
# Replay the training loop's turnaround statistic. For each trace prefix the
# loop checked (lengths 2*window_size .. len(trace)-1), compute the mean of the
# last window minus the mean of the window before it. A negative value is what
# triggers the step-size turnaround.
window_size = 5
# NOTE(review): assumes the last entry of each `history` row is the ELBO;
# confirm against what m.gradient_step appends.
# Stacked once (not on every iteration) and under its own name, so this cell
# no longer overwrites the training `trace` and can be re-run safely.
elbo_history = np.stack(history)[:, -1]
stat = []
for i in range(2*window_size, len(trace)):
    prefix = elbo_history[:i]
    last_epoch = prefix[-window_size:]
    prev_epoch = prefix[-2*window_size:-window_size]
    stat.append(np.mean(last_epoch) - np.mean(prev_epoch))
# The per-iteration m.elbo_estimate(...) call was removed: its result was discarded.
pd.Series(stat).plot()

In [None]:
# ELBO estimate after training, using 1000 particles.
m.elbo_estimate(phylo_log_like, particle_count=1000)

ELBO estimates recorded from earlier runs (higher is better):

lognormal: -7153.707223708818

* good: -2934.5720311158702
* gamma with 200 steps: -2935.098012362405
* lognormal with 200 steps: -2934.699239103927

In [None]:
# Plot the model after training (same call as the initial plot).
m.plot(phylo_log_like, max_x=max_x)

---

In [None]:
# Repeated 1000-particle ELBO estimate; if the estimate is sampled, values
# will vary between runs of this cell.
m.elbo_estimate(phylo_log_like, particle_count=1000)

In [None]:
# ELBO trajectory stored in `history`.
# NOTE(review): assumes the last entry of each history row is the ELBO; the
# training loop only passes `history` to m.gradient_step, so confirm what it appends.
pd.DataFrame(np.stack(history)[:,-1], columns=["elbo"]).plot.line()

In [None]:
# Plot the current model again.
m.plot(phylo_log_like, max_x=max_x)

In [None]:
# Another 1000-particle ELBO estimate of the current model.
m.elbo_estimate(phylo_log_like, particle_count=1000)

---

In [None]:
# Sanity check: gradient, with respect to the (loc, scale) row, of the squared
# distance between a LogNormal's mode and the target value 0.025.
with tf.GradientTape() as g:
    tf_params = tf.constant(np.array([[-2., 0.5]]), dtype=tf.float32)
    g.watch(tf_params)
    q_distribution = lognormal_factory(tf_params)
    mode_error = (0.025 - q_distribution.mode())**2
# Record the forward pass inside the tape; take the gradient after leaving it.
grad = g.gradient(mode_error, tf_params)
grad

In [None]:
m.optimizer = optimizers.SGD_Server({"params": m.param_matrix.shape})
# Scatter of each branch's log length against column 1 of the model's parameter matrix.
(
    pd.DataFrame({"bl": np.log(branch_lengths), "y": m.param_matrix[:, 1]})
    .plot.scatter(x="bl", y="y")
)

In [None]:
    # Fix: the statement below was dedented; its stray 4-space indent raised IndentationError.
# Append a fresh 500-particle ELBO estimate to the most recent `history` row.
history[-1] = np.append(history[-1], m.elbo_estimate(phylo_log_like, particle_count=500))
