In [None]:
# Modules not shipped with the standard library or the packages below;
# presumably project-local code that sits next to this notebook.
import branch_length
import distributions

import altair as alt
import numpy as np
import pandas as pd
import scipy.stats as stats

# Select Altair's "notebook" renderer for chart display.
# NOTE(review): this renderer name targets the classic Jupyter Notebook
# frontend; confirm it is still available in the installed Altair version.
alt.renderers.enable("notebook")

# Re-execute the two modules imported above so that edits made to their source
# after the kernel first imported them are picked up when this cell is re-run.
import importlib 
importlib.reload(branch_length)
importlib.reload(distributions)

In [None]:
# Evaluation grid shared by the plotting and KL helpers: 40 evenly spaced
# points covering [0.1, 5].
x_vals = np.linspace(0.1, 5, num=40)
# The same grid as an independent (40, 1) column array, one point per row.
x_vals_transpose = x_vals.reshape(-1, 1).copy()

def plot_functions(f_true, f_approx):
    """Plot two functions, evaluated on the shared grid, as overlaid lines.

    Both callables are invoked once with the module-level column grid
    ``x_vals_transpose``; their results become the ``"truth"`` and
    ``"approx"`` columns next to ``x_vals``.

    Parameters
    ----------
    f_true, f_approx : callable
        Each takes the grid array and returns one value per grid point.
        (Assumed to return something pandas accepts as a column of the same
        length as ``x_vals`` -- TODO confirm against the callers.)

    Returns
    -------
    alt.Chart
        Line chart with ``x`` on the horizontal axis, ``value`` on the
        vertical axis and one colored line per series.
    """
    data = pd.DataFrame({
        "x": x_vals,
        "truth": f_true(x_vals_transpose),
        "approx": f_approx(x_vals_transpose),
    })
    # melt() reshapes to long format (x, variable, value) so that the color
    # channel can split the rows into one line per series.
    return alt.Chart(data.melt(id_vars=["x"])).mark_line().encode(
        x='x',
        y='value',
        color='variable'
    )

def kl_div(f_true, f_approx):
    """Compare two functions on the shared grid with ``scipy.stats.entropy``.

    Each callable is evaluated on the module-level ``x_vals_transpose`` and the
    results are passed to ``stats.entropy(pk, qk)`` in both argument orders.

    Returns
    -------
    dict
        ``{"standard": entropy(true, approx), "reversed": entropy(approx, true)}``.

    NOTE(review): ``stats.entropy`` treats its inputs as (unnormalized)
    probabilities. The callers in this notebook pass wrappers around
    ``d.log_prob``; if those return log-densities, confirm whether the values
    should be exponentiated before being compared here.
    """
    grid = x_vals_transpose
    forward = stats.entropy(f_true(grid), f_approx(grid))
    backward = stats.entropy(f_approx(grid), f_true(grid))
    return {"standard": forward, "reversed": backward}

In [None]:
# --- Optimisation settings (read by gradient_step) ---
stepsz = 0.1       # factor each parameter gradient is multiplied by
grad_clip = 100.0  # symmetric bound handed to np.clip for each update

# --- Target distribution ---
# NOTE(review): the meaning of the argument 1 is defined by
# distributions.Normal, which is not visible here -- confirm.
d = distributions.Normal(1)
true_loc = np.array([2.0])
true_shape = np.array([1.0])


def phylo_log_like(x):
    """Evaluate ``d.log_prob`` at ``x`` with the fixed target parameters."""
    return d.log_prob(x, true_loc, true_shape)


def phylo_log_like_grad(x):
    """Evaluate ``d.log_prob_grad`` at ``x`` with the fixed target parameters."""
    return d.log_prob_grad(x, true_loc, true_shape)


# --- Approximating distribution (parameters updated during optimisation) ---
loc = np.array([1.0])
shape = np.array([0.5])


def q(x):
    """Evaluate ``d.log_prob`` at ``x`` with the current approximation.

    ``loc`` and ``shape`` are looked up as globals on every call, so this
    always reflects their latest values, including later rebinding.
    """
    return d.log_prob(x, loc, shape)

In [None]:
def gradient_step(sample_count, verbose=False):
    """Take one stochastic update of the global ``loc`` and ``shape``.

    Draws ``sample_count`` samples from ``d`` under the current parameters,
    obtains weights and parameter gradients from ``branch_length``, and adds
    the clipped, step-size-scaled gradients to ``loc`` and ``shape``.

    Parameters
    ----------
    sample_count : int
        Number of samples requested from ``d.sample``.
    verbose : bool, optional
        When true, print the samples, weights and gradients before the update
        and the true/current parameters with the applied step afterwards.
        Defaults to ``False``, which matches the previous silent behavior.

    Returns
    -------
    dict
        Result of ``kl_div(phylo_log_like, q)`` evaluated after the update.

    Notes
    -----
    ``loc`` and ``shape`` are modified in place with ``+=``, so any other
    reference to the same arrays observes the change.
    """
    global loc, shape
    x = d.sample(loc, shape, sample_count)
    weights = branch_length.like_weights(d, phylo_log_like(x), x, loc, shape)
    loc_grad, shape_grad = branch_length.param_grad(
        d, weights, phylo_log_like_grad(x), x, loc, shape)

    if verbose:
        print("before")
        print(x)
        print(weights)
        print(loc_grad)
        print(shape_grad)

    # NOTE(review): the clip is applied to the already-scaled step
    # (grad * stepsz), not to the raw gradient -- confirm that is intended
    # given the name grad_clip.
    loc += np.clip(loc_grad*stepsz, -grad_clip, grad_clip)
    shape += np.clip(shape_grad*stepsz, -grad_clip, grad_clip)

    if verbose:
        print("after")
        print(true_loc, loc, loc_grad*stepsz)
        print(true_shape, shape, shape_grad*stepsz)
    return kl_div(phylo_log_like, q)
    
# Overlay the target ("truth") and the current approximation ("approx") on the
# shared grid. When the notebook is run top to bottom, no gradient step has
# been taken at this point, so this shows the starting approximation.
plot_functions(phylo_log_like, q)

In [None]:
# Restart the approximation from fresh parameter arrays. Note that loc starts
# at 0 here rather than the 1 used when the parameters were first defined.
loc = np.array([0.0])
shape = np.array([0.5])

# Entry 0 is the divergence before any update; every later entry is the value
# returned by one single-sample gradient step (40 steps in total).
# NOTE(review): no random seed is set in the visible cells, so the trajectory
# presumably differs from run to run -- confirm how d.sample draws its samples.
results = [kl_div(phylo_log_like, q)]
results.extend(gradient_step(1) for _ in range(40))

# reset_index() exposes the step number as an "index" column for the x axis.
plot_data = pd.DataFrame(results).reset_index()

# Long format: one row per (step, variant) pair, one line per variant.
alt.Chart(
    plot_data.melt(id_vars=["index"], var_name="variant", value_name="KL divergence")
).mark_line().encode(
    x=alt.X("index"),
    y=alt.Y("KL divergence", scale=alt.Scale(domain=(0.0, 0.6))),
    color="variant",
)