In [None]:
import sys
sys.path.append("../../BayesFlow")
sys.path.append("../")

import os
if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import pickle
import time
import keras

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import bayesflow as bf
from dmc import DMC, dmc_helpers



step_size = 50

num_max_obs = 800




parent_dir = '/home/administrator/Documents/BF-LIGHT'

network_name = 'dmc_optimized_updated_priors'

model_specs_path = parent_dir + '/model_specs/model_specs_' + network_name + '.pickle'
with open(model_specs_path, 'rb') as file:
    model_specs = pickle.load(file)

simulator, adapter, inference_net, summary_net, workflow = dmc_helpers.load_model_specs(model_specs, network_name)
## Load Approximator

approximator = keras.saving.load_model(parent_dir +"/data/training_checkpoints/" + network_name + ".keras")
approximator.compile()

workflow.approximator = approximator


simulator.fixed_num_obs = num_max_obs
## Simulate Validation Data Set
val_data = simulator.sample(1000)


val_file_path = parent_dir + '/data/validation_data_metrics/validation_data_metrics_' + network_name + '.pickle'

with open(val_file_path, 'wb') as file:
    pickle.dump(val_data, file)

for key, values in val_data.items():
    print(f'{key}: {values.shape}')


# network_plot_folder = parent_dir + "/plots/metrics_num_obs/" + network_name

# if not os.path.exists(network_plot_folder):
#     os.makedirs(network_plot_folder)

list_metrics = []


with torch.enable_grad():
in contexts where you need gradients (e.g. custom training loops).
Existing checkpoints can _not_ be restored/loaded using this workflow. Upon refitting, the checkpoints will be overwritten. To load the stored approximator from the checkpoint, use approximator = keras.saving.load_model(...)
  instance.compile_from_config(compile_config)


A: (1000, 1)
tau: (1000, 1)
mu_c: (1000, 1)
mu_r: (1000, 1)
b: (1000, 1)
rt: (1000, 800, 1)
accuracy: (1000, 800, 1)
conditions: (1000, 800, 1)
num_obs: (1000, 1)


In [7]:
model_specs

{'simulation_settings': {'prior_means': array([ 16. , 111. ,   0.5, 322. ,  75. ]),
  'prior_sds': array([10.  , 47.  ,  0.13, 40.  , 23.  ]),
  'tmax': 1500,
  'contamination_probability': None,
  'min_num_obs': 50,
  'max_num_obs': 800,
  'fixed_num_obs': None},
 'inference_network_settings': {'coupling_kwargs': {'subnet_kwargs': {'dropout': 0.0100967297}},
  'depth': 10},
 'summary_network_settings': {'dropout': 0.0100967297,
  'num_seeds': 2,
  'summary_dim': 32,
  'embed_dim': (128, 128)},
 'batch_size': 16,
 'learning_rate': 0.0004916,
 'param_names': ['A', 'tau', 'mu_c', 'mu_r', 'b']}

In [4]:
val_data

{'A': array([[ 7.47226115],
        [ 8.41077568],
        [10.12338227],
        [ 7.7463453 ],
        [14.66103273],
        [15.08819532],
        [12.33148822],
        [13.60746005],
        [16.44561007],
        [18.37150887],
        [14.66812529],
        [27.05405815],
        [10.59083557],
        [18.31545803],
        [13.54453793],
        [ 3.91308745],
        [ 9.45667253],
        [12.37867924],
        [21.61032217],
        [28.63770232],
        [18.09408827],
        [19.94671327],
        [ 7.99113637],
        [ 6.66778121],
        [ 9.81879297],
        [13.70658305],
        [ 4.60762535],
        [11.86504984],
        [31.00558713],
        [12.66355034],
        [ 8.0384915 ],
        [12.24607448],
        [34.50943197],
        [ 0.72224523],
        [19.52775152],
        [16.27465185],
        [13.70963762],
        [27.12848019],
        [17.43890496],
        [17.87567862],
        [ 9.84934638],
        [ 0.28683823],
        [ 5.64096168],
      

In [None]:
def subset_data(data, num_obs, keys = ['rt', 'accuracy', 'conditions'], random=True):

    data = data.copy()

    max_obs = data[keys[0]].shape[1]

    if random:
        random_idx = np.random.choice(np.arange(0, max_obs), size=num_obs, replace=False)

        for k in keys:
            # print(f'{data[k].shape}')
            data[k] = data[k][:, random_idx, :]
            print(f'{k}: {data[k].shape} (random)')
    
    else:
        for k in keys:
            # print(f'{data[k].shape}')
            data[k] = data[k][:, :num_obs, :]
            print(f'{k}: {data[k].shape}')

In [12]:
data_subset = subset_data(val_data.copy(), num_obs=100, random=True)

In [6]:

for n_obs in np.arange(50, num_max_obs+1, step_size):
    
    print(f'num_obs: {n_obs}')
    # simulator.num_obs = n_obs

    data_subset = subset_data(val_data.copy(), num_obs=n_obs, random=False)

    start_time = time.time()
    samples = approximator.sample(conditions=data_subset, num_samples=1000)
    end_time = time.time()


    pc_df = pd.DataFrame(bf.diagnostics.metrics.posterior_contraction(samples, data_subset))

    pc_df['values'] = 1 - pc_df['values']

    ce_df = pd.DataFrame(bf.diagnostics.metrics.calibration_error(samples, data_subset))

    nrmse_df = pd.DataFrame(bf.diagnostics.metrics.root_mean_squared_error(samples, data_subset))

    results_single = pd.concat([ce_df, pc_df, nrmse_df])
    
    
    results_single["num_obs"] = n_obs
    results_single["sampling_time"] = end_time - start_time
    
    list_metrics.append(results_single)
    
data_set_metrics = pd.concat(list_metrics)

data_set_metrics.reset_index(inplace=True)


num_obs: 50
(1000, 50, 1)
(1000, 50, 1)
(1000, 50, 1)


AttributeError: 'NoneType' object has no attribute 'copy'

In [None]:

fig, axes = plt.subplots(1,5,sharey=True, figsize=(15,3))

for p, ax in zip(model_specs['param_names'], axes):
    
    suff = "$\\" if p in ["tau", "mu_c", "mu_r"] else "$"

    label = suff + p + "$"
    
    sns.lineplot(data_set_metrics[data_set_metrics["variable_names"] == p], x="num_obs", y="values", hue="metric_name", ax=ax, palette="colorblind")
    ax.set_title(label)
    ax.legend(title="")
    if p != "b":
        ax.get_legend().remove()
    
    ax.set_xlabel("")

    # plt.ylim(0, 1)

fig.tight_layout()
fig.supxlabel("Number of Observations", fontsize=12) 

fig.savefig(network_plot_folder + '/metrics_num_obs_' + network_name + '.png')


data_set_metrics_time = data_set_metrics[(data_set_metrics["metric_name"] == 'Calibration Error')]
data_set_metrics_time = data_set_metrics_time[(data_set_metrics_time["variable_names"] == 'A')]


plt.figure()

time_plot = sns.lineplot(data_set_metrics_time, x="num_obs", y="sampling_time")
    # ax.set_title(label)
    # ax.legend(title="")
    # if p != "b":
    #     ax.get_legend().remove()
    
    # ax.set_xlabel("")

    # plt.ylim(0, 1)

# fig.tight_layout()
# fig.supxlabel("Number of Observations", fontsize=12) 
time_plot_fig = time_plot.get_figure()
time_plot_fig.savefig(network_plot_folder + '/metrics_num_obs_sampling_time_' + network_name + '.png')