In [None]:
import branch_length
import distributions
import optimizers
import sbn

import altair as alt
import numpy as np
import pandas as pd
import scipy.stats as stats

import importlib

# Render Altair charts with the classic-notebook renderer.
alt.renderers.enable("notebook")

# Re-import the local helper modules so edits to them are picked up without
# restarting the kernel.
importlib.reload(branch_length)
importlib.reload(distributions)
importlib.reload(optimizers)

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

In [None]:
# Build a single-tree likelihood problem: an sbn instance holding one tree over
# the taxa "mars", "saturn" and "jupiter", with sequences read from hello.fasta.
inst = sbn.instance("charlie")
# NOTE(review): parent id vector [3, 3, 3] presumably makes nodes 0-2 (the
# taxa) children of node 3, i.e. a 3-leaf star tree -- confirm against sbn.
inst.tree_collection = sbn.TreeCollection(
    [sbn.Tree.of_parent_id_vector([3, 3, 3])],
    ["mars", "saturn", "jupiter"])
inst.read_fasta_file('../data/hello.fasta')
# Presumably allocates 2 BEAGLE likelihood instances -- TODO confirm semantics.
inst.make_beagle_instances(2)
# Wrap the tree's branch lengths without copying: log_like_with writes into
# this array and expects the likelihood calls to see the change.
# NOTE(review): under NumPy < 2, copy=False silently copies when the buffer
# cannot be shared (which would break that aliasing); NumPy >= 2 raises instead.
branch_lengths = np.array(inst.tree_collection.trees[0].branch_lengths,
                          copy=False)
# Four entries; presumably one per node with the root last (hence 0) -- confirm.
branch_lengths[:] = np.array([0.1, 0.1, 0.3, 0.])

def log_like_with(branch_id: int, branch_length: float, grad=False):
    """Evaluate the log likelihood (or one gradient entry) with a branch length temporarily changed.

    The module-level ``branch_lengths`` array is written in place, ``inst`` is
    queried, and the original length is put back afterwards -- also when the
    query raises.

    Args:
        branch_id: Index into ``branch_lengths`` of the branch to modify.
        branch_length: Temporary value for that branch.
        grad: If True, return the ``branch_id`` entry of the gradient from
            ``inst.branch_gradients()`` instead of the log likelihood.

    Returns:
        The first log likelihood from ``inst.log_likelihoods()``, or the
        gradient component for ``branch_id``.
    """
    saved_branch_length = branch_lengths[branch_id]
    branch_lengths[branch_id] = branch_length
    try:
        if grad:
            # The first element of branch_gradients() is unpacked as a pair;
            # its second member is indexed by branch id.
            _, log_grad = inst.branch_gradients()[0]
            return np.array(log_grad)[branch_id]
        return np.array(inst.log_likelihoods())[0]
    finally:
        # Restore unconditionally so a failed evaluation cannot leave the
        # shared tree with a modified branch length.
        branch_lengths[branch_id] = saved_branch_length

def phylo_log_like(x_arr):
    """Log likelihood as a function of branch 2's length, one value per element of ``x_arr``."""
    log_likes = []
    for length in x_arr:
        log_likes.append(log_like_with(2, length))
    return np.array(log_likes)

def phylo_log_like_grad(x_arr):
    """Gradient w.r.t. branch 2's length at each element of ``x_arr``, as a ``(len(x_arr), 1)`` array."""
    rows = []
    for length in x_arr:
        rows.append([log_like_with(2, length, grad=True)])
    return np.array(rows)

# Log-likelihood profile of branch 2 over [0, 0.3]; the y axis is not forced to
# include zero so the shape near the optimum stays visible.
x_vals = np.linspace(0, 0.3, 100)
df = pd.DataFrame(dict(x=x_vals, y=phylo_log_like(x_vals)))
(
    alt.Chart(df)
    .mark_line()
    .encode(alt.X("x"), alt.Y("y", scale=alt.Scale(zero=False)))
)

In [None]:
# Fake "phylogenetic" targets with closed-form densities, for sanity-checking
# the variational fit. Both branches are disabled; switching a guard to True
# overrides phylo_log_like / phylo_log_like_grad from the cell above.
if False:
    # LogNormal target.
    d = distributions.LogNormal(1)
    true_loc = np.array([-2.3])
    true_shape = np.array([-0.6])

    def phylo_log_like(x):
        return d.log_prob(x, true_loc, true_shape)

    def phylo_log_like_grad(x):
        return d.log_prob_grad(x, true_loc, true_shape)
if False:
    # Gamma target.
    # NOTE(review): unlike the LogNormal branch, the gradient here is scaled by
    # exp(x); confirm which parameterization log_prob_grad differentiates in.
    d = distributions.Gamma(1)
    alpha = np.array([2.])
    beta = np.array([5.])

    def phylo_log_like(x):
        return d.log_prob(x, alpha, beta)

    def phylo_log_like_grad(x):
        return d.log_prob_grad(x, alpha, beta) * np.exp(x)

In [None]:
# Grid used for plotting and for the grid-based divergence estimates.
x_vals = np.linspace(0.01, 0.7, 100)
# The same grid as a (100, 1) column; this is the layout the density functions
# are called with below.
x_vals_transpose = np.transpose(np.array([x_vals]))


def _exp_normalize(log_values):
    """Turn log values on the grid into weights that sum to one.

    The maximum is subtracted before exponentiating (log-sum-exp trick). The
    result is mathematically unchanged, but very negative log likelihoods no
    longer underflow to 0 and produce 0/0 = NaN.
    """
    log_values = np.asarray(log_values)
    shifted = np.exp(log_values - np.max(log_values))
    return shifted / np.sum(shifted)


# Transforms applied to function outputs before plotting or comparing them.
transforms = {
    "identity": lambda x: x,
    "normalize": lambda x: x / np.sum(x),
    "exp": lambda x: np.exp(x),
    "exp_normalize": _exp_normalize,
}

def plot_functions(f_true, f_approx, transform="identity"):
    """Overlay two functions, evaluated on the ``x_vals`` grid, as an Altair line chart.

    Args:
        f_true: Reference function, called with ``x_vals_transpose``.
        f_approx: Approximating function, called the same way.
        transform: Key into ``transforms`` applied to each output before plotting.

    Returns:
        An ``alt.Chart`` with one line per function, coloured by ``variable``
        ("truth" / "approx").
    """
    apply = transforms[transform]
    wide = pd.DataFrame({
        "x": x_vals,
        "truth": apply(f_true(x_vals_transpose)),
        "approx": apply(f_approx(x_vals_transpose)),
    })
    long_form = wide.melt(id_vars=["x"])
    return alt.Chart(long_form).mark_line().encode(x="x", y="value", color="variable")

def kl_div(f_true, f_approx, transform="identity"):
    """Grid estimate of the KL divergence between two functions, in both directions.

    Each function is evaluated once on ``x_vals_transpose`` and passed through
    ``transforms[transform]``. ``scipy.stats.entropy`` normalizes both inputs
    to sum to one before computing the divergence.

    Returns:
        ``{"standard": KL(true || approx), "reversed": KL(approx || true)}``.
    """
    apply = transforms[transform]
    # Evaluate each function only once: f_true may be expensive
    # (phylo_log_like runs one likelihood evaluation per grid point).
    true_vals = apply(f_true(x_vals_transpose))
    approx_vals = apply(f_approx(x_vals_transpose))
    return {
        "standard": stats.entropy(true_vals, approx_vals),
        "reversed": stats.entropy(approx_vals, true_vals),
    }

def elbo(f_true, f_approx):
    """Grid estimate of the ELBO, E_q[log p(x) - log q(x)].

    ``f_approx`` (log q) is evaluated on ``x_vals_transpose``. Its
    exponentiated values are normalized into weights over the grid, and these
    weights average ``f_true(x) - f_approx(x)``.

    Args:
        f_true: Target log density / log likelihood (log p).
        f_approx: Variational log density (log q).

    Returns:
        ``{"elbo": value}``.
    """
    approx_log_likes = f_approx(x_vals_transpose)
    # Subtract the max before exponentiating so very negative log densities
    # do not underflow to 0/0.
    weights = np.exp(approx_log_likes - np.max(approx_log_likes))
    weights = weights / np.sum(weights)
    # Use the f_true argument (not a global) so any target can be evaluated.
    true_log_likes = f_true(x_vals_transpose)
    return {"elbo": np.sum(weights * (true_log_likes - approx_log_likes))}

In [None]:
# Optimizer configuration: one branch length is fitted, so each variational
# parameter ("loc" and "shape") has a single entry.
branch_length_param_count = 1
sgd_server_args = dict(loc=branch_length_param_count, shape=branch_length_param_count)
infer_opt = optimizers.SGD_Server(sgd_server_args)

stepsz = 0.005       # base step size for each parameter
clip = 5.            # gradients are clipped to [-clip, clip]
anneal_freq = 500    # iterations between step-size decays
anneal_rate = 0.95   # multiplicative decay applied at each anneal


def measured_divergence():
    """Divergence tracked during optimization: the grid ELBO of q against phylo_log_like.

    Alternative: kl_div(phylo_log_like, q_log_like, transform="exp_normalize").
    """
    return elbo(phylo_log_like, q_log_like)

# Variational distribution q: a LogNormal whose parameters are the module-level
# arrays `loc` and `shape`. They are looked up on every call, so reassigning
# loc/shape later changes q_log_like.
q = distributions.LogNormal(1)
loc, shape = np.array([1.]), np.array([0.5])


def q_log_like(x):
    """Log density of q at ``x`` under the current global ``loc`` / ``shape``."""
    return q.log_prob(x, loc, shape)

def multi_elbo_grad(x, loc, shape, clip):
    """Gradient estimate w.r.t. ``loc`` and ``shape`` from samples ``x`` of q.

    Thin wrapper around the ``branch_length`` helpers. Likelihood-based weights
    are computed first and then combined with the target's gradient. Callers
    unpack the result as ``(loc_grad, shape_grad)``.
    """
    log_likes = phylo_log_like(x)
    weights = branch_length.like_weights(q, log_likes, x, loc, shape, clip)
    log_like_grads = phylo_log_like_grad(x)
    return branch_length.param_grad(q, weights, log_like_grads, x, loc, shape)

In [None]:
def tf_multi_elbo_grad(x, loc, shape, clip, concentration=2., rate=5., particle_count=100):
    """Gradient of a multi-sample (importance-weighted) ELBO w.r.t. ``loc`` and ``shape``.

    Samples are reparameterized as ``exp(loc + shape * eps)`` with
    ``eps ~ N(0, 1)`` and differentiated with ``tf.GradientTape`` against a
    Gamma target density.

    Args:
        x: Unused -- fresh samples are drawn internally. Kept so the signature
            matches ``multi_elbo_grad``.
        loc: Current location parameter (numpy array).
        shape: Current shape parameter (numpy array), used as the LogNormal
            ``scale``. NOTE(review): assumes this matches how
            ``distributions.LogNormal`` interprets ``shape`` -- confirm.
        clip: Unused; clipping is done by the caller.
        concentration: Gamma target concentration. Passed explicitly because
            the globals ``alpha``/``beta`` are only bound when the disabled
            fake-Gamma cell is switched on. The defaults equal that cell's values.
        rate: Gamma target rate.
        particle_count: Number of reparameterized samples per estimate.

    Returns:
        ``[loc_grad, shape_grad]`` as numpy arrays.

    NOTE(review): the target is a fixed Gamma, not ``phylo_log_like``, whereas
    ``measured_divergence`` evaluates against ``phylo_log_like``.
    """
    tf_loc = tf.constant(loc)
    tf_scale = tf.constant(shape)
    # Give the target the same dtype as loc (float64 for the arrays used here).
    # Bare Python floats would produce a float32 Gamma that rejects float64 samples.
    target = tfp.distributions.Gamma(
        concentration=tf.constant(concentration, dtype=tf_loc.dtype),
        rate=tf.constant(rate, dtype=tf_loc.dtype),
    )
    # Standard-normal noise for the reparameterization trick.
    epsilon = tf.constant(np.random.normal(0., 1., particle_count))
    with tf.GradientTape() as tape:
        tape.watch(tf_loc)
        tape.watch(tf_scale)
        tf_x = tf.math.exp(tf_loc + tf_scale * epsilon)
        approx = tfp.distributions.LogNormal(loc=tf_loc, scale=tf_scale)
        # Log of the summed importance ratios p(x_k) / q(x_k), as in the
        # equation just before (7) of the 2018 ICLR paper. The 1/K factor is
        # omitted: after the log it is an additive constant with zero gradient.
        log_ratio_sum = tf.math.log(
            tf.math.reduce_sum(target.prob(tf_x) / approx.prob(tf_x)))
    grads = tape.gradient(log_ratio_sum, [tf_loc, tf_scale])
    return [grad.numpy() for grad in grads]

In [None]:
def gradient_step(sample_count, grad_fn=None, step_sizes=None):
    """Take one Adam step on the global variational parameters ``loc`` and ``shape``.

    Args:
        sample_count: Number of samples drawn from q for the gradient estimate.
        grad_fn: Gradient estimator with signature ``(x, loc, shape, clip)``
            returning ``(loc_grad, shape_grad)``. Defaults to
            ``tf_multi_elbo_grad``; ``multi_elbo_grad`` has the same signature.
        step_sizes: Per-parameter step sizes passed to ``infer_opt.adam``.
            Defaults to the global ``stepsz_dict``, looked up at call time
            because it is created in the training cell.

    Returns:
        The value of ``measured_divergence()`` after the update.
    """
    global loc, shape
    if grad_fn is None:
        grad_fn = tf_multi_elbo_grad
    if step_sizes is None:
        step_sizes = stepsz_dict
    x = q.sample(loc, shape, sample_count)
    loc_grad, shape_grad = grad_fn(x, loc, shape, clip)
    if clip:
        # Bound each gradient component to [-clip, clip] before the optimizer sees it.
        loc_grad = np.clip(loc_grad, -clip, clip)
        shape_grad = np.clip(shape_grad, -clip, clip)
    update = infer_opt.adam(step_sizes,
                            {'loc': loc, 'shape': shape},
                            {'loc': loc_grad, 'shape': shape_grad})
    # `+=` on the numpy arrays updates them in place.
    loc += update['loc']
    shape += update['shape']
    return measured_divergence()

In [None]:
# Divergence at a fixed probe point (loc=-0.6, shape=1.1).
loc, shape = np.array([-0.6]), np.array([1.1])
measured_divergence()

In [None]:
# Divergence at a second probe point (loc=-1.6, shape=1.1).
loc, shape = np.array([-1.6]), np.array([1.1])
measured_divergence()

In [None]:
# Target vs. variational density on the grid, each exponentiated and normalized.
plot_functions(phylo_log_like, q_log_like, "exp_normalize")

In [None]:
# One 100-sample gradient estimate at the current parameters.
x = q.sample(loc, shape, 100)
multi_elbo_grad(x=x, loc=loc, shape=shape, clip=clip)

In [None]:
# One hand-rolled SGD step (fixed step size, clipped gradients). The parameters
# and gradients are printed before the update.
x = q.sample(loc, shape, 100)
loc_grad, shape_grad = (np.clip(grad, -clip, clip)
                        for grad in multi_elbo_grad(x, loc, shape, clip))
print(loc, loc_grad, shape, shape_grad)
loc += stepsz * loc_grad
shape += stepsz * shape_grad
measured_divergence()

In [None]:
# Histogram of single-sample loc gradients at the current parameters. Values
# below -5 are floored at -5 so the tail does not stretch the axis.
x_samples = q.sample(loc, shape, 1000)

gradient_samples = pd.DataFrame({
    "x": [multi_elbo_grad(np.array([[sample]]), loc, shape, clip)[0][0]
          for sample in x_samples.T[0]]
})
gradient_samples.loc[gradient_samples["x"].lt(-5), "x"] = -5

alt.Chart(gradient_samples).mark_bar().encode(
    x=alt.X("x", bin=alt.Bin(maxbins=50)),
    y="count()",
)

In [None]:
# Summary statistics of the single-sample gradients (after flooring at -5),
# with the default quartiles spelled out.
gradient_samples.describe(percentiles=[0.25, 0.5, 0.75])

In [None]:
# Draws from q at the current parameters, keeping only x < 0.7 for the plot.
lognormal_samples = pd.DataFrame({"x": q.sample(loc, shape, 5000).T[0]})
lognormal_samples = lognormal_samples.query("x < 0.70")

alt.Chart(lognormal_samples).mark_bar().encode(
    x=alt.X("x", bin=alt.Bin(maxbins=50)),
    y="count()",
)

In [None]:
# Single-sample loc gradient as a function of the sample position x, overlaid
# (in red) with q_log_like evaluated on the same grid.
loc_grad_data = pd.DataFrame({
    "x": x_vals,
    "loc_grad": [multi_elbo_grad(np.array([[point]]), loc, shape, clip)[0][0]
                 for point in x_vals],
    # NOTE(review): q_log_like wraps q.log_prob, so despite its name this
    # column presumably holds log-density values; it also shares the y axis
    # with loc_grad.
    "pdf": q_log_like(x_vals_transpose),
})
base = alt.Chart(loc_grad_data).mark_line().encode(
    x=alt.X("x"),
    y=alt.Y("loc_grad", scale=alt.Scale(zero=False)),
)
base + base.mark_line(color="red").encode(y="pdf")


In [None]:
# Target vs. variational density on the grid, each exponentiated and normalized.
plot_functions(phylo_log_like, q_log_like, "exp_normalize")

In [None]:
# Train q from a fixed starting point with a freshly constructed optimizer.
# A new SGD_Server is built (instead of re-running __init__ on the old object),
# so no optimizer state can carry over between runs. gradient_step looks up
# infer_opt at call time, so rebinding the global is enough.
infer_opt = optimizers.SGD_Server(sgd_server_args)
stepsz_dict = {'loc': stepsz, 'shape': stepsz}
loc = np.array([-1.])
shape = np.array([1.])
n_steps = 2000
results = [measured_divergence()]

# Count completed steps from 1 so the first step-size decay happens after
# anneal_freq steps rather than right after the first step.
for i in range(1, n_steps + 1):
    results.append(gradient_step(100))
    if i % anneal_freq == 0:
        for k in stepsz_dict:
            stepsz_dict[k] *= anneal_rate

plot_data = pd.DataFrame(results).reset_index()

alt.Chart(
    pd.melt(plot_data, id_vars=["index"], var_name="measured divergence", value_name="divergence")
).mark_line().encode(
    alt.X("index"),
    alt.Y("divergence"),
    color="measured divergence",
)

In [None]:
# Fitted variational density vs. target after training, both exp-normalized.
plot_functions(phylo_log_like, q_log_like, "exp_normalize")

In [None]:
def sample_gradients(particle_count, sample_count):
    """Collect repeated gradient estimates so their spread can be compared.

    Args:
        particle_count: Number of samples from q used for each estimate.
        sample_count: Number of independent estimates to draw.

    Returns:
        DataFrame with columns ``loc_grad_<particle_count>`` and
        ``shape_grad_<particle_count>``, one row per estimate.
    """
    suffix = str(particle_count)

    def sample_gradient():
        # Draw particle_count samples per estimate; this is the quantity the
        # column suffix reports and what separates the 1- vs 10-particle runs.
        x = q.sample(loc, shape, particle_count)
        loc_grad, shape_grad = multi_elbo_grad(x, loc, shape, clip)
        return {"loc_grad_" + suffix: loc_grad[0], "shape_grad_" + suffix: shape_grad[0]}

    return pd.DataFrame([sample_gradient() for _ in range(sample_count)])

sample_count = 1000
# Gradient statistics for particle counts 1 and 10, side by side; columns are
# suffixed with the particle count.
raw = pd.concat([sample_gradients(pc, sample_count) for pc in (1, 10)], axis=1)

raw.describe()