In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
%matplotlib inline
#%config InlineBackend.figure_formats = ['svg']

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions

In [None]:
import continuous_parameter_models
import optimizers
import sbn

import importlib 
importlib.reload(continuous_parameter_models)
importlib.reload(optimizers)
importlib.reload(sbn)

from continuous_parameter_models import TFContinuousParameterModel

The Gamma log-density and its gradient defined in the next cell can be used in place of the phylogenetic likelihood for testing.

In [None]:
# Stand-in target for testing: a Gamma distribution with fixed
# concentration (alpha) and rate (beta).
alpha, beta = 2., 5.
gamma = tfd.Gamma(rate=beta, concentration=alpha)

def grad_log_like(x):
    """Gradient of the stand-in Gamma log density with respect to `x`.

    `x` is converted to a float32 tensor before differentiation, and the
    gradient is handed back as a NumPy value via `.numpy()`.
    """
    point = tf.constant(x, dtype=tf.float32)
    with tf.GradientTape() as tape:
        # `point` is a constant, not a variable, so it is watched explicitly.
        tape.watch(point)
        log_density = gamma.log_prob(point)
    return tape.gradient(log_density, point).numpy()

def log_like(x):
    """Log density of the module-level `gamma` distribution evaluated at `x`."""
    log_density = gamma.log_prob(x)
    return log_density

In [None]:
# Create the libsbn instance and load the data set selected by `data`.
# `max_x` is chosen per data set and is later passed to `m.plot`.
inst = sbn.instance("charlie")
data = "hello"
if data == "DS1":
    # Tree and alignment are both read from disk.
    inst.read_newick_file("../data/ds1.raxml.tre")
    inst.read_fasta_file('../data/DS1.fasta')
    max_x = 0.02
elif data == "hello":
    # The tree is built in code from a parent-id vector and three taxon
    # names; only the alignment is read from disk.
    inst.tree_collection = sbn.TreeCollection(
        [sbn.Tree.of_parent_id_vector([3, 3, 3])],
        ["mars", "saturn", "jupiter"])
    inst.read_fasta_file('../data/hello.fasta')
    max_x = 0.4
else:
    # Raise explicitly instead of `assert False`: assertions are removed
    # under `python -O`, which would leave `max_x` undefined and `inst`
    # without data, and this message names the offending value.
    raise ValueError("Unknown data set: {!r}".format(data))

inst.make_beagle_instances(1)

# Wrap the first tree's branch lengths as a NumPy array, asking NumPy not to
# copy (copy=False) -- presumably so that in-place writes below reach the
# tree itself; confirm against the sbn bindings.
branch_lengths_extended = np.array(
    inst.tree_collection.trees[0].branch_lengths, copy=False)

# The final element is a fake entry. Slicing it off leaves an array holding
# only the actual branch lengths.
branch_lengths = branch_lengths_extended[:-1]

if data == "hello":
    branch_lengths_extended[:] = np.array([0.1, 0.1, 0.3, 0.])

In [None]:
def log_like_with(in_branch_lengths, grad=False):
    """Evaluate the log likelihood, or its branch gradient, at given branch lengths.

    The supplied values are written into the module-level `branch_lengths`
    array, `inst` is queried, and the previous contents of `branch_lengths`
    are then written back. The restore happens on both the likelihood and
    the gradient path, and also if the query raises, so a call never leaves
    the shared array modified.

    Args:
        in_branch_lengths: Values assigned to `branch_lengths[:]`.
        grad: When true, return the gradient instead of the log likelihood.

    Returns:
        With `grad=False`, element 0 of `np.array(inst.log_likelihoods())`.
        With `grad=True`, a NumPy copy of the second item of
        `inst.branch_gradients()[0]`.
    """
    # Only slice assignment is used on `branch_lengths` (the name is never
    # rebound), so no `global` declaration is required.
    saved_branch_lengths = branch_lengths.copy()
    branch_lengths[:] = in_branch_lengths
    try:
        if grad:
            _, log_grad = inst.branch_gradients()[0]
            # np.array makes a copy, so the restore below cannot alter it.
            return np.array(log_grad)
        return np.array(inst.log_likelihoods())[0]
    finally:
        branch_lengths[:] = saved_branch_lengths

def phylo_log_like(x_arr):
    """Apply `log_like_with` to every 1-D slice of `x_arr` along axis 1.

    Returns whatever `np.apply_along_axis` assembles from the per-slice
    results.
    """
    per_row = np.apply_along_axis(log_like_with, 1, x_arr)
    return per_row

def grad_phylo_log_like(x_arr):
    """Apply `log_like_with(..., grad=True)` to every 1-D slice of `x_arr` along axis 1.

    The last two columns of the stacked result are removed before returning.
    """
    def row_gradient(row):
        return log_like_with(row, grad=True)

    stacked = np.apply_along_axis(row_gradient, 1, x_arr)
    # NOTE(review): the trailing two gradient entries are discarded --
    # presumably they do not correspond to actual branch lengths; confirm
    # against what inst.branch_gradients() returns.
    return stacked[:, :-2]

In [None]:
def exponential_factory(params):
    """Build a `tfd.Exponential` whose rate is column 0 of `params`."""
    rate = params[:, 0]
    return tfd.Exponential(rate=rate)

def gamma_factory(params):
    """Build a `tfd.Gamma` from `params`: column 0 is the concentration, column 1 the rate."""
    concentration = params[:, 0]
    rate = params[:, 1]
    return tfd.Gamma(concentration=concentration, rate=rate)

def inverse_gamma_factory(params):
    """Build a `tfd.InverseGamma` from `params`: column 0 is the concentration, column 1 the scale."""
    concentration = params[:, 0]
    scale = params[:, 1]
    return tfd.InverseGamma(concentration=concentration, scale=scale)

def lognormal_factory(params):
    """Build a `tfd.LogNormal` from `params`: column 0 is `loc`, column 1 is `scale`."""
    loc = params[:, 0]
    scale = params[:, 1]
    return tfd.LogNormal(loc=loc, scale=scale)

def truncated_lognormal_factory(params):
    """Build a shifted, truncated log-normal distribution from `params`.

    Column 0 is the `loc` and column 1 the `scale` of an underlying
    `tfd.TruncatedNormal`; column 2 is a shift. The normal is truncated
    below at `log(shift)` and above at 999, then pushed through `exp`
    followed by a shift of `-shift`.
    """
    loc = params[:, 0]
    scale = params[:, 1]
    shift = params[:, 2]

    # Chain composes right-to-left: Exp is applied first, then the shift.
    exp_then_shift = tfp.bijectors.Chain([
        tfp.bijectors.AffineScalar(shift=-shift),
        tfp.bijectors.Exp(),
    ])
    base = tfd.TruncatedNormal(
        loc=loc,
        scale=scale,
        low=tf.math.log(shift),
        high=999)
    return tfd.TransformedDistribution(
        distribution=base,
        bijector=exp_then_shift,
        name="TruncatedLogNormal")

# Active model: Gamma family with initial parameters [1.3, 50.], sized by the
# number of actual branch lengths.
# NOTE(review): the fourth positional argument (100) is presumably the number
# of samples/particles -- confirm in TFContinuousParameterModel.
m = TFContinuousParameterModel(
    gamma_factory,
    np.array([1.3, 50.]),
    len(branch_lengths),
    100,
    step_size=0.1,
)
# Alternative families; uncomment one to use it instead of the Gamma model.
#m = TFContinuousParameterModel(lognormal_factory, np.array([-2., 0.5]), len(branch_lengths), 5)
#m = TFContinuousParameterModel(truncated_lognormal_factory, np.array([-1., 0.5, 0.1]), len(branch_lengths), 100)

In [None]:
# Plot the model as constructed, i.e. before the optimization loop in the
# following cell (in notebook order). `phylo_log_like` is handed to m.plot and
# `max_x` is the per-data-set value chosen in the setup cell.
m.plot(phylo_log_like, max_x=max_x)

In [None]:
# %%timeit -n 1

# Run 400 iterations: draw a sample, then take one gradient step using the
# phylogenetic gradient evaluated at m.x (presumably the sample just drawn).
# `history` is passed to every gradient_step call -- presumably the model
# appends to it; confirm in TFContinuousParameterModel.
history = []
for _ in range(400):
    m.sample()
    m.gradient_step(grad_phylo_log_like(m.x), history)

In [None]:
# Same plot call as before the optimization loop, repeated here after the
# gradient steps (in notebook order) so the two figures can be compared.
m.plot(phylo_log_like, max_x=max_x)

In [None]:
# Draw a sample, print what m.elbo_gradient_using_current_sample returns for
# the phylogenetic gradient at m.x, then clear the sample from the model.
m.sample()
print(m.elbo_gradient_using_current_sample(grad_phylo_log_like(m.x)))
m.clear_sample()