In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch

from sbi.inference import NPE
from sbi.analysis import pairplot
from sbi.utils import BoxUniform

In [2]:
num_dim = 3

In [13]:
def simulator(theta):
    m,k,b = theta
    ts = 5.0
    dt = 0.01
    step = int(ts/dt)
    t = np.linspace(0, ts, step)
    xk = np.array([10,3])
    states = []

    for i in range(step): 
        xk = xk + np.array([
            (0, dt), 
            (-k*dt/m, -b*dt/m)
        ]) @ xk + np.array((0, dt/m)) * np.cos(i*dt * 10)
        states.append(xk.tolist())

    return torch.tensor(states).flatten()

def eval(t, x, val):
    index = np.argmin(np.abs(t - val))
    return x[index]


In [15]:
# theta_o = np.array((10,5,2))
# t, x = simulator(theta_o)

In [16]:
# plt.plot(t, x[:, 0], "k")

In [17]:
# plt.plot(t, x[:, 1], "k")

In [18]:
def get_3_values(t, x):
    """
    Return 3 'x' values corresponding to t=-0.5,0,0.75 as summary statistic vector
    """
    return np.array(
        [
            eval(t, x, 0.5),
            eval(t, x, 2.5),
            eval(t, x, 4.75),
        ]
    )

In [19]:
prior = BoxUniform(
    low=torch.as_tensor([1.0, 1.0, 0.5]), 
    high=torch.as_tensor([20.0, 10.0, 2.0])
)

In [20]:
theta_samples = prior.sample((1000,))
x_samples = torch.stack([simulator(theta) for theta in theta_samples])

In [None]:
x_samples

In [85]:
theta

tensor([[ 5.0114,  8.5277,  1.3535],
        [12.0914,  1.3118,  1.4654],
        [15.7647,  4.9210,  1.4610],
        ...,
        [18.9567,  2.8286,  1.0826],
        [12.1697,  4.4626,  0.8181],
        [ 9.5686,  9.6051,  1.4269]])

In [300]:
vals = get_3_values(t, x)
vals = torch.as_tensor(x, dtype=torch.float32)

In [301]:
inference = NPE(prior)
_ = inference.append_simulations(theta, vals).train()
posterior = inference.build_posterior()

AssertionError: Number of parameter sets (=1000 must match the number of simulation outputs (=5000)

In [293]:
x_o = torch.as_tensor(get_3_values(t, x), dtype=float)

In [295]:
theta_p = posterior.sample((10000,), x=x_o)

fig, axes = pairplot(
    theta_p,
    limits=list(zip(prior_min, prior_max)),
    ticks=list(zip(prior_min, prior_max)),
    figsize=(7, 7),
    labels=["a", "b", "c"],
    fig_kwargs=dict(
        points_offdiag={"markersize": 6},
        points_colors="r",
    ),
    points=theta_o,
);

NameError: name 'posterior' is not defined