In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from joblib import Parallel, delayed

from sbi.inference import SNLE
from sbi.inference import simulate_for_sbi
from sbi.utils import BoxUniform
from sbi.utils.user_input_checks import process_simulator, process_prior, check_sbi_inputs


In [2]:
# Reference parameter pair (v, a). Later cells use it for the direct
# simulate_PA sample and as the condition when evaluating the learned likelihood.
v, a = 1, 1

# SBI

## prior

In [3]:
# Narrower alternative bounds, kept for reference:
# low = torch.tensor([0.5, 0.5])
# high = torch.tensor([2.0, 2.0])

# Uniform box prior over the parameter pair (v, a), in the order simulate_PA
# unpacks them. Both bounds are written with float literals so `low` and `high`
# share a floating dtype; integer literals would make `high` an int64 tensor.
low = torch.tensor([0.01, 0.01])
high = torch.tensor([10.0, 10.0])

prior = BoxUniform(low=low, high=high)
# process_prior hands back the prior to use from here on, the number of
# parameters, and a flag that is forwarded to process_simulator below.
prior, num_parameters, prior_returns_numpy = process_prior(prior)

## simulate single bound

In [4]:
def simulate_PA(params):
    """Simulate one first-passage time of a drifting random walk to an upper bound.

    The decision variable starts at 0 and advances in steps of ``dt``. Each
    step adds the drift ``v * dt`` plus zero-mean Gaussian noise with standard
    deviation ``dB``. The walk stops the first time the variable is at or
    above ``a``.

    Args:
        params: Length-2 sequence ``(v, a)``: drift rate and bound height.
            Lists, NumPy arrays and 1-D torch tensors are all accepted.

    Returns:
        float: Simulated time elapsed at the first crossing. The bound is only
        checked after a step has been taken, so the result is at least ``dt``.

    Note:
        There is no time limit. The loop ends only when the bound is reached,
        so a very small or non-positive drift can run for a very long time and
        may never terminate.
    """
    # Convert once to plain Python floats. When the parameters arrive as a
    # tensor, unpacking yields 0-d tensors and every step of the loop below
    # would become tensor arithmetic, accumulating in the tensor's dtype.
    v, a = (float(p) for p in params)

    dt = 1e-4
    dB = 1e-2  # per-step noise SD; numerically sqrt(dt)
    DV = 0.0
    t = 0.0
    while True:
        DV += v * dt + np.random.normal(0, dB)
        t += dt

        if DV >= a:
            return t


## sbi simulator

In [5]:
# Wrap simulate_PA with sbi's process_simulator helper. The third argument is
# the flag returned by process_prior in the prior cell.
simulator = process_simulator(simulate_PA, prior, prior_returns_numpy)

In [6]:
# Run sbi's consistency check on the wrapped simulator and processed prior
# before launching the large simulation run (presumably raises on a mismatch;
# confirm against the sbi docs).
check_sbi_inputs(simulator, prior)

In [None]:
# Number of (theta, x) pairs to generate for training.
N_sim = 100_000

# Parameters are proposed directly from the prior; each draw is passed through
# the wrapped simulator, spread over 32 workers.
# `x_o` holds the simulator outputs returned by simulate_for_sbi.
proposal = prior
theta, x_o = simulate_for_sbi(
    simulator=simulator,
    proposal=proposal,
    num_simulations=N_sim,
    num_workers=32,
)

  0%|          | 0/100000 [00:00<?, ?it/s]

## train network

In [None]:
# SNLE is imported with the other sbi imports at the top of the notebook.
trainer = SNLE()

# Training data needs one simulation per row. reshape(-1, 1) gives an (N, 1)
# column whether x_o arrives as (N,) or already as (N, 1); by contrast
# torch.atleast_2d(x_o).T would flip an (N, 1) input into (1, N).
x_o_2d = x_o.reshape(-1, 1)

# train() returns the fitted density estimator used in the likelihood cells below.
estimator = trainer.append_simulations(theta, x_o_2d).train(training_batch_size=512)

## likelihood

In [None]:
# Histogram grid: bin edges spaced `bin_width` apart starting at 0 (built with
# np.arange, so the stop value 10 is nominally excluded), and `t_pts` at the
# centre of every bin.
bin_width = 0.1
bins = np.arange(0, 10, bin_width)
t_pts = bins[:-1] + bin_width/2

# Reference sample: 50,000 independent simulate_PA runs at the reference
# parameters (v, a), dispatched through joblib with n_jobs=-1.
sim_results = Parallel(n_jobs=-1)(
    delayed(simulate_PA)([v, a])
    for _ in range(50_000)
)



In [None]:
# Reference parameters as a float32 tensor (torch.Tensor always builds float32).
og_params = torch.Tensor([v, a])

# Evaluation points: the bin centres, cast to float32, with a leading axis of
# size 1 -> shape (1, len(t_pts)).
t_grid = torch.as_tensor(t_pts, dtype=torch.float32).unsqueeze(0)

# One copy of the reference parameters per evaluation point -> shape
# (len(t_pts), 2). Built from og_params so the condition is float32; building
# it from the integer literals v, a would produce an int64 tensor.
theta_ref = og_params.repeat(len(t_pts), 1)

# NOTE(review): the (1, n) input / (n, 2) condition layout follows what this
# notebook passed before; confirm it against the shape convention of the
# installed sbi version's log_prob.
loglike = estimator.log_prob(t_grid, theta_ref)
like = torch.exp(loglike)

In [None]:
# Overlay the learned likelihood (evaluated at the bin centres) on the
# normalised histogram of the direct simulate_PA sample. The line is drawn
# first and the histogram second, so both keep their default colours.
fig, ax = plt.subplots()
ax.plot(t_pts, like.squeeze().detach().numpy(), label='learned likelihood')
ax.hist(sim_results, bins=bins, density=True, alpha=0.5, label='simulate_PA histogram')
ax.set_xlabel('first-passage time t')
ax.set_ylabel('density')
ax.set_title(f'v = {v}, a = {a}')
ax.legend();

In [None]:
# Training diagnostics: loss histories recorded by the trainer.
fig, ax = plt.subplots()
ax.plot(trainer.summary['validation_loss'], label='validation loss')
ax.plot(trainer.summary['training_loss'], label='training loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend();

# check simulator?

In [None]:
# Disabled sanity check, kept for reference. It would push simulations through
# the sbi pipeline with both parameters pinned to 1, then overlay the histogram
# of those outputs (red) on the direct simulate_PA sample `sim_results`
# (blue, dashed) using the same `bins`.
# NOTE(review): prior_1 below has low == high; a zero-width box is presumably
# rejected by the distribution, so confirm before re-enabling.
# prior_1 = BoxUniform(low=torch.tensor([1,1]), high=torch.tensor([1,1]))
# prior_1, num_parameters_1, prior_returns_numpy_1 = process_prior(prior_1)
# check_sbi_inputs(simulator, prior_1)
# proposal_1 = prior_1
# theta_1, x_o_1 = simulate_for_sbi(simulator=simulator,\
#                             proposal=proposal_1,
#                             num_simulations=N_sim,
#                             num_workers=32)

# x_o_1_np = x_o_1.numpy()

# plt.hist(x_o_1_np, bins=bins, density=True, alpha=0.5, histtype='step', lw=2, color='r');
# plt.hist(sim_results, bins=bins, density=True, alpha=0.5, histtype='step', lw=2, color='b', ls='--');