In [None]:
import branch_length
import distributions
import optimizers

import altair as alt
import numpy as np
import pandas as pd
import scipy.stats as stats

# Select Altair's "notebook" renderer for every chart produced below.
# NOTE(review): presumably targets the classic Jupyter Notebook frontend --
# confirm that the installed Altair version still provides this renderer.
alt.renderers.enable("notebook")

import importlib 
# Re-execute the three local modules so that edits made to their source since
# the first import are picked up by this kernel without a restart.
# The return value of the last call (the reloaded module object) is what this
# cell displays as its output.
importlib.reload(branch_length)
importlib.reload(distributions)
importlib.reload(optimizers)

In [None]:
# Evaluation grid shared by the plotting and KL helpers below:
# 40 evenly spaced points from 0.1 to 5 (both endpoints included).
x_vals = np.linspace(0.1, 5, num=40)
# The same grid as a (40, 1) column vector, built from a copy of `x_vals`.
x_vals_transpose = np.array([x_vals]).T

def plot_functions(f_true, f_approx):
    """Draw two functions, evaluated on the shared grid, as lines on one chart.

    Args:
        f_true: Callable applied to the column-vector grid
            `x_vals_transpose`; its values form the "truth" series.
        f_approx: Callable applied to the same grid; its values form the
            "approx" series.

    Both callables must return one value per grid point (something pandas
    accepts as a column alongside `x_vals`).

    Returns:
        An Altair line chart with `x` on the horizontal axis, the function
        value on the vertical axis and one coloured line per series.
    """
    data = pd.DataFrame({
        "x": x_vals,
        "truth": f_true(x_vals_transpose),
        "approx": f_approx(x_vals_transpose),
    })
    # Long format (x, variable, value) lets a single colour encoding draw
    # both series.
    long_data = data.melt(id_vars=["x"])
    return alt.Chart(long_data).mark_line().encode(
        x='x',
        y='value',
        color='variable'
    )

def kl_div(f_true, f_approx, log_inputs=True, grid=None):
    """Discrete KL divergence between two functions evaluated on a grid.

    `scipy.stats.entropy(pk, qk)` treats its arguments as non-negative
    probability weights and normalises them to sum to one. The callers in
    this notebook pass *log*-density functions, so the values are
    exponentiated first; feeding the raw (negative) log values to `entropy`
    would produce a number that is not a KL divergence.

    Args:
        f_true: Callable evaluated on the grid; the reference distribution.
        f_approx: Callable evaluated on the grid; the approximation.
        log_inputs: When True (default) the callables are taken to return
            log-densities and their values are exponentiated. Pass False if
            they already return non-negative (possibly unnormalised)
            densities.
        grid: Points to evaluate on, in whatever form the callables accept.
            Defaults to the module-level `x_vals_transpose`.

    Returns:
        A dict with "standard" = KL(true || approx) and
        "reversed" = KL(approx || true), both computed over the grid points.
    """
    if grid is None:
        grid = x_vals_transpose

    # Evaluate each function once and reuse the values for both directions.
    true_vals = f_true(grid)
    approx_vals = f_approx(grid)

    if log_inputs:
        # Subtracting the maximum only rescales the result by a constant,
        # which entropy() normalises away, but it keeps exp() from
        # underflowing to all zeros for very negative log-densities.
        true_vals = np.exp(true_vals - np.max(true_vals))
        approx_vals = np.exp(approx_vals - np.max(approx_vals))

    return {
        "standard": stats.entropy(true_vals, approx_vals),
        "reversed": stats.entropy(approx_vals, true_vals),
    }

In [None]:
# Optimizer settings: one entry per variational parameter ('loc', 'shape').
branch_length_param_count = 1
stepsz = 0.05
clip = 100.
infer_opt = optimizers.SGD_Server(
    dict.fromkeys(('loc', 'shape'), branch_length_param_count))
stepsz_dict = dict.fromkeys(('loc', 'shape'), stepsz)

# Distribution we wish to approximate -- we pretend this is the
# "phylogenetic" distribution.
d = distributions.Normal(1)
true_loc = np.array([2.])
true_shape = np.array([1.])


def phylo_log_like(x):
    """Evaluate `d.log_prob` at `x` using the fixed target parameters."""
    return d.log_prob(x, true_loc, true_shape)


def phylo_log_like_grad(x):
    """Evaluate `d.log_prob_grad` at `x` using the fixed target parameters."""
    return d.log_prob_grad(x, true_loc, true_shape)


# Variational distribution and its initial parameters.
q = distributions.Normal(1)
loc = np.array([1.])
shape = np.array([0.5])


def q_log_like(x):
    """Evaluate `q.log_prob` at `x`.

    `loc` and `shape` are looked up at call time, so this always reflects
    the current module-level values, including later updates.
    """
    return q.log_prob(x, loc, shape)

In [None]:
def gradient_step(sample_count):
    """Apply one optimizer update to the global `loc` and `shape`.

    Draws `sample_count` samples from `q` at the current parameters, obtains
    weights and parameter gradients from `branch_length`, clips the gradients
    to [-clip, clip] when `clip` is truthy, and adds the update returned by
    `infer_opt.adam` to `loc` and `shape` in place.

    Args:
        sample_count: Number of samples to request from `q.sample`.

    Returns:
        The result of `kl_div(phylo_log_like, q_log_like)` evaluated after
        the parameters have been updated.
    """
    global loc, shape

    samples = q.sample(loc, shape, sample_count)
    sample_weights = branch_length.like_weights(
        q, phylo_log_like(samples), samples, loc, shape, clip)
    loc_grad, shape_grad = branch_length.param_grad(
        q, sample_weights, phylo_log_like_grad(samples), samples, loc, shape)

    if clip:
        loc_grad, shape_grad = (
            np.clip(grad, -clip, clip) for grad in (loc_grad, shape_grad))

    current_params = {'loc': loc, 'shape': shape}
    current_grads = {'loc': loc_grad, 'shape': shape_grad}
    updates = infer_opt.adam(stepsz_dict, current_params, current_grads)

    loc += updates['loc']
    shape += updates['shape']
    return kl_div(phylo_log_like, q_log_like)

In [None]:
# Start this run from fresh variational parameters.
loc = np.array([0.])
shape = np.array([1.])

# First entry: divergence before any update; then one entry per step for
# 100 steps of 10 samples each.
results = [kl_div(phylo_log_like, q_log_like)]
results.extend(gradient_step(10) for _ in range(100))

# `index` is the step number (0 = before the first update).
plot_data = pd.DataFrame(results).reset_index()

alt.Chart(
    plot_data.melt(id_vars=["index"], var_name="variant", value_name="KL divergence")
).mark_line().encode(
    x=alt.X("index"),
    y=alt.Y("KL divergence", scale=alt.Scale()),
    color="variant",
)

In [None]:
# Compare the target (`phylo_log_like`, drawn as "truth") with the current
# variational fit (`q_log_like`, drawn as "approx") on the shared grid.
plot_functions(phylo_log_like, q_log_like)

In [None]:
# Display the current value of the variational `loc` parameter.
loc

In [None]:
# Display the current value of the variational `shape` parameter.
shape

In [None]:
# Draw 100 samples from q at the current parameters and compute their
# weights; `x` and `weights` are reused by the cells below.
x = q.sample(loc, shape, 100)
weights = branch_length.like_weights(
    q, phylo_log_like(x), x, loc, shape, clip
)
# Shown as the cell output: the parameter gradients for this sample batch.
branch_length.param_grad(
    q, weights, phylo_log_like_grad(x), x, loc, shape
)

In [None]:
# Pair the first column of each sample with its weight, ordered by x so the
# line chart below is drawn left to right.
df = pd.DataFrame({"x": x[:, 0], "weights": weights}).sort_values("x")

In [None]:
# Weight assigned to each sample as a function of its position.
alt.Chart(df).mark_line().encode(
    x=alt.X("x"),
    y=alt.Y("weights"),
)

In [None]:
# Disabled alternative target: the `if False` guard means nothing below runs.
# If enabled, it would rebind `d`, `phylo_log_like` and `phylo_log_like_grad`
# to a `distributions.Gamma(1)` evaluated with alpha = beta = 2.
# NOTE(review): this cell sits after the optimization cells, so enabling it
# only affects cells that are (re-)run afterwards.
if False:
    d = distributions.Gamma(1)
    alpha = np.array([2.])
    beta = np.array([2.])
    phylo_log_like = lambda x: d.log_prob(x, alpha, beta)
    phylo_log_like_grad = lambda x: d.log_prob_grad(x, alpha, beta)

