In [22]:
import numpy as np
import bayesflow as bf
import tensorflow as tf
from scipy.stats import halfnorm
import sys

In [23]:
sys.path.append("../src")
from observation_model import batched_dynamic_ddm
from priors import rwddm_local_prior, rwddm_hyper_prior, rwddm_shared_prior
from helpers import scale_z, unscale_z

In [24]:
# GPU setup: enable memory growth so TensorFlow allocates GPU memory on demand
# instead of reserving all of it at start-up. Iterating over the (possibly empty)
# device list keeps the notebook runnable on CPU-only machines, where indexing
# `physical_devices[0]` would raise IndexError, and applies the setting to every
# GPU as the TensorFlow GPU guide recommends.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for device in physical_devices:
        tf.config.experimental.set_memory_growth(device, enable=True)
except RuntimeError as err:
    # Memory growth must be set before the GPUs are initialized; this is hit
    # when the cell is re-run in a kernel that has already used the GPU.
    print(err)
print(physical_devices)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


# Constants

In [25]:
# Standardization constants consumed by `configure` when scaling parameter draws.
# Local parameters use fixed location / scale values.
LOCAL_PARAM_MEANS = 1.7
LOCAL_PARAM_STDS = 1

# Shared and hyper parameters use the analytic mean / std of half-normal
# distributions (scipy parameterization: loc, scale).
# NOTE(review): these (loc, scale) pairs presumably mirror the priors defined in
# `priors.py` — confirm they stay in sync if the priors change.
_shared_prior_ref = halfnorm(loc=0.1, scale=0.2)
SHARED_PRIOR_MEAN = _shared_prior_ref.mean()
SHARED_PRIOR_STD = _shared_prior_ref.std()

_hyper_prior_ref = halfnorm(loc=0.0, scale=0.1)
HYPER_PARAM_MEANS = _hyper_prior_ref.mean()
HYPER_PARAM_STDS = _hyper_prior_ref.std()

# Generative Model

## Prior

In [26]:
# Two-level prior assembled from the three sampling functions in `priors.py`:
# local parameters, parameters shared across a data set, and hyper-parameters.
prior = bf.simulation.TwoLevelPrior(
    local_prior_fun=rwddm_local_prior,
    shared_prior_fun=rwddm_shared_prior,
    hyper_prior_fun=rwddm_hyper_prior,
)

## Simulator

In [27]:
# Wrap the batched observation model so BayesFlow can call it on whole batches.
simulator = bf.simulation.Simulator(
    batch_simulator_fun=batched_dynamic_ddm,
)

In [28]:
# Combine prior and simulator into a single generative model; constructing it
# runs the pilot simulations whose shapes are logged below.
generative_model = bf.simulation.TwoLevelGenerativeModel(
    prior,
    simulator,
)

INFO:root:Performing 2 pilot runs with the anonymous model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 150, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 150, 1)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 150, 2)
INFO:root:Shape of shared_prior_draws batch after 2 pilot simulations: (batch_size = 2, 1)
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.


# Amortizer

## Summary Network

In [36]:
# Two-stage hierarchical summary network: a stacked-LSTM front end followed by
# a TimeSeriesTransformer that emits a 32-dimensional summary vector.
summary_network = bf.networks.HierarchicalNetwork(
    [
        tf.keras.Sequential(
            [
                tf.keras.layers.LSTM(512, return_sequences=True),
                tf.keras.layers.LSTM(128, return_sequences=True),
            ]
        ),
        # NOTE(review): the first positional argument (128) presumably has to
        # match the 128 units of the last LSTM above — confirm in BayesFlow docs.
        bf.networks.TimeSeriesTransformer(
            128,
            template_dim=128,
            summary_dim=32,
        ),
    ]
)

## Inference Network

In [37]:
# Local posterior: 2 parameters per local draw (local_prior_draws have a
# trailing dimension of 2 in the pilot-run log above).
local_net = bf.amortizers.AmortizedPosterior(
    bf.networks.InvertibleNetwork(
        num_params=2,
        num_coupling_layers=8,
        coupling_design='interleaved',
    )
)

# Global posterior: 2 hyper-parameters + 1 shared parameter, matching the
# hyper_prior_draws (..., 2) and shared_prior_draws (..., 1) pilot-run shapes.
global_net = bf.amortizers.AmortizedPosterior(
    bf.networks.InvertibleNetwork(
        num_params=2 + 1,
        num_coupling_layers=6,
        coupling_design='interleaved',
    )
)

In [38]:
# Joint amortizer: local posterior, global posterior, and the shared summary network.
amortizer = bf.amortizers.TwoLevelAmortizedPosterior(
    local_net,
    global_net,
    summary_network,
)

# Configurator

In [39]:
def configure(forward_dict):
    """Turn raw generative-model output into the input dict for the amortizer.

    Parameters
    ----------
    forward_dict : dict
        Output of ``generative_model``; must contain the keys ``'sim_data'``,
        ``'local_prior_draws'``, ``'shared_prior_draws'`` and
        ``'hyper_prior_draws'``.

    Returns
    -------
    dict
        ``'summary_conditions'`` (the simulated data) together with the
        standardized ``'hyper_parameters'``, ``'local_parameters'`` and
        ``'shared_parameters'``; every value is cast to ``np.float32``.
    """
    # Standardize each parameter level with its own location / scale constants
    # via `scale_z` (presumably (x - mean) / std — see helpers.py).
    local_z = scale_z(forward_dict['local_prior_draws'],
                      LOCAL_PARAM_MEANS, LOCAL_PARAM_STDS)
    shared_z = scale_z(forward_dict['shared_prior_draws'],
                       SHARED_PRIOR_MEAN, SHARED_PRIOR_STD)
    hyper_z = scale_z(forward_dict['hyper_prior_draws'],
                      HYPER_PARAM_MEANS, HYPER_PARAM_STDS)

    fields = {
        'summary_conditions': forward_dict['sim_data'],
        'hyper_parameters': hyper_z,
        'local_parameters': local_z,
        'shared_parameters': shared_z,
    }
    # The networks expect single-precision inputs.
    return {key: value.astype(np.float32) for key, value in fields.items()}

# Training

In [40]:
# Trainer wiring amortizer, generative model and configurator together.
# Checkpoints live under `checkpoint_path` (the log shows a fresh start when
# none exist yet).
trainer = bf.trainers.Trainer(
    amortizer=amortizer,
    generative_model=generative_model,
    configurator=configure,
    checkpoint_path='checkpoints/rwddm_shared_ndt',
)

INFO:root:Initialized empty loss history.
INFO:root:Initialized empty simulation memory.
INFO:root:Initialized networks from scratch.
INFO:root:Performing a consistency check with provided components...
2023-06-06 09:59:46.762416: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2023-06-06 09:59:46.763504: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32
	 [[{{node gradients/split_grad/concat/split/split_dim}}]]
2023-06-06 09:59:46.7

In [41]:
# Online training: data are simulated on the fly by the generative model
# (50 epochs x 1000 iterations, 32 simulated data sets per batch).
h = trainer.train_online(epochs=50, iterations_per_epoch=1000, batch_size=32)

Training epoch 1:   0%|          | 0/1000 [00:00<?, ?it/s]

2023-06-06 09:59:47.943217: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2023-06-06 09:59:47.944393: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32
	 [[{{node gradients/split_grad/concat/split/split_dim}}]]
2023-06-06 09:59:47.945315: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You mus

Training epoch 2:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 3:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 4:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 5:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 6:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 7:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 8:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 9:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 10:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 11:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 12:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 13:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 14:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 15:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 16:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 17:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 18:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 19:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 20:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 21:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 22:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 23:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 24:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 25:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 26:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 27:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 28:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 29:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 30:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 31:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 32:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 33:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 34:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 35:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 36:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 37:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 38:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 39:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 40:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 41:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 42:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 43:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 44:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 45:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 46:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 47:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 48:   0%|          | 0/1000 [00:00<?, ?it/s]

Training epoch 49:   0%|          | 0/1000 [00:00<?, ?it/s]

In [44]:
# Plot the training-loss trajectory recorded in the trainer's loss history.
l = bf.diagnostics.plot_losses(
    train_losses=trainer.loss_history.get_plottable()
)

# Validation

In [45]:
# Simulate a single data set and pass it through the same configurator used
# during training, so it matches the amortizer's expected input format.
val_data = configure(
    generative_model(1)
)

In [47]:
# Summarize the configured validation batch by array shape and dtype instead of
# dumping every value (the data arrays carry a length-150 axis per the pilot-run
# log); inspect `val_data[...]` directly when the raw values are needed.
{key: (value.shape, value.dtype) for key, value in val_data.items()}

## Local Params

## Hyper Params

# Empirical Data

## Martin's Data

In [32]:
# Data preparation

## Average Parameter Dynamic

## Absolute Fit

## Other Dataset

In [None]:
# Data preparation

## Average Parameter Dynamic

## Absolute Fit