In [7]:
# dependencies

from itertools import product

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.special import expit
from scipy.stats import multivariate_normal

from infemus import infemus
from infemus.models import logitx

sns.set_theme()

In [None]:
# config

# Number of independent replicate estimates computed per problem size; the
# reference estimate is run with n_samples * n_variates samples instead.
n_variates = 128
# Sample count forwarded to infemus.est_mlik for each replicate estimate.
n_samples = 1
# Burn-in count forwarded to infemus.est_mlik.
n_burnin = 8
# Grid resolution per axis: l + 1 points for the `lams` grid and 2 * l + 1
# points for the `lame` grid.
l = 16
# Seed for the NumPy Generator driving the whole simulation.
seed = 1
# Forwarded as the last element of the data tuple; eval_logprior pairs it with
# alp0 as that coefficient's precision.
tau0 = 1

In [8]:
# data-generating process (DGP)

def sample_coef_fixture(j, tau, rng):
    """Draw one mean-centred coefficient vector per factor.

    For every pair ``(tau_k, j_k)`` taken from ``zip(tau, j)``, ``j_k`` normal
    variates with mean 0 and standard deviation ``1 / sqrt(tau_k)`` are drawn
    from ``rng`` and the vector's own sample mean is subtracted.

    Returns a list with one array per factor, in input order.
    """
    centred = []
    for precision, n_levels in zip(tau, j):
        draws = rng.normal(0, 1 / np.sqrt(precision), n_levels)
        centred.append(draws - np.mean(draws))
    return centred

def sample_randfx_fixture(i, df_tau, scale_tau, rng):
    """Draw per-factor precisions and matching centred coefficient vectors.

    One precision per entry of ``i`` is drawn as ``scale_tau`` times a
    chi-square variate with ``df_tau`` degrees of freedom; the coefficient
    vectors are then drawn by ``sample_coef_fixture`` using those precisions.

    Returns ``(alp, tau)``: the list of coefficient vectors and the array of
    precisions.
    """
    n_factors = len(i)
    # Precisions are drawn before the coefficients so the generator is
    # consumed in the same order on every call.
    tau = scale_tau * rng.chisquare(df_tau, n_factors)
    return sample_coef_fixture(i, tau, rng), tau

def sample_mar_design(j, p_miss, rng):
    """Sample a randomly thinned, row-shuffled full crossing of factor levels.

    Parameters
    ----------
    j : sequence of int
        Number of levels of each factor; any number of factors is supported.
    p_miss : float
        A cell of the full crossing is kept when an independent uniform
        draw on [0, 1) exceeds ``p_miss``.
    rng : numpy.random.Generator
        Source of randomness for both the thinning and the row shuffle.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_kept, len(j))``; each row holds the level
        index of every factor for one retained cell.
    """
    n_factors = len(j)
    grid = np.meshgrid(*[np.arange(j_) for j_ in j])
    # After stacking and transposing the factor axis is last, so one row per
    # cell needs a width equal to the number of factors (not a fixed 2).
    i = np.stack(grid).T.reshape(-1, n_factors)
    i = i[rng.uniform(size=i.shape[0]) > p_miss]
    rng.shuffle(i, 0)
    return i

def sample_mar_fixture(j, df_tau=2, scale_tau=1, p_miss=.9, rng=None, alp0=0):
    """Generate a linear predictor on a randomly thinned crossed design.

    Parameters
    ----------
    j : sequence of int
        Number of levels of each factor.
    df_tau, scale_tau : float
        Degrees of freedom and scale forwarded to ``sample_randfx_fixture``
        for drawing the per-factor precisions.
    p_miss : float
        Thinning threshold forwarded to ``sample_mar_design``.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted a fresh unseeded generator is
        created per call, instead of one generator built at definition time
        and silently shared by every call.
    alp0 : float, optional
        Intercept added to the linear predictor (default 0, matching
        ``sample_balanced_fixture``).

    Returns
    -------
    tuple
        ``((eta, i), (alp0, alp, tau))`` where ``i`` is the design, ``eta`` the
        linear predictor for each design row, ``alp`` the per-factor
        coefficient vectors and ``tau`` their precisions.
    """
    if rng is None:
        rng = np.random.default_rng()
    alp, tau = sample_randfx_fixture(j, df_tau, scale_tau, rng)
    i = sample_mar_design(j, p_miss, rng)
    # Each design column selects that factor's coefficient for every row;
    # summing over factors gives the additive predictor.
    eta = alp0 + np.sum([alp_[j_] for alp_, j_ in zip(alp, i.T)], 0)
    return (eta, i), (alp0, alp, tau)

def sample_balanced_design(j, rng):
    """Build a row-shuffled design in which every level of a factor is equally frequent.

    The number of rows is the least common multiple of the level counts in
    ``j``; column ``k`` repeats each of its ``j[k]`` levels ``n / j[k]`` times.

    Parameters
    ----------
    j : sequence of int
        Number of levels of each factor.
    rng : numpy.random.Generator
        Used to shuffle the rows in place.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(lcm(j), len(j))``.
    """
    n = 1
    for j_ in j:
        n = np.lcm(n, j_)
    # Floor division keeps the repeat count an integer; true division yields a
    # float, which np.repeat refuses to cast to a count.
    i = np.array([np.repeat(np.arange(j_), n // j_) for j_ in j]).T
    rng.shuffle(i, 0)
    return i

def sample_balanced_fixture(j, alp0=0, df_tau=2, scale_tau=1, rng=None):
    """Generate a linear predictor on a balanced design.

    Parameters
    ----------
    j : sequence of int
        Number of levels of each factor.
    alp0 : float, optional
        Intercept added to the linear predictor.
    df_tau, scale_tau : float
        Degrees of freedom and scale forwarded to ``sample_randfx_fixture``
        for drawing the per-factor precisions.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted a fresh unseeded generator is
        created per call, instead of one generator built at definition time
        and silently shared by every call.

    Returns
    -------
    tuple
        ``((eta, i), (alp0, alp, tau))`` where ``i`` is the design, ``eta`` the
        linear predictor for each design row, ``alp`` the per-factor
        coefficient vectors and ``tau`` their precisions.
    """
    if rng is None:
        rng = np.random.default_rng()
    alp, tau = sample_randfx_fixture(j, df_tau, scale_tau, rng)
    i = sample_balanced_design(j, rng)
    # Each design column selects that factor's coefficient for every row;
    # summing over factors gives the additive predictor.
    eta = alp0 + np.sum([alp_[j_] for alp_, j_ in zip(alp, i.T)], 0)
    return (eta, i), (alp0, alp, tau)

In [9]:
# sampling

def eval_logmargin(y1, n, j, i, tau0, tau, lam):
    """Zero-mean multivariate normal log-density of ``y1 / n``.

    The covariance starts as ``diag(1 / (lam * n))`` plus the constant
    ``1 / tau0`` in every entry. For each factor ``k`` (column ``k`` of ``i``)
    and each level in ``range(j[k])``, ``1 / tau[k]`` is added to every entry
    whose row and column observations both carry that level.

    Returns the value of ``multivariate_normal.logpdf`` at ``y1 / n``.
    """
    cov = np.diag(1 / (lam * n)) + 1 / tau0
    for factor, n_levels in enumerate(j):
        assigned = i[:, factor]
        for level in range(n_levels):
            members = assigned == level
            # Add the shared variance to the block of observations at this level.
            cov[np.ix_(members, members)] += 1 / tau[factor]
    mean = np.zeros_like(y1)
    return multivariate_normal.logpdf(y1 / n, mean, cov)

def eval_logprior(data, params, hyper):
    """Sum of the coefficient log-prior and the precision log-hyperprior.

    ``data[-1]`` is taken as the precision paired with the scalar ``alp0``;
    ``params`` is ``(alp0, alp)`` with ``alp`` a list of coefficient vectors,
    and ``hyper`` holds one precision per vector in ``alp``.

    Each group contributes
    ``(len * log(tau / (2 * pi)) - tau * sum(coef ** 2)) / 2`` and the
    hyperprior term is ``-sum(hyper + log(hyper)) / 2``.
    """
    intercept, effects = params
    groups = [[intercept]] + effects
    precisions = [data[-1]] + list(hyper)

    terms = []
    for coefs, prec in zip(groups, precisions):
        quad = prec * np.sum(np.square(coefs))
        terms.append((len(coefs) * np.log(prec / (2 * np.pi)) - quad) / 2)

    log_hyperprior = -np.sum(hyper + np.log(hyper)) / 2
    return np.sum(terms) + log_hyperprior

def cond_sampler(data, hyper, rng):
    """Adapter that unpacks ``data`` and delegates to ``logitx.sample_posterior``.

    ``data`` must be the 5-tuple ``(y, n, j, i, tau0)``; its fields are passed
    on positionally, followed by ``hyper``, the literal flag ``True`` and
    ``rng``. The delegate's return value is returned unchanged.
    """
    obs, trials, levels, design, tau0 = data
    return logitx.sample_posterior(obs, trials, levels, design, tau0, hyper, True, rng)

def est_emus(lams, lame, y, n, j, i, tau0, rng, n_samples, n_burnin):
    """Run ``infemus.est_mlik`` for this model.

    Bundles ``(y, n, j, i, tau0)`` into the data tuple expected by
    ``cond_sampler`` and ``eval_logprior`` and forwards the grids, the
    generator and the sample/burn-in counts. Returns whatever
    ``infemus.est_mlik`` returns.
    """
    data = (y, n, j, i, tau0)
    return infemus.est_mlik(
        lams, lame, data, cond_sampler, eval_logprior, rng, n_samples, n_burnin
    )

def est_rmse(l, d, tau0, rng, n_variates, n_samples, n_burnin):
    """Simulate one two-factor problem of size ``d`` and estimate on a lambda grid.

    A thinned crossed design with ``d`` levels per factor is sampled, binary
    responses are drawn from the logistic of its linear predictor, and
    ``est_emus`` is run ``n_variates`` times with ``n_samples`` samples plus
    once with ``n_samples * n_variates`` samples as the reference.

    Returns ``(lame, u_gt, u_est)``: the evaluation grid, the reference
    estimate and the stacked replicate estimates. No error metric is computed
    here despite the name.
    """
    bds_t1 = (-1, 1)
    bds_t2 = (-1, 1)
    j = np.repeat(d, 2)

    # NOTE(review): df_tau=1e64 with scale_tau=1e-64 presumably pins the
    # sampled precisions at ~1 - confirm intent.
    eta, i = sample_mar_fixture(j, 1e64, 1e-64, .5, rng)[0]
    y1 = expit(eta) > rng.uniform(size=len(eta))
    n = np.ones_like(eta)

    def axis_grid(bounds, n_points):
        # Log-spaced axis whose exponent range is scaled by 8 / sqrt(d).
        return np.exp(np.linspace(*bounds, n_points) * 8 / np.sqrt(d))

    lams = np.array(list(product(axis_grid(bds_t1, l + 1), axis_grid(bds_t2, l + 1))))
    lame = np.array(list(product(axis_grid(bds_t1, l * 2 + 1), axis_grid(bds_t2, l * 2 + 1))))

    replicates = []
    for _ in range(n_variates):
        replicates.append(est_emus(lams, lame, y1, n, j, i, tau0, rng, n_samples, n_burnin))
    u_est = np.array(replicates)

    u_gt = est_emus(lams, lame, y1, n, j, i, tau0, rng, n_samples * n_variates, n_burnin)
    return (lame, u_gt, u_est)

In [26]:
# simulation 

# A single seeded generator is threaded through every problem size.
rng = np.random.default_rng(seed)
# Problem sizes: 10, 100, 1000 levels per factor.
d = np.power(10, np.arange(1, 4))
# est_rmse returns (grid, reference, replicates); transpose into three tuples.
lam, u_gt, u_est = zip(*(
    est_rmse(l, d_, tau0, rng, n_variates, n_samples, n_burnin) for d_ in d
))

In [None]:
# Per problem size: mean Euclidean distance between each replicate estimate
# (rows of u_est) and the reference estimate.
error = [
    np.linalg.norm(u_emus_ - u_, ord=2, axis=1).mean()
    for u_emus_, u_ in zip(u_est, u_gt)
]

plt.figure(figsize=(4, 3))
plt.plot(d, error, marker='o')
plt.xscale('log', base=10)
plt.xlabel(r'$d$')
plt.ylabel('error')

In [None]:
plt.figure(figsize=(4, 3))
# One colour per problem size, however many were simulated.
for lam_, u_, col in zip(lam, u_gt, sns.color_palette('flare', len(lam))):
    # lam_ is a square product grid with one row per grid point, so the side
    # length comes from this grid itself rather than from the first estimate.
    side_lam = int(np.sqrt(len(lam_)))
    x = np.reshape([lam__[0] for lam__ in lam_], (side_lam, side_lam)).T[0]
    # Collapse the estimate to one grid axis by summing over the other.
    side_u = int(np.sqrt(len(u_)))
    zx = np.sum(np.reshape(u_, (side_u, side_u)).T, 1)
    # Normalise by the log-width of the grid axis.
    dens = np.log(x)[-1] - np.log(x)[0]
    plt.fill_between(x, zx / dens, alpha=.5, color=col)
plt.xscale('log')
# The horizontal axis shows the lambda grid values, not the problem size d.
plt.xlabel(r'$\lambda$')
plt.ylabel('density')