In [57]:
import functools
import pickle

import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
from jax import vmap
from numpyro import handlers

In [None]:
%load_ext autoreload
%autoreload 2

from experiments.src.model import BNNRegressor

In [15]:
# Select the "nbAgg" matplotlib backend for the figures drawn below.
# NOTE(review): matplotlib.pyplot is already imported in the first cell of this
# file, and this cell's execution count (15) is lower than that cell's (57).
# Backend selection is normally done before pyplot is imported -- confirm this
# still takes effect on a fresh Restart & Run All.
import matplotlib
matplotlib.use("nbAgg")

In [88]:
# Load the pickled object stored at /tmp/hmc/hmc-samples.pkl into `samples`.
# NOTE(review): pickle.load can execute arbitrary code from the file, so only
# point this at files you produced yourself. The path is a hard-coded /tmp
# location.
with open("/tmp/hmc/hmc-samples.pkl", "rb") as f:
    samples = pickle.load(f)
# Show the shape of the "w" entry. The recorded output is (4, 1000, 1746),
# presumably (chains, draws per chain, flattened weights) -- TODO confirm.
samples['w'].shape

(4, 1000, 1746)

In [89]:
# Load the pickled object stored at /tmp/hmc/hmc-preds.pkl into `preds`.
# Later cells read its "Y" and "Y_mean" entries.
# NOTE(review): pickle.load can execute arbitrary code from the file, so only
# point this at files you produced yourself.
with open("/tmp/hmc/hmc-preds.pkl", "rb") as f:
    preds = pickle.load(f)
# Show the shape of the "Y" entry. The recorded output is (4, 1000, 500, 1),
# presumably (chains, draws, input points, output dim) -- TODO confirm.
preds["Y"].shape

(4, 1000, 500, 1)

In [83]:
# Summaries of preds["Y_mean"] taken along axis 1. Going by the shape printed
# for preds["Y"], axis 1 is presumably the per-chain draw axis; the shape of
# "Y_mean" itself is not shown -- TODO confirm.
mean_means = preds["Y_mean"].mean(axis=1)
# 5th and 95th percentiles along the same axis. The two requested percentiles
# become the leading axis of the result, and every size-1 axis is then dropped.
mean_percentiles = jnp.squeeze(
    jnp.percentile(preds["Y_mean"], q=jnp.array([5.0, 95.0]), axis=1)
)

In [84]:
# 5th and 95th percentiles of preds["Y"] along axis 1. With the recorded input
# shape (4, 1000, 500, 1) the percentile result is (2, 4, 500, 1), and dropping
# the size-1 axes leaves (2, 4, 500): [lower/upper, chain, input point].
full_percentiles = jnp.squeeze(
    jnp.percentile(preds["Y"], q=jnp.array([5.0, 95.0]), axis=1)
)

In [13]:
# Load the pickled training data from /tmp/trdata.pkl.
# Later cells use train[0] as the model inputs (X) and train[1] as the
# targets (Y).
# NOTE(review): pickle.load can execute arbitrary code from the file, so only
# point this at files you produced yourself.
with open('/tmp/trdata.pkl', "rb") as f:
    train = pickle.load(f)

(DeviceArray([[ 1.        , -1.        ,  1.        ],
              [ 1.        , -0.9877551 ,  0.9756602 ],
              [ 1.        , -0.97551024,  0.9516202 ],
              [ 1.        , -0.96326536,  0.92788017],
              [ 1.        , -0.95102036,  0.90443975],
              [ 1.        , -0.9387755 ,  0.8812994 ],
              [ 1.        , -0.9265306 ,  0.85845894],
              [ 1.        , -0.9142857 ,  0.83591837],
              [ 1.        , -0.90204084,  0.81367767],
              [ 1.        , -0.88979596,  0.79173684],
              [ 1.        , -0.877551  ,  0.77009577],
              [ 1.        , -0.86530614,  0.74875474],
              [ 1.        , -0.8530612 ,  0.7277134 ],
              [ 1.        , -0.8408163 ,  0.70697206],
              [ 1.        , -0.82857144,  0.68653065],
              [ 1.        , -0.81632656,  0.66638905],
              [ 1.        , -0.8040817 ,  0.6465473 ],
              [ 1.        , -0.79183674,  0.6270054 ],
          

In [52]:
# Load the pickled test inputs from /tmp/tsdata.pkl.
# Later cells pass `test` as X_test and plot against its column 1.
# NOTE(review): pickle.load can execute arbitrary code from the file, so only
# point this at files you produced yourself.
with open("/tmp/tsdata.pkl", "rb") as f:
    test = pickle.load(f)

In [85]:
# x-axis grid for the per-chain prediction plot: 500 evenly spaced points on
# [-1.7, 1.7]. The count matches the 500-long axis of preds["Y"]; presumably
# these are the inputs the predictions were generated at -- TODO confirm.
t = jnp.linspace(-1.7, 1.7, num=500)

In [87]:
# 2x2 grid with one panel per index i along the leading axis of the summaries
# (4 entries, presumably one per HMC chain).
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 7), sharex=True, sharey=True)
for i, ax in enumerate(axs.ravel()):
    # Line: mean of Y_mean for this panel.
    ax.plot(t, mean_means[i])
    # Shaded 5-95% bands, drawn in this order: Y_mean (orange) first, then
    # Y (light green) on top.
    for band, colour in ((mean_percentiles, "orange"), (full_percentiles, "lightgreen")):
        ax.fill_between(t, band[0, i], band[1, i], alpha=0.5, color=colour)
    # Black crosses: the first 50 training rows, input column 1 against target.
    ax.plot(train[0][:50, 1], train[1][:50], "kx")
    ax.set_ylim(-6, 6)
fig.tight_layout()
# The relative "figs/" directory is assumed to exist already.
fig.savefig("figs/hmc-by-chain-first-half.png", dpi=480)

<IPython.core.display.Javascript object>

# Prior / likelihood calculation for samples

In [36]:
# Model configuration constants.
# NOTE(review): BETA is not referenced by any cell visible in this part of the
# notebook -- confirm whether it is still needed.
BETA = 1.0
# Input dimensionality handed to BNNRegressor; matches the three columns of
# the training inputs printed above.
D_X = 3
# Passed to BNNRegressor as D_H; presumably the hidden-layer widths -- TODO
# confirm against experiments.src.model.
BNN_SIZE = [32, 32, 16]

In [38]:
# Build the model object that loglik() seeds, substitutes into and traces.
# What each keyword does is defined in experiments.src.model.BNNRegressor and
# is not visible here; the per-argument notes are best guesses -- TODO confirm.
bnn = BNNRegressor(
    nonlin=jax.nn.silu,  # activation function
    D_X=D_X,  # input dimension
    D_Y=1,  # presumably the output dimension
    D_H=BNN_SIZE,  # presumably the hidden-layer sizes
    biases=True,
    prior_scale=1,
)

In [106]:
# Pick ONE draw per chain: integer index -700 along axis 1 (i.e. draw 300 of
# the 1000 recorded), not a slice. With samples['w'] of shape (4, 1000, 1746)
# this yields a (4, 1746) array, whose leading axis is what vmap maps over
# in the evaluation cells.
posterior_samples = dict(w=samples['w'][:, -700, :])

In [122]:
def loglik(rng_key, params, X_train, Y_train, X_test):
    """Score one parameter set on training data and predict at test inputs.

    The module-level ``bnn`` model is traced twice with ``params`` substituted
    into it: once called with ``(X_train, Y_train)`` to read log-densities off
    the trace, and once called with ``X_test`` and ``Y=None`` to read the
    predictive mean and scale.

    Args:
        rng_key: JAX PRNG key; split in two, one half seeding each model run.
        params: Mapping of site name to value given to ``handlers.substitute``
            (the notebook passes ``{"w": ...}``).
        X_train: Inputs for the scoring run.
        Y_train: Targets for the scoring run.
        X_test: Inputs for the prediction run.

    Returns:
        A dict with four entries:
        ``prior_logprob`` -- log-density of the traced ``"w"`` value under the
        ``"w"`` site's distribution (not reduced further here);
        ``loglik`` -- log-density of the traced ``"Y"`` value under the ``"Y"``
        site's distribution, summed over all elements;
        ``Y_mean`` and ``Y_scale`` -- the squeezed values of those two sites
        from the test-input trace.
    """
    score_key, predict_key = random.split(rng_key)

    def traced_run(key, X, Y):
        # Seed the model, pin the sites named in `params`, and record the trace.
        pinned = handlers.substitute(handlers.seed(bnn, key), params)
        return handlers.trace(pinned).get_trace(X=X, Y=Y)

    train_trace = traced_run(score_key, X_train, Y_train)
    test_trace = traced_run(predict_key, X_test, None)

    w_site = train_trace["w"]
    y_site = train_trace["Y"]

    return {
        "prior_logprob": w_site["fn"].log_prob(w_site["value"]),
        "loglik": y_site["fn"].log_prob(y_site["value"]).sum(),
        "Y_mean": test_trace["Y_mean"]["value"].squeeze(),
        "Y_scale": test_trace["Y_scale"]["value"].squeeze(),
    }

In [123]:
# loglik bound to the whole training set; the remaining positional arguments
# are (rng_key, params).
loglik_full = functools.partial(loglik, X_train=train[0], Y_train=train[1], X_test=test)

In [124]:
# loglik bound to the first 50 training rows only; the split point 50 is
# hard-coded.
loglik_first_half = functools.partial(loglik, X_train=train[0][:50], Y_train=train[1][:50], X_test=test)

In [125]:
# loglik bound to the training rows from index 50 onward; the split point 50
# is hard-coded.
loglik_second_half = functools.partial(loglik, X_train=train[0][50:], Y_train=train[1][50:], X_test=test)

In [126]:
# Evaluate on the full training set, once per selected draw: vmap maps over the
# leading axis of the 4 split keys and of posterior_samples["w"].
# PRNGKey(1) is also the seed used for the first-half and second-half runs.
ll_full = vmap(loglik_full)(random.split(random.PRNGKey(1), 4), posterior_samples)

In [127]:
# Same evaluation restricted to the first 50 training rows, using the same
# PRNGKey(1) split as the full-data run.
ll_first_half = vmap(loglik_first_half)(random.split(random.PRNGKey(1), 4), posterior_samples)

In [128]:
# Same evaluation restricted to training rows 50 onward, using the same
# PRNGKey(1) split as the full-data run.
ll_second_half = vmap(loglik_second_half)(random.split(random.PRNGKey(1), 4), posterior_samples)

In [131]:
# Display the summed log-likelihoods (one value per selected draw) for the
# full data, the first half and the second half. In the recorded output the
# two halves add up to the full value.
ll_full['loglik'], ll_first_half['loglik'], ll_second_half['loglik']

(DeviceArray([28.59112 , 29.961697, 28.041077, 25.048267], dtype=float32),
 DeviceArray([7.6411657, 9.588444 , 8.756374 , 7.7478023], dtype=float32),
 DeviceArray([20.949963, 20.373266, 19.28472 , 17.300488], dtype=float32))

In [132]:
# One panel per selected draw: training data, predictive mean, and a band of
# +/- 2 * Y_scale around the mean, with the log-density summary in the title.
fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(9, 8), sharex='all', sharey='all')
for i, ax in enumerate(axs.ravel()):
    panel_mean = ll_full['Y_mean'][i]
    panel_halfwidth = ll_full['Y_scale'][i] * 2
    panel_logprior = ll_full['prior_logprob'][i]
    panel_loglik_all = ll_full['loglik'][i]
    panel_loglik_first = ll_first_half['loglik'][i]

    # All training points (input column 1 against target column 0), faded.
    ax.plot(train[0][:, 1], train[1][:, 0], 'kx', alpha=0.2)
    ax.plot(test[:, 1], panel_mean)
    ax.fill_between(test[:, 1], panel_mean - panel_halfwidth,
                    panel_mean + panel_halfwidth, alpha=0.2)
    ax.set_ylim(-10, 10)

    # The second-half figure is obtained as full minus first half, not read
    # from ll_second_half.
    panel_title = (
        f"logprior={panel_logprior:.3f}\n"
        f"loglik first/second half={panel_loglik_first:.3f} / {panel_loglik_all - panel_loglik_first:.3f}\n"
        f"logjoint={panel_loglik_all + panel_logprior:.3f}"
    )
    ax.set_title(panel_title)
fig.tight_layout()
fig.show()

<IPython.core.display.Javascript object>

In [133]:
# Write whatever `fig` currently refers to (the figure built in the preceding
# plotting cell when cells run in file order) to disk. The relative "figs/"
# directory is assumed to exist already.
fig.savefig("figs/loglik-two-part.png")