In [None]:
%load_ext autoreload
%autoreload 2

import sys, os
from pathlib import Path
# Pick the directory to work from: the resolved current working directory, or
# its parent when the current folder is named 'testing'. That directory is put
# at the front of sys.path (unless already listed) and becomes the new cwd.
parent_dir = Path.cwd().resolve()
parent_dir = parent_dir.parent if parent_dir.name == 'testing' else parent_dir
_root_entry = str(parent_dir)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
os.chdir(parent_dir)

import numpy as np
import jax.random as random
import jax.numpy as jnp
import time

import jax_gibbs as gs_jax
import utils
import cache as _cache

# Precompute sampling cache

Run this notebook once to populate `cache/` with Gibbs chains, KDE MLE samples, and full-data MCMC chains for all (k, m) configurations. After this, all analysis notebooks load from cache instantly.

In [None]:
# Configuration — must match what the analysis notebooks use.
# Grid of (k, m) configurations: every k in `ks` is paired with every m in `ms`.
ks = [1.0, 2.0, 3.0, 5.0]
ms = [6, 10, 14, 20, 30, 50, 100]
mu_true = 2.0

# Chain lengths / simulation counts for the three cached artifacts.
T_gibbs = 50_000
T_kde = 50_000
T_fulldata = 50_000

seed = 0
CACHE_DIR = 'cache'

# Parameters shared by every configuration; the precompute loop copies this
# dict and adds the per-configuration entries.
base_params = {
    'mu_true': mu_true,
    'prior_mean': 0.0, 'prior_std': 10.0,
    'proposal_std_mu': 0.9, 'proposal_std_z': 0.03,
}

total = len(ks) * len(ms)

# Echo the configuration so the run log records what was requested.
for _summary_line in (
    f"Will precompute {total} (k, m) configs",
    f"  ks = {ks}",
    f"  ms = {ms}",
    f"  T_gibbs = {T_gibbs:,}, T_kde = {T_kde:,}, T_fulldata = {T_fulldata:,}",
    f"  Cache dir: {CACHE_DIR}/",
):
    print(_summary_line)

In [None]:
# Precompute all Gibbs chains + KDE samples + full-data MCMC
#
# For every (k, m) pair this cell draws one synthetic data set, computes its
# MLE, and then fills in whichever of the three cache entries (Gibbs chain,
# KDE MLE samples, full-data MCMC chain) is not cached yet.
#
# PRNG discipline: the three subkeys of a configuration (data, Gibbs,
# full-data) are split off unconditionally and in a fixed order *before* any
# cache check. If the Gibbs / full-data splits only happened on a cache miss,
# the key stream -- and with it the data drawn for this and every later
# configuration -- would depend on which cache files already exist, so a
# resumed run could store a full-data chain computed from different data than
# the Gibbs entry already cached for the same (k, m). With an empty cache the
# sequence of splits is the same as splitting lazily on each miss.
key = random.PRNGKey(seed)
t_total_start = time.time()

for i_k, k in enumerate(ks):
    for i_m, m in enumerate(ms):
        # 1-based position of this configuration in the k-major grid.
        idx = i_k * len(ms) + i_m + 1
        print(f"\n{'='*60}")
        print(f"[{idx}/{total}] k={k}, m={m}")
        print(f"{'='*60}")

        params = base_params.copy()
        params['k'] = k
        params['m'] = m
        params['num_iterations_T'] = T_gibbs

        gibbs_path = _cache.cache_path('gibbs', k, m, T_gibbs, seed, CACHE_DIR)
        kde_path = _cache.cache_path('kde', k, m, T_kde, seed, CACHE_DIR)
        fulldata_path = _cache.cache_path('fulldata', k, m, T_fulldata, seed, CACHE_DIR)

        # --- Per-configuration PRNG keys (always consumed, see header) ---
        key, subkey = random.split(key)
        key, key_gibbs = random.split(key)
        key, key_full = random.split(key)

        # --- Generate data ---
        # Student-t noise with k degrees of freedom, shifted by the true location.
        data = random.t(subkey, df=k, shape=(m,)) + mu_true
        mle = utils.get_mle(data, params)
        print(f"  MLE = {mle:.4f}")

        # --- Gibbs chain ---
        if _cache.is_cached(gibbs_path):
            print(f"  [skip] Gibbs already cached")
        else:
            print(f"  Running Gibbs sampler (T={T_gibbs:,})...")
            t0 = time.time()
            gibbs_results = gs_jax.run_gibbs_sampler_mle_jax(key_gibbs, mle, params.copy())
            # Converting to NumPy waits for the result to be materialised; JAX
            # may dispatch work asynchronously, so stop the clock afterwards.
            mu_chain = np.array(gibbs_results['mu_chain'])
            elapsed = time.time() - t0
            _cache.save_gibbs(gibbs_path, mu_chain, data, mle)
            print(f"  Gibbs done ({elapsed:.1f}s), saved to {gibbs_path}")

        # --- KDE MLE samples ---
        if _cache.is_cached(kde_path):
            print(f"  [skip] KDE samples already cached")
        else:
            kde_params = params.copy()
            if k == 1.0:
                # NOTE(review): k == 1 gets an explicit, very small KDE
                # bandwidth; the rationale is not visible here -- confirm
                # against utils.get_benchmark_mle_samples.
                kde_params['kde_bw_method'] = 0.001
            print(f"  Computing KDE MLE samples ({T_kde:,} simulations)...")
            t0 = time.time()
            mle_samples = utils.get_benchmark_mle_samples(kde_params, num_simulations=T_kde)
            elapsed = time.time() - t0
            _cache.save_kde_samples(kde_path, mle_samples)
            print(f"  KDE samples done ({elapsed:.1f}s), saved to {kde_path}")

        # --- Full-data MCMC ---
        if _cache.is_cached(fulldata_path):
            print(f"  [skip] Full-data MCMC already cached")
        else:
            fd_params = params.copy()
            fd_params['num_iterations_T'] = T_fulldata
            print(f"  Running full-data MCMC (T={T_fulldata:,})...")
            t0 = time.time()
            full_results = gs_jax.run_metropolis_x_jax(key_full, jnp.asarray(data), fd_params.copy())
            # As for the Gibbs chain: materialise the result before timing.
            mu_chain_full = np.array(full_results['mu_chain'])
            elapsed = time.time() - t0
            _cache.save_fulldata(fulldata_path, mu_chain_full, data, mle)
            print(f"  Full-data MCMC done ({elapsed:.1f}s), saved to {fulldata_path}")

t_total = time.time() - t_total_start
print(f"\n{'='*60}")
print(f"All done! Total time: {t_total:.1f}s ({t_total/60:.1f}min)")
print(f"{'='*60}")

In [None]:
# Verify: report every .npz file directly inside the cache directory together
# with its on-disk size in kilobytes.
import glob

_npz_pattern = os.path.join(CACHE_DIR, '*.npz')
files = sorted(glob.glob(_npz_pattern))
print(f"Cached files ({len(files)}):")
for npz_file in files:
    kilobytes = os.path.getsize(npz_file) / 1024
    print(f"  {npz_file} ({kilobytes:.1f} KB)")