In [None]:
import branch_length
import distributions
import optimizers

import altair as alt
import numpy as np
import pandas as pd
import scipy.stats as stats

# Altair renderer selection. NOTE(review): the "notebook" renderer targets the classic
# Jupyter Notebook frontend in older Altair releases; confirm it exists in the installed version.
alt.renderers.enable("notebook")

# Re-import the local project modules so edits to them are picked up without restarting the kernel.
import importlib 
importlib.reload(branch_length)
importlib.reload(distributions)
importlib.reload(optimizers)

In [None]:
from pytest import approx
import tensorflow as tf
import tensorflow_probability as tfp

# TF 1.x needs eager mode switched on explicitly; TF 2.x runs eagerly by default and no
# longer exposes `tf.enable_eager_execution` at the top level, so only call it when present.
if hasattr(tf, "enable_eager_execution"):
    tf.enable_eager_execution()

### Validating gradient calculation

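In this section we validate the hand-written multi-sample reparameterization gradient of the log of the summed importance ratios $p/q$ by comparing it with TensorFlow's automatic differentiation.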

In [None]:
# Reference gradient from TensorFlow autodiff: derivative w.r.t. (mu, sigma) of the log of
# the summed ratios p(x_k) / q(x_k), where x_k = mu + sigma * epsilon_k.
true_loc_val = -2.
true_shape_val = 1.4
true_normal = tfp.distributions.Normal(loc=true_loc_val, scale=true_shape_val)
std_normal = tfp.distributions.Normal(loc=0., scale=1.)

# Fixed noise values, so the comparison with the hand-written gradient is deterministic.
epsilon = tf.constant([-0.1, 0.14, -0.51, -2., 1.8])

with tf.GradientTape() as g:
    mu, sigma = tf.constant(1.1), tf.constant(1.)
    # Plain tensors are not tracked automatically; tell the tape to watch both parameters.
    g.watch([mu, sigma])

    # Reparameterized samples from q = Normal(mu, sigma).
    tf_x = mu + sigma * epsilon
    variational_normal = tfp.distributions.Normal(loc=mu, scale=sigma)
    ratios = true_normal.prob(tf_x) / variational_normal.prob(tf_x)
    # In principle we should have a 1/K term here, but it disappears in the log grad.
    y = tf.math.log(tf.math.reduce_sum(ratios))

# The same samples as a (K, 1) column array for the numpy implementation below.
x = np.array(tf_x.numpy()).reshape(-1, 1)
loc_tf_grad, shape_tf_grad = g.gradient(y, [mu, sigma])
tf_gradient = [loc_tf_grad.numpy(), shape_tf_grad.numpy()]

In [None]:
# Target distribution we wish to approximate -- we pretend this is the "phylogenetic" distribution.
d = distributions.Normal(1)
true_loc = np.array([true_loc_val])
true_shape = np.array([true_shape_val])


def phylo_log_like(x):
    """Target log-density at `x`, via `d.log_prob` with the true parameters."""
    return d.log_prob(x, true_loc, true_shape)


def phylo_log_like_grad(x):
    """`d.log_prob_grad` of the target at `x` (presumably the derivative w.r.t. x -- confirm)."""
    return d.log_prob_grad(x, true_loc, true_shape)


# Variational distribution q and its starting parameters (module globals, updated by the optimizer).
q = distributions.Normal(1)
loc = np.array([1.1])
shape = np.array([1.])


def q_log_like(x):
    """Variational log-density at `x`, read with whatever global `loc`/`shape` hold at call time."""
    return q.log_prob(x, loc, shape)

def complete_grad(x, loc, shape):
    """Hand-written gradient of the multi-sample objective w.r.t. `loc` and `shape`.

    Evaluates the target log-likelihood (and its gradient) at the samples `x`, turns
    the log-likelihoods into weights with `branch_length.like_weights`, and passes
    everything to `branch_length.param_grad` using the global variational family `q`.
    Callers unpack the result as `(loc_grad, shape_grad)`.
    """
    target_log_like = phylo_log_like(x)
    weights = branch_length.like_weights(q, target_log_like, x, loc, shape)
    target_log_like_grad = phylo_log_like_grad(x)
    return branch_length.param_grad(q, weights, target_log_like_grad, x, loc, shape)


loc_grad, shape_grad = complete_grad(x, loc, shape)

Here we can see that the TensorFlow gradient matches the hand-calculated gradient to a relative tolerance of $10^{-5}$.

In [None]:
# Autodiff (loc, shape) gradients must agree with the hand-written ones to rel. tolerance 1e-5.
for tf_value, hand_value in zip(tf_gradient, (loc_grad, shape_grad)):
    assert tf_value == approx(hand_value, rel=1e-5)
tf_gradient, loc_grad, shape_grad

### Running gradient ascent

Here we actually run the gradient ascent, applying Adam updates to the variational parameters `loc` and `shape`.

In [None]:
# Optimizer settings: one parameter each for loc and shape, sharing a single Adam step size.
branch_length_param_count = 1
sgd_server_args = dict.fromkeys(('loc', 'shape'), branch_length_param_count)
step_count = 400
stepsz = 0.02
stepsz_dict = dict.fromkeys(('loc', 'shape'), stepsz)

def gradient_step(infer_opt, sample_count):
    """Take one Adam step on the global variational parameters.

    Draws `sample_count` particles from `q` at the current global `loc`/`shape`,
    estimates the gradient with `complete_grad`, asks `infer_opt.adam` for the
    updates and adds them to the globals with `+=`. Returns the KL diagnostics
    `kl_div(phylo_log_like, q_log_like)` evaluated after the update.
    """
    global loc, shape
    particles = q.sample(loc, shape, sample_count)
    grad_loc, grad_shape = complete_grad(particles, loc, shape)
    params = {'loc': loc, 'shape': shape}
    grads = {'loc': grad_loc, 'shape': grad_shape}
    step = infer_opt.adam(stepsz_dict, params, grads)
    loc += step['loc']
    shape += step['shape']
    return kl_div(phylo_log_like, q_log_like)

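This cell defines the plotting helper and the grid-based KL-divergence diagnostic that `gradient_step` reports after every update; the plotting part can be skimmed.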

In [None]:
# Evaluation grid on [-5, 5], plus the same points as a (40, 1) column array.
x_vals = np.linspace(-5, 5, 40)
x_vals_transpose = x_vals.reshape(-1, 1).copy()

def plot_functions(f_true, f_approx):
    """Line chart comparing two functions on the `x_vals` grid.

    Both callables are evaluated on the column grid `x_vals_transpose` and drawn
    as the series "truth" and "approx" against `x_vals`. In this notebook they
    are the log-density functions `phylo_log_like` and `q_log_like`, so the
    y-axis shows log-densities. Assumes each call returns one value per grid
    point (1-D, length `len(x_vals)`) -- TODO confirm for other callables.
    """
    data = pd.DataFrame({"x": x_vals,
                         "truth": f_true(x_vals_transpose),
                         "approx": f_approx(x_vals_transpose)})
    return alt.Chart(data.melt(id_vars=["x"])).mark_line().encode(
        x='x',
        y='value',
        color='variable'
    )

def kl_div(f_true, f_approx, grid=None):
    """Grid-based KL divergences between two densities given as log-density functions.

    Args:
        f_true: callable returning log-densities of the reference distribution
            (e.g. `phylo_log_like`).
        f_approx: callable returning log-densities of the approximation
            (e.g. `q_log_like`).
        grid: points to evaluate both callables on; defaults to the module-level
            `x_vals_transpose`.

    Returns:
        dict with "standard" = KL(true || approx) and "reversed" = KL(approx || true),
        computed by `scipy.stats.entropy` on the densities normalized over the grid.
    """
    if grid is None:
        grid = x_vals_transpose
    log_p_true = np.asarray(f_true(grid), dtype=float)
    log_p_approx = np.asarray(f_approx(grid), dtype=float)
    # stats.entropy expects (unnormalized) probabilities, not log values: feeding it
    # log-densities silently yields a meaningless number. Shift by the max before
    # exponentiating to avoid underflow; entropy's normalization removes the constant.
    p_true = np.exp(log_p_true - np.max(log_p_true))
    p_approx = np.exp(log_p_approx - np.max(log_p_approx))
    return {
        "standard": stats.entropy(p_true, p_approx),
        "reversed": stats.entropy(p_approx, p_true)}

This runs the optimization and collects the results to plot.

In [None]:
def run_optimization(step_count, particle_count, init_loc=0.5, init_shape=1.):
    """Reset the variational parameters and run `step_count` optimization steps.

    Resets the module-level `loc`/`shape` to `init_loc`/`init_shape`, builds a
    fresh `optimizers.SGD_Server`, and calls `gradient_step` `step_count` times
    with `particle_count` particles each. The fitted `loc`/`shape` are left in the
    module globals for later cells.

    Returns:
        An interactive Altair line chart of the `kl_div` diagnostics per step
        (index 0 is the state before any update), one line per variant.
    """
    global loc, shape
    # Float dtype so the in-place `+=` updates in gradient_step also work for int inputs.
    loc = np.array([init_loc], dtype=float)
    shape = np.array([init_shape], dtype=float)
    results = [kl_div(phylo_log_like, q_log_like)]
    infer_opt = optimizers.SGD_Server(sgd_server_args)

    for _ in range(step_count):
        results.append(gradient_step(infer_opt, particle_count))

    plot_data = pd.DataFrame(results).reset_index()

    return alt.Chart(
        pd.melt(plot_data, id_vars=["index"], var_name="variant", value_name="KL divergence")
        ).mark_line().encode(
            alt.X("index"),
            alt.Y("KL divergence"),
            color="variant",
            tooltip=['index', 'KL divergence']
        ).interactive()

Here we see that a single particle converges quickly.

In [None]:
run_optimization(particle_count=1, step_count=step_count)

The fit is quite reasonable.

In [None]:
plot_functions(f_true=phylo_log_like, f_approx=q_log_like)

On the other hand, the fitting procedure with more particles (10 instead of 1) is a lot worse.

In [None]:
run_optimization(particle_count=10, step_count=step_count)

In [None]:
plot_functions(f_true=phylo_log_like, f_approx=q_log_like)

In [None]:
# Zero steps: this only resets loc/shape to their starting values (the returned chart is
# discarded), so the plot shows the initial approximation. The gradient-sampling cells
# below are evaluated at this reset state.
run_optimization(particle_count=10, step_count=0)
plot_functions(f_true=phylo_log_like, f_approx=q_log_like)

In [None]:
def sample_gradients(particle_count, sample_count):
    """Draw `sample_count` independent gradient estimates at the current global `loc`/`shape`.

    Each estimate samples `particle_count` particles from `q` and runs
    `complete_grad` on them. Returns a DataFrame with one row per estimate and the
    columns "loc_grad_<particle_count>" and "shape_grad_<particle_count>" (first
    component of each gradient).
    """
    suffix = str(particle_count)
    loc_key, shape_key = "loc_grad_" + suffix, "shape_grad_" + suffix

    def sample_gradient():
        # Use `particle_count` particles per estimate; sampling a fixed single particle
        # would make every column the same 1-particle estimator regardless of its label.
        x = q.sample(loc, shape, particle_count)
        loc_grad, shape_grad = complete_grad(x, loc, shape)
        return {loc_key: loc_grad[0], shape_key: shape_grad[0]}

    return pd.DataFrame([sample_gradient() for _ in range(sample_count)])

sample_count = 5000
# Gradient samples for 1 and 10 particles, side by side (both share the row index 0..sample_count-1).
raw = pd.concat([sample_gradients(n_particles, sample_count) for n_particles in (1, 10)], axis=1)

def compare_pair(key1, key2):
    """Q-Q style scatter plot of two gradient columns of `raw`.

    Each column is sorted independently, so point i pairs the i-th smallest value
    of `key1` with the i-th smallest value of `key2`; both axes share one domain
    spanning the min and max of the two columns.
    """
    sorted_pair = pd.DataFrame({
        key: raw[key].sort_values().reset_index(drop=True)
        for key in (key1, key2)
    })
    lowest, highest = sorted_pair.values.min(), sorted_pair.values.max()
    shared_scale = alt.Scale(domain=[lowest, highest])
    chart = alt.Chart(sorted_pair, width=500, height=500).mark_point()
    return chart.encode(alt.X(key1, scale=shared_scale), alt.Y(key2, scale=shared_scale))


compare_pair("loc_grad_1", "loc_grad_10")

In [None]:
# Summary statistics (count, mean, std, min, quartiles, max) of every sampled gradient column.
raw.describe()

In [None]:
compare_pair(key1="shape_grad_1", key2="shape_grad_10")

In [None]:
def compare_pair2(key1, key2):
    """Independently sorted columns `key1`/`key2` of `raw`, plus their difference and ratio.

    Row i holds the i-th smallest value of each column; "difference" is
    key1 - key2 and "ratio" is key1 / key2, computed row by row.
    """
    quantiles = pd.DataFrame({
        column: raw[column].sort_values().reset_index(drop=True)
        for column in (key1, key2)
    })
    return quantiles.assign(
        difference=lambda frame: frame[key1] - frame[key2],
        ratio=lambda frame: frame[key1] / frame[key2],
    )
compare_loc = compare_pair2("loc_grad_1", "loc_grad_10").reset_index()
(
    alt.Chart(compare_loc, width=500, height=500)
    .mark_point(clip=True)
    .encode(
        x=alt.X("index"),
        y=alt.Y("difference", scale=alt.Scale(domain=(-0.3, 0.3))),
    )
)

In [None]:
# Same summary as the earlier `raw.describe()` cell, shown again next to the difference/ratio plots.
raw.describe()

In [None]:
(
    alt.Chart(compare_loc, width=500, height=500)
    .mark_point(clip=True)
    .encode(
        x=alt.X("index"),
        y=alt.Y("ratio", scale=alt.Scale(domain=(0, 2))),
    )
)

In [None]:
compare_shape = compare_pair2("shape_grad_1", "shape_grad_10").reset_index()
(
    alt.Chart(compare_shape, width=500, height=500)
    .mark_point(clip=True)
    .encode(
        x=alt.X("index"),
        y=alt.Y("difference", scale=alt.Scale(domain=(-0.2, 0.1))),
    )
)

In [None]:
(
    alt.Chart(compare_shape, width=500, height=500)
    .mark_point(clip=True)
    .encode(
        x=alt.X("index"),
        y=alt.Y("ratio", scale=alt.Scale(domain=(0, 2))),
    )
)