In [1]:
# dependencies

from itertools import product

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist
from scipy.special import logsumexp

from infemus import infemus
from infemus.models import lgp_iid
from infemus.tools.metropolis_mv import eval_norm_prec

# apply seaborn's default theme to the matplotlib figures drawn further down
sns.set_theme()

In [2]:
# config

# seed of the single `np.random.default_rng` created in the data-generation cell
seed = 2

# algo settings

# number of replicate estimates computed per regime entry (loop length in est_emus_flexgrid)
n_variates = 128
# (lower, upper) bounds of the t1 / t2 grids on the log scale; the grids themselves are
# built as np.exp(np.linspace(lower, upper, k + 1))
bds_t1 = (0, 3.5)
bds_t2 = (-2, 2)
# resolution of the evaluation grid: m + 1 points per axis
m = 32
# each regime is a list of (l, n) pairs: `l` sets the sampling-grid resolution
# (l + 1 points per axis) and `n` is forwarded as `n_samples` to the estimator
hires_regime = [(16, 1), (16, 4), (16, 16), (16, 64)]  # l fixed at 16, n grows
infill_regime = [(4, 16), (8, 16), (16, 16), (32, 16)]  # n fixed at 16, l grows
lores_regime = [(4, 16), (4, 64), (4, 256), (4, 1024), (4, 4096)]  # l fixed at 4, n grows

# DGP settings

# inputs are drawn as an (n_obs, n_dim) standard-normal array
n_obs = 32
n_dim = 1
# constant mean, repeated once per observation in gen_suff
mu = 0
# kernel parameters: cov = t1 / t2 * exp(-t2 * squared_distance / n_dim)
t1 = 1
t2 = 1
# noise standard deviation; enters the marginal as phi ** 2 added to the eigenvalues
phi = .25
# observations whose sum of |x| exceeds this threshold get a 10x larger noise sd in gen_inputs
sd_out = 1.5

In [3]:
# DGP

def gen_inputs(n_obs, n_dim, mu, t1, t2, phi, sd_out, rng):
    """Simulate one data set plus a latent path on a dense 1-D grid.

    Steps, in the order the generator `rng` is consumed:
      1. draw inputs `x` as an (n_obs, n_dim) standard-normal array;
      2. draw the latent vector `the` from the Gaussian described by gen_suff;
      3. draw observations `y` around `the`, with noise sd `phi`, or `10 * phi`
         for rows whose sum of absolute inputs exceeds `sd_out`;
      4. draw `thec`, the latent values on 1000 grid points in [-2.5, 2.5],
         conditionally on `the`.

    The grid `xc` has a single column, so the cross-distance in step 4 needs
    `n_dim == 1`.

    Returns
    -------
    tuple
        ((x, y, the), (xc, thec), (mean, eig_cov, phi))
    """
    def _kernel(a, b):
        # same covariance function as in gen_suff, without the diagonal jitter
        return t1 / t2 * np.exp(-np.square(cdist(a, b)) * t2 / x.shape[1])

    x = rng.standard_normal(size=(n_obs, n_dim))
    mean, cov, eig_cov = gen_suff(x, mu, t1, t2)
    the = rng.multivariate_normal(mean, cov)

    # heteroscedastic noise: rows far from the origin are ten times noisier
    far_out = np.sum(np.abs(x), 1) > sd_out
    noise_sd = np.where(far_out, 10 * phi, phi)
    y = rng.normal(the, noise_sd)

    # Gaussian conditioning of the grid values on the sampled latent vector
    xc = np.linspace(-2.5, 2.5, int(1e3))[:, np.newaxis]
    cov1 = _kernel(xc, xc)
    cov12 = _kernel(xc, x)
    gain = cov12 @ np.linalg.inv(cov)
    meanc = gain @ the
    covc = cov1 - gain @ cov12.T
    thec = rng.multivariate_normal(meanc, covc)

    return (x, y, the), (xc, thec), (mean, eig_cov, phi)

def gen_suff(x, mu, t1, t2, cond=1e-6):
    """Build the mean vector, regularised covariance and its eigendecomposition.

    Parameters
    ----------
    x : 2-D array of inputs, one row per observation.
    mu : value repeated once per row of `x` to form the mean vector.
    t1, t2 : kernel parameters; the covariance between two rows is
        t1 / t2 * exp(-t2 * squared_euclidean_distance / x.shape[1]).
    cond : jitter added to the diagonal before the matrix is returned and decomposed.

    Returns
    -------
    tuple
        (mean vector, jittered covariance, np.linalg.eigh of the jittered covariance)
    """
    mean_vec = np.repeat(mu, x.shape[0])
    sq_dist = np.square(cdist(x, x))
    kernel = t1 / t2 * np.exp(-sq_dist * t2 / x.shape[1])
    # the jittered matrix is both returned and decomposed, so build it once
    cov_reg = kernel + cond * np.identity(len(mean_vec))
    return mean_vec, cov_reg, np.linalg.eigh(cov_reg)
 
def eval_logmargin(y, mean, cov, eig_cov, phi):
    """Sum of `eval_norm_prec` for `y` with eigenvalues shifted by `phi ** 2`.

    Adding `phi ** 2` to every eigenvalue in `eig_cov[0]` while keeping the
    eigenvectors in `eig_cov[1]` corresponds to adding `phi ** 2` times the
    identity to the decomposed matrix. `cov` is accepted so the output of
    gen_suff can be splatted into the call; it is not read here.

    NOTE(review): `eval_norm_prec` is presumably a Gaussian log-density taking
    eigenvectors and inverse eigenvalues — confirm against infemus.
    """
    eigvals = eig_cov[0]
    eigvecs = eig_cov[1]
    shifted = eigvals + phi ** 2
    inv_shifted = 1 / shifted[np.newaxis]
    return np.sum(eval_norm_prec(y[np.newaxis], mean, eigvecs, inv_shifted))

In [4]:
# sampling

def eval_logprior(data, param, hyper):
    """Sum of `eval_norm_prec` for `param` under the (mean, _, eigendecomposition) in `hyper`.

    `hyper` is unpacked as a 3-tuple whose middle element is ignored; `data` is
    part of the callback signature but is not read here.

    NOTE(review): `eval_norm_prec` is presumably a Gaussian log-density taking
    eigenvectors and inverse eigenvalues — confirm against infemus.
    """
    mean, _, eig_cov = hyper
    eigvals, eigvecs = eig_cov[0], eig_cov[1]
    inv_eigvals = 1 / eigvals[np.newaxis]
    return np.sum(eval_norm_prec(param[np.newaxis], mean, eigvecs, inv_eigvals))

def cond_sampler(data, hyper, rng):
    """Forward one (data, hyper) pair to `lgp_iid.sample_posterior`.

    `data` is unpacked as (y, phi) and `hyper` as (mean, cov, _); the third
    element of `hyper` is dropped. Returns whatever `sample_posterior` returns.
    """
    obs, noise_sd = data
    prior_mean, prior_cov, _ = hyper
    return lgp_iid.sample_posterior(obs, prior_mean, prior_cov, noise_sd, rng)

def est_emus(lams, lame, y, x, mu, phi, rng, n_samples, n_burnin):
    """Run `infemus.est_mlik` for one pair of (t1, t2) grids.

    Every row of `lams` (sampling grid) and `lame` (evaluation grid) is expanded
    into the (mean, cov, eigendecomposition) triple produced by gen_suff; the two
    lists are then handed to `infemus.est_mlik` together with the data `(y, phi)`,
    the sampler and prior callbacks, the generator and the sample / burn-in counts.
    """
    def _expand(grid):
        # one gen_suff triple per (t1, t2) row
        return [gen_suff(x, mu, *w_) for w_ in grid]

    suffs = _expand(lams)
    suffe = _expand(lame)
    data = (y, phi)
    return infemus.est_mlik(suffs, suffe, data, cond_sampler, eval_logprior, rng, n_samples, n_burnin)

def est_emus_flexgrid(n_variates, l, m, n, bds_t1, bds_t2, y, x, mu, phi, rng):
    """Repeat est_emus `n_variates` times on grids of resolution `l` and `m`.

    Both grids are Cartesian products of log-spaced axes: each axis is
    np.exp(np.linspace(lower, upper, k + 1)) with bounds taken from `bds_t1` /
    `bds_t2`, where k is `l` for the sampling grid and `m` for the evaluation
    grid. Every replicate is run with `n` samples and a burn-in of 0, and all
    replicates draw from the same `rng`.

    Returns
    -------
    np.ndarray
        The est_emus outputs stacked along the first axis.
    """
    def _grid(resolution):
        axis_t1 = np.exp(np.linspace(*bds_t1, resolution + 1))
        axis_t2 = np.exp(np.linspace(*bds_t2, resolution + 1))
        return np.array(list(product(axis_t1, axis_t2)))

    lams = _grid(l)
    lame = _grid(m)

    replicates = []
    for _ in range(n_variates):
        replicates.append(est_emus(lams, lame, y, x, mu, phi, rng, n, 0))
    return np.array(replicates)

In [None]:
# generate input data

# one generator, seeded from the config cell, shared by the simulation here and by
# every estimation call in the later cells
rng = np.random.default_rng(seed)
# NOTE(review): the data are simulated with noise sd `phi / 2`, while the ground-truth and
# estimation cells pass the full `phi` — confirm this mismatch is intentional.
# Only x, y and the dense-grid pair (xc, thec) are kept; the other outputs are discarded.
(x, y, _), (xc, thec), _ = gen_inputs(n_obs, n_dim, mu, t1, t2, phi / 2, sd_out, rng)

In [17]:
# evaluate ground truth

# evaluation grid: m + 1 log-spaced points per axis, then every (t1, t2) pair
lame_t1 = np.exp(np.linspace(*bds_t1, m + 1))
lame_t2 = np.exp(np.linspace(*bds_t2, m + 1))
lame = np.array(list(product(lame_t1, lame_t2)))

# exact log marginal at each grid node (no sampling involved)
log_u = [eval_logmargin(y, *gen_suff(x, mu, *w_), phi) for w_ in lame]

# normalise on the log scale so the weights sum to one
u = np.exp(np.array(log_u) - logsumexp(log_u))

# square layout of the weights; transposed so t2 indexes rows and t1 indexes columns
u_rec = np.reshape(u, (int(np.sqrt(len(u))),) * 2).T

In [18]:
# sampling distribution of respective methods

# For every (l, n) pair of a regime, compute `n_variates` replicate estimates on a sampling grid
# of resolution l with n samples each; the evaluation grid resolution `m` is the same throughout.
# All three statements draw from the same `rng`, so their order presumably determines the
# random draws each regime receives — keep hires, lores, infill in this order.
u_variates_hires = [est_emus_flexgrid(n_variates, l_, m, n_, bds_t1, bds_t2, y, x, mu, phi, rng) for l_, n_ in hires_regime]
u_variates_lores = [est_emus_flexgrid(n_variates, l_, m, n_, bds_t1, bds_t2, y, x, mu, phi, rng) for l_, n_ in lores_regime]
u_variates_infill = [est_emus_flexgrid(n_variates, l_, m, n_, bds_t1, bds_t2, y, x, mu, phi, rng) for l_, n_ in infill_regime]

In [None]:
# draw scaling curves [fig 5]

# error of one regime entry: Euclidean distance between each replicate estimate and the
# reference weights `u`, averaged over the replicates
err_hires = [np.mean(np.linalg.norm(u_ - u, axis=1)) for u_ in u_variates_hires]
err_lores = [np.mean(np.linalg.norm(u_ - u, axis=1)) for u_ in u_variates_lores]
err_infill = [np.mean(np.linalg.norm(u_ - u, axis=1)) for u_ in u_variates_infill]

# x-coordinate of one regime entry: l ** 2 * n for its (l, n) pair
n_hires = [a ** 2 * b for a, b in hires_regime]
n_lores = [a ** 2 * b for a, b in lores_regime]
n_infill = [a ** 2 * b for a, b in infill_regime]

plt.figure(figsize=(4, 3))
# reference line through the last infill point: the error halves when the x-value
# quadruples, i.e. a power law with exponent -1/2 (axline is straight on log-log axes)
plt.axline(
    (n_infill[-1] / 4, err_infill[-1] * 2),
    (n_infill[-1], err_infill[-1]),
    color='gray',
    linestyle='dashed',
    label=r'$\propto (l^2 n)^{-1/2}$',
)
plt.plot(n_hires, err_hires, marker='o', label='hires')
plt.plot(n_lores, err_lores, marker='o', label='lores')
plt.plot(n_infill, err_infill, marker='o', label='infill')
plt.xscale('log', base=2)
plt.yscale('log', base=2)
# labels and legend so the figure can be read without the surrounding code
plt.xlabel(r'$l^2 \cdot n$')
plt.ylabel(r'mean $L_2$ error')
plt.legend(fontsize='small');