In [13]:
import sys
sys.path.append("../../BayesFlow")
sys.path.append("../")

import os
if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import pickle

import keras

import optuna

Function Trials:

In [14]:
import bayesflow as bf
from dmc import DMC

In [15]:
# Build the simulator from the project-local `dmc` module with fixed prior
# hyperparameters. It is later handed to bf.BasicWorkflow as `simulator`.
# NOTE(review): the five prior entries presumably follow the parameter order
# A, tau, mu_c, t0, b used by the adapter further down — confirm against dmc.DMC.
simulator = DMC(
    prior_means=np.array([16., 111., 0.5, 322., 75.]), 
    prior_sds=np.array([10., 47., 0.13, 40., 23.]),
    # NOTE(review): presumably the maximum simulated time per trial — confirm unit and meaning in dmc.DMC
    tmax=1500,
)

Experiment Function:

In [16]:
# Adapter shared by every workflow built in this notebook. The steps are chained
# in the order listed; see the BayesFlow Adapter documentation for exact semantics.
adapter = (
    bf.adapters.Adapter()
    # dtype conversion float64 -> float32 (argument order presumably from_dtype, to_dtype — confirm)
    .convert_dtype("float64", "float32")
    # square-root transform applied to the "num_obs" entry
    .sqrt("num_obs")
    # the five model parameters are combined under the key "inference_variables"
    .concatenate(["A", "tau", "mu_c", "t0", "b"], into="inference_variables")
    # response times, accuracies and condition labels are combined under "summary_variables"
    .concatenate(["rt", "accuracy", "conditions"], into="summary_variables")
    # standardization is restricted to the "inference_variables" entry
    .standardize(include="inference_variables")
    # the (transformed) "num_obs" entry is passed on under the key "inference_conditions"
    .rename("num_obs", "inference_conditions")
)

## Load Training and Validation Data

In [17]:
def _load_pickle(path):
    """Read and return the object stored in the pickle file at ``path``.

    SECURITY: ``pickle.load`` can execute arbitrary code while unpickling, so
    only use this on files that come from a trusted source.
    """
    # The context manager closes the file handle even if unpickling raises.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Locations of the pickled offline training and validation data sets.
training_file_path = '../data/data_offline_training/data_offline_training.pickle'
val_file_path = '../data/data_offline_training/data_offline_validation.pickle'

# Both files are read through the same helper so the loading logic lives in one place.
train_data = _load_pickle(training_file_path)
val_data = _load_pickle(val_file_path)


In [63]:
def weighted_metric_sum(metrics_table, weight_recovery=1, weight_pc=1, weight_sbc=1):
    """Collapse a diagnostics table into a single score (weighted sum of row means).

    The table is read positionally: rows 0, 1 and 2 are weighted by
    ``weight_recovery``, ``weight_pc`` and ``weight_sbc`` respectively, and each
    row is averaged across its columns (one column per parameter). Row 1
    (posterior contraction) is recoded as ``1 - value`` before averaging,
    presumably so that smaller is better for all three rows.

    Parameters
    ----------
    metrics_table : pandas.DataFrame
        Table with three rows (recovery, posterior contraction, SBC) and one
        column per parameter. The table is NOT modified by this function.
    weight_recovery, weight_pc, weight_sbc : float, optional
        Weights applied to the mean of row 0, row 1 and row 2. Default is 1 each.

    Returns
    -------
    float
        Dot product of the three row means with the three weights.
    """
    # Work on a copy: recoding in place would silently alter the caller's table,
    # and a second call with the same table would undo the recoding.
    recoded = metrics_table.copy()

    # recode posterior contraction
    recoded.iloc[1, :] = 1 - recoded.iloc[1, :]

    # compute means across parameters
    metrics_means = recoded.mean(axis=1)

    # weights in row order (Recovery, Posterior Contraction, SBC)
    metrics_weights = np.array([weight_recovery, weight_pc, weight_sbc])

    return np.dot(metrics_means, metrics_weights)

In [64]:
### define objective function

def objective(epochs=1, dropout=0.1, initial_learning_rate=5e-4, batch_size=128):
    """Train one workflow with fixed hyperparameters and return its diagnostic score.

    Used as a quick manual check of the training/scoring pipeline before running
    the Optuna search. Relies on the notebook-level objects ``simulator``,
    ``adapter``, ``train_data``, ``val_data`` and ``weighted_metric_sum``.

    Parameters
    ----------
    epochs : int, optional
        Number of offline training epochs. Default is 1.
    dropout : float, optional
        Dropout rate used in both the coupling-flow subnets and the summary
        network. Default is 0.1.
    initial_learning_rate : float, optional
        Initial learning rate passed to the workflow. Default is 5e-4.
    batch_size : int, optional
        Batch size for offline training. Default is 128.

    Returns
    -------
    float
        Weighted metric sum computed from the workflow's default diagnostics
        on ``val_data``.
    """
    # Create inference and summary networks; both share the same dropout rate.
    inference_net = bf.networks.CouplingFlow(
        coupling_kwargs=dict(subnet_kwargs=dict(dropout=dropout))
    )
    summary_net = bf.networks.SetTransformer(summary_dim=32, num_seeds=2, dropout=dropout)

    workflow = bf.BasicWorkflow(
        simulator=simulator,
        adapter=adapter,
        initial_learning_rate=initial_learning_rate,
        inference_network=inference_net,
        summary_network=summary_net,
        inference_variables=["A", "tau", "mu_c", "t0", "b"],
    )

    # The returned training history is not needed: scoring uses diagnostics only.
    workflow.fit_offline(
        train_data,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=val_data,
    )

    metrics_table = workflow.compute_default_diagnostics(test_data=val_data)

    # compute weighted sum
    return weighted_metric_sum(metrics_table)

# Smoke test: one training run using the default arguments of `objective`.
# NOTE: this only works while the no-`trial` definition of `objective` is the live
# one; the later cell redefines `objective(trial, epochs=50)`, after which this
# call fails because `trial` is a required argument.
objective_test=objective()



INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.


[1m 83/391[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m20s[0m 66ms/step - loss: 6.4618 - loss/inference_loss: 6.4618

KeyboardInterrupt: 

In [None]:
objective_test

0.8598938856409903

In [72]:

def objective(trial, epochs=50, batch_size=128):
    """Optuna objective: train one workflow with sampled hyperparameters and score it.

    Search space (sampled from ``trial``):
      * ``dropout``   – float in [0.01, 0.3], shared by the coupling-flow subnets
        and the summary network
      * ``lr``        – float in [1e-4, 1e-3], initial learning rate
      * ``num_seeds`` – int in [1, 4], passed to the SetTransformer
      * ``depth``     – int in [5, 10], passed to the CouplingFlow

    Relies on the notebook-level objects ``simulator``, ``adapter``,
    ``train_data``, ``val_data`` and ``weighted_metric_sum``.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial object used to sample the hyperparameters.
    epochs : int, optional
        Number of offline training epochs per trial. Default is 50.
    batch_size : int, optional
        Batch size for offline training. Default is 128.

    Returns
    -------
    float
        Weighted metric sum computed from the workflow's default diagnostics
        on ``val_data``; this is the value the study optimizes.
    """
    # Optimize hyperparameters
    dropout = trial.suggest_float("dropout", 0.01, 0.3)
    initial_learning_rate = trial.suggest_float("lr", 1e-4, 1e-3)
    num_seeds = trial.suggest_int("num_seeds", 1, 4)
    depth = trial.suggest_int("depth", 5, 10)

    # Create inference and summary networks
    inference_net = bf.networks.CouplingFlow(
        coupling_kwargs=dict(subnet_kwargs=dict(dropout=dropout)),
        depth=depth,
    )
    summary_net = bf.networks.SetTransformer(summary_dim=32, num_seeds=num_seeds, dropout=dropout)

    workflow = bf.BasicWorkflow(
        simulator=simulator,
        adapter=adapter,
        initial_learning_rate=initial_learning_rate,
        inference_network=inference_net,
        summary_network=summary_net,
        inference_variables=["A", "tau", "mu_c", "t0", "b"],
    )

    # verbose=0 keeps per-epoch progress bars out of the study log. The returned
    # history is not needed because scoring uses diagnostics only.
    workflow.fit_offline(
        train_data,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=val_data,
        verbose=0,
    )

    metrics_table = workflow.compute_default_diagnostics(test_data=val_data)

    # compute weighted sum
    return weighted_metric_sum(metrics_table)

In [None]:
# Hyperparameter search. `objective` returns a weighted metric sum and the study
# looks for the hyperparameters giving the smallest value.
# NOTE(review): no sampler seed is passed, so suggested values differ between runs.
study = optuna.create_study(direction="minimize")

# 40 trials; each trial is one full training run of `objective` (default epochs=50).
study.optimize(objective, n_trials=40)

[I 2025-04-02 18:57:01,588] A new study created in memory with name: no-name-3df42576-5ff5-4a51-863e-3aa99e7bac20
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.


In [None]:
# Report the best trial. The study value is the weighted metric sum returned by
# `objective` (recovery, recoded posterior contraction, SBC) — not a validation
# loss — so the label says so.
trial = study.best_trial
print("Best weighted metric sum: {}".format(trial.value))
print("Best hyperparameters: {}".format(trial.params))

INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.


Epoch 1/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 124ms/step - loss: 4.7947 - loss/inference_loss: 4.7947 - val_loss: 4.2981 - val_loss/inference_loss: 4.2981
Epoch 2/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 124ms/step - loss: 2.9251 - loss/inference_loss: 2.9251 - val_loss: 3.0569 - val_loss/inference_loss: 3.0569
Epoch 3/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 125ms/step - loss: 2.1602 - loss/inference_loss: 2.1602 - val_loss: 2.1239 - val_loss/inference_loss: 2.1239
Epoch 4/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 126ms/step - loss: 1.6297 - loss/inference_loss: 1.6297 - val_loss: 1.6652 - val_loss/inference_loss: 1.6652
Epoch 5/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 126ms/step - loss: 1.3089 - loss/inference_loss: 1.3089 - val_loss: 2.3758 - val_loss/inference_loss: 2.3758
Epoch 6/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m

KeyboardInterrupt: 