In [20]:
import numpy as np
import bayesflow as bf
import tensorflow as tf
import sys

In [21]:
sys.path.append("../src")
from observation_model import batched_dynamic_ddm
from priors import rwddm_local_prior, rwddm_hyper_prior, rwddm_shared_prior
from helpers import scale_z, unscale_z

In [22]:
# gpu setting and checking
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
print(tf.config.list_physical_devices('GPU'))

IndexError: list index out of range

# Generative Model

## Prior

In [14]:
prior = bf.simulation.TwoLevelPrior(
    hyper_prior_fun=rwddm_hyper_prior,
    shared_prior_fun=rwddm_shared_prior,
    local_prior_fun=rwddm_local_prior,
)

## Simulator

In [16]:
simulator = bf.simulation.Simulator(batch_simulator_fun=batched_dynamic_ddm)

In [17]:
generative_model = bf.simulation.TwoLevelGenerativeModel(prior, simulator)

INFO:root:Performing 2 pilot runs with the anonymous model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 150, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 150, 1)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 150, 2)
INFO:root:Shape of shared_prior_draws batch after 2 pilot simulations: (batch_size = 2, 1)
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.


# Amortizer

## Summary Network

In [18]:
summary_network = bf.networks.HierarchicalNetwork([
    tf.keras.Sequential([
        tf.keras.layers.LSTM(512, return_sequences=True),
        tf.keras.layers.LSTM(128, return_sequences=True),
    ]),
    bf.networks.TimeSeriesTransformer(128, template_dim=128, summary_dim=32)
])

## Inference Network

In [None]:
local_net = bf.amortizers.AmortizedPosterior(
    bf.networks.InvertibleNetwork(num_params=2,
                                  num_coupling_layers=8,
                                  coupling_design='interleaved'
                                  ))

global_net = bf.amortizers.AmortizedPosterior(
    bf.networks.InvertibleNetwork(num_params=2+1,
                                  num_coupling_layers=6,
                                  coupling_design='interleaved'
                                  ))

In [None]:
amortizer = bf.amortizers.TwoLevelAmortizedPosterior(local_net, global_net, summary_network)

# Configurator

In [None]:
def configure(forward_dict):
    # standardize local and hyper parameters
    local_prior_draws_z = scale_z(forward_dict['local_prior_draws'],
                                  LOCAL_PARAM_MEANS, LOCAL_PARAM_STDS)
    shared_prior_draws_z = scale_z(forward_dict['shared_prior_draws'],
                                   SHARED_PRIOR_MEAN, SHARED_PRIOR_STD)
    hyper_prior_draws_z = scale_z(forward_dict['hyper_prior_draws'],
                                  HYPER_PARAM_MEANS, HYPER_PARAM_STDS)

    output_dict = {
        'summary_conditions': forward_dict['sim_data'].astype(np.float32),
        'hyper_parameters': hyper_prior_draws_z.astype(np.float32),
        'local_parameters': local_prior_draws_z.astype(np.float32),
        'shared_parameters': shared_prior_draws_z.astype(np.float32)
    }
    return output_dict

# Training

In [None]:
trainer = bf.trainers.Trainer(
    generative_model=generative_model,
    configurator=configure,
    amortizer=amortizer
    )

In [None]:
h = trainer.train_online(
    epochs=50,
    iterations_per_epoch=1000,
    batch_size=32,
    checkpoint_path='checkpoints/rwddm_shared_ndt'
    )

In [None]:
l = bf.diagnostics.plot_losses(train_losses=trainer.loss_history.get_plottable())