In [2]:
# dependencies

from itertools import product

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist
from ucimlrepo import fetch_ucirepo 

from infemus import infemus
from infemus.models import logitgp
from infemus.tools.metropolis_mv import eval_norm_prec

# apply seaborn's default theme globally to all subsequent matplotlib figures
sns.set_theme()

In [3]:
# config

# random seed used to (re)initialise the generator before each estimation cell
seed = 0

# Monte Carlo effort (sample / burn-in sizes are passed to infemus.est_mlik)
n_variates = 128     # number of repeated estimator runs (sampling distribution)
n_samples = 256      # samples per repeated run
n_burnin = 256       # burn-in per run
n_samples_gt = 8192  # samples for the single high-precision run

# hyperparameter grids: bounds are on the natural-log scale (grids use np.exp)
l = 16               # intervals per axis of the `lams` grid (l + 1 points)
n_interp = 2         # `lame` uses l * n_interp intervals per axis
bds_t1, bds_t2 = (-.5, 2.5), (-6, 0)

In [4]:
# sampling

def gen_suff(x, mu, t1, t2, nug=1e-6):
    """Build the Gaussian prior quantities for one (t1, t2) hyperparameter point.

    The covariance is a squared-exponential kernel over the rows of ``x``::

        K[i, j] = t1 / t2 * exp(-||x_i - x_j||^2 * t2 / p),   p = x.shape[1]

    with ``nug`` added to the diagonal for numerical stability.

    Parameters
    ----------
    x : (n, p) array
        Covariates, one row per observation.
    mu : scalar or array with n entries
        Prior mean. A scalar (or size-1 array) is repeated once per observation;
        an array with one entry per observation is used as is.
    t1, t2 : float
        Kernel hyperparameters: the marginal variance is ``t1 / t2`` and ``t2``
        also scales the squared distances.
    nug : float, optional
        Diagonal jitter added to the covariance.

    Returns
    -------
    mean : (n,) array
        Prior mean vector.
    cov : (n, n) array
        Kernel matrix including the nugget.
    eig : result of ``np.linalg.eigh(cov)``
        Eigenvalues and eigenvectors of exactly the ``cov`` returned above.

    Raises
    ------
    ValueError
        If ``mu`` is neither size 1 nor has one entry per observation.
    """
    n = x.shape[0]
    mu = np.asarray(mu)
    # Scalar means are broadcast over observations; per-observation means are
    # accepted unchanged (np.repeat would have turned them into length n**2).
    mean = np.repeat(mu, n) if mu.size == 1 else mu.reshape(n)
    kernel = t1 / t2 * np.exp(-np.square(cdist(x, x)) * t2 / x.shape[1])
    # Add the nugget once so the returned matrix and its eigendecomposition
    # are guaranteed to describe the same covariance.
    cov = kernel + nug * np.identity(n)
    return mean, cov, np.linalg.eigh(cov)

def eval_logprior(data, param, hyper):
    """Summed log prior density of ``param`` under the Gaussian prior in ``hyper``.

    ``hyper`` is unpacked as ``(mean, cov, eig_cov)``; presumably the triple built
    by ``gen_suff``, with ``eig_cov`` the output of ``np.linalg.eigh``. Only the
    mean and the eigendecomposition are used. ``data`` is ignored; it is kept so
    the function matches the callback signature handed to ``infemus.est_mlik``.
    """
    mean, _cov, eig_cov = hyper
    cov_eigvals = eig_cov[0]
    cov_eigvecs = eig_cov[1]
    # Reciprocal covariance eigenvalues = precision eigenvalues
    # (eval_norm_prec is presumably precision-parameterised, per its name).
    prec_eigvals = 1 / cov_eigvals[np.newaxis]
    log_dens = eval_norm_prec(param[np.newaxis], mean, cov_eigvecs, prec_eigvals)
    return np.sum(log_dens)

def cond_sampler(data, hyper, rng):
    """Draw one posterior sample via ``logitgp.sample_posterior``.

    ``data`` must be the one-element tuple ``(y,)``; ``hyper`` is unpacked as
    ``(mean, cov, eig_cov)`` and only the mean and eigendecomposition are
    forwarded (the dense covariance is not used here).
    """
    (y,) = data
    prior_mean, _unused_cov, prior_eig = hyper
    return logitgp.sample_posterior(y, prior_mean, prior_eig, rng)

def est_emus(lams, lame, y, x, mu, rng, n_samples, n_burnin):
    """Run ``infemus.est_mlik`` on two grids of (t1, t2) hyperparameter points.

    Prior quantities are built with ``gen_suff`` for every row of ``lams`` and of
    ``lame``; these two lists, the data ``(y,)``, the sampler/prior callbacks,
    the generator and the sample sizes are passed straight to ``est_mlik``.
    In this notebook the result is reshaped onto the ``lame`` grid downstream.
    """
    def _suffs_for(grid):
        # one gen_suff triple per hyperparameter row, in grid order
        return [gen_suff(x, mu, *point) for point in grid]

    suffs_lams = _suffs_for(lams)
    suffs_lame = _suffs_for(lame)
    return infemus.est_mlik(
        suffs_lams, suffs_lame, (y,), cond_sampler, eval_logprior,
        rng, n_samples, n_burnin,
    )

In [5]:
# fetch dataset (UCI ML repository id 45)
heart_disease = fetch_ucirepo(id=45)

# Keep rows whose features are all present, then the first `n_obs` of them.
# The mask is computed once so features and targets stay aligned by construction.
n_obs = 100
complete_rows = ~heart_disease.data.features.isna().any(axis=1)
x = heart_disease.data.features.loc[complete_rows].values[:n_obs]
# binary response: True where the 'num' target column is > 0
y = (heart_disease.data.targets.loc[complete_rows, 'num'].values > 0)[:n_obs]
assert len(x) == len(y), "features and targets are misaligned"

In [6]:
# standardise every feature column to zero mean and unit (population, ddof=0) std
x_std = (x - x.mean(axis=0)) / x.std(axis=0)

In [7]:
# generate grid
# Each axis is evenly spaced on the natural-log scale between the configured
# bounds; `lame` uses n_interp times as many intervals per axis as `lams`.
lams_t1 = np.exp(np.linspace(*bds_t1, l + 1))
lams_t2 = np.exp(np.linspace(*bds_t2, l + 1))
lame_t1 = np.exp(np.linspace(*bds_t1, l * n_interp + 1))
lame_t2 = np.exp(np.linspace(*bds_t2, l * n_interp + 1))

# all (t1, t2) pairs with t1 varying slowest
lams = np.array([(a, b) for a in lams_t1 for b in lams_t2])
lame = np.array([(a, b) for a in lame_t1 for b in lame_t2])

# coordinate arrays of shape (len(lame_t2), len(lame_t1)) for pcolormesh
t1_rec, t2_rec = np.meshgrid(lame_t1, lame_t2)

In [8]:
# estimate sampling distribution
# Repeat the estimator n_variates times with a single generator, so each repeat
# draws from a different part of the random stream.
# The prior mean is the scalar 0: gen_suff repeats it once per observation.
# (It is a scalar, not a per-feature vector.)
rng = np.random.default_rng(seed)
u_est = np.array([
    est_emus(lams, lame, y, x_std, 0., rng, n_samples, n_burnin)
    for _ in range(n_variates)
])

In [None]:
# draw high-precision variate and typical variate

# Restart from `seed` so this cell is reproducible on its own.
# NOTE(review): this is the same seed as the repeated-estimate cell, so the
# high-precision run shares the start of its random stream with u_est[0].
rng = np.random.default_rng(seed)

# long run (n_samples_gt samples) used as the reference surface;
# scalar prior mean 0, repeated per observation by gen_suff
u = est_emus(lams, lame, y, x_std, 0., rng, n_samples_gt, n_burnin)

# Assuming estimates come back in `lame` order (t1 slowest), reshape to
# (n_t1, n_t2) and transpose to the (n_t2, n_t1) layout of t1_rec / t2_rec.
# The explicit shape (rather than sqrt(len)) also works for non-square grids.
grid_shape = (len(lame_t1), len(lame_t2))
u_rec = np.reshape(u, grid_shape).T
u_est_rec = np.reshape(u_est[0], grid_shape).T

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3), sharex=True, sharey=True)

for ax, surface, title in ((ax1, u_rec, 'high precision'), (ax2, u_est_rec, 'typical variate')):
    ax.pcolormesh(t1_rec, t2_rec, surface, cmap='magma')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xlabel(r'$\tau_{1}$')
    ax.set_title(title)
ax1.set_ylabel(r'$\tau_{2}$')
fig.tight_layout()

In [None]:
# profile error plots [fig 7]
# Profile = maximum over the other hyperparameter. Black line: high-precision
# run; shaded band: 12.5th-87.5th percentiles (central 75%) of the repeats.

# One layout for everything: rows tau_1, columns tau_2 (the order of `lame`).
grid_shape = (len(lame_t1), len(lame_t2))
u_grid = np.reshape(u, grid_shape)
u_est_grid = np.reshape(u_est, (len(u_est),) + grid_shape)

u_est_prof_t1 = np.max(u_est_grid, 2)  # (n_variates, len(lame_t1))
u_est_prof_t2 = np.max(u_est_grid, 1)  # (n_variates, len(lame_t2))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3), sharex='col', sharey=True)

for ax, t_grid, u_prof, u_est_prof, label in (
    (ax1, lame_t1, np.max(u_grid, 1), u_est_prof_t1, r'$\tau_{1}$'),
    (ax2, lame_t2, np.max(u_grid, 0), u_est_prof_t2, r'$\tau_{2}$'),
):
    ax.plot(t_grid, u_prof, color='black', label='high precision')
    ax.fill_between(t_grid, *np.percentile(u_est_prof, [12.5, 87.5], 0), alpha=.5,
                    label='12.5-87.5% band')
    ax.set_xlabel(label)
    ax.set_xscale('log', base=2)

ax1.set_yticklabels([])
ax1.legend()
fig.tight_layout()