# Imports and setup

In [2]:
import numpy as np
# Always print floats in fixed-point notation (no scientific notation),
# which keeps the array outputs further down easier to read.
np.set_printoptions(suppress=True)

In [3]:
from bayesflow.forward_inference import *
from bayesflow.configuration import *
from bayesflow.amortized_inference import AmortizedLikelihood
from bayesflow.networks import InvertibleNetwork
from bayesflow.trainers import Trainer

In [4]:
# Enable IPython's autoreload extension. Mode 2 reloads all modules before
# each cell is executed, so edits to imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# User-defined functions

In [5]:
def prior_fun():
    """Draw one parameter vector from the prior.

    Returns
    -------
    np.ndarray
        Array of shape (3,) holding independent standard-normal draws
        taken from NumPy's global random state.
    """
    n_params = 3
    draws = np.random.normal(loc=0.0, scale=1.0, size=n_params)
    return draws

def n_obs_gen():
    """Draw a random number of observations for one simulated data set.

    Returns
    -------
    int
        Integer drawn uniformly from 20 to 100, both ends inclusive,
        using NumPy's global random state.
    """
    lowest, highest = 20, 100
    # randint excludes its upper bound, hence the +1 to keep 100 attainable.
    return np.random.randint(lowest, highest + 1)

def generator(params, n_obs):
    """Simulate ``n_obs`` Gaussian observations centred on ``params``.

    Parameters
    ----------
    params : np.ndarray
        Vector of means (e.g. one draw from ``prior_fun``); its first
        dimension sets the number of data columns.
    n_obs : int
        Number of rows (observations) to draw.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_obs, params.shape[0])`` of unit-variance normal
        draws whose means are ``params`` broadcast across the rows.
    """
    n_dims = params.shape[0]
    return np.random.normal(loc=params, scale=1.0, size=(n_obs, n_dims))

# Generative Model Interface

In [6]:
# Context generator for the simulator: n_obs_gen is registered as the
# non-batchable context function, so it supplies the number of observations.
# NOTE(review): presumably this is drawn once per batch and shared by all
# simulations in it (the pilot run below reports one common n_obs for both
# simulations) -- confirm against the BayesFlow documentation.
sim_context = ContextGenerator(
    non_batchable_context_fun=n_obs_gen,
)

In [7]:
# Wrap the user-defined `generator` together with its context generator.
# NOTE(review): presumably the context value is forwarded to `generator` as
# its `n_obs` argument -- confirm against the Simulator API.
sim = Simulator(simulator_fun=generator, context_generator=sim_context)

In [8]:
# Combine the prior and the simulator into one generative model.
# Construction runs a couple of pilot simulations and prints the resulting
# parameter/simulation shapes and context diagnostics (see the output below).
gen_model = GenerativeModel(prior_fun, sim)

Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 3)
Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 27, 3)
No prior non-batchable context provided.
No prior batchable context provided.
Could not determine shape of simulation non-batchable context. Type appears to be non-array: <class 'int'>,                                    so make sure your input configurator takes cares of that!
No simulation batchable context provided.


# Networks and Amortizer

In [9]:
# Invertible network wrapped as an amortized likelihood surrogate.
# NOTE(review): 'n_params' is 3 here, which matches the number of data columns
# produced by `generator` (params.shape[0] with a size-3 prior draw);
# presumably for a likelihood network this is the dimensionality of the
# modelled data -- confirm, and keep it in sync with the prior size.
inference_net = InvertibleNetwork({'n_params': 3})
surrogate = AmortizedLikelihood(inference_net)

# Trainer and Simulation-Based Training

In [11]:
# Tie the amortizer to the generative model so training data can be simulated
# on the fly. Construction performs a consistency check of the provided
# modules (see the output below).
trainer = Trainer(amortizer=surrogate, generative_model=gen_model)

Performing a consistency check with provided modules...Done.


In [12]:
# Short online-training run: 1 epoch of 50 iterations with batches of 32.
# `h` keeps whatever train_online returns
# (presumably the loss history -- TODO confirm).
h = trainer.train_online(epochs=1, iterations_per_epoch=50, batch_size=32)

Training epoch 1:   0%|          | 0/50 [00:00<?, ?it/s]

# Some Checks

In [13]:
# Draw a fresh batch from the generative model (argument 32, presumably the
# batch size -- confirm) and pass it through the trainer's configurator to
# obtain inputs in the format the amortizer is called with below.
inp_data = trainer.configurator(trainer.generative_model(32))

## Likelihood estimation

In [14]:
# Evaluate the surrogate's log-likelihood on the configured batch; the result
# is displayed as a 2-D float32 array (see the output below).
surrogate.log_likelihood(inp_data)

array([[ 8.006765 ,  5.7138405,  6.270076 , ...,  7.929564 ,  4.8080263,
         4.0986767],
       [ 5.5476646,  4.8460035,  5.2039223, ...,  7.567896 ,  5.5118637,
         4.7959375],
       [ 3.7478576,  6.58309  , -1.0359812, ...,  3.758586 ,  4.55163  ,
         4.5234065],
       ...,
       [ 5.722727 ,  4.624796 ,  5.6521225, ...,  5.6183677,  5.6762486,
         2.930715 ],
       [ 5.9108696,  5.1875954,  4.8316517, ...,  6.9187045,  4.294381 ,
         5.995303 ],
       [ 5.726029 ,  6.0359755,  6.496496 , ...,  4.6192036,  7.2494993,
         6.3057346]], dtype=float32)

## Synthetic data generation

In [15]:
# Generate synthetic data from the surrogate, requesting n_samples=24; the
# result is displayed as a 3-D array whose last axis has length 3.
surrogate.sample(inp_data, n_samples=24)

array([[[ 1.0670216 , -0.02788644,  0.6036515 ],
        [ 0.79939246, -0.09182572, -1.478431  ],
        [-1.7825991 ,  1.3088417 , -0.0720353 ],
        ...,
        [-1.1067494 , -0.00606982, -0.8406898 ],
        [ 0.48378423, -0.29521313, -0.14911191],
        [-1.8107058 , -0.88614047, -0.21139371]],

       [[ 1.9681001 ,  4.323698  ,  1.3546277 ],
        [ 1.3001264 ,  5.0588145 ,  0.27764672],
        [ 1.4281783 ,  3.0455346 ,  0.31784675],
        ...,
        [ 1.8273635 , -0.68548566,  1.677741  ],
        [ 1.6110762 ,  4.59672   ,  0.64524925],
        [ 1.5031681 ,  3.3782067 ,  0.9029791 ]],

       [[-4.9283485 ,  0.21323866, -2.9626484 ],
        [-1.0500737 ,  5.405244  ,  3.0227685 ],
        [ 2.08179   ,  3.9479432 , -4.770816  ],
        ...,
        [ 0.90229535,  0.52661914, -2.7231574 ],
        [-3.0022712 ,  0.19553289, -3.1069388 ],
        [-1.0290381 , -3.6024685 , -3.1178894 ]],

       ...,

       [[-3.9568725 , -2.8689969 ,  2.994326  ],
        [-2

# Extensions