In [None]:
import branch_length
import distributions
import optimizers
import sbn

import altair as alt
import numpy as np
import pandas as pd
import scipy.stats as stats

# Render Altair charts in the classic Jupyter Notebook front end.
# NOTE(review): the "notebook" renderer belongs to older Altair releases; newer
# Altair versions may need a different renderer name — confirm against the
# installed version.
alt.renderers.enable("notebook")

# Reload the project modules so edits to them are picked up without restarting
# the kernel. `sbn` is not reloaded — presumably because it is a compiled
# extension that importlib.reload cannot refresh; confirm.
import importlib 
importlib.reload(branch_length)
importlib.reload(distributions)
importlib.reload(optimizers)

In [None]:
# Build an sbn instance holding a single tree over three taxa, then read the
# sequences from FASTA and create the BEAGLE instances used for likelihoods.
inst = sbn.instance("charlie")
inst.tree_collection = sbn.TreeCollection(
    # Parent-id vector: taxa 0, 1 and 2 all hang off node 3 (presumably the root).
    [sbn.Tree.of_parent_id_vector([3, 3, 3])],
    ["mars", "saturn", "jupiter"])
inst.read_fasta_file('../data/hello.fasta')
inst.make_beagle_instances(2)
# NumPy array over the first tree's branch lengths. Writes through this array
# are meant to change the tree that `inst` evaluates (log_like_with relies on it).
# NOTE(review): this depends on `copy=False` not copying. Under NumPy 1.x,
# `copy=False` silently copies when a copy is needed; under NumPy 2.x it raises
# instead — confirm the buffer really is shared.
branch_lengths = np.array(inst.tree_collection.trees[0].branch_lengths,
                          copy=False)
# Four entries: presumably one per node (three leaves plus node 3) — TODO confirm.
branch_lengths[:] = np.array([0.1, 0.1, 0.3, 0.])

In [None]:
def log_like_with(branch_id: int, branch_length: float, grad=False):
    """Evaluate the tree log-likelihood (or one gradient entry) at a trial branch length.

    Temporarily writes ``branch_length`` into ``branch_lengths[branch_id]``
    (the array set up to alias the first tree's branch lengths), evaluates
    ``inst``, and always restores the previous value — even if evaluation raises.

    Args:
        branch_id: index into ``branch_lengths``.
        branch_length: trial length for that branch. Note that inside this
            function the name shadows the ``branch_length`` module.
        grad: if True, return entry ``branch_id`` of the gradient vector from
            the first element of ``inst.branch_gradients()``; otherwise return
            the first entry of ``inst.log_likelihoods()``.

    Returns:
        The selected scalar value.
    """
    # Indexing a single element yields a scalar copy, so this survives the write below.
    saved_branch_length = branch_lengths[branch_id]
    branch_lengths[branch_id] = branch_length
    try:
        if grad:
            _, log_grad = inst.branch_gradients()[0]
            return np.array(log_grad)[branch_id]
        return np.array(inst.log_likelihoods())[0]
    finally:
        # Restore unconditionally so a failed evaluation never leaves the shared
        # tree with the trial branch length.
        branch_lengths[branch_id] = saved_branch_length

# Profile the log-likelihood as a function of branch 2's length, holding the
# other branch lengths at their current values.
x_vals = np.linspace(0, 0.3, 100)
df = pd.DataFrame({"x": x_vals, "y": [log_like_with(2, x) for x in x_vals]})
# This chart is not the last expression of the cell (function definitions
# follow), so Jupyter would not render it implicitly; `display` is IPython's
# built-in display function, available in the notebook namespace.
display(alt.Chart(df).mark_line().encode(
    alt.X("x", title="branch length"),
    alt.Y("y", title="log likelihood", scale=alt.Scale(zero=False))
))

def phylo_log_like(x_arr):
    """Tree log-likelihood with branch 2 set to each value of `x_arr` in turn.

    Returns a NumPy array with one log-likelihood per element of `x_arr`.
    """
    values = []
    for candidate_length in x_arr:
        values.append(log_like_with(2, candidate_length))
    return np.array(values)

def phylo_log_like_grad(x_arr):
    """Branch-2 gradient entry of the tree log-likelihood at each value of `x_arr`.

    Each gradient is wrapped in its own single-element row, so a non-empty input
    yields an array of shape (len(x_arr), 1).
    """
    rows = []
    for candidate_length in x_arr:
        rows.append([log_like_with(2, candidate_length, grad=True)])
    return np.array(rows)

In [None]:
# Fake "phylogenetic" distributions.
# When enabled, these rebind `phylo_log_like` / `phylo_log_like_grad` from the
# previous cell, so the optimisation below targets a known closed-form density
# instead of the tree likelihood.
# FAKE_TARGET: "gamma", "lognormal", or None to keep the real tree likelihood.
FAKE_TARGET = "gamma"

if FAKE_TARGET == "lognormal":
    d = distributions.LogNormal(1)
    true_loc = np.array([-2.3])
    true_shape = np.array([-0.6])
    phylo_log_like = lambda x: d.log_prob(x, true_loc, true_shape)
    phylo_log_like_grad = lambda x: d.log_prob_grad(x, true_loc, true_shape)
elif FAKE_TARGET == "gamma":
    d = distributions.Gamma(1)
    alpha = np.array([2.])
    beta = np.array([5.])
    phylo_log_like = lambda x: d.log_prob(x, alpha, beta)
    # NOTE(review): this gradient is multiplied by exp(x), while the real
    # `phylo_log_like_grad` and the LogNormal target return it without any extra
    # factor. If `branch_length.param_grad` expects d/dx, the factor looks wrong;
    # if it expects d/d(log x), the chain-rule factor would be x, not exp(x).
    # Confirm against branch_length.param_grad.
    phylo_log_like_grad = lambda x: d.log_prob_grad(x, alpha, beta) * np.exp(x)
elif FAKE_TARGET is not None:
    raise ValueError(f"unknown FAKE_TARGET: {FAKE_TARGET!r}")

In [None]:
# Evaluation grid shared by the plotting and KL-divergence helpers: `x_vals` is
# 1-D, and `x_vals_transpose` holds the same points as an independent (n, 1) column.
x_vals = np.linspace(0.01, 0.7, 100)
x_vals_transpose = x_vals.reshape(-1, 1).copy()

def _exp_normalize(x):
    """Exponentiate log-values and normalize them to sum to one along axis 0.

    Mathematically equal to ``exp(x) / sum(exp(x))``, but the maximum is
    subtracted first so that large-magnitude log-likelihoods (e.g. -1000) do
    not underflow to 0/0 = NaN. For an (n, 1) column each column is
    normalized separately, and empty input returns an empty array.
    """
    x = np.asarray(x, dtype=float)
    # `initial` keeps np.max defined on empty input.
    shifted = np.exp(x - np.max(x, axis=0, initial=-np.inf))
    return shifted / np.sum(shifted, axis=0)


# Named transforms applied to function outputs before plotting / KL computation.
transforms = {
    "identity": lambda x: x,
    "normalize": lambda x: x / np.sum(x),
    "exp": np.exp,
    "exp_normalize": _exp_normalize,
}

def plot_functions(f_true, f_approx, transform="identity"):
    """Overlay the target and the approximating function on the shared grid.

    Args:
        f_true: callable evaluated on the (n, 1) grid ``x_vals_transpose``.
        f_approx: same, for the variational approximation.
        transform: key into ``transforms`` applied to both outputs before plotting.

    Returns:
        An Altair line chart with columns "x", "variable" ("truth"/"approx")
        and "value".
    """
    transform_fn = transforms[transform]
    data = pd.DataFrame({
        "x": x_vals,
        # Flatten so callables that return either (n,) or (n, 1) both work;
        # DataFrame columns must be 1-D.
        "truth": np.ravel(transform_fn(f_true(x_vals_transpose))),
        "approx": np.ravel(transform_fn(f_approx(x_vals_transpose))),
    })
    return alt.Chart(data.melt(id_vars=["x"])).mark_line().encode(
        x='x',
        y='value',
        color='variable'
    )

def kl_div(f_true, f_approx, transform="identity"):
    """Discrete KL divergences between the target and the approximation on the grid.

    Each callable is evaluated once on ``x_vals_transpose`` and passed through
    ``transforms[transform]``. ``scipy.stats.entropy`` then normalizes both
    vectors to sum to one before computing the divergence.

    Args:
        f_true: callable for the target (log-)density.
        f_approx: callable for the variational (log-)density.
        transform: key into ``transforms``.

    Returns:
        dict with "standard" = KL(true || approx) and "reversed" = KL(approx || true).
    """
    transform_fn = transforms[transform]
    # Evaluate each callable only once: with the real tree likelihood each call
    # is a full pass of likelihood evaluations over the grid. Flattening makes
    # the results scalars even if a callable returns an (n, 1) column.
    p_true = np.ravel(transform_fn(f_true(x_vals_transpose)))
    p_approx = np.ravel(transform_fn(f_approx(x_vals_transpose)))
    return {
        "standard": stats.entropy(p_true, p_approx),
        "reversed": stats.entropy(p_approx, p_true)}

In [None]:
# Optimizer setup: parameter name -> number of values (presumably used by
# SGD_Server to size its per-parameter state — confirm in optimizers).
branch_length_param_count = 1
sgd_server_args = dict(loc=branch_length_param_count,
                       shape=branch_length_param_count)
infer_opt = optimizers.SGD_Server(sgd_server_args)

# Optimisation hyper-parameters.
stepsz = 0.01        # initial step size, copied into stepsz_dict before training
clip = 100.          # gradient clip bound; also passed to branch_length.like_weights
anneal_freq = 50     # how often (in iterations) the step size is annealed
anneal_rate = 0.95   # multiplicative step-size decay applied when annealing

# Variational distribution
q = distributions.LogNormal(1)
loc, shape = np.array([1.]), np.array([0.5])


def q_log_like(x):
    """Log-density of the variational distribution `q` at `x`.

    Reads the module-level `loc` and `shape` on every call, so it always
    reflects the current optimisation state.
    """
    return q.log_prob(x, loc, shape)

In [None]:
def gradient_step(sample_count):
    """Take one stochastic-gradient step on the variational parameters.

    Draws `sample_count` samples from `q` and turns them into weights and
    parameter gradients with the `branch_length` helpers. When `clip` is
    truthy, the gradients are clipped to [-clip, clip]. The update returned by
    `infer_opt.adam` is then added in place to the global `loc` and `shape`.

    Returns:
        The KL-divergence dict from `kl_div` after the update.
    """
    global loc, shape
    samples = q.sample(loc, shape, sample_count)
    weights = branch_length.like_weights(
        q, phylo_log_like(samples), samples, loc, shape, clip)
    grad_loc, grad_shape = branch_length.param_grad(
        q, weights, phylo_log_like_grad(samples), samples, loc, shape)
    if clip:
        grad_loc, grad_shape = (np.clip(g, -clip, clip) for g in (grad_loc, grad_shape))
    current_params = {'loc': loc, 'shape': shape}
    current_grads = {'loc': grad_loc, 'shape': grad_shape}
    step = infer_opt.adam(stepsz_dict, current_params, current_grads)
    loc += step['loc']
    shape += step['shape']
    return kl_div(phylo_log_like, q_log_like, transform="exp_normalize")

In [None]:
# Re-initialize the optimizer, step sizes and variational parameters, then train.
# A fresh SGD_Server instance is created so no optimizer state from an earlier
# run of this cell carries over.
infer_opt = optimizers.SGD_Server(sgd_server_args)
stepsz_dict = {'loc': stepsz, 'shape': stepsz}
loc = np.array([-1.])
shape = np.array([1.])

n_iterations = 10       # number of gradient steps
samples_per_step = 1    # samples drawn from q per gradient step

# results[0] is the KL divergence before any training.
results = [kl_div(phylo_log_like, q_log_like, transform="exp_normalize")]

for i in range(n_iterations):
    results.append(gradient_step(samples_per_step))
    # Anneal after every `anneal_freq` completed steps; `i` is 0-based, so the
    # +1 keeps the step size from being annealed right after the first step.
    if (i + 1) % anneal_freq == 0:
        for k in stepsz_dict:
            stepsz_dict[k] *= anneal_rate

plot_data = pd.DataFrame(results).reset_index()

alt.Chart(
    pd.melt(plot_data, id_vars=["index"], var_name="variant", value_name="KL divergence")
    ).mark_line().encode(
        alt.X("index", title="iteration"),
        alt.Y("KL divergence"),
        color="variant"
    )

In [None]:
# Compare the normalized target density with the fitted variational density.
plot_functions(f_true=phylo_log_like, f_approx=q_log_like, transform="exp_normalize")

In [None]:
# Variational `loc` parameter after the training loop.
loc

In [None]:
# Variational `shape` parameter after the training loop.
shape