In [1]:
# dependencies

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.interpolate import PchipInterpolator
from scipy.special import logsumexp
from scipy.stats import norm

from infemus import emus

# use seaborn's default theme for the figure drawn later in the notebook
sns.set_theme()

In [32]:
# helper functions

def run_experiment(n_variates, bounds, l, n, y, q, t, rng):
    """Build the reference weights and draw replicate EMUS and Gibbs estimates of them.

    The reference weights are the values of
    ``N(y; lam, 1/q + 1/t) + N(-y; lam, 1/q + 1/t)`` on the evaluation grid,
    normalised to sum to one.

    Parameters
    ----------
    n_variates : int
        Number of independent estimates drawn from each estimator.
    bounds : tuple
        ``(low, high)`` end points shared by the sampling and evaluation grids.
    l : int
        The sampling grid has ``l + 1`` points, the evaluation grid ``l ** 2 + 1``.
    n : int
        Sample-size parameter forwarded unchanged to both estimators.
    y, q, t : float
        Location of the two mirrored densities (``+y`` / ``-y``) and the two
        precisions whose inverse sum ``1/q + 1/t`` is the variance used above.
    rng : numpy.random.Generator
        Random generator shared by every estimator call.

    Returns
    -------
    tuple
        ``(lame, ue, emus_estimates, gibbs_estimates)``: the evaluation grid,
        the reference weights on it, and one array per estimator stacking
        ``n_variates`` estimates (empty when ``n_variates`` is 0).
    """
    lams = np.linspace(*bounds, l + 1)
    lame = np.linspace(*bounds, l ** 2 + 1)
    marg_scale = np.sqrt(1/q + 1/t)

    # Stay in log space: log(pdf + pdf) becomes -inf once both densities
    # underflow, which turns the normalised weights into NaN.
    log_ue = np.logaddexp(
        norm.logpdf(y, loc=lame, scale=marg_scale),
        norm.logpdf(-y, loc=lame, scale=marg_scale),
    )
    ue = np.exp(log_ue - logsumexp(log_ue))

    # All EMUS replicates are drawn before any Gibbs replicate so the shared
    # generator is consumed in a fixed order.
    emus_estimates = np.array([
        sample_emus_estimate(y, lams, lame, n, q, t, rng) for _ in range(n_variates)
    ])
    gibbs_estimates = np.array([
        sample_gibbs_estimate(y, lams, lame, n, q, t, rng) for _ in range(n_variates)
    ])
    return lame, ue, emus_estimates, gibbs_estimates

def sample_emus_estimate(y, lams, lame, n, q, t, rng):
    """Draw one EMUS-based estimate of the weights on the evaluation grid.

    For every centre in ``lams``, ``n`` values are sampled from a two-component
    Gaussian mixture: a sign (+1 with probability ``p``, otherwise -1) selects
    the mean ``(sign*q*y + centre*t) / (q + t)`` and the variance is
    ``1 / (q + t)``.  ``p`` is the share of ``N(y; centre, 1/q + 1/t)`` in the
    sum of that density and ``N(-y; centre, 1/q + 1/t)``.

    Parameters
    ----------
    y, q, t : float
        Observation location and the two precisions used in the formulas above.
    lams : numpy.ndarray
        1-D grid of centres that samples are drawn around.
    lame : numpy.ndarray
        1-D grid on which the estimate is returned.
    n : int
        Number of samples per centre.
    rng : numpy.random.Generator
        Source of randomness.

    Returns
    -------
    numpy.ndarray
        Output of ``emus.extrapolate`` divided by its sum.
    """
    prior_prec = t
    post_prec = q + prior_prec
    marg_scale = np.sqrt(1/q + 1/t)
    prior_scale = np.sqrt(1/prior_prec)
    post_scale = np.sqrt(1/post_prec)

    def plus_share(centre):
        # share of the +y density in (+y density) + (-y density), in log space
        log_plus = norm.logpdf(y, loc=centre, scale=marg_scale)
        log_minus = norm.logpdf(-y, loc=centre, scale=marg_scale)
        return np.exp(log_plus - np.logaddexp(log_plus, log_minus))

    sign_probs = [plus_share(centre) for centre in lams]

    # Per centre: draw the n signs first, then the n Gaussian values.
    the_samples = []
    for p_plus, centre in zip(sign_probs, lams):
        signs = rng.choice([-1, 1], p=[1-p_plus, p_plus], size=n)
        means = (signs*q*y + centre*prior_prec)/post_prec
        the_samples.append(norm.rvs(loc=means, scale=post_scale, random_state=rng))

    # log N(sample; centre, 1/t) for every (sample, grid point) pair, one
    # (n, len(grid)) array per sampling centre.
    log_psis = [
        norm.logpdf(draws[:, np.newaxis], loc=lams[np.newaxis], scale=prior_scale)
        for draws in the_samples
    ]
    log_psie = [
        norm.logpdf(draws[:, np.newaxis], loc=lame[np.newaxis], scale=prior_scale)
        for draws in the_samples
    ]

    # NOTE(review): presumably the first return value holds the estimated
    # weights on `lams` and `extrapolate` carries them over to `lame` --
    # confirm against the infemus documentation.
    us_est, _ = emus.eval_vardi_estimator(log_psis)
    ue_est = emus.extrapolate(log_psie, log_psis, us_est)
    return ue_est / np.sum(ue_est)

def sample_gibbs_estimate(y, ls, le, n, q, t, rng):
    """Draw one Gibbs-sampler estimate of the weights on the evaluation grid.

    A chain over indices into ``ls`` is started at a uniformly drawn index and
    run for ``n * len(ls)`` sweeps.  Each sweep draws a value from the same
    two-component Gaussian mixture used around the current centre (sign +1
    with probability ``p``, mean ``(sign*q*y + centre*t) / (q + t)``, variance
    ``1 / (q + t)``) and then redraws the index with probabilities
    proportional to ``N(value; ls[k], 1/t)``.

    Parameters
    ----------
    y, q, t : float
        Observation location and the two precisions used in the formulas above.
    ls : numpy.ndarray
        1-D grid of centres the chain moves on.
    le : numpy.ndarray
        1-D grid on which the estimate is returned.
    n : int
        Number of sweeps per grid point.
    rng : numpy.random.Generator
        Source of randomness.

    Returns
    -------
    numpy.ndarray
        Visit counts interpolated from ``ls`` to ``le`` with a PCHIP
        interpolant and divided by their sum; if the chain never left one
        index, an indicator of the ``le`` point closest to that centre.
    """
    prior_prec = t
    post_prec = q + prior_prec
    marg_scale = np.sqrt(1/q + 1/t)
    prior_scale = np.sqrt(1/prior_prec)
    post_scale = np.sqrt(1/post_prec)
    n_states = len(ls)

    lat = [rng.choice(n_states)]
    for _ in range(n * n_states):
        centre = ls[lat[-1]]
        # probability of the +1 sign given the current centre
        log_plus = norm.logpdf(y, loc=centre, scale=marg_scale)
        log_minus = norm.logpdf(-y, loc=centre, scale=marg_scale)
        p = np.exp(log_plus - np.logaddexp(log_plus, log_minus))
        sign = rng.choice([-1, 1], p=[1-p, p])
        the = norm.rvs(loc=(sign*q*y + centre*prior_prec)/post_prec, scale=post_scale, random_state=rng)
        # redraw the index given the freshly sampled value
        log_psi = norm.logpdf(the, loc=ls, scale=prior_scale)
        weights = np.exp(log_psi - logsumexp(log_psi))
        lat.append(rng.choice(np.arange(n_states), p=weights))

    counts = np.bincount(lat, minlength=n_states)
    us_est = counts / len(lat)
    if np.max(us_est) == 1:
        # the chain stayed on a single index: put all mass on the nearest `le` point
        nearest = le[np.argmin(np.abs(le - ls[lat[0]]))]
        ue_est = np.where(le == nearest, 1, 0)
    else:
        ue_est = PchipInterpolator(ls, counts)(le)
    return ue_est / np.sum(ue_est)

In [33]:
# config

# reproducibility: seed for the single generator shared by all runs
seed = 0

# number of independent estimates drawn per estimator and per value of t
n_variates = 128

# grids: end points, and the resolution parameter
# (sampling grid has l + 1 points, evaluation grid l ** 2 + 1)
bounds = (-2, 2)
l = 16

# sample-size parameter: n draws per sampling-grid point for EMUS,
# n * (number of sampling-grid points) sweeps for Gibbs
n = 16

# model: the two mirrored densities sit at +y and -y; q and t are the
# precisions entering the variance 1/q + 1/t
y = 1
q = 2 ** 6

# sweep over t in powers of two, 2**0 ... 2**14
t = np.power(2, np.arange(15))

In [34]:
# run experiment

# A single generator drives the whole sweep, so the results depend on the
# order in which the values of t are visited.
rng = np.random.default_rng(seed)
output = [run_experiment(n_variates, bounds, l, n, y, q, t_, rng) for t_ in t]

# Error for each t: absolute deviation from the reference weights, averaged
# over the replicates (axis 0) and then summed over the evaluation grid.
emus_err = [np.abs(est - ref).mean(axis=0).sum() for _, ref, est, _ in output]
gibbs_err = [np.abs(est - ref).mean(axis=0).sum() for _, ref, _, est in output]

In [None]:
# draw figure

# Explicit figure/axes; the curves are labelled so the two estimators can be
# told apart in the rendered figure.
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(t, emus_err, label='EMUS')
ax.plot(t, gibbs_err, label='Gibbs')

ax.set_xscale('log')
ax.set_xlabel(r'$\tau$')
ax.set_ylabel('error')
ax.legend();  # trailing semicolon hides the Legend repr in the cell output