In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
%matplotlib inline
%config InlineBackend.figure_formats = ['svg']

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions

In [None]:
import optimizers
import sbn

The following can be used in place of the phylogenetic likelihood for testing.

In [None]:
# Concentration (alpha) and rate (beta) of a Gamma distribution that serves as
# a cheap stand-in target density when testing without the phylogenetic
# likelihood.
alpha = 2.
beta = 5.
gamma = tfd.Gamma(concentration=alpha, rate=beta)

def grad_log_like(x):
    """Gradient of the test Gamma log-density with respect to x.

    x is converted to a float32 tensor, differentiated through
    ``gamma.log_prob`` with a gradient tape, and the result is returned as a
    NumPy array.
    """
    point = tf.constant(x, dtype=tf.float32)
    with tf.GradientTape() as tape:
        # Constants are not tracked automatically, so watch explicitly.
        tape.watch(point)
        log_density = gamma.log_prob(point)
    return tape.gradient(log_density, point).numpy()

def log_like(x):
    """Log-density of the test Gamma distribution evaluated at x."""
    log_density = gamma.log_prob(x)
    return log_density

In [None]:
# Set up a small sbn instance whose likelihood is the target for the
# phylogenetic functions defined below.
inst = sbn.instance("charlie")
# A single tree built from the parent-id vector [3, 3, 3], tagged with three
# taxon names. NOTE(review): presumably leaves 0-2 all attach to node 3 --
# confirm against the sbn documentation.
inst.tree_collection = sbn.TreeCollection(
    [sbn.Tree.of_parent_id_vector([3, 3, 3])],
    ["mars", "saturn", "jupiter"])
inst.read_fasta_file('../data/hello.fasta')
# NOTE(review): the meaning of the argument 2 is not visible here; check the
# sbn API for make_beagle_instances.
inst.make_beagle_instances(2)
# copy=False requests no copy: the intent is that this array aliases the
# tree's own branch-length storage, so in-place writes below change what
# `inst` evaluates.
branch_lengths_extended = np.array(inst.tree_collection.trees[0].branch_lengths,
                          copy=False)
# In-place write of initial values; the final entry is the "fake" element
# mentioned below.
branch_lengths_extended[:] = np.array([0.1, 0.1, 0.3, 0.])
# Here we are getting a slice that excludes the last (fake) element.
# Thus we can just deal with the actual branch lengths.
# Basic slicing yields a view, so writes to branch_lengths also land in
# branch_lengths_extended.
branch_lengths = branch_lengths_extended[:len(branch_lengths_extended)-1]

def log_like_with(in_branch_lengths, grad=False):
    """Evaluate the phylogenetic log-likelihood, or its gradient, at given branch lengths.

    The module-level ``branch_lengths`` array is overwritten in place with
    ``in_branch_lengths``, ``inst`` is evaluated, and the previous branch
    lengths are always written back afterwards -- in both the likelihood and
    the gradient mode, and also if the evaluation raises.

    Args:
        in_branch_lengths: values assignable to ``branch_lengths[:]``.
        grad: when True return the branch gradient instead of the
            log-likelihood.

    Returns:
        With ``grad=False``, the first entry of ``inst.log_likelihoods()``.
        With ``grad=True``, a NumPy array holding the gradient part of the
        first entry of ``inst.branch_gradients()``.
    """
    # No `global` statement is needed: the shared array is only mutated in
    # place, never rebound.
    saved_branch_lengths = branch_lengths.copy()
    try:
        branch_lengths[:] = in_branch_lengths
        if grad:
            _, log_grad = inst.branch_gradients()[0]
            # np.array copies, so the result stays valid after the restore.
            return np.array(log_grad)
        return np.array(inst.log_likelihoods())[0]
    finally:
        # Previously the restore only happened on the non-gradient path,
        # leaving the shared state modified after every gradient call.
        branch_lengths[:] = saved_branch_lengths

def phylo_log_like(x_arr):
    """Apply ``log_like_with`` to every 1-D slice of x_arr taken along axis 1.

    For a 2-D input this yields one log-likelihood per row.
    """
    per_row = np.apply_along_axis(log_like_with, axis=1, arr=x_arr)
    return per_row

def grad_phylo_log_like(x_arr):
    """Row-wise gradient of the phylogenetic log-likelihood.

    ``log_like_with(..., grad=True)`` is applied along axis 1 of x_arr, and
    the last two columns of the stacked result are dropped.
    NOTE(review): presumably those trailing entries do not correspond to real
    branches -- confirm against what ``inst.branch_gradients()`` returns.
    """
    def row_gradient(row):
        return log_like_with(row, grad=True)

    stacked = np.apply_along_axis(row_gradient, 1, x_arr)
    return stacked[:, :-2]

In [None]:
class TFContinuousParameterModel:
    """Variational model with one continuous distribution per branch.

    Each row of ``param_matrix`` holds the parameters of the variational
    distribution q for one branch. ``q_factory`` turns such a matrix into a
    distribution object. ``sample`` draws particles and caches the gradients
    that ``gradient_step`` needs to update the parameters.
    """

    def __init__(self, q_factory, initial_params, branch_count, particle_count, step_size=0.01):
        """
        Args:
            q_factory: callable mapping a (branch_count, param_count) matrix
                to a distribution object providing ``sample``, ``log_prob``
                and ``name``.
            initial_params: 1-D array-like of starting parameters; every
                branch starts from a copy of it.
            branch_count: number of branches (rows of ``param_matrix``).
            particle_count: number of samples drawn by ``sample``.
            step_size: multiplier applied to the gradient in ``gradient_step``.
        """
        # Coerce to float so the in-place update in gradient_step cannot fail
        # on an integer-typed parameter matrix.
        initial_params = np.asarray(initial_params, dtype=float)
        assert initial_params.ndim == 1
        self.q_factory = q_factory
        self.param_matrix = np.full((branch_count, len(initial_params)), initial_params)
        self.particle_count = particle_count
        self.step_size = step_size
        # The current stored sample.
        self.x = None
        # The gradient of x with respect to the parameters of q.
        self.grad_x = None
        # The gradient, with respect to the parameters, of the sum of
        # log q(x) over the current sample.
        self.grad_log_sum_q = None

    @property
    def branch_count(self):
        """Number of rows of ``param_matrix``."""
        return self.param_matrix.shape[0]

    @property
    def param_count(self):
        """Number of columns of ``param_matrix``."""
        return self.param_matrix.shape[1]

    def sample(self):
        """Draw ``particle_count`` samples from q and cache their gradients.

        Stores the sample in ``self.x``, the gradient of the sample with
        respect to the parameters in ``self.grad_x`` and the gradient of the
        summed log-density in ``self.grad_log_sum_q``.

        Returns:
            The sample as a NumPy array (the same object as ``self.x``).
        """
        # persistent=True because the tape is queried twice below.
        with tf.GradientTape(persistent=True) as g:
            tf_params = tf.constant(self.param_matrix, dtype=tf.float32)
            g.watch(tf_params)
            q_distribution = self.q_factory(tf_params)
            tf_x = q_distribution.sample(self.particle_count)
            # log_prob avoids the underflow of log(prob(x)) for tiny densities.
            q_term = tf.math.reduce_sum(q_distribution.log_prob(tf_x))
        self.x = tf_x.numpy()
        # The Jacobian is laid out as particles x edges x edges x params.
        self.grad_x = np.sum(g.jacobian(tf_x, tf_params).numpy(), axis=2)
        self.grad_log_sum_q = g.gradient(q_term, tf_params).numpy()
        del g  # Should happen anyway but being explicit to remember.
        return self.x

    def clear_sample(self):
        """Forget the current sample and its cached gradients."""
        self.x = None
        self.grad_x = None
        self.grad_log_sum_q = None

    @staticmethod
    def _chain_rule(grad_log_p_x, grad_x):
        """Sum over particles of ``grad_log_p_x[p, b] * grad_x[p, b, k]``.

        Args:
            grad_log_p_x: array indexed [particle, branch].
            grad_x: array indexed [particle, branch, param].

        Returns:
            Array indexed [branch, param].
        """
        # einsum contracts the particle axis directly instead of forming a
        # branch x branch intermediate and keeping only its diagonal.
        return np.einsum("pb,pbk->bk", grad_log_p_x, grad_x)

    @staticmethod
    def _slow_chain_rule(grad_log_p_x, grad_x):
        """Explicit-loop reference implementation of ``_chain_rule``."""
        particle_count, branch_count, param_count = grad_x.shape
        result = np.zeros((branch_count, param_count))
        for branch in range(branch_count):
            for param in range(param_count):
                for particle in range(particle_count):
                    result[branch, param] += grad_log_p_x[particle, branch] * grad_x[particle, branch, param]
        return result

    def linspace_one_branch(self, branch, min_x, max_x, num=50, default=0.1):
        """
        Fill a num x branch_count array with default, except for the branch column,
        which gets filled with a linspace.
        """
        a = np.full((num, self.branch_count), default)
        a[:, branch] = np.linspace(min_x, max_x, num)
        return a

    def elbo_gradient_using_current_sample(self, grad_log_p_x):
        """Gradient estimate for the current sample.

        Computes ``_chain_rule(grad_log_p_x, grad_x) - grad_log_sum_q``
        divided by ``particle_count``. ``sample`` must have been called since
        the last ``clear_sample``.

        Args:
            grad_log_p_x: gradient of the target log-density at the current
                sample, indexed [particle, branch].
        """
        assert self.grad_x is not None
        chained = self._chain_rule(grad_log_p_x, self.grad_x)
        # Sanity check of the vectorized contraction against the loop version.
        assert np.allclose(chained, self._slow_chain_rule(grad_log_p_x, self.grad_x))
        unnormalized_result = chained - self.grad_log_sum_q
        return unnormalized_result / self.particle_count

    def gradient_step(self, grad_log_p_x, history = None):
        """Add ``step_size`` times the gradient to the parameters, then drop the sample.

        Args:
            grad_log_p_x: see ``elbo_gradient_using_current_sample``.
            history: optional list; a copy of the updated ``param_matrix`` is
                appended to it.
        """
        grad = self.elbo_gradient_using_current_sample(grad_log_p_x)
        self.param_matrix += self.step_size * grad
        # The cached sample was drawn under the old parameters.
        self.clear_sample()
        if history is not None:
            history.append(self.param_matrix.copy())

    def _plot_1d_frame(self, q_distribution, target_log_like, which_branch, max_x):
        """Build the DataFrame plotted by ``plot_1d``.

        The grid has 100 points from ``max_x / 100`` to ``max_x`` for
        ``which_branch`` (other branches keep the ``linspace_one_branch``
        default). Both curves are passed through a softmax over the grid so
        that they are on a comparable scale.
        """
        min_x = max_x / 100
        # Use this model's own grid helper rather than a notebook-level global.
        x_vals = self.linspace_one_branch(which_branch, min_x, max_x, 100)
        return pd.DataFrame({
            "x": x_vals[:, which_branch],
            "target": sp.special.softmax(target_log_like(x_vals)),
            "fit": sp.special.softmax(q_distribution.log_prob(x_vals).numpy()[:, which_branch])})

    def plot_1d(self, ax, target_log_like, which_branch, max_x = 0.5):
        """Plot the target and the fitted q for one branch on ``ax``.

        Args:
            ax: matplotlib axes to draw on.
            target_log_like: callable evaluated on the whole grid array; its
                result is softmax-normalized and plotted as "target".
            which_branch: index of the branch whose length is varied.
            max_x: upper end of the grid.

        Returns:
            Whatever ``DataFrame.plot`` returns.
        """
        q_distribution = self.q_factory(self.param_matrix)
        df = self._plot_1d_frame(q_distribution, target_log_like, which_branch, max_x)
        return df.plot(ax=ax, x="x", y=["target", "fit"], kind="line",
                title=q_distribution.name+" "+str(self.param_matrix[which_branch,:]))

    def plot(self, target_log_like, max_x = 0.5):
        """Draw one ``plot_1d`` panel per branch, stacked with a shared x axis."""
        # squeeze=False keeps a 2-D axes array even for a single branch.
        _, axarr = plt.subplots(self.branch_count, sharex=True, squeeze=False)
        for which_branch in range(self.branch_count):
            # Forward the caller's max_x instead of a hard-coded value.
            self.plot_1d(axarr[which_branch, 0], target_log_like, which_branch, max_x=max_x)
        plt.tight_layout()
        plt.tight_layout()

In [None]:
def exponential_factory(params):
    """Exponential distributions, one per row of params.

    Column 0 of params is used as the rate.
    """
    rates = params[:, 0]
    return tfp.distributions.Exponential(rate=rates)

def gamma_factory(params):
    """Gamma distributions, one per row of params.

    Column 0 of params is the concentration, column 1 the rate.
    """
    concentrations = params[:, 0]
    rates = params[:, 1]
    return tfp.distributions.Gamma(concentration=concentrations, rate=rates)

def inverse_gamma_factory(params):
    """Inverse-gamma distributions, one per row of params.

    Column 0 of params is the concentration, column 1 the scale.
    """
    concentrations = params[:, 0]
    scales = params[:, 1]
    return tfp.distributions.InverseGamma(concentration=concentrations, scale=scales)

def lognormal_factory(params):
    """Log-normal distributions, one per row of params.

    Column 0 of params is loc, column 1 is scale.
    """
    locs = params[:, 0]
    scales = params[:, 1]
    return tfp.distributions.LogNormal(loc=locs, scale=scales)

# Alternative variational families are kept commented out; uncomment exactly
# one assignment to `m` to choose the family being fitted.
#m = TFContinuousParameterModel(inverse_gamma_factory, np.array([2., 0.5]), 3, 100, step_size=0.05)
# Active model: Gamma family, initial (concentration, rate) = (2, 12) for each
# of the 3 branches, 100 particles per sample, step size 0.05.
m = TFContinuousParameterModel(gamma_factory, np.array([2., 12.]), 3, 100, step_size=0.05)
#m = TFContinuousParameterModel(lognormal_factory, np.array([-2., 0.5]), 3, 5)

In [None]:
# Per-branch comparison of the model's current q against the phylogenetic
# target.
m.plot(phylo_log_like)

In [None]:
# Run 100 update steps. Each step draws a fresh sample, evaluates the
# phylogenetic gradient at it, and updates the parameters of `m`.
# `history` receives a copy of the parameter matrix after every step.
history = []
for _ in range(100):
    m.sample()
    m.gradient_step(grad_phylo_log_like(m.x), history)

In [None]:
# Per-branch comparison of q against the phylogenetic target, using the
# parameters `m` holds at this point in the notebook.
m.plot(phylo_log_like)