# Imports

In [80]:
import numpy as np
np.set_printoptions(suppress=True)

In [103]:
from bayesflow.forward_inference import *
from bayesflow.configuration import *
from bayesflow.amortized_inference import AmortizedPosterior
from bayesflow.networks import InvertibleNetwork, InvariantNetwork
from bayesflow.trainers import Trainer

In [82]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# User-defined functions

In [83]:
def prior_fun():
    return np.random.normal(size=3)

def n_obs_gen():
    return np.random.randint(20, 100+1)

def generator(params, n_obs):
    return np.random.normal(params, size=(n_obs, params.shape[0]))

# Generative Model Interface

In [84]:
sim_context = ContextGenerator(
    non_batchable_context_fun=n_obs_gen,
)

In [85]:
sim = Simulator(simulator_fun=generator, context_generator=sim_context)

In [86]:
gen_model = GenerativeModel(prior_fun, sim)

Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 3)
Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 58, 3)
No prior non-batchable context provided.
No prior batchable context provided.
Could not determine shape of simulation non-batchable context. Type appears to be non-array: <class 'int'>,                                    so make sure your input configurator takes cares of that!
No simulation batchable context provided.


# Networks and Amortizer

In [104]:
summary_net = InvariantNetwork()
inference_net = InvertibleNetwork({'n_params': 3})
amortizer = AmortizedPosterior(inference_net, summary_net, summary_loss_fun='mmd')

# Trainer and Simulation-Based Training

In [106]:
# change var_obs
trainer = Trainer(amortizer=amortizer, generative_model=gen_model, configurator='var_obs')

Performing consistency check with provided modules...Done.


In [108]:
h = trainer.train_online(epochs=1, iterations_per_epoch=100, batch_size=32)

Training epoch 1:   0%|          | 0/100 [00:00<?, ?it/s]

# Some Checks

In [91]:
inp_data = trainer.configurator(trainer.generative_model(32))

In [100]:
amortizer.sample(inp_data, 10)

array([[[ 2.1646936 , -0.21531303,  5.5063796 ],
        [ 1.6260811 , -1.3071998 ,  9.322541  ],
        [ 3.060134  , -0.46829444,  0.7760233 ],
        [ 1.5900848 , -1.5491301 ,  7.372608  ],
        [ 0.06635528,  1.931855  ,  0.1290187 ],
        [ 2.326499  ,  0.37538654,  7.810817  ],
        [ 2.744444  ,  0.8728799 ,  5.480768  ],
        [ 2.9325137 ,  2.6140668 ,  2.1806927 ],
        [ 1.6944929 , -0.06102439,  9.003493  ],
        [ 2.4327252 , -2.8572407 ,  6.644378  ]],

       [[ 2.6241312 , -1.8137338 ,  6.3449836 ],
        [ 4.23963   , -1.4367161 ,  4.029674  ],
        [ 4.909225  ,  1.074682  ,  3.615526  ],
        [ 1.974734  , -1.0667816 ,  9.2077875 ],
        [ 7.2513084 ,  1.0038848 ,  2.0818434 ],
        [ 3.3621936 , -0.9641121 ,  6.1516933 ],
        [ 0.2385915 ,  0.23268044,  6.4222445 ],
        [ 3.3396232 , -2.9981494 ,  2.3405478 ],
        [ 1.2197895 , -1.5217446 ,  7.5101104 ],
        [ 2.2429206 ,  1.2091366 ,  7.037413  ]],

       [[ 3.1014

In [101]:
amortizer.log_posterior(inp_data)

array([ -9.982149 ,  -9.19886  ,  -9.189264 ,  -7.6114187,  -8.850998 ,
        -9.52586  ,  -8.6193075,  -8.136214 ,  -8.723217 ,  -7.7660174,
        -9.705035 ,  -9.256691 ,  -8.379083 , -10.392677 ,  -6.6298227,
        -8.636785 ,  -9.133993 , -12.507931 , -10.862332 ,  -6.6702795,
        -8.614443 ,  -8.6319895,  -9.378907 , -10.925872 ,  -9.41332  ,
        -8.34085  ,  -8.560421 ,  -7.5433216,  -8.812265 ,  -8.915992 ,
        -8.41679  ,  -8.928123 ], dtype=float32)