In [1]:
import sys
sys.path.append("../")

from bayesflow import Approximator
from bayesflow import OfflineDataset
from bayesflow.simulators import TwoMoonsSimulator
from bayesflow.networks import CIF

In [2]:
batch_size = 128
num_batches = 32

simulator = TwoMoonsSimulator()
data = simulator.sample((batch_size * num_batches,))
dataset = OfflineDataset(data, workers=4, batch_size=batch_size)
print("Batch size:", dataset.batch_size)
print([key for key in dataset[0].keys()])


Batch size: 128
['r', 'alpha', 'theta', 'x']


In [3]:
cif = CIF()
approximator = Approximator(
	inference_network=cif,
	inference_variables=["theta"],
	inference_conditions=["r", "alpha", "x"]
)
approximator.compile(optimizer="adam")
approximator.build_from_data(next(iter(dataset)))

In [4]:
samples = approximator.sample((128,1), next(iter(dataset)))
print(samples)

{'theta': <tf.Tensor: shape=(128, 2, 2), dtype=float32, numpy=
array([[[ 1.60534620e-01,  1.25463843e-01],
        [ 9.13245618e-01,  1.32534459e-01]],

       [[-3.24484736e-01, -1.19637430e-01],
        [-7.08619729e-02, -6.32366687e-02]],

       [[-2.37683922e-01,  1.95614010e-01],
        [-2.25730434e-01, -1.33314103e-01]],

       [[-2.87156492e-01, -6.46867380e-02],
        [-3.79036993e-01,  2.85323888e-01]],

       [[ 4.37962919e-01,  2.83477474e-02],
        [-1.82896063e-01, -1.95387423e-01]],

       [[-1.80056974e-01, -4.36633736e-01],
        [-3.34053904e-01, -5.16214848e-01]],

       [[-4.29401249e-01, -4.31349780e-03],
        [ 3.40517402e-01,  3.17731947e-01]],

       [[-5.30258045e-02,  1.11299232e-01],
        [-1.16609931e-01, -2.42464140e-01]],

       [[-5.37763536e-01,  3.29821497e-01],
        [-3.29010077e-02,  2.67980415e-02]],

       [[-6.51495159e-01, -3.74955446e-01],
        [-8.05761293e-02,  5.89222133e-01]],

       [[ 1.89012989e-01,  4.73343194

In [5]:
approximator.fit(dataset, epochs=5)

Epoch 1/5


In [None]:
samples = approximator.sample((128,1), next(iter(dataset)))
print(samples)



{'x': <tf.Tensor: shape=(128, 2, 2), dtype=float32, numpy=
array([[[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan],
        [nan, nan]],

       [[nan, nan