In [None]:
# dependencies

from itertools import product

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.special import logsumexp
from scipy.stats import norm, multivariate_normal, gaussian_kde

from umbrella import emus
from umbrella.models import xfx

sns.set()

In [None]:
# config
#
# Tunable constants read by the cells below.

# grid points per axis of the sampling grid w1; the evaluation grid w2 uses 2 * n_windows per axis
n_windows = 16
# draws kept / discarded per sampler (est_margin_griddy multiplies both by len(w1))
n_samples = 2
n_burnin = 2
# number of levels of each factor (passed as `j` to the fixture and estimator functions)
n_levels = np.array([16, 16])
# NOTE(review): alp0 is not referenced by any cell visible in this notebook — confirm it is still needed
alp0 = 0
# precision paired with the intercept in eval_logprior / eval_logmargin
tau0 = 1
# observation precision: noise is drawn with standard deviation sqrt(1 / lam)
lam = 1
# factor precisions are drawn as scale_tau * chisquare(df_tau); here scale_tau * df_tau == 1,
# so the draws are presumably concentrated very tightly around 1 — confirm this is intended
df_tau = 1e64
scale_tau = 1e-64
# seed of the notebook-wide random generator
seed = 0
# log-scale end points of the grid axes: each axis is exp(linspace(*bounds, n))
bounds1 = [-1, 1]
bounds2 = [0, 2]

In [None]:
# design

def sample_coef_fixture(j, tau, ome):
    """Draw one zero-centred coefficient vector per factor.

    Parameters
    ----------
    j : iterable of int
        Number of coefficients to draw for each factor.
    tau : iterable of float
        Precision of each factor; draws use standard deviation 1 / sqrt(tau).
    ome : numpy.random.Generator
        Random source (only its ``normal`` method is used).

    Returns
    -------
    list of numpy.ndarray
        One array per (tau, j) pair, shifted so that its mean is zero.
    """
    centred = []
    for precision, size in zip(tau, j):
        draw = ome.normal(0, 1 / np.sqrt(precision), size)
        # remove the sample mean so each factor's effects sum to zero
        centred.append(draw - np.mean(draw))
    return centred

def sample_randfx_fixture(i, df_tau, scale_tau, ome):
    """Draw factor precisions and, given those, zero-centred coefficients.

    Parameters
    ----------
    i : sequence of int
        Level count of each factor (callers in this file pass their ``j``).
    df_tau : float
        Degrees of freedom of the chi-square draw behind each precision.
    scale_tau : float
        Multiplier applied to the chi-square draws.
    ome : numpy.random.Generator
        Random source.

    Returns
    -------
    tuple
        ``(alp, tau)``: the list of coefficient arrays produced by
        ``sample_coef_fixture`` and the array of precisions, one per factor.
    """
    n_factors = len(i)
    # precisions are drawn first, so they consume the generator before the coefficients do
    precisions = scale_tau * ome.chisquare(df_tau, n_factors)
    coefs = sample_coef_fixture(i, precisions, ome)
    return coefs, precisions

def sample_data_fixture(i, n_inflator, alp0, alp, lam, ome):
    """Simulate replicated Gaussian responses and reduce them to sufficient statistics.

    Parameters
    ----------
    i : numpy.ndarray, shape (n_obs, n_factors)
        Level index of every observation in every factor.
    n_inflator : int
        Number of replicate responses drawn per row of ``i``.
    alp0 : float
        Intercept.
    alp : list of numpy.ndarray
        Coefficient vector of each factor, indexed by level.
    lam : float
        Observation precision; responses have standard deviation 1 / sqrt(lam).
    ome : numpy.random.Generator
        Random source.

    Returns
    -------
    tuple of numpy.ndarray
        ``(y1, y2, n)``: per-row sum of responses, per-row sum of squared
        responses, and the replicate count repeated for every row.
    """
    n_obs = i.shape[0]
    # look up, for every factor, the coefficient of the level each row belongs to
    effects = [coef[levels] for coef, levels in zip(alp, i.T)]
    eta = alp0 + np.sum(effects, 0)
    draws = ome.normal(eta[:, np.newaxis], 1 / np.sqrt(lam), size=(n_obs, n_inflator))
    sum_y = np.sum(draws, 1)
    sum_y_sq = np.sum(np.square(draws), 1)
    counts = np.repeat(n_inflator, n_obs)
    return sum_y, sum_y_sq, counts

def sample_balanced_design(j, ome):
    """Build a balanced design in which every level of a factor appears equally often.

    Parameters
    ----------
    j : sequence of int
        Number of levels of each factor.
    ome : numpy.random.Generator
        Random source used to shuffle each column independently.

    Returns
    -------
    numpy.ndarray, shape (lcm(j), len(j))
        Column ``k`` holds the levels ``0 .. j[k] - 1``, each repeated
        ``lcm(j) / j[k]`` times, in random order.
    """
    # the number of rows is the least common multiple of the level counts
    n = 1
    for j_ in j:
        n = np.lcm(n, j_)
    # integer division: np.repeat rejects float repeat counts, and n is a
    # multiple of every j_ by construction so nothing is lost
    i = np.array([np.repeat(np.arange(j_), n // j_) for j_ in j]).T
    for k_ in range(len(j)):
        # in-place shuffle of one column (a view into i)
        ome.shuffle(i[:, k_])
    return i

def sample_balanced_fixture(j, alp0=0, lam=1, df_tau=2, scale_tau=1, n_inflator=1, ome=None):
    """Simulate a balanced crossed random-effects data set together with its true parameters.

    Parameters
    ----------
    j : sequence of int
        Number of levels of each factor.
    alp0 : float, optional
        Intercept.
    lam : float, optional
        Observation precision.
    df_tau, scale_tau : float, optional
        Factor precisions are drawn as ``scale_tau * chisquare(df_tau)``.
    n_inflator : int, optional
        Replicate responses per design row.
    ome : numpy.random.Generator, optional
        Random source. When omitted, a fresh unseeded generator is created
        per call (the previous default was built once at definition time and
        silently shared between calls).

    Returns
    -------
    tuple
        ``((y1, y2, n, j, i), (alp0, alp, tau, lam))``: sufficient statistics
        with the design, followed by the generating parameters.
    """
    if ome is None:
        ome = np.random.default_rng()
    # order matters: each step consumes the shared generator
    alp, tau = sample_randfx_fixture(j, df_tau, scale_tau, ome)
    i = sample_balanced_design(j, ome)
    y1, y2, n = sample_data_fixture(i, n_inflator, alp0, alp, lam, ome)
    return (y1, y2, n, j, i), (alp0, alp, tau, lam)

def sample_mar_design(j, p_miss, ome):
    """Build a full factorial design and drop rows at random.

    Parameters
    ----------
    j : sequence of int
        Number of levels of each factor (any number of factors).
    p_miss : float
        A row is kept when an independent uniform draw exceeds ``p_miss``.
    ome : numpy.random.Generator
        Random source for the row deletion and the final shuffle.

    Returns
    -------
    numpy.ndarray, shape (n_kept, len(j))
        Level indices of the retained rows, in random order.
    """
    grids = np.meshgrid(*[np.arange(j_) for j_ in j])
    # one column per factor; len(j) instead of a hard-coded 2 so that designs
    # with other than two factors are laid out correctly
    i = np.stack(grids).T.reshape(-1, len(j))
    i = i[ome.uniform(size=i.shape[0]) > p_miss]
    ome.shuffle(i, 0)
    return i

def sample_mar_fixture(j, df_tau=2, scale_tau=1, p_miss=.1, ome=None):
    """Simulate linear predictors on a factorial design with randomly missing cells.

    Parameters
    ----------
    j : sequence of int
        Number of levels of each factor.
    df_tau, scale_tau : float, optional
        Factor precisions are drawn as ``scale_tau * chisquare(df_tau)``.
    p_miss : float, optional
        Threshold forwarded to ``sample_mar_design`` for dropping rows.
    ome : numpy.random.Generator, optional
        Random source. When omitted, a fresh unseeded generator is created
        per call (the previous default was built once at definition time and
        silently shared between calls).

    Returns
    -------
    tuple
        ``((eta, i), (alp0, alp, tau))``: the noiseless linear predictor per
        design row with the design, followed by the generating parameters.
    """
    if ome is None:
        ome = np.random.default_rng()
    # the intercept is fixed at zero for this fixture
    alp0 = 0
    alp, tau = sample_randfx_fixture(j, df_tau, scale_tau, ome)
    i = sample_mar_design(j, p_miss, ome)
    # sum, over factors, of the coefficient of the level each row belongs to
    eta = alp0 + np.sum([alp_[j_] for alp_, j_ in zip(alp, i.T)], 0)
    return (eta, i), (alp0, alp, tau)

In [None]:
# sampling

def eval_logmargin(y1, n, j, i, tau0, tau, lam):
    """Evaluate the Gaussian log density of the row means ``y1 / n`` with all coefficients integrated out.

    The covariance is ``diag(1 / (lam * n)) + 1 / tau0`` plus, for every
    factor ``k``, ``1 / tau[k]`` on each pair of rows sharing a level of that
    factor; the mean is zero.

    Parameters
    ----------
    y1 : numpy.ndarray
        Per-row sum of responses.
    n : numpy.ndarray
        Per-row number of responses.
    j : sequence of int
        Number of levels of each factor.
    i : numpy.ndarray, shape (n_obs, n_factors)
        Level index of every row in every factor.
    tau0 : float
        Intercept precision.
    tau : sequence of float
        Precision of each factor.
    lam : float
        Observation precision.

    Returns
    -------
    float
        ``multivariate_normal.logpdf`` of ``y1 / n`` under that covariance.
    """
    cov = np.diag(1 / (lam * n)) + 1 / tau0
    for k, n_lev in enumerate(j):
        factor_var = 1 / tau[k]
        labels = i[:, k]
        for lev in range(n_lev):
            # rows sharing this level are correlated through the level's coefficient
            member = labels == lev
            cov[np.ix_(member, member)] += factor_var
    return multivariate_normal.logpdf(y1 / n, np.zeros_like(y1), cov)

def eval_logprior(alp0, alp, tau0, tau):
    """Sum the zero-mean Gaussian log densities of the intercept and all coefficient vectors.

    Parameters
    ----------
    alp0 : float
        Intercept, evaluated with precision ``tau0``.
    alp : list of array-like
        Coefficient vector of each factor (must be a ``list``: it is
        concatenated with the intercept block).
    tau0 : float
        Intercept precision.
    tau : iterable of float
        Precision of each factor, aligned with ``alp``.

    Returns
    -------
    float
        Total log density.
    """
    blocks = [[alp0]] + alp
    precisions = [tau0] + list(tau)
    terms = []
    for coef, prec in zip(blocks, precisions):
        log_norm = len(coef) * np.log(2 * np.pi / prec)
        quad = prec * np.sum(np.square(coef))
        terms.append(-(log_norm + quad) / 2)
    return np.sum(terms)

def sample_cond(w1, w2, x0, x1, x2, j, tau0, tau, lam, ome):
    """Endless generator of log-prior evaluations along a chain run at fixed precisions ``tau``.

    Every iteration refreshes the coefficients and then the intercept through
    ``xfx.update_coefs`` / ``xfx.update_intercept`` (from ``umbrella.models``,
    not shown here — presumably conditional draws given the data; confirm),
    then evaluates ``eval_logprior`` of the current state at every grid point.

    Parameters
    ----------
    w1, w2 : sequence of array-like
        Grids of candidate precision vectors.
    x0, x1, x2 :
        Reduced data as returned by ``xfx.reduce_data``.
    j : sequence of int
        Number of levels of each factor; coefficients start at zero.
    tau0 : float
        Intercept precision.
    tau : array-like
        Factor precisions held fixed for this chain.
    lam : float
        Observation precision.
    ome : numpy.random.Generator
        Random source handed to the update functions.

    Yields
    ------
    tuple of numpy.ndarray
        Log prior of the current state at every point of ``w1`` and of ``w2``.
    """
    coefs = [np.zeros(n_lev) for n_lev in j]
    while True:
        coefs = xfx.update_coefs(x1, x2, None, coefs, tau0, tau, lam, ome)
        intercept = xfx.update_intercept(x0, x1, coefs, tau0, lam, ome)
        on_w1 = np.array([eval_logprior(intercept, coefs, tau0, point) for point in w1])
        on_w2 = np.array([eval_logprior(intercept, coefs, tau0, point) for point in w2])
        yield on_w1, on_w2

def sample_joint(w1, x0, x1, x2, j, tau0, lam, ome):
    """Endless generator of grid indices from a chain that also moves the precisions over ``w1``.

    The chain starts at a uniformly chosen grid point. Every iteration
    refreshes the coefficients and the intercept through ``xfx.update_coefs``
    / ``xfx.update_intercept`` (from ``umbrella.models``, not shown here),
    then redraws the grid index with probability proportional to
    ``exp(eval_logprior)`` of the current state at each grid point.

    Parameters
    ----------
    w1 : sequence of array-like
        Grid of candidate precision vectors.
    x0, x1, x2 :
        Reduced data as returned by ``xfx.reduce_data``.
    j : sequence of int
        Number of levels of each factor; coefficients start at zero.
    tau0 : float
        Intercept precision.
    lam : float
        Observation precision.
    ome : numpy.random.Generator
        Random source.

    Yields
    ------
    int
        Index into ``w1`` of the precision vector selected in this iteration.
    """
    coefs = [np.zeros(n_lev) for n_lev in j]
    n_grid = len(w1)
    idx = ome.choice(n_grid)
    tau = w1[idx]
    while True:
        coefs = xfx.update_coefs(x1, x2, None, coefs, tau0, tau, lam, ome)
        intercept = xfx.update_intercept(x0, x1, coefs, tau0, lam, ome)
        log_weights = np.array([eval_logprior(intercept, coefs, tau0, point) for point in w1])
        # normalise on the log scale before exponentiating
        probs = np.exp(log_weights - logsumexp(log_weights))
        idx = ome.choice(n_grid, p=probs)
        tau = w1[idx]
        yield idx
    
def est_margin_emus(w1, w2, y1, y2, n, j, i, tau0, lam, ome, n_samples, n_burnin):
    """Estimate normalised weights over the grid ``w2`` from one fixed-precision chain per point of ``w1``.

    One ``sample_cond`` chain is run at each grid point of ``w1``. After
    discarding ``n_burnin`` draws per chain, ``n_samples`` draws per chain are
    collected and passed to ``emus.eval_vardi_estimator`` and
    ``emus.extrapolate`` (from the ``umbrella`` package, not shown here).

    Parameters
    ----------
    w1, w2 : sequence of array-like
        Sampling grid and evaluation grid of precision vectors.
    y1, y2, n : numpy.ndarray
        Per-row sum of responses, sum of squared responses and counts.
    j : sequence of int
        Number of levels of each factor.
    i : numpy.ndarray
        Design matrix of level indices.
    tau0, lam : float
        Intercept and observation precisions.
    ome : numpy.random.Generator
        Random source shared by all chains.
    n_samples, n_burnin : int
        Draws kept and discarded per chain.

    Returns
    -------
    numpy.ndarray
        The extrapolated values divided by their sum.
    """
    x0, x1, x2 = xfx.reduce_data(y1, y2, n, i)
    chains = [sample_cond(w1, w2, x0, x1, x2, j, tau0, point, lam, ome) for point in w1]
    # burn in every chain before any chain is sampled (they share one generator)
    for chain in chains:
        for _ in range(n_burnin):
            next(chain)
    log_psi1, log_psi2 = [], []
    for chain in chains:
        draws = [next(chain) for _ in range(n_samples)]
        log_psi1.append(np.array([on_w1 for on_w1, _ in draws]))
        log_psi2.append(np.array([on_w2 for _, on_w2 in draws]))
    z1 = emus.eval_vardi_estimator(log_psi1)[0]
    z2 = emus.extrapolate(log_psi2, log_psi1, z1)
    return z2 / np.sum(z2)
    
def est_margin_griddy(w1, w2, y1, y2, n, j, i, tau0, lam, ome, n_samples, n_burnin):
    """Estimate normalised weights over the grid ``w2`` from a single chain that moves over ``w1``.

    A ``sample_joint`` chain is run for ``n_burnin * len(w1)`` discarded
    iterations followed by ``n_samples * len(w1)`` retained ones. A Gaussian
    kernel density estimate is fitted to the log of the visited grid points
    and evaluated at the log of every point of ``w2``.

    Parameters
    ----------
    w1, w2 : sequence of array-like
        Sampling grid and evaluation grid of precision vectors.
    y1, y2, n : numpy.ndarray
        Per-row sum of responses, sum of squared responses and counts.
    j : sequence of int
        Number of levels of each factor.
    i : numpy.ndarray
        Design matrix of level indices.
    tau0, lam : float
        Intercept and observation precisions.
    ome : numpy.random.Generator
        Random source.
    n_samples, n_burnin : int
        Retained and discarded iterations, each multiplied by ``len(w1)``.

    Returns
    -------
    numpy.ndarray
        The density values on ``w2`` divided by their sum.
    """
    x0, x1, x2 = xfx.reduce_data(y1, y2, n, i)
    sampler = sample_joint(w1, x0, x1, x2, j, tau0, lam, ome)
    n_grid = len(w1)
    last_burnin = None
    for _, last_burnin in zip(range(n_burnin * n_grid), sampler):
        continue
    # As before, the final burn-in state seeds the retained sample; with no
    # burn-in there is no such state, so start empty instead of failing on an
    # undefined name.
    idx = [] if last_burnin is None else [last_burnin]
    for _, idx_ in zip(range(n_samples * n_grid), sampler):
        idx.append(idx_)
    # density estimate on the log scale of the precisions
    z2 = gaussian_kde(np.log(np.array(w1)[idx].T))(np.log(np.array(w2).T))
    return z2 / np.sum(z2)

In [None]:
# sample data

# seeded generator shared by every stochastic step of the notebook
ome = np.random.default_rng(seed)
# factorial design with p_miss = 0; data = (eta, i) and params = (alp0, alp, tau)
data, params = sample_mar_fixture(n_levels, df_tau, scale_tau, 0, ome)
# one noisy response per design row, noise standard deviation sqrt(1 / lam)
y1 = ome.normal(data[0], np.sqrt(1/lam))
# disabled alternative that would repack the data as (y1, y2, n, j, i); kept for reference
#data = (y1, np.square(y1), np.ones_like(y1), n_levels, data[-1])

In [None]:
# generate grid

# sampling grid w1: n_windows log-spaced points per axis, enumerated as product(wa1, wb1)
wa1 = np.exp(np.linspace(*bounds1, n_windows))
wb1 = np.exp(np.linspace(*bounds2, n_windows))
w1 = [np.array([lam1_, lam2_]) for (lam1_, lam2_) in product(wa1, wb1)]

# evaluation grid w2: twice as fine, same enumeration order
wa2 = np.exp(np.linspace(*bounds1, n_windows * 2))
wb2 = np.exp(np.linspace(*bounds2, n_windows * 2))
w2 = [np.array([lam1_, lam2_]) for (lam1_, lam2_) in product(wa2, wb2)]

# Plot coordinates. product(wa2, wb2) varies wb2 fastest, so values reshaped to
# (len(wa2), len(wb2)) have wa2 along rows and wb2 along columns; 'ij' indexing
# gives x_rec[a, b] = wa2[a] and y_rec[a, b] = wb2[b] to match (the default
# 'xy' indexing would transpose every 2-D plot against its axes).
x_rec, y_rec = np.meshgrid(wa2, wb2, indexing='ij')

In [None]:
# evaluate ground truth

# exact log marginal at every point of the fine grid w2, normalised to sum to one
log_z = [eval_logmargin(y1, np.ones_like(y1), n_levels, data[1], tau0, w_, lam) for w_ in w2]
z = np.exp(np.array(log_z) - logsumexp(log_z))
# w2 enumerates product(wa2, wb2): rows index wa2, columns index wb2
z_rec = np.reshape(z, (len(wa2), len(wb2)))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.contourf(x_rec, y_rec, z_rec, 8, cmap='magma')
ax2.pcolormesh(x_rec, y_rec, z_rec, cmap='magma')
# both panels share the same log-scaled, labelled axes
for ax in (ax1, ax2):
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(r'$\lambda_{1}$')
    ax.set_ylabel(r'$\lambda_{2}$')
fig.suptitle('ground truth');

In [None]:
# estimate by emus

z_emus = est_margin_emus(w1, w2, y1, y1 ** 2, np.ones_like(y1), n_levels, data[1], tau0, lam, ome, n_samples, n_burnin)
# estimates follow the order of w2 = product(wa2, wb2): rows index wa2, columns index wb2
z_rec_emus = np.reshape(z_emus, (len(wa2), len(wb2)))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.contourf(x_rec, y_rec, z_rec_emus, 8, cmap='magma')
ax2.pcolormesh(x_rec, y_rec, z_rec_emus, cmap='magma')
# both panels share the same log-scaled, labelled axes
for ax in (ax1, ax2):
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(r'$\lambda_{1}$')
    ax.set_ylabel(r'$\lambda_{2}$')
fig.suptitle('emus estimate');

In [None]:
# estimate by griddy

# note: this single run uses twice the configured burn-in
z_griddy = est_margin_griddy(w1, w2, y1, y1 ** 2, np.ones_like(y1), n_levels, data[1], tau0, lam, ome, n_samples, 2 * n_burnin)
# estimates follow the order of w2 = product(wa2, wb2): rows index wa2, columns index wb2
z_rec_griddy = np.reshape(z_griddy, (len(wa2), len(wb2)))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.contourf(x_rec, y_rec, z_rec_griddy, 8, cmap='magma')
ax2.pcolormesh(x_rec, y_rec, z_rec_griddy, cmap='magma')
# both panels share the same log-scaled, labelled axes
for ax in (ax1, ax2):
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(r'$\lambda_{1}$')
    ax.set_ylabel(r'$\lambda_{2}$')
fig.suptitle('griddy estimate');

In [None]:
# error distribution of respective methods

n_sims = 16
# rerun each estimator n_sims times on the same data (all emus runs first,
# then all griddy runs, so the shared generator is consumed in that order);
# each result has one row per run
z_emus_sim, z_griddy_sim = [
    np.array([
        estimator(w1, w2, y1, y1 ** 2, np.ones_like(y1), n_levels, data[1],
                  tau0, lam, ome, n_samples, n_burnin)
        for _ in range(n_sims)
    ])
    for estimator in (est_margin_emus, est_margin_griddy)
]

In [None]:
# marginal error plots

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
palette = sns.color_palette('colorblind')
grid_shape = (len(wa2), len(wb2))

# Grids reshaped from product(wa2, wb2) order have wa2 (lambda_1) along rows
# and wb2 (lambda_2) along columns. The lambda_1 marginal therefore sums over
# axis 1 and the lambda_2 marginal over axis 0; the two were swapped before,
# which drew each marginal against the other parameter's axis.
panels = ((ax1, wa2, 1, r'$\lambda_{1}$'), (ax2, wb2, 0, r'$\lambda_{2}$'))
for ax, grid, summed_axis, label in panels:
    # exact marginal in black, then one faint line per simulated estimate
    ax.plot(grid, np.sum(z_rec, summed_axis), color='black')
    for color, sims in ((palette[0], z_emus_sim), (palette[1], z_griddy_sim)):
        for z_ in sims:
            ax.plot(grid, np.sum(np.reshape(z_, grid_shape), summed_axis), color=color, alpha=.25)
    ax.set_xlabel(label)
    ax.set_xscale('log')

In [None]:
# profile error plots

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
palette = sns.color_palette('colorblind')
grid_shape = (len(wa2), len(wb2))

# Grids reshaped from product(wa2, wb2) order have wa2 (lambda_1) along rows
# and wb2 (lambda_2) along columns. The lambda_1 profile therefore maximises
# over axis 1 and the lambda_2 profile over axis 0; the two were swapped
# before, which drew each profile against the other parameter's axis.
panels = ((ax1, wa2, 1, r'$\lambda_{1}$'), (ax2, wb2, 0, r'$\lambda_{2}$'))
for ax, grid, maxed_axis, label in panels:
    # exact profile in black, then one faint line per simulated estimate
    ax.plot(grid, np.max(z_rec, maxed_axis), color='black')
    for color, sims in ((palette[0], z_emus_sim), (palette[1], z_griddy_sim)):
        for z_ in sims:
            ax.plot(grid, np.max(np.reshape(z_, grid_shape), maxed_axis), color=color, alpha=.25)
    ax.set_xlabel(label)
    ax.set_xscale('log')

In [None]:
# L1 distance between every simulated estimate and the exact grid weights, by algorithm
sns.boxplot(
    data=pd.DataFrame({
        'emus': np.linalg.norm(z_emus_sim - z, ord=1, axis=1),
        'griddy': np.linalg.norm(z_griddy_sim - z, ord=1, axis=1),
    }).melt(var_name='algo'),
    x='algo',
    y='value',
)