In [None]:
import os     
if "KERAS_BACKEND" not in os.environ:
    #os.environ["KERAS_BACKEND"] = "tensorflow"  
    os.environ["KERAS_BACKEND"] = "jax"  
import bayesflow as bf
import sys
sys.path.append('../')
import keras
import matplotlib.pyplot as plt
import numpy as np
import numba as nb
from numba import njit
import math
import ipynbname
from pathlib import Path
RNG = np.random.default_rng(2025)

In this exercise, you load two sets of networks: one trained on standard DDM data and one trained on contaminated DDM data. You then test the robustness of both and briefly compare them.

Because we load trained models, you need to know where your notebook and the model files (`.keras`) are stored. The following code helps you locate them. If you already know the path, you can also define it manually.

In [57]:
# Option 1: set the directory by hand (uncomment and replace `your_path`).
#notebook_path = your_path
# Option 2: take the path reported by ipynbname, resolve it, and keep its
# parent directory. The .keras model files are later looked up in this directory.
notebook_path = Path(ipynbname.path()).resolve().parent
print(notebook_path)

C:\Windows\System32\bayesflow


## Introduction

This tutorial focuses on:
1. The simulation and inference of stochastic cognitive models in BayesFlow
2. Testing and improving the robustness of inference in BayesFlow

Here we fit the Drift Diffusion Model (DDM), a popular stochastic model in the field of decision-making.

### Drift Diffusion Model

Below is a graphical illustration of the drift diffusion process and the resulting reaction time data in a hypothetical visual recognition memory task. Participants view an image and judge whether it is old (i.e., previously seen) or new (i.e., previously unseen). The rows in the resulting data table correspond to individual trials of the experiment, with conditions and responses coded as 0 (old image) and 1 (new image).

![Diffusion Model Plot](plots/Diffusion_Model_plot.png)


The standard DDM assumes that decision-making is an evidence accumulation process: once the accumulated evidence reaches a certain boundary, a decision is made. The upper and lower boundaries correspond to different choices. The model thus has four types of parameters:
1. Drift rate ($v$): the average rate of evidence accumulation under a certain condition
2. Boundary separation ($a$): the distance between two boundaries
3. Response bias ($z$): the relative starting point of the diffusion process
4. Non-decision time ($ndt$): any time unrelated to the decision process itself (e.g., stimulus encoding, motor execution)

### Sensitivity to Outliers
The DDM is sensitive to outliers due to the nature of its assumptions. By design, the key parameter non-decision time ($ndt$) is estimated to be lower than the shortest reaction time in the data set. Since the decision process is jointly determined by all DDM parameters, a short outlier can distort not only the estimate of $ndt$ but also those of the other parameters, leading to biased results. Consequently, addressing the influence of outliers has been a persistent challenge in DDM fitting. We thus use the DDM as an example to test the robustness of amortized Bayesian inference and to try to improve it.

## A Standard DDM Estimator
We first train a standard DDM estimator, that is, we simulate standard (uncontaminated) reaction time data from the DDM and use them to train the neural networks.
### Single trial simulation
Here we simulate one single diffusion trial.

In [58]:
@nb.jit(nopython=True, cache=True)
def diffusion_trial(v, a, z, ndt, dt=1e-3, max_steps=15000):
    """Simulate a single trial of the drift diffusion model.

    The evidence starts at ``a * z`` and is advanced with Euler-Maruyama
    steps of size ``dt`` until it leaves the open interval ``(0, a)`` or
    ``max_steps`` steps have been taken.

    Parameters
    ----------
    v : float
        Drift rate.
    a : float
        Boundary separation (the upper boundary; the lower boundary is 0).
    z : float
        Relative starting point; the start value is ``a * z``.
    ndt : float
        Non-decision time, added to the simulated decision time.
    dt : float, optional
        Size of one simulation time step.
    max_steps : int, optional
        Maximum number of simulation steps before the trial is cut off.

    Returns
    -------
    tuple of (float, float)
        ``(rt, resp)`` where ``rt`` is the decision time plus ``ndt`` and
        ``resp`` is ``1.`` if the final evidence is positive (upper boundary,
        or a cut-off trial that is still above 0) and ``0.`` otherwise.
    """
    x = a * z
    mu = v * dt
    sigma = math.sqrt(dt)

    # Count completed steps, so a boundary crossing on the k-th update
    # corresponds to an elapsed decision time of k * dt (not (k - 1) * dt).
    n_steps = 0
    while n_steps < max_steps:
        # DDM equation
        x += mu + sigma * np.random.normal(0, 1)
        n_steps += 1
        # Stop when out of bounds
        if x <= 0.0 or x >= a:
            break

    rt = float(n_steps) * dt

    if x > 0:
        resp = 1.
    else:
        resp = 0.
    return rt + ndt, resp

In [None]:
# Quick sanity check: simulate one trial; the result is a (reaction time, response) pair.
diffusion_trial(v = 1, a = 2, z = 0.5, ndt = 0.3)

### Prior
We define a prior that is wide enough to cover a realistic range of true parameters. Two drift rates are assumed, corresponding to two conditions in the experimental design.

In [None]:
def diffusion_prior():
    """Draw one parameter set from the joint prior.

    All draws use the module-level generator ``RNG``:

    * ``v``   -- two drift rates (one per condition), each Uniform(-7, 7)
    * ``a``   -- boundary separation, Uniform(0.5, 5)
    * ``ndt`` -- non-decision time, Gamma(shape=1.5, scale=1/5)
    * ``z``   -- relative starting point, Uniform(0.01, 0.99)

    Returns
    -------
    dict
        Keys ``"v"`` (array of length 2), ``"a"``, ``"ndt"`` and ``"z"``.
    """
    # Draw order matters for reproducibility: v_1, v_2, a, ndt, z.
    drift_rates = [RNG.uniform(-7, 7) for _ in range(2)]
    boundary = RNG.uniform(0.5, 5)
    non_decision = RNG.gamma(1.5, 1 / 5.0)
    start_point = RNG.uniform(.01, .99)

    return {
        "v": np.array(drift_rates),
        "a": boundary,
        "ndt": non_decision,
        "z": start_point,
    }

In [None]:
# Quick sanity check: one draw from the prior, returned as a dict with keys v, a, ndt, z.
diffusion_prior()

### Number of Observations
The number of observations (trials) in each batch is randomly sampled.

In [None]:
# Range from which the number of trials per data set is drawn.
# np.random.randint excludes the upper bound, so values lie in [min_obs, max_obs).
min_obs = 100
max_obs = 500


def meta(batch_size, num_obs=None):
    """Return the number of observations (trials) shared by one batch.

    Parameters
    ----------
    batch_size : int
        Accepted for the simulator's ``meta_fn`` call; not used here.
    num_obs : int, optional
        Fixed number of trials. If ``None`` (default), a value is drawn
        uniformly from the integers ``min_obs, ..., max_obs - 1``.

    Returns
    -------
    dict
        ``{"num_obs": num_obs}``.
    """
    # Identity check: `== None` would be evaluated element-wise (and then fail
    # inside `if`) when num_obs is passed as a NumPy array.
    if num_obs is None:
        num_obs = np.random.randint(min_obs, max_obs)
    return dict(num_obs=num_obs)

### Observational Model
We wrap the prior, single trial simulator, and number of observations up to obtain a complete model. 

In [None]:
def diffusion_experiment(v, a, z, ndt, num_obs, rng=None, *args):
    """Simulate one data set of ``num_obs`` diffusion trials in two conditions.

    Trials are split between two conditions by a multinomial draw with equal
    probabilities and then shuffled. Trial ``n`` is simulated with the drift
    rate ``v[condition[n]]``; ``a``, ``z`` and ``ndt`` are shared by all
    trials.

    Parameters
    ----------
    v : array-like
        Drift rates, indexed by condition (0 or 1).
    a, z, ndt : float
        Boundary separation, relative starting point, non-decision time.
    num_obs : int
        Number of trials to simulate.
    rng, *args
        Accepted but not used; randomness comes from ``np.random`` and from
        ``diffusion_trial``.

    Returns
    -------
    dict
        ``rt`` (log reaction times), ``resp`` (responses coded 0./1.) and
        ``con`` (condition label per trial), each of length ``num_obs``.
    """
    # Condition labels: a dummy variable taking the values 0 and 1
    n_conditions = 2
    trials_per_condition = np.random.multinomial(
        num_obs, [1 / n_conditions] * n_conditions
    )
    condition = np.concatenate(
        [np.full(size, label) for label, size in enumerate(trials_per_condition)]
    )
    np.random.shuffle(condition)

    reaction_times = np.zeros(num_obs)
    responses = np.zeros(num_obs)
    for trial, label in enumerate(condition):
        reaction_times[trial], responses[trial] = diffusion_trial(v[label], a, z, ndt)

    # Reaction times are returned on the log scale for faster convergence
    return dict(rt=np.log(reaction_times), resp=responses, con=condition)

In [None]:
# Combine the prior and the observation model into one simulator; `meta` is
# passed as meta_fn and supplies the `num_obs` value that diffusion_experiment takes.
simulator = bf.simulators.make_simulator([diffusion_prior, diffusion_experiment], meta_fn=meta)

In [None]:
# Draw a test batch and inspect what the simulator returns.
# NOTE(review): 200 is presumably the batch size (number of simulated data sets) - confirm.
sim_data = simulator.sample(200)
print("Number of observations:", sim_data["num_obs"])
print("Shape of rt:", sim_data["rt"].shape)
print("Shape of response:", sim_data["resp"].shape)
print("Shape of condition:", sim_data["con"].shape)

In [None]:
# Keys of the parameter entries returned by diffusion_prior ...
par_keys = ["v", "a", "ndt", "z"]
# ... and one display label per scalar parameter ("v" holds two drift rates).
par_names = [r"v_1", r"v_2", "a", "ndt", "z"]

### Adapter

In [None]:
# The adapter turns the simulator's output dict into the named inputs used for training.
# NOTE(review): the per-step notes below follow the BayesFlow method names;
# confirm the exact semantics against the Adapter documentation.
adapter = (
    bf.Adapter()
    # presumably expands num_obs so that its batch shape matches "rt"
    .broadcast("num_obs", to="rt")
    # all model parameters become the inference targets
    .concatenate(["v", "a","ndt","z"], into = "inference_variables")
    # treat the trial-level arrays as sets of observations
    .as_set(["rt", "resp", "con"])
    # per-trial data (log RT, response, condition) are the summary network input
    .concatenate(["rt","resp", "con"], into = "summary_variables")
    # the number of trials is passed on under the name "inference_conditions"
    .rename("num_obs", "inference_conditions"))

### Networks

In [None]:
# Summary network: a SetTransformer configured with summary_dim=10.
summary_network = bf.networks.SetTransformer(summary_dim=10)
# Inference network: a CouplingFlow with its default settings.
inference_network = bf.networks.CouplingFlow()

If you want to train a new model, you can define ```checkpoint_filepath =``` and ```checkpoint_name =``` to store and name your trained model in the ```BasicWorkflow``` object. They are now commented out. 

In [None]:
# Workflow for the standard (uncontaminated) DDM estimator.
# The variable names mirror the keys used by the simulator and the adapter
# ("v", "a", "ndt", "z" / "rt", "resp", "con" / "num_obs").
workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=inference_network,
    summary_network=summary_network,
    initial_learning_rate=5e-4,
    # No `optimizer` argument: no optimizer object exists at this point of the
    # notebook, so referencing one would raise a NameError. To train with a
    # custom optimizer, create it before this cell and pass `optimizer=...`.
    summary_variables=["rt", "resp", "con"],
    inference_variables=["v", "a", "ndt", "z"],
    inference_conditions=["num_obs"],
    #checkpoint_filepath = checkpoint_path,
    #checkpoint_name = "standard_ddm"
)

As we already have the trained model, we just need to load it with ```keras.saving.load_model```.

In [None]:
#define the path where you store the keras
checkpoint_path_standard = notebook_path/"standard_ddm.keras"
workflow.approximator = keras.saving.load_model(checkpoint_path_standard)

In case you want to train the network yourself, you can use the following settings.

In [None]:
# epochs = 200
# num_batches = 128
# batch_size = 32
# learning_rate = keras.optimizers.schedules.CosineDecay(5e-4, decay_steps=epochs*num_batches, alpha=1e-6)
# optimizer = keras.optimizers.Adam(learning_rate=learning_rate, clipnorm=1.0)


# history = workflow.fit_online(
#      epochs=epochs,
#      batch_size=batch_size,
#      num_batches_per_epoch=num_batches
#  )

## Robust Estimator
We then define a robust estimator, where a small amount of data is assumed to be contaminated.

In [None]:
def diffusion_experiment_robust(v, a, z, ndt, num_obs, rng=None, *args):
    """Simulate one data set of diffusion trials with random contamination.

    The clean data are generated as in the standard observation model: trials
    are split between two conditions (multinomial with equal probabilities,
    then shuffled) and trial ``n`` is simulated with drift ``v[condition[n]]``.
    Afterwards each trial is independently replaced, with probability 0.1, by
    a contaminant: its reaction time is ``|t|`` with ``t`` drawn from a
    Student-t distribution with 1 degree of freedom, and its response is a
    fair coin flip.

    Parameters
    ----------
    v : array-like
        Drift rates, indexed by condition (0 or 1).
    a, z, ndt : float
        Boundary separation, relative starting point, non-decision time.
    num_obs : int
        Number of trials to simulate.
    rng, *args
        Accepted but not used; randomness comes from ``np.random`` and from
        ``diffusion_trial``.

    Returns
    -------
    dict
        ``rt`` (log reaction times), ``resp`` (responses coded 0./1.) and
        ``con`` (condition label per trial), each of length ``num_obs``.
    """
    out = np.zeros((num_obs, 2))

    # Condition labels: a dummy variable taking the values 0 and 1
    num_conditions = 2
    counts = np.random.multinomial(num_obs, [1 / num_conditions] * num_conditions)
    condition = np.concatenate([np.full(count, i) for i, count in enumerate(counts)])
    np.random.shuffle(condition)

    for n in range(num_obs):
        index = condition[n]
        out[n, 0], out[n, 1] = diffusion_trial(v[index], a, z, ndt)

    # Reaction times are handled on the log scale
    log_rt = np.log(out[:, 0])
    resp = out[:, 1]

    contaminants_rt = np.abs(np.random.standard_t(df=1, size=num_obs))
    contaminants_resp = np.random.binomial(n=1, p=0.5, size=num_obs)

    # Boolean mask of the trials that are replaced by a contaminant
    replace = np.random.binomial(n=1, p=.1, size=num_obs).astype(bool)

    # Select instead of blending arithmetically: with a 0/1 weighted sum, a
    # non-finite value in the branch that is NOT chosen (e.g. log(0) = -inf)
    # would turn the result into NaN through 0 * inf.
    log_rt = np.where(replace, np.log(contaminants_rt), log_rt)
    resp = np.where(replace, contaminants_resp, resp)

    return dict(rt=log_rt, resp=resp, con=condition)

In [None]:
# Same prior and meta function as the standard simulator, but with the contaminated observation model.
simulator_robust = bf.simulators.make_simulator([diffusion_prior, diffusion_experiment_robust], meta_fn=meta)
# The robust estimator uses a larger summary (summary_dim=15 instead of 10) ...
summary_network_robust = bf.networks.SetTransformer(summary_dim=15)
# ... and a coupling flow configured with transform="spline".
inference_network_robust = bf.networks.CouplingFlow(transform="spline")

In [None]:
# Workflow for the robust estimator (trained on contaminated data).
# The variable names mirror the keys used by the simulator and the adapter
# ("v", "a", "ndt", "z" / "rt", "resp", "con" / "num_obs").
workflow_robust = bf.BasicWorkflow(
    simulator=simulator_robust,
    adapter=adapter,
    inference_network=inference_network_robust,
    summary_network=summary_network_robust,
    initial_learning_rate=5e-4,
    # No `optimizer` argument: no optimizer object exists at this point of the
    # notebook, so referencing one would raise a NameError. To train with a
    # custom optimizer, create it before this cell and pass `optimizer=...`.
    summary_variables=["rt", "resp", "con"],
    inference_variables=["v", "a", "ndt", "z"],
    inference_conditions=["num_obs"],
    #checkpoint_filepath = checkpoint_path,
    #checkpoint_name = "robust_ddm"
)

In [None]:
# epochs_robust = 400
# history = workflow_robust.fit_online(
#      epochs=epochs_robust,
#      batch_size=batch_size,
#      num_batches_per_epoch=num_batches
#  )

In [None]:
# Location of the stored robust model: the file robust_ddm.keras inside notebook_path
checkpoint_path_robust = notebook_path/"robust_ddm.keras"
# Replace the robust workflow's approximator with the one loaded from that file
workflow_robust.approximator = keras.saving.load_model(checkpoint_path_robust)