In [None]:
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from pathlib import Path
import pymc as pm
import arviz as az
import multiprocessing as mp
from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel
import nevergrad as ng

from tbdynamics.constants import BURN_IN, OPTI_DRAWS
from autumn.infrastructure.remote import springboard
from tbdynamics.calib_utils import get_bcm
import pandas as pd
from estival.utils.sample import SampleTypes

In [None]:
# Parameter values handed to get_bcm to construct the calibration model `bcm`.
# NOTE(review): meanings/units below are inferred from the names only -- e.g.
# seed_time and screening_inflection_time look like calendar years and
# screening_end_asymp like a proportion; confirm against tbdynamics.
params = dict(
    start_population_size=2300000.0,
    seed_time=1810.0,
    seed_num=10.0,
    seed_duration=3.0,
    progression_multiplier=1.8,
    screening_scaleup_shape=0.1,
    screening_inflection_time=1993.0,
    screening_end_asymp=0.65,
)

bcm = get_bcm(params)

In [None]:
def optimize_ng(sample, budget=2000, num_workers=4):
    """Optimise the module-level ``bcm`` with nevergrad TwoPointsDE from one start point.

    Args:
        sample: Starting parameter set, forwarded to estival as the optimiser's
            ``suggested`` point (one element of the LHS list the caller maps over).
        budget: Number of objective evaluations. Used both to configure the
            optimiser and as the ``minimize`` budget so the two cannot drift apart.
        num_workers: Forwarded unchanged to ``eng.optimize_model``.

    Returns:
        ``rec.value[1]`` of the nevergrad recommendation -- presumably the
        keyword-argument dict of optimised parameters (nevergrad instrumentation
        values are ``(args, kwargs)``); TODO confirm against estival.

    Note:
        This reads the global ``bcm``. It is run in worker processes via
        ``map_parallel(mode="process")``, so it relies on ``bcm`` being available
        in the child process (e.g. with the fork start method).
        NOTE(review): confirm behaviour on spawn-based platforms.
    """
    opt = eng.optimize_model(
        bcm,
        budget=budget,
        opt_class=ng.optimizers.TwoPointsDE,
        suggested=sample,
        num_workers=num_workers,
    )
    rec = opt.minimize(budget)
    return rec.value[1]
# Multi-start optimisation: draw N_OPT_STARTS Latin-hypercube points and run an
# independent optimize_ng from each, spread over 4 worker processes.
N_OPT_STARTS = 16
N_CHAINS = 6
opt_samples = map_parallel(
    optimize_ng,
    bcm.sample.lhs(N_OPT_STARTS, SampleTypes.LIST_OF_DICTS),
    n_workers=4,
    mode="process",
)

# Rank the optimised points by log-posterior, best first.
lle_samps = esamp.likelihood_extras_for_samples(opt_samples, bcm)
best_ll_idx = lle_samps.sort_values("logposterior", ascending=False).index
ranked_samps = (
    bcm.sample.convert(opt_samples, SampleTypes.SAMPLEITERATOR)[best_ll_idx]
    .convert(SampleTypes.LIST_OF_DICTS)
)
# Single best point, kept under its original name for any downstream use.
init_samps = ranked_samps[0]
# Start each chain from a different high-posterior point instead of all chains
# from the same one: dispersed starts keep between-chain diagnostics (R-hat in
# az.summary) informative. PyMC needs exactly one initval dict per chain when a
# list is given, so cycle through the ranked points if there are fewer than
# N_CHAINS of them.
chain_initvals = [ranked_samps[i % len(ranked_samps)] for i in range(N_CHAINS)]

with pm.Model() as model:
    variables = epm.use_model(bcm)
    idata_raw = pm.sample(
        step=[pm.DEMetropolisZ(variables)],
        draws=100000,
        tune=20000,
        cores=N_CHAINS,
        # Warm-up draws are retained in the returned InferenceData.
        discard_tuned_samples=False,
        chains=N_CHAINS,
        progressbar=True,
        initvals=chain_initvals,
    )

In [None]:
import arviz as az

In [None]:
# ArviZ summary table of the sampled posterior; shown as the cell's output.
posterior_summary = az.summary(idata_raw)
posterior_summary

In [None]:
# Per-chain trace plots of the posterior. Height scales with the number of
# posterior variables (len() of an xarray Dataset counts its data variables).
# The trailing ';' suppresses the returned Axes-array repr under the figure.
az.plot_trace(idata_raw, figsize=(16, 3.2 * len(idata_raw.posterior)), compact=False);

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import arviz as az

def plot_chain_traces(idata, group="posterior", width=12.0, height_per_var=6.0):
    """Plot one trace panel per variable, with one line per chain.

    This is a hand-rolled alternative to ``az.plot_trace``. It works on any
    InferenceData group whose variables have a ``chain`` dimension, so it can be
    applied to the demo object below or to ``idata_raw``.

    Args:
        idata: ArviZ InferenceData to plot.
        group: Name of the InferenceData group to read (default "posterior").
        width: Figure width in inches.
        height_per_var: Figure height per variable panel, in inches.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: If the selected group has no data variables.
    """
    data = getattr(idata, group)
    var_names = list(data.data_vars)
    if not var_names:
        raise ValueError(f"InferenceData group {group!r} has no data variables")
    # squeeze=False keeps a 2-D Axes array even for a single variable. Without it,
    # plt.subplots would return a bare Axes and per-row indexing would fail.
    fig, axs = plt.subplots(
        len(var_names), 1, figsize=(width, height_per_var * len(var_names)), squeeze=False
    )
    for ax, var_name in zip(axs[:, 0], var_names):
        for chain in data.chain.values:
            ax.plot(data[var_name].sel(chain=chain), label=f"Chain {chain}")
        ax.set_title(f"Trace for {var_name}")
        ax.legend()
        ax.set_ylabel("Value")
        ax.set_xlabel("Draw")
    fig.tight_layout()
    return fig


# Demo on simulated data: 2 variables, 4 chains, 100 draws each.
np.random.seed(0)
data_var1 = np.random.randn(4, 100)  # standard normal draws
data_var2 = np.random.randn(4, 100) * 2 + 5  # different scale and offset

dataset = xr.Dataset({
    "var1": (("chain", "draw"), data_var1),
    "var2": (("chain", "draw"), data_var2)
}, coords={
    "chain": range(data_var1.shape[0]),
    "draw": range(data_var1.shape[1])
})

# Wrap the dataset as the posterior group so it has the same layout as idata_raw.
idata = az.InferenceData(posterior=dataset)

fig = plot_chain_traces(idata)
plt.show()