In [6]:
# dependencies

import os
# Pin the OpenBLAS / MKL thread pools to a single thread. These variables are
# read when the BLAS library is loaded, so they must be set before numpy is
# imported below. Presumably this also keeps the per-iteration
# time.process_time() measurements in the simulation loop single-threaded.
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

import gzip
import time
# `load` is only used by the (commented-out) restart-from-saved-state cell.
from pickle import dump, load

import numpy as np

# Project package: SDE simulation, MCMC sampler, the log-growth model, and the
# Skeleton class used to build the hidden regime path.
from switch import mcmc, sde
from switch.models import loggrowth
from switch.mjplib.skeleton import Skeleton

In [None]:
# configure design
# n_blocks:     number of (regime 1, regime 0) block pairs in the hidden path
# obs_dens:     observations per unit time after bridge refinement
# block_length: duration of each regime block
n_blocks, obs_dens, block_length = 4, 1, 64

# configure algorithm
# The chain runs n_init + n_samples iterations; every n_thin-th draw is kept.
n_samples, n_init, n_thin = 10000, 1000, 10
# NOTE(review): init_scale_h, pr_portkey_h and n_dc_splits are passed straight
# to mcmc.sample_posterior -- see switch.mcmc for their exact meaning.
init_scale_h, pr_portkey_h, n_dc_splits = 2, 1e-2, 1
# Seeds the single numpy Generator shared by data simulation and the sampler.
seed = 0

# configure prior
# NOTE(review): passed positionally to mcmc.sample_posterior; presumably the
# shape and rate of a gamma prior on the regime jump rate -- confirm.
prior_shape_jumps, prior_rate_jumps = 1, 2 ** 6

In [8]:
# generate input time series
#
# Simulate a synthetic regime-switching series. A hidden two-state path y0
# alternates 0 -> 1 -> 0 ... every block_length time units. The observed
# process vt is simulated from loggrowth.mod with parameters thi0, given the
# regime active on each interval.

# Single RNG for the whole notebook. It is also passed to the sampler, so the
# sampler's random stream depends on how many draws are consumed here.
ome = np.random.default_rng(seed)
# Initial value of the observed process at t = 0.
init_v = 1/np.sqrt(2)
# True parameters, one row per regime (two regimes).
# NOTE(review): presumably row k is used while the regime label is k, and the
# column meaning is defined by loggrowth.mod -- confirm.
thi0 = np.array([[1, np.sqrt(2), 1/16], [1, 1/np.sqrt(2), 1/16]])
# Switch times: 0, block_length, 2*block_length, ..., 2*n_blocks*block_length.
jump_t = np.append(0, block_length * (1 + np.arange(2 * n_blocks)))
# Regime entered at each switch time: 0, then n_blocks repetitions of (1, 0),
# so the path ends in regime 0.
jump_yt = np.append(0, np.tile([1, 0], n_blocks))
# Hidden regime path; it runs half a block past the last switch.
y0 = Skeleton(jump_t, jump_yt, fin_t=jump_t[-1] + block_length / 2)

# Coarse simulation grid with spacing 1 over [0, fin_t]. The `* 1` means one
# grid point per unit time, and the trimming at the end of this cell relies on
# that unit spacing.
t = np.linspace(0, y0.fin_t, int(y0.fin_t * 1) + 1)
# Forward-simulate vt on the coarse grid, one step per interval, using the
# regime at the left endpoint of each interval.
# NOTE(review): the trailing `1` passed to sample_forward is presumably a
# sample / sub-step count -- confirm against switch.sde.
vt = [init_v]
for dt, y_ in zip(np.diff(t), y0(t[:-1])):
    vt.append(sde.sample_forward(thi0, dt, y_, vt[-1], loggrowth.mod, ome, 1))
vt = np.array(vt)

# Refine to obs_dens observations per unit time. Between each pair of coarse
# values (v0, v1), insert obs_dens - 1 equally spaced interior points drawn
# with sde.sample_bridge (presumably a bridge conditioned on both endpoints).
# The coarse points themselves are kept. With obs_dens == 1 the interior time
# array is empty.
tt = [t[0]]
vtt = [vt[0]]
for dt, y_, v0, v1 in zip(np.diff(t), y0(t[:-1]), vt, vt[1:]):
    v_ = sde.sample_bridge(thi0, np.linspace(0, dt, obs_dens + 1)[1:-1], dt, y_, v0, v1, loggrowth.mod, ome)
    tt.extend(list(np.linspace(0, dt, obs_dens + 1)[1:-1] + tt[-1]) + [dt + tt[-1]])
    vtt.extend(list(v_) + [v1])
t = np.array(tt)
vt = np.array(vtt)
# Regime label at every observation time (not passed to the sampler).
yt = y0(t)

# Discard the first (obs_dens * block_length) // 2 points. The refined grid
# has spacing 1/obs_dens, so this is time block_length / 2. Then shift times
# to start at 0. Together with fin_t above, the retained series starts and
# ends with half a block in regime 0.
t = t[((obs_dens * block_length) // 2):] - t[((obs_dens * block_length) // 2)]
vt = vt[((obs_dens * block_length) // 2):]
yt = yt[((obs_dens * block_length) // 2):]

In [9]:
# set up sampler
#
# `sampler` is consumed as a generator (advanced with next() and inspected via
# gi_frame in later cells). Presumably no iterations run until the first next().
# Thinned draws from the simulation loop are collected in `samples`.

samples = []
sampler = mcmc.sample_posterior(
    t, vt,
    2,  # NOTE(review): presumably the number of regimes (thi0 has 2 rows) -- confirm
    prior_shape_jumps, prior_rate_jumps,
    loggrowth.mod,
    ome,  # same Generator used to simulate the data
    init_scale_h=init_scale_h,
    n_init=n_init,
    n_dc_splits=n_dc_splits,
    # t starts at 0 after trimming, so t[-1] is the length of the observed window.
    pr_portkey_thi=1 / t[-1],
    pr_portkey_h=pr_portkey_h,
)

In [10]:
# simulate markov chain
#
# Advance the sampler n_init + n_samples times. Every n_thin-th draw (counting
# from the first) is kept together with the CPU time that draw took. Draws from
# the first n_init iterations are NOT excluded here; any burn-in handling
# happens downstream or inside the sampler.
#
# The loop index is `i`, not `_`: in IPython, assigning to `_` shadows the
# output-history variable and stops `_`, `__`, `___` from being updated.
# Note that re-running this cell continues the same generator and appends
# further draws to `samples`.

for i in range(n_init + n_samples):
    t0 = time.process_time()
    thi, lam, h, z = next(sampler)
    t1 = time.process_time()
    itime = t1 - t0  # CPU seconds spent producing this draw
    if i % n_thin == 0:
        samples.append((thi, lam, h, {'itime': itime}))

In [11]:
# dump to disk
#
# Snapshot the sampler generator's local variables and the thinned samples as
# gzip-compressed pickles, so the chain can be resumed later (see the restart
# cell below, which feeds `state` to mcmc.resume_sampling).

# gzip.open does not create missing directories; make sure the output
# directory exists on a fresh checkout.
os.makedirs('paper', exist_ok=True)

# A finished generator has gi_frame set to None, so its state is gone; fail
# with a clear message instead of an AttributeError on None.
if sampler.gi_frame is None:
    raise RuntimeError('sampler generator has finished; its state is no longer available')

# Sampler-internal locals needed to resume.
# NOTE(review): these names must match mcmc.resume_sampling's parameters -- confirm.
state_keys = {'thi', 'lam', 'h', 'z', 't', 'vt', 'hyper_lam', 'thi_samplers', 'node_sampler', 'ctrl', 'ome'}
state = {k: v for k, v in sampler.gi_frame.f_locals.items() if k in state_keys}
with gzip.open('paper/simstud_base_state.pkl', 'wb') as f:
    dump(state, f)
with gzip.open('paper/simstud_base_samples.pkl', 'wb') as f:
    dump(samples, f)

In [12]:
# to restart from saved state
#
# Uncomment to reload the sampler state and thinned samples written by the
# dump cell above, then rebuild `sampler` so the simulation loop can continue
# the chain. Re-running that loop appends further draws to `samples`.
# WARNING: pickle.load can execute arbitrary code -- only load files you
# created yourself.

# with gzip.open('paper/simstud_base_state.pkl', 'rb') as f:
#     state = load(f)
# with gzip.open('paper/simstud_base_samples.pkl', 'rb') as f:
#     samples = load(f)

# sampler = mcmc.resume_sampling(**state, mod=loggrowth.mod)