In [1]:
import torch
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer

In [2]:
import BioMime.utils.basics as bm_basics
import BioMime.models.generator as bm_gen

# Run the BioMime generator on the GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config = bm_basics.update_config('../BioMime/BioMime/configs/config.yaml')
generator = bm_gen.Generator(config.Model.Generator)
# Pass the detected device type ('cuda' or 'cpu') instead of a hard-coded 'cuda',
# so loading the checkpoint does not fail on CPU-only machines.
generator = bm_basics.load_generator('../BioMime/ckp/model_linear.pth', generator, device.type)
generator = generator.to(device)

In [3]:
from utils import generate_muap

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
def simulator(pars, n_repeats=10):
    """Simulate MUAPs with the BioMime generator, averaged over repeated draws.

    Args:
        pars: Tensor of conditioning parameters, either one parameter vector of
            shape (n_params,) or a batch of shape (n_MU, n_params).
        n_repeats: Number of generator draws to average element-wise
            (default 10).

    Returns:
        For 1-D ``pars``: a flat tensor holding the averaged MUAP (unchanged
        behavior for single parameter vectors).
        For 2-D ``pars``: a tensor of shape (n_MU, n_features), with one
        flattened MUAP per parameter row. This is the batch layout sbi expects
        from a simulator.
    """
    single = pars.ndim == 1
    if single:
        pars = pars[None, :]

    n_MU = pars.shape[0]
    # Loop-invariant: move and cast the conditioning parameters once.
    cond = pars.to(device).float()

    sim_muaps = []
    with torch.no_grad():  # sampling only; no autograd graph is needed
        for _ in range(n_repeats):
            sim = generator.sample(n_MU, cond, cond.device).to("cpu")
            if n_MU == 1:
                # presumably generator.sample drops the batch axis for a single MU — TODO confirm
                sim = sim.permute(1, 2, 0)
            else:
                sim = sim.permute(0, 2, 3, 1)
            sim_muaps.append(sim.detach().numpy())

    muap = np.array(sim_muaps).mean(0)

    if single:
        return torch.from_numpy(muap.flatten())
    # Keep one row per parameter set. Flattening the whole batch into a single
    # vector would concatenate unrelated simulations.
    return torch.from_numpy(muap.reshape(n_MU, -1))

In [5]:
num_dim = 6  # six parameters: FD, D, A, IZ, CV, FL
# Independent uniform prior on [0.5, 1] for every parameter.
prior = utils.BoxUniform(low=torch.full((num_dim,), 0.5), high=torch.ones(num_dim))

In [None]:
# Amortized neural posterior estimation over the whole prior.
# ("SNLE" and "SNRE" are the other methods accepted here.)
posterior = infer(simulator, prior, method="SNPE", num_simulations=100_000)

Running 100000 simulations.:   0%|          | 0/100000 [00:00<?, ?it/s]

 Training neural network. Epochs trained: 191

In [None]:
# NOTE(review): leftover snippet. `inference`, `theta` and `x` are only defined
# further down (multi-round SNPE section) and `train_kwargs` is never defined,
# so this line must stay commented out; kept for reference only.
# _ = inference.append_simulations(theta, x).train(**train_kwargs)

In [None]:
# Quick look at the trained amortized posterior (replaces a stray undefined name).
posterior

In [None]:
import h5py

# Read the `muaps` and `params` datasets fully into memory
# (presumably MUAPs and the parameters that generated them).
with h5py.File('./lhs1000_biomime.h5', 'r') as h5_in:
    params = h5_in['params'][()]
    muaps = h5_in['muaps'][()]

In [None]:
# One held-out parameter set (row 12) and the averaged MUAP simulated from it.
obs_row = params[12]
observation_pars = torch.from_numpy(obs_row)
observation = simulator(observation_pars)

In [None]:
# Sample the amortized posterior for the single observation and compare the
# marginals against the parameters that generated it. (The unused log_prob of
# all 10k samples is no longer computed.)
samples = posterior.sample((10000,), x=observation)
fig, ax = analysis.pairplot(
    samples,
    limits=[[0.5, 1] for _ in range(num_dim)],
    figsize=(6, 6),
    points=observation_pars,
    labels=('FD', 'D', 'A', 'IZ', 'CV', 'FL'),
)

In [None]:
def get_credible_intervals(samples, q=(2.5, 97.5)):
    """Per-dimension percentile intervals from posterior samples.

    Args:
        samples: Array-like (numpy array or CPU torch tensor) of shape
            (n_samples, n_dims).
        q: Percentiles in [0, 100] to compute for every dimension. The default
            (2.5, 97.5) gives the central 95% interval.

    Returns:
        float64 ndarray of shape (len(q), n_dims). Row j holds percentile q[j]
        of every dimension, so with the default, row 0 is the lower bound and
        row 1 is the upper bound.
    """
    # One vectorized call along the sample axis replaces the per-dimension loop.
    return np.asarray(np.percentile(samples, q=list(q), axis=0), dtype=np.float64)

In [None]:
# For every held-out parameter set:
#   1. simulate an observation,
#   2. sample the amortized posterior,
#   3. record the highest-density sample and the 2.5-97.5 percentile interval.
# The loop uses its own names so it does not overwrite `observation`,
# `observation_pars` or `samples` from the single-observation cells above.
n_post_samples = 1000
maps = []
cis = []
for par_i in params:
    obs_i = simulator(torch.from_numpy(par_i))
    samples_i = posterior.sample((n_post_samples,), x=obs_i, show_progress_bars=False)
    log_prob_i = posterior.log_prob(samples_i, x=obs_i)
    # Approximate MAP: the drawn sample with the highest posterior log-density.
    maps.append(samples_i[log_prob_i.argmax()])
    cis.append(get_credible_intervals(samples_i))

In [None]:
# Stack per-observation results into arrays:
# maps -> (n_obs, num_dim), cis -> (n_obs, 2, num_dim).
maps, cis = np.array(maps), np.array(cis)

In [None]:
# Save the ground truth next to the NPE point estimates and intervals.
npe_results = {
    'true_muaps': muaps,
    'true_params': params,
    'MAP_estimates': maps,
    'conf_intervals': cis,
}
with h5py.File('./lhs1000_npe.h5', 'w') as h5_out:
    for key, value in npe_results.items():
        h5_out.create_dataset(key, data=value)

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(9, 6))
titles = ['FD', 'D', 'A', 'IZ', 'CV', 'FL']

for i, title in enumerate(titles):
    # Fill the 2x3 grid column by column.
    col, row = divmod(i, 2)
    ax = axes[row, col]
    lower, upper = cis[:, 0, i], cis[:, 1, i]
    # Asymmetric error bars: distance from the MAP down to the lower bound and
    # up to the upper bound. Clip at 0 rather than taking abs(): if the MAP
    # lies outside its interval, abs() would draw the bar on the wrong side.
    yerr = np.clip(np.stack([maps[:, i] - lower, upper - maps[:, i]]), 0, None)
    ax.errorbar(params[:, i], maps[:, i], yerr=yerr, fmt='o', linewidth=0.75, ms=2)
    # Identity line: perfect recovery lies on y = x.
    ax.plot([0.5, 1], [0.5, 1], color='grey', linestyle='--', linewidth=0.75)
    ax.set_ylim([0.5, 1])
    ax.set_title(title)
    ax.set_xlabel('true value')
    ax.set_ylabel('MAP estimate')

plt.tight_layout()
sns.despine()
plt.show()

### Multi-round (sequential) SNPE focused on a single observation

In [None]:
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi.utils.get_nn_models import posterior_nn

In [None]:
# Multi-round SNPE: the first round simulates from the prior; each later round
# simulates parameter sets sampled from the current posterior conditioned on x_obs.
num_rounds = 5
sims_per_round = 2000

# The specific observation we want to focus the inference on.
true_par = torch.from_numpy(params[12])
x_obs = simulator(true_par)
prior = utils.BoxUniform(low=0.5 * torch.ones(num_dim), high=torch.ones(num_dim))

# Let sbi adapt the simulator/prior to its batched interface, so the round
# below does not depend on whether `simulator` was already wrapped elsewhere.
snpe_simulator, snpe_prior = prepare_for_sbi(simulator, prior)

posteriors = []
inference = SNPE(prior=snpe_prior)
proposal = snpe_prior

for _ in range(num_rounds):
    theta, x = simulate_for_sbi(
        snpe_simulator,
        proposal,
        num_simulations=sims_per_round,
        simulation_batch_size=1,
    )

    # SNPE gets the proposal so it can correct for non-prior sampling in later
    # rounds. (In `SNLE` and `SNRE`, the proposal must NOT be passed here.)
    density_estimator = inference.append_simulations(
        theta, x, proposal=proposal
    ).train()
    posterior = inference.build_posterior(density_estimator)
    posteriors.append(posterior)
    proposal = posterior.set_default_x(x_obs)

In [None]:
# Posterior after the final SNPE round, conditioned on x_obs, against the true
# parameters. (The unused log_prob of all 10k samples is no longer computed.)
samples = posterior.sample((10000,), x=x_obs)
fig, ax = analysis.pairplot(
    samples,
    limits=[[0.5, 1] for _ in range(num_dim)],
    figsize=(6, 6),
    points=true_par,
    labels=('FD', 'D', 'A', 'IZ', 'CV', 'FL'),
)

In [None]:
def run_snle(simulator, x_obs, proposal_prior, rounds=5, sim_budget=10000):
    """Multi-round SNLE focused on a single observation.

    The previous stub ignored its arguments and returned whatever global
    ``posterior`` happened to exist.

    Args:
        simulator: Callable mapping parameters to simulations.
        x_obs: Observation the rounds after the first focus on.
        proposal_prior: Prior used for the first round and as the SNLE prior.
        rounds: Number of simulate/train rounds (must be >= 1).
        sim_budget: Total number of simulations, split evenly across rounds.

    Returns:
        The posterior built after the final round, with ``x_obs`` set as its
        default observation.

    Raises:
        ValueError: If ``rounds < 1`` or the budget leaves fewer than one
            simulation per round.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")
    sims_per_round = sim_budget // rounds
    if sims_per_round < 1:
        raise ValueError(
            f"sim_budget={sim_budget} is too small for {rounds} rounds"
        )

    # Imported here so the function stays self-contained within the notebook.
    from sbi.inference import SNLE, prepare_for_sbi, simulate_for_sbi

    simulator, prior = prepare_for_sbi(simulator, proposal_prior)
    inference = SNLE(prior=prior)
    proposal = prior
    posterior = None
    for _ in range(rounds):
        theta, x = simulate_for_sbi(
            simulator,
            proposal,
            num_simulations=sims_per_round,
            simulation_batch_size=1,
        )
        # Unlike SNPE, SNLE must not be given the proposal here.
        density_estimator = inference.append_simulations(theta, x).train()
        posterior = inference.build_posterior(density_estimator).set_default_x(x_obs)
        proposal = posterior
    return posterior

### SMC-ABC

In [None]:
# Use existing SMC-ABC implementation instead of the pyABC impl.
from sbi.inference import prepare_for_sbi, simulate_for_sbi, SMCABC

# Adapt the simulator and prior for sbi's ABC interface.
# NOTE(review): this rebinds the module-level `simulator` and `prior`. Cells
# above that are re-run after this one use the prepared versions, and
# re-running this cell feeds the already-prepared objects through
# prepare_for_sbi again.
simulator, prior = prepare_for_sbi(simulator, prior)

In [None]:
# SMC-ABC sampler over the prepared simulator and prior.
# Note: this rebinds `inference` (previously the SNPE object).
inference = SMCABC(
    simulator,
    prior,
    show_progress_bars=False,
)

In [None]:
# Observation for ABC: the simulated MUAP for held-out parameter row 12, with a
# leading batch axis. It must match the simulator's output layout; a 3-element
# zero vector does not.
observation = simulator(torch.from_numpy(params[12])[None, :])

In [None]:
# SMC-ABC settings. kde=True fits a KDE to the final particles, and
# return_summary=True also returns the run summary.
smc_settings = dict(
    num_particles=50,
    num_initial_pop=100,
    num_simulations=100000,
    epsilon_decay=0.1,
    kde=True,
    return_summary=True,
)
# Note: this rebinds `posterior`.
posterior, summary = inference(x_o=observation, **smc_settings)

In [None]:
# Sample the KDE posterior returned by SMC-ABC and compare it with the true
# parameters of held-out row 12, using the 6-dim [0.5, 1] prior box rather
# than toy-example limits. The unused log_prob is no longer computed.
samples = posterior.sample((1000,))
_ = analysis.pairplot(
    samples,
    limits=[[0.5, 1] for _ in range(num_dim)],
    figsize=(6, 6),
    points=torch.from_numpy(params[12]),
    labels=('FD', 'D', 'A', 'IZ', 'CV', 'FL'),
)