In [10]:
# dependencies

from functools import reduce
from itertools import islice

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.integrate import quad_vec
from scipy.special import logsumexp
from scipy.stats import norm

from infemus import emus

# apply seaborn's default theme to every matplotlib figure drawn below
sns.set_theme()

In [11]:
# config

# seed for the single random generator created in the ground-truth cell
seed = 0

# model: observation y, likelihood precision q, prior precision tau
# (downstream code uses prior_prec = tau and post_prec = q + tau)
y = 1
q = tau = 2 ** 5

# grid: n_windows0 * n_interp + 1 evaluation points spanning bds;
# every n_interp-th grid point is an initial sampling window
n_windows0, n_interp = 8, 32
bds = -6, 6

# budget: n_iter adaptive rounds, n_variates independent replicates,
# n_samples draws per selected window per round
n_iter, n_variates, n_samples = 8, 32, 8

In [12]:
# quadrature

def eval_logmargin(lam, y, q, tau):
    """Log marginal likelihood of the observation ``y`` for each window center.

    Evaluates, elementwise over ``lam``,
        log( N(y; lam, v) + N(-y; lam, v) ),   v = 1/q + 1/tau.
    The result is not normalized; callers normalize it over the grid.
    """
    sd = np.sqrt(1/q + 1/tau)
    log_plus = norm.logpdf(y, loc=lam, scale=sd)
    log_minus = norm.logpdf(-y, loc=lam, scale=sd)
    return np.logaddexp(log_plus, log_minus)

def quad_weights(lame, y, q, tau):
    """Reference EMUS normalizers and sampling intensity computed by numerical quadrature.

    Builds the matrix F (row i = ``eval_1mom`` for window center ``lame[i]``) and
    the per-window matrices Xi_i (``eval_xi``), then solves the EMUS system with
    the infemus solver and scores every window.

    Parameters
    ----------
    lame : array of window centers (also used as the evaluation grid).
    y, q, tau : observation, likelihood precision and prior precision.

    Returns
    -------
    u_quad
        First output of ``emus.solve_emus_system(F, ones, 100)``; presumably the
        window normalizing constants (TODO: confirm against infemus).
    w_quad : np.ndarray
        Weights summing to one, proportional to sqrt(max(score_i, 0)) where
        score_i = tr(u_quad[i]**2 * Finv.T @ Xi_i @ Finv) and Finv is the
        solver's third output.
    """
    prior_prec, post_prec = tau, q + tau
    marg_var = 1/q + 1/tau
    f = eval_f(lame, y, prior_prec, post_prec, marg_var)
    xi = eval_xi(f, lame, y, prior_prec, post_prec, marg_var)
    # the literal 100 is forwarded unchanged to the infemus solver
    u_quad, _, f_inv = emus.solve_emus_system(f, np.ones_like(f[0]), 100)
    score = np.array([
        np.trace(u_quad[i] ** 2 * f_inv.T @ cov @ f_inv)
        for i, cov in enumerate(xi)
    ])
    # negative (and NaN) scores get zero weight
    trunc_score = np.sqrt(np.where(score >= 0, score, 0))
    return u_quad, trunc_score / np.sum(trunc_score)

def eval_f(lam, y, prior_prec, post_prec, marg_var):
    """Matrix F by quadrature: row i is ``eval_1mom`` for window center ``lam[i]``."""
    rows = []
    for center in lam:
        rows.append(eval_1mom(center, lam, y, prior_prec, post_prec, marg_var))
    return np.array(rows)

def eval_xi(f, lam, y, prior_prec, post_prec, marg_var):
    """Per-window covariance-like matrices by quadrature.

    Entry i is ``eval_2mom(lam[i], ...) - outer(f[i], f[i])``: the second-moment
    matrix minus the outer product of the first moment ``f[i]`` (row i of ``eval_f``).
    Returns a list with one matrix per row of ``f``.
    """
    covs = []
    for i, mean in enumerate(f):
        second = eval_2mom(lam[i], lam, y, prior_prec, post_prec, marg_var)
        covs.append(second - np.outer(mean, mean))
    return covs

def eval_1mom(lam_, lam, y, prior_prec, post_prec, marg_var):
    """First moment of the kernel weights under window ``lam_``'s density.

    Integrates ``eval_dens(theta) * eval_integ(theta, lam)`` over the real line
    with ``quad_vec``; the result has one entry per element of ``lam``.
    The mixture weight ``p_`` is N(y; lam_, v) / (N(y; lam_, v) + N(-y; lam_, v))
    with v = ``marg_var``.
    """
    sd = np.sqrt(marg_var)
    log_plus = norm.logpdf(y, loc=lam_, scale=sd)
    log_minus = norm.logpdf(-y, loc=lam_, scale=sd)
    p_ = np.exp(log_plus - np.logaddexp(log_plus, log_minus))

    def integrand(the):
        return eval_dens(the, lam_, p_, y, prior_prec, post_prec) * eval_integ(the, lam, prior_prec)

    return quad_vec(integrand, -np.inf, np.inf)[0]

def eval_2mom(lam_, lam, y, prior_prec, post_prec, marg_var):
    """Second-moment matrix of the kernel weights under window ``lam_``'s density.

    Integrates ``eval_dens(theta) * outer(phi, phi)`` with
    ``phi = eval_integ(theta, lam)`` over the real line. The integrand is
    flattened for ``quad_vec`` and the result is reshaped back to
    ``(len(lam), len(lam))``.
    """
    sd = np.sqrt(marg_var)
    log_plus = norm.logpdf(y, loc=lam_, scale=sd)
    log_minus = norm.logpdf(-y, loc=lam_, scale=sd)
    p_ = np.exp(log_plus - np.logaddexp(log_plus, log_minus))
    n_lam = len(lam)

    def integrand(the):
        phi = eval_integ(the, lam, prior_prec)
        return eval_dens(the, lam_, p_, y, prior_prec, post_prec) * np.outer(phi, phi).ravel()

    flat = quad_vec(integrand, -np.inf, np.inf)[0]
    return flat.reshape(n_lam, n_lam)

def eval_dens(the, lam_, p_, y, prior_prec, post_prec, lik_prec=None):
    """Density of theta under window ``lam_``: a two-component normal mixture.

    With weight ``p_`` the component N((lik*y + lam_*prior_prec)/post_prec, 1/post_prec),
    and with weight ``1 - p_`` the component N((-lik*y + lam_*prior_prec)/post_prec, 1/post_prec),
    where lik is the likelihood precision.

    Parameters
    ----------
    the : float or array of evaluation points.
    lam_ : window center.
    p_ : weight of the ``+y`` component.
    y : observation.
    prior_prec, post_prec : prior and posterior precisions.
    lik_prec : optional likelihood precision. Defaults to ``post_prec - prior_prec``.
        The previous version read the notebook-global ``q`` here instead of its
        inputs, silently coupling the result to hidden kernel state.
    """
    if lik_prec is None:
        lik_prec = post_prec - prior_prec
    sd = np.sqrt(1/post_prec)
    mean_plus = (lik_prec*y + lam_*prior_prec)/post_prec
    mean_minus = (-lik_prec*y + lam_*prior_prec)/post_prec
    return p_ * norm.pdf(the, loc=mean_plus, scale=sd) + (1-p_) * norm.pdf(the, loc=mean_minus, scale=sd)

def eval_integ(the, lam, prior_prec):
    """Normalized kernel weights of a point ``the`` against every window center.

    Entry j is N(the; lam[j], 1/prior_prec) / sum_k N(the; lam[k], 1/prior_prec),
    computed in log space for stability. ``the`` is expected to be a scalar
    (``quad_vec`` evaluates integrands one point at a time).
    """
    log_kernel = norm.logpdf(the, loc=lam, scale=np.sqrt(1/prior_prec))
    log_total = logsumexp(log_kernel)
    return np.exp(log_kernel - log_total)

In [13]:
# sampling

def eval_logmargin(lam, y, q, tau):
    """Log marginal likelihood of the observation ``y`` for each window center.

    Evaluates, elementwise over ``lam``,
        log( N(y; lam, v) + N(-y; lam, v) ),   v = 1/q + 1/tau.
    NOTE: this re-declares the function of the same name from the quadrature
    cell with the same behavior; being defined later, this copy shadows it.
    """
    sd = np.sqrt(1/q + 1/tau)
    return np.logaddexp(
        norm.logpdf(y, loc=lam, scale=sd),
        norm.logpdf(-y, loc=lam, scale=sd),
    )

def resample(lams, lame, y, q, tau, n_samples, ome):
    prior_prec = tau
    post_prec = q + prior_prec
    marg_var = 1/q + 1/tau
    p = [np.exp(norm.logpdf(y, loc=lam_, scale=np.sqrt(marg_var)) - np.logaddexp(norm.logpdf(y, loc=lam_, scale=np.sqrt(marg_var)), norm.logpdf(-y, loc=lam_, scale=np.sqrt(marg_var)))) for lam_ in lams]
    the_samples = [norm.rvs(loc=(ome.choice([-1, 1], p=[1-p_, p_], size=n_samples)*q*y + lam_*prior_prec)/post_prec, scale=np.sqrt(1/post_prec), random_state=ome) for p_, lam_ in zip(p, lams)]
    log_psi = [norm.logpdf(the_[:, np.newaxis], loc=lame[np.newaxis], scale=np.sqrt(1/prior_prec)) for the_ in the_samples]
    return log_psi

def acquire_windows(n_new, f, ome):
    """Select about ``n_new`` window indices with inclusion probability proportional to ``f``.

    ``f`` is normalized to sum to one and passed to ``subsample_pivotal``,
    which targets inclusion probabilities ``n_new * f / sum(f)``.
    """
    return subsample_pivotal(f / np.sum(f), n_new, ome)

def subsample_stratified(p, n, ome):
    """Draw ``n`` indices from the discrete distribution ``p`` by stratified sampling.

    One uniform is drawn in each stratum [i/n, (i+1)/n) and mapped to an index
    through the cumulative sum of ``p`` (inverse-CDF lookup).

    Parameters
    ----------
    p : 1-D array of nonnegative weights, expected to sum to one.
    n : number of strata / draws.
    ome : random generator providing ``uniform(low, high)``.

    Returns
    -------
    np.ndarray of ``n`` integer indices into ``p``.

    Raises
    ------
    ValueError
        If ``p`` has no positive entry (there is nothing to sample).
    """
    positive = np.flatnonzero(np.asarray(p) > 0)
    if positive.size == 0:
        raise ValueError("p must contain at least one positive entry")
    u = [ome.uniform(i/n, (i+1)/n) for i in range(n)]
    idx = np.searchsorted(np.cumsum(p), u)
    # If rounding leaves cumsum(p)[-1] slightly below a draw, searchsorted
    # returns len(p) (out of bounds); such draws belong to the last index with mass.
    return np.minimum(idx, positive[-1])

def subsample_pivotal(p, n, ome):
    """Select indices by pairwise (tournament-style) pivotal sampling.

    The target inclusion probability of index i is ``p[i] * n``. Entries with
    ``p[i] <= 0`` are dropped up front and never selected. Each round pairs up
    adjacent surviving candidates:

    * if their masses sum to at most 1, one candidate (chosen proportionally
      to its mass) absorbs the pair's mass and the other drops out;
    * otherwise one candidate is selected outright (moved to the result) and
      the other survives carrying the leftover mass ``q_i + q_j - 1``.

    An unpaired last candidate carries over unchanged. The loop ends when at
    most one candidate survives; that candidate is included. Every update
    conserves ``sum(q) + len(r)``.

    Parameters
    ----------
    p : 1-D np.ndarray of nonnegative weights. Presumably sums to one so
        that ``p * n`` are probabilities in [0, 1].
        NOTE(review): entries with ``p[i] * n > 1`` are not handled specially.
    n : target sample size.
    ome : random generator providing ``uniform()``.

    Returns
    -------
    np.ndarray
        Sorted, unique integer indices into ``p``.
    """
    t = np.arange(len(p))
    s = []  # positions (into current q/t) surviving to the next round
    r = []  # original indices already selected with certainty
    q = np.copy(p) * n
    t = t[p > 0]
    q = q[p > 0]
    while len(q) > 1:
        for i in range(0, len(q), 2):
            if i + 1 == len(q):
                # odd candidate out: carried to the next round unchanged
                s.append(i)
                break
            if q[i] + q[i+1] <= 1:
                # merge: the winner takes the combined mass, the loser drops out
                if ome.uniform() <= q[i] / (q[i] + q[i+1]):
                    q[i] += q[i+1]
                    q[i+1] = 0
                    s.append(i)
                else:
                    q[i+1] += q[i]
                    q[i] = 0
                    s.append(i+1)
            else:
                # one candidate is selected; the other keeps the excess mass.
                # When both masses are exactly 1 this ratio is 0/0 = nan, the
                # comparison is False and the else-branch (still correct) runs.
                if ome.uniform() <= (1 - q[i]) / (2 - q[i] - q[i+1]):
                    q[i] = q[i] + q[i+1] - 1
                    q[i+1] = 1
                    s.append(i)
                    r.append(t[i+1])
                else:
                    q[i+1] = q[i] + q[i+1] - 1
                    q[i] = 1
                    s.append(i+1)
                    r.append(t[i])
        q = q[s]
        t = t[s]
        s = []
    # r is a plain list and may be empty (e.g. n == 1). Give it t's integer dtype
    # so union1d does not promote the result to float64; float arrays cannot be
    # used as indices (e.g. lame[ixn] in est_emus_iter).
    return np.union1d(np.asarray(r, dtype=t.dtype), t)

def eval_weights(log_psi):
    """Estimate normalizers and a sampling-intensity weight for every window from samples.

    Parameters
    ----------
    log_psi : list of np.ndarray
        One entry per window. Entry i is a 2-D array whose rows are the samples
        collected so far in window i and whose columns hold log psi_j(theta),
        one column per window j (as produced by ``resample``). An entry may have
        zero rows if that window has not been sampled yet.

    Returns
    -------
    ue
        Output of ``emus.extrapolate``; presumably normalizing-constant
        estimates extended to all windows (TODO: confirm against infemus).
    we : np.ndarray
        Nonnegative weights summing to one, proportional to
        sqrt(max(score_i, 0)) with
        score_i = tr(ue[i]**2 * Finv.T @ xi[i] @ Finv), where Finv is the third
        output of ``emus.solve_emus_system(f_fill, ...)``.

    Naming: suffix ``s`` = restricted to the sampled windows; suffix ``e`` =
    using every window's column.
    """
    # Boolean mask over windows: True where the window has at least one sample.
    is_sampled = [(log_psi_.shape[0] > 0) for log_psi_ in log_psi]
    # Samples of sampled windows, keeping only the sampled-window columns.
    log_psis = [log_psi_[:, is_sampled] for log_psi_ in log_psi if log_psi_.shape[0] > 0]
    # Samples of sampled windows, all columns.
    log_psie = [log_psi_ for log_psi_ in log_psi if log_psi_.shape[0] > 0]
    # All columns, each row normalized by its log-sum over the sampled-window columns only.
    log_kaps = [log_psi_ - logsumexp(log_psi_[:, is_sampled], 1)[:, np.newaxis] for log_psi_ in log_psie]
    # All columns, each row normalized by its log-sum over all columns.
    log_kape = [log_psi_ - logsumexp(log_psi_, 1)[:, np.newaxis] for log_psi_ in log_psie]
    # First output of the Vardi estimator on the sampled windows; presumably their
    # normalizing constants (TODO: confirm in infemus).
    us = emus.eval_vardi_estimator(log_psis)[0]
    # Per sampled window: log( z * mean over samples of kap_s[:, j] * kap_e[:, k] ),
    # a (n_windows, n_windows) array; lazily generated.
    log_phi1 = (logsumexp(log_phi_s_[:, :, np.newaxis] + log_phi_e_[:, np.newaxis, :], 0) - np.log(log_phi_s_.shape[0]) + np.log(z_) for z_, log_phi_s_, log_phi_e_ in zip(us, log_kaps, log_kape))
    # Sum the per-window terms in log space.
    log_f_fill = reduce(np.logaddexp, log_phi1)
    # Row-normalize so that each f_fill[j] sums to one.
    f_fill = np.exp(log_f_fill - logsumexp(log_f_fill, 1)[:,np.newaxis])
    # Same construction for the triple product kap_s[:, j] * kap_e[:, k] * kap_e[:, l].
    # NOTE: the broadcast intermediate has shape (n_i, n_windows, n_windows, n_windows),
    # so memory grows cubically with the grid size.
    log_phi2 = (logsumexp(log_phi_s_[:, :, np.newaxis, np.newaxis] + log_phi_e1_[:, np.newaxis, :, np.newaxis] + log_phi_e2_[:, np.newaxis, np.newaxis, :], 0) - np.log(log_phi_s_.shape[0]) + np.log(z_) for z_, log_phi_s_, log_phi_e1_, log_phi_e2_ in zip(us, log_kaps, log_kape, log_kape))
    log_f2_fill = reduce(np.logaddexp, log_phi2)
    # Normalize each slice f2_fill[j] to sum to one.
    f2_fill = np.exp(log_f2_fill - logsumexp(log_f2_fill, (1, 2))[:,np.newaxis,np.newaxis])
    # Per-window covariance-like matrices: second moment minus outer product of the first.
    xi = [f2_filled_ - np.outer(f_filled_, f_filled_) for f2_filled_, f_filled_ in zip(f2_fill, f_fill)]
    # Third output of the solver; the name suggests the inverse of f_fill (TODO: confirm).
    # The literal 100 is forwarded unchanged to infemus.
    f_filled_inv = emus.solve_emus_system(f_fill, np.ones(len(log_psi)), 100)[2]
    ue = emus.extrapolate(log_psie, log_psis, us)
    score = np.array([np.trace(ue[i] ** 2 * f_filled_inv.T @ xi[i] @ f_filled_inv) for i in range(len(xi))])
    # Negative (presumably round-off) and NaN scores get zero weight.
    trunc_score = np.sqrt(np.where(score >= 0, score, 0))
    we = trunc_score / np.sum(trunc_score)
    return ue, we

def adjust_weights(n_old, n_new, w_new):
    """Convert target proportions into top-up proportions given the current counts.

    The target count for window i is ``(n_new + sum(n_old)) * w_new[i]``.
    The deficit ``target - n_old`` is clipped at zero and normalized to sum
    to one. If no window has a positive deficit, the result is NaN (0/0).
    """
    total = n_new + sum(n_old)
    deficit = total * w_new - n_old
    deficit = np.where(deficit > 0, deficit, 0)
    return deficit / np.sum(deficit)

def est_emus_iter(ixn0, lame, y, q, tau, n_samples, ome):
    """Run adaptive EMUS indefinitely, yielding the current estimates after each round.

    Each round draws ``n_samples`` samples in every currently selected window
    (initially ``ixn0``, indices into the grid ``lame``). It then re-estimates
    with ``eval_weights``, square-root-smooths the weights, converts them into
    top-up weights with ``adjust_weights``, and selects the next windows via
    ``acquire_windows`` with ``len(ixn0)`` as the target count.

    Yields
    ------
    (u, w)
        The two outputs of ``eval_weights`` over all samples collected so far.
        ``ome`` is advanced by the next window acquisition before each yield.
    """
    n_grid = len(lame)
    n_windows = len(ixn0)
    samples = [np.empty(shape=(0, n_grid)) for _ in range(n_grid)]
    active = ixn0
    while True:
        fresh = resample(lame[active], lame, y, q, tau, n_samples, ome)
        for new_rows, win in zip(fresh, active):
            samples[win] = np.append(samples[win], new_rows, axis=0)
        u, w = eval_weights(samples)
        root_w = np.sqrt(w)
        w_smooth = root_w / np.sum(root_w)
        counts = np.array([arr.shape[0] for arr in samples])
        # NOTE(review): n_new is n_samples, although each round adds n_samples to
        # each of ~len(ixn0) windows; confirm this is intended.
        w_adj = adjust_weights(counts, n_samples, w_smooth)
        active = acquire_windows(n_windows, w_adj, ome)
        yield u, w

In [14]:
# evaluate ground truth

# exact (closed-form) marginal on the grid, normalized to sum to one
lam = np.linspace(*bds, n_windows0 * n_interp + 1)
log_u = eval_logmargin(lam, y, q, tau)
u = np.exp(log_u - logsumexp(log_u))

# reference normalizers and intensity from numerical quadrature
u_quad, w_quad = quad_weights(lam, y, q, tau)

# one generator shared by all estimator runs below, so results follow from `seed`
rng = np.random.default_rng(seed)

In [None]:
# draw sampling intervals for both methods [fig 10a]

# initial (and, for the fixed design, only) windows: every n_interp-th grid point
ixsfix = np.arange(len(lam))[::n_interp]
# flexible design: keep the estimates after n_iter adaptive rounds
u_est_flex, w_est_flex = zip(*[next(islice(est_emus_iter(ixsfix, lam, y, q, tau, n_samples, rng), n_iter - 1, n_iter)) for _ in range(n_variates)])
# fixed design: one round with n_samples * n_iter draws per window (matching total budget,
# assuming each adaptive round selects len(ixsfix) windows)
u_est_fix, _ = zip(*[next(islice(est_emus_iter(ixsfix, lam, y, q, tau, n_samples * n_iter, rng), 0, 1)) for _ in range(n_variates)])

plt.figure(figsize=(4, 3))
plt.plot(lam, u, color='black', label='exact')
# Central 75% band (12.5th-87.5th percentile) over replicates. Each estimate is pooled
# with its mirror image; this relies on the grid being symmetric about 0 (bds) and on
# u(lambda) = u(-lambda), which follows from eval_logmargin's formula.
plt.fill_between(lam, *np.percentile(np.vstack([u_est_flex, np.array(u_est_flex)[:,::-1]]), [12.5, 87.5], 0), alpha=.5, label='flexible')
plt.fill_between(lam, *np.percentile(np.vstack([u_est_fix, np.array(u_est_fix)[:,::-1]]), [12.5, 87.5], 0), alpha=.5, label='fixed')
plt.xlim([-3, 3])
plt.xlabel(r'$\lambda$')
plt.ylabel(r'$u(\lambda)$')
plt.legend()
plt.show()

In [None]:
# draw optimal intensity function [fig 10b]

plt.figure(figsize=(4, 3))
plt.plot(lam, w_quad, color='black', label='quadrature')
# Central 75% band over replicates of the flexible design's final weights, pooled with
# their mirror images (assumes w is symmetric in lambda, like u, on the symmetric grid).
plt.fill_between(lam, *np.percentile(np.vstack([w_est_flex, np.array(w_est_flex)[:,::-1]]), [12.5, 87.5], 0), alpha=.5, label='flexible estimate')
plt.xlim([-3, 3])
plt.xlabel(r'$\lambda$')
plt.ylabel(r'$w(\lambda)$')
plt.legend()
plt.show()