In [None]:
# dependencies

import gzip
from pickle import dump, load

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from switch.mjplib.skeleton import Skeleton
from switch.misc import plot
from switch import sde
from switch.models import loggrowth

# switch on seaborn's default theme for the figures drawn in the cells below
sns.set_theme()

In [None]:
## plot data for base design || figure 7

# design constants
n_blocks, obs_dens, block_length = 4, 1, 64


# generate fixture: seeded generator, initial value, one parameter row per state,
# and a state path that flips between 0 and 1 at every multiple of block_length

ome = np.random.default_rng(0)
init_v = 1/np.sqrt(2)
thi0 = np.array([
    [1, np.sqrt(2), 1/16],
    [1, 1/np.sqrt(2), 1/16],
])
jump_t = block_length * np.arange(2 * n_blocks + 1)
jump_yt = np.arange(2 * n_blocks + 1) % 2
y0 = Skeleton(jump_t, jump_yt, fin_t=jump_t[-1] + block_length / 2)


# sample data

# pass 1: walk forward over a coarse grid of int(fin_t) + 1 points on [0, fin_t]
t = np.linspace(0, y0.fin_t, int(y0.fin_t) + 1)
vt = [init_v]
for dt, y_ in zip(np.diff(t), y0(t[:-1])):
    v_next = sde.sample_forward(thi0, dt, y_, vt[-1], loggrowth.mod, ome, 1)
    vt.append(v_next)
vt = np.array(vt)

# pass 2: between each pair of coarse points, fill in obs_dens - 1 bridge points
tt, vtt = [t[0]], [vt[0]]
for dt, y_, v0, v1 in zip(np.diff(t), y0(t[:-1]), vt[:-1], vt[1:]):
    inner_t = np.linspace(0, dt, obs_dens + 1)[1:-1]
    v_ = sde.sample_bridge(thi0, inner_t, dt, y_, v0, v1, loggrowth.mod, ome)
    t_left = tt[-1]
    tt.extend(inner_t + t_left)
    tt.append(dt + t_left)
    vtt.extend(v_)
    vtt.append(v1)
t = np.array(tt)
vt = np.array(vtt)
yt = y0(t)

# drop the leading (obs_dens * block_length) // 2 grid points and restart the clock at 0
i_trim = (obs_dens * block_length) // 2
t = t[i_trim:] - t[i_trim]
vt = vt[i_trim:]
yt = yt[i_trim:]

plt.figure(figsize=(2*(1+np.sqrt(5)), 2))

# one line per state; a point is kept when it or either neighbour is in that state,
# so consecutive segments join up (np.roll wraps around at the two ends)
yt_prev, yt_next = np.roll(yt, 1), np.roll(yt, -1)
for state in (0, 1):
    near_state = (yt == state) | (yt_prev == state) | (yt_next == state)
    plt.plot(t, np.where(near_state, vt, np.nan), linewidth=1)
plt.xlabel(r'$t$')
plt.ylabel(r'$v_{t}$')
plt.xticks(2 * np.arange(n_blocks + 1) * block_length)
plt.xlim(0, 2 * n_blocks * block_length)
plt.yticks([thi0[0, 1], thi0[1, 1]], labels=[r'$\kappa_{1}$', r'$\kappa_{2}$'])

plt.savefig('paper/simstud_data.pdf', backend='pgf', bbox_inches='tight')

In [None]:
## compute and plot performance metrics || figure 8

def format_samples(path, n_thin=10, n_discard=1000, meta=None):
    """Load a gzip-pickled chain and tabulate its parameter draws by iteration.

    Every stored sample ``s`` is read as: ``s[0]``, a 2-D array with one row per
    state whose columns 0, 1, 2 are labelled ``beta``, ``kappa``, ``rho``; and
    ``s[1]``, a square matrix whose negated diagonal is labelled ``lambda``.
    Within each sample the states are re-ordered by ascending ``kappa`` before
    being numbered, so state ``0`` is always the one with the smallest kappa.

    Parameters
    ----------
    path : str or path-like
        Gzip-compressed pickle containing a sliceable sequence of samples.
    n_thin : int
        Step used to map stored-sample positions to the ``iter`` labels
        (presumably the thinning interval the chain was saved with).
    n_discard : int
        The first ``n_discard // n_thin`` stored samples are dropped, and
        ``iter`` labels start at ``n_discard``.
    meta : mapping, optional
        Constant key/value pairs added as extra index levels. ``None`` (the
        default) adds none.

    Returns
    -------
    pandas.DataFrame
        One row per ``(state, param, *meta)``, one column per ``iter``.
    """
    # a None default avoids sharing one dict object between calls
    meta = {} if meta is None else meta
    # NOTE(review): unpickling can execute arbitrary code -- only open trusted files
    with gzip.open(path, 'rb') as f:
        samples = load(f)[n_discard // n_thin:]
    # per sample: rows sorted by kappa (column 1), with -diag(s[1]) appended as a 4th column
    order = [np.argsort(s[0][:, 1]) for s in samples]
    draws = [np.hstack([s[0][o, :], -np.diag(s[1])[o][:, np.newaxis]]) for s, o in zip(samples, order)]
    frames = []
    for i in range(draws[0].shape[0]):
        state_draws = np.array([draw[i] for draw in draws])
        df_ = pd.DataFrame({
            'beta': state_draws[:, 0],
            'kappa': state_draws[:, 1],
            'rho': state_draws[:, 2],
            'lambda': state_draws[:, 3],
        })
        df_['state'] = i
        df_['iter'] = n_discard + df_.index * n_thin
        for k, v in meta.items():
            df_[k] = v
        frames.append(df_)
    meta_keys = list(meta.keys())
    df_long = pd.concat(frames, axis=0).melt(id_vars=['iter', 'state'] + meta_keys, var_name='param')
    return df_long.set_index(['iter', 'state', 'param'] + meta_keys).unstack('iter').value

def eval_summaries(sample_path, n_thin=10, n_discard=1000, meta=None):
    """Load a gzip-pickled chain and tabulate the ``itime`` entry of every sample.

    The last element of each stored sample is read as a mapping and its
    ``'itime'`` value is extracted (plotted elsewhere in this notebook as
    seconds per iteration).

    Parameters
    ----------
    sample_path : str or path-like
        Gzip-compressed pickle containing a sliceable sequence of samples.
    n_thin : int
        Step used to map stored-sample positions to the ``iter`` labels
        (presumably the thinning interval the chain was saved with).
    n_discard : int
        The first ``n_discard // n_thin`` stored samples are dropped, and
        ``iter`` labels start at ``n_discard``.
    meta : mapping, optional
        Constant key/value pairs added as extra index levels. ``None`` (the
        default) adds none.

    Returns
    -------
    pandas.DataFrame
        One row per ``(param, *meta)`` -- the only ``param`` is ``'itime'`` --
        and one column per ``iter``.
    """
    # a None default avoids sharing one dict object between calls
    meta = {} if meta is None else meta
    # NOTE(review): unpickling can execute arbitrary code -- only open trusted files
    with gzip.open(sample_path, 'rb') as f:
        samples = load(f)[n_discard // n_thin:]
    df = pd.DataFrame({'itime': [s_[-1]['itime'] for s_ in samples]})
    df['iter'] = n_discard + df.index * n_thin
    for k, v in meta.items():
        df[k] = v
    meta_keys = list(meta.keys())
    df_long = df.melt(id_vars=['iter'] + meta_keys, var_name='param')
    return df_long.set_index(['iter', 'param'] + meta_keys).unstack('iter').value

# sampler output archives, one per observation regime
sample_sources = (
    ('paper/simstud_base_samples.pkl', 'base'),
    ('paper/simstud_extend_samples.pkl', 'outfill'),
    ('paper/simstud_infill_samples.pkl', 'infill'),
)

itime_samples = pd.concat(
    [eval_summaries(path, meta={'regime': regime}) for path, regime in sample_sources]
)

param_samples = pd.concat(
    [format_samples(path, meta={'regime': regime}) for path, regime in sample_sources]
)

n_lags = 512
# rows are chains, columns are iterations: fill gaps along each row from the left
param_samples = param_samples.T.ffill().T
# one autocorrelation estimate per row, expanded into one column per lag
acf = param_samples.apply(
    lambda chain: plot.est_acf(chain.dropna().values, n_lags),
    axis=1,
    raw=False,
    result_type='expand',
).rename_axis(columns='lag')
iat_by_chain = acf.apply(
    lambda lag_row: plot.est_int_autocor(lag_row.values),
    axis=1,
    raw=False,
    result_type='expand',
).rename('iat')
ess = pd.DataFrame({'iat': iat_by_chain})
ess['spi'] = 2 * ess['iat']
# mean 'itime' per regime, broadcast across the regime columns of the unstacked 'iat' table
mean_itime = itime_samples.mean(axis=1).droplevel(0)
ess['sps'] = (2 * ess.unstack('regime')['iat'] * mean_itime).stack()
ess = ess.reset_index()
itimes_long = itime_samples.droplevel(0).stack().rename('value').reset_index()

fig, g = plt.subplots(1, 3, figsize=(8, 2))

# left panel: distribution of per-iteration timings by regime
sns.boxplot(data=itimes_long, x='value', y='regime', whis=16, fliersize=1, ax=g[0])
g[0].set(xlabel='Seconds/Iterations', ylabel=None)
g[0].set_xscale('log', base=2)

# middle and right panels: one marker per (state, param) chain, sharing the left panel's regime rows
for ax, col, label in (
    (g[1], 'spi', 'Iterations/Eff Samples'),
    (g[2], 'sps', 'Seconds/Eff Samples'),
):
    sns.stripplot(data=ess, x=col, y='regime', marker='o', dodge=True, ax=ax)
    ax.set(xlabel=label, ylabel='', yticks=[])
    ax.set_xscale('log', base=2)

plt.savefig('paper/simstud_performance.pdf', backend='pgf', bbox_inches='tight')