In [None]:
%load_ext autoreload
%autoreload 2

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.utils import read_from_file, save_to_file
from tqdm import tqdm

In [None]:
def save_as_chunks(test_dataset, save_dir, chunk_size=5000):
    """Split ``test_dataset`` into consecutive chunks and save each one to disk.

    Chunk ``i`` (1-based) is written with ``save_to_file`` to
    ``save_dir / f"test_dataset_chunk_{i}.pkl"``.

    Args:
        test_dataset: A sliceable sequence/array, or a dict whose values are
            sliceable and all have the same length. For a dict, every value
            is sliced with the same ``[start:end]`` range so the chunks stay
            aligned across keys.
        save_dir: Output directory, as a ``pathlib.Path`` (or anything that
            supports ``/``) or a plain string.
        chunk_size: Maximum number of items per chunk; must be positive.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if the values of a
            dict ``test_dataset`` have different lengths.
    """
    from pathlib import Path

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def get_length(obj):
        if isinstance(obj, dict):
            lengths = {len(v) for v in obj.values()}
            if len(lengths) > 1:
                # Chunking would silently misalign keys otherwise.
                raise ValueError(
                    "All dict values must have the same length, "
                    f"got lengths {sorted(lengths)}"
                )
            # An empty dict has nothing to chunk.
            return lengths.pop() if lengths else 0
        return len(obj)

    def get_chunk(obj, start, end):
        if isinstance(obj, dict):
            return {k: v[start:end] for k, v in obj.items()}
        return obj[start:end]

    if isinstance(save_dir, str):
        # Allow plain string paths; `/` below needs a Path-like object.
        save_dir = Path(save_dir)

    total = get_length(test_dataset)
    num_chunks = (total + chunk_size - 1) // chunk_size  # Ceiling division

    for i in range(num_chunks):
        start = i * chunk_size
        end = min(start + chunk_size, total)
        chunk = get_chunk(test_dataset, start, end)
        file_path = save_dir / f"test_dataset_chunk_{i + 1}.pkl"
        save_to_file(chunk, file_path)

Prepare the test datasets for different tasks.

## GEV

In [None]:
import pymc as pm
import pymc_extras.distributions as pmx
from sbi_mcmc.tasks import PyMCTask


# GEV with wider priors
class GeneralizedExtremeValueWide(PyMCTask):
    """Generalized extreme value (GEV) model with widened priors.

    Registered under the task name ``"GEV"`` with parameters ``mu``,
    ``sigma`` and ``xi``. The PyMC model also records the return level
    ``z_p`` for exceedance probability ``p = 1/10`` as a deterministic.
    """

    def __init__(self):
        super().__init__(var_names=["mu", "sigma", "xi"], task_name="GEV")

    def setup_pymc_model(self, observation=None) -> pm.Model:
        """Build the PyMC model; fall back to the built-in data when ``observation`` is None."""
        observed_values = (
            self.get_observation() if observation is None else observation
        )
        exceedance_prob = 1 / 10
        with pm.Model() as gev_model:
            # Priors (created in the order mu, sigma, xi)
            mu_rv = pm.Normal("mu", mu=3.8, sigma=0.4)
            sigma_rv = pm.HalfNormal("sigma", sigma=0.6)
            xi_rv = pm.TruncatedNormal(
                "xi", mu=0, sigma=0.4, lower=-1.2, upper=1.2
            )

            # Likelihood of the observed data; registered on the model context.
            pmx.GenExtreme(
                "gev",
                mu=mu_rv,
                sigma=sigma_rv,
                xi=xi_rv,
                observed=observed_values,
            )

            # Return level for exceedance probability `exceedance_prob`.
            reduced_variate = -np.log(1 - exceedance_prob)
            pm.Deterministic(
                "z_p",
                mu_rv - sigma_rv / xi_rv * (1 - reduced_variate ** (-xi_rv)),
            )
        return gev_model

    def get_observation(self):
        """Return the fixed observed sample (65 values) as a numpy array."""
        # fmt: off
        observed = np.array([4.03, 3.83, 3.65, 3.88, 4.01, 4.08, 4.18, 3.80,
                             4.36, 3.96, 3.98, 4.69, 3.85, 3.96, 3.85, 3.93,
                             3.75, 3.63, 3.57, 4.25, 3.97, 4.05, 4.24, 4.22,
                             3.73, 4.37, 4.06, 3.71, 3.96, 4.06, 4.55, 3.79,
                             3.89, 4.11, 3.85, 3.86, 3.86, 4.21, 4.01, 4.11,
                             4.24, 3.96, 4.21, 3.74, 3.85, 3.88, 3.66, 4.11,
                             3.71, 4.18, 3.90, 3.78, 3.91, 3.72, 4.00, 3.66,
                             3.62, 4.33, 4.55, 3.75, 4.08, 3.90, 3.88, 3.94,
                             4.33])
        # fmt: on
        return observed


task = GeneralizedExtremeValueWide()
paths = get_paths(task)
# NOTE(review): task.sample is called without a seed argument; unless it
# seeds internally, re-running this cell produces a different test set.
test_dataset = task.sample(1000)

# 1000 draws fit in one chunk (save_as_chunks defaults to 5000 per chunk),
# so write them directly under the first chunk's file name.
save_to_file(test_dataset, paths["dataset_dir"] / "test_dataset_chunk_1.pkl")

## Bernoulli GLM 

In [None]:
task = BernoulliGLMTask()
paths = get_paths(task)

# Draw the full test set once and store it unchunked; the next cell splits it.
test_dataset = task.sample(10000)
save_to_file(test_dataset, paths["dataset_dir"] / "test_dataset.pkl")

In [None]:
# Reload the saved Bernoulli GLM test set from disk and split it into
# `test_dataset_chunk_<i>.pkl` files in the same directory.
test_dataset = read_from_file(paths["dataset_dir"] / "test_dataset.pkl")
save_as_chunks(test_dataset, paths["dataset_dir"])

## Psychometric Task


In [None]:
from brainbox.behavior.training import (
    compute_performance,
    compute_psychometric,
    get_signed_contrast,
)
from brainbox.io.one import SessionLoader
from one.api import ONE

# A connection to the public IBL database is only needed to regenerate
# `psychometric_data.csv` with the commented-out cells below. The CSV is
# already provided, so by default skip the network connection.
DOWNLOAD_IBL_DATA = False
if DOWNLOAD_IBL_DATA:
    ONE.setup(base_url="https://openalyx.internationalbrainlab.org", silent=True)
    # "international" is the shared public password of the open IBL Alyx
    # server; it can be overridden through the IBL_ONE_PASSWORD variable.
    one = ONE(password=os.environ.get("IBL_ONE_PASSWORD", "international"))  # noqa

In [None]:
from sbi_mcmc.tasks import PsychometricTask

# Psychometric task with overdispersion enabled; `paths["dataset_dir"]` is
# where the test dataset files below are written.
task = PsychometricTask(overdispersion=True)
paths = get_paths(task)

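We provide the preprocessed CSV (`psychometric_data.csv` in the task info directory), so downloading and preprocessing are not necessary. The commented-out cells below document how that file was generated.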

In [None]:
# # Download the data
# eid = "4ecb5d24-f5cc-402c-be28-9d0f7cb14b3a"
# trials = one.load_object(eid, "trials", collection="alf")
# eids = one.search(task="biasedChoiceWorld")
# print(len(eids))
# errors = []
# for eid in eids:
#     try:
#         sl = SessionLoader(eid=eid, one=one)
#         sl.load_trials()
#     except Exception as e:
#         errors.append((eid, e))
# print(errors)

In [None]:
# # Create a DataFrame
# eids = one.search(task="biasedChoiceWorld")
# data = []
# for i in tqdm(range(len(eids))):
#     eid = eids[i]
#     try:
#         sl = SessionLoader(eid=eid, one=one)
#         sl.load_trials()
#     except Exception as e:
#         print(e)
#         continue
#     trials = sl.trials
#     for block in [0.5, 0.2, 0.8]:
#         signed_contrast = get_signed_contrast(trials)
#         trials["signed_contrast"] = signed_contrast
#         performance, contrasts, n_contrasts = compute_performance(
#             trials,
#             signed_contrast=signed_contrast,
#             block=block,
#             prob_right=True,
#         )
#         if np.isnan(n_contrasts).any():
#             continue
#         data.append(
#             {
#                 "eid": eid,
#                 "block": block,
#                 "performance": performance.tolist(),
#                 "contrasts": contrasts.tolist(),
#                 "n_contrasts": n_contrasts.tolist(),
#             }
#         )
# # Convert to DataFrame
# df = pd.DataFrame(data)

# # Save as CSV
# csv_path = paths["dataset_dir"] / "psychometric_data.csv"
# df.to_csv(csv_path, index=False)

In [None]:
import ast

import numpy as np

# Contrast levels (in percent) of the sessions kept below; other contrast
# sets are dropped for simplicity so every dataset has the same 9 levels.
STANDARD_CONTRASTS = (-100.0, -25.0, -12.5, -6.25, 0.0, 6.25, 12.5, 25.0, 100.0)

csv_path = task.task_info_dir / "psychometric_data.csv"
# List-valued columns are stored as strings in the CSV; parse them back into
# Python lists.
df = pd.read_csv(
    csv_path,
    converters={
        "performance": ast.literal_eval,
        "n_contrasts": ast.literal_eval,
        "contrasts": ast.literal_eval,
    },
)
# Count of "right" choices per contrast: fraction x number of trials, rounded
# to the nearest integer to absorb floating-point error in the stored fractions.
# A list comprehension (rather than df.apply(axis=1)) also works on an empty frame.
df["right_count"] = [
    [round(p * n) for p, n in zip(performance, n_trials, strict=True)]
    for performance, n_trials in zip(
        df["performance"], df["n_contrasts"], strict=True
    )
]
df = df[["contrasts", "n_contrasts", "right_count"]]

# Lists are unhashable, so count the distinct contrast sets via tuples.
unique_contrasts = df["contrasts"].apply(tuple)
unique_counts = unique_contrasts.value_counts()
print(unique_counts)

# For simplicity, only keep rows whose contrasts equal STANDARD_CONTRASTS.
df = df[unique_contrasts == STANDARD_CONTRASTS]

In [None]:
# Build an array of shape (N_dataset, N_contrasts, 3): for every dataset and
# contrast level the last axis holds [contrast, n_trials, n_right].
test_dataset = np.stack(
    [df[column].tolist() for column in ("contrasts", "n_contrasts", "right_count")],
    axis=-1,
)
# Contrasts are given in percent; rescale them to lie between -1 and 1.
test_dataset[..., 0] /= 100
save_to_file(test_dataset, paths["dataset_dir"] / "test_dataset.pkl")

In [None]:
from sbi_mcmc.utils.utils import read_from_file

# Reload the saved psychometric array from disk and report its shape
# (N_dataset, N_contrasts, 3).
test_dataset = read_from_file(paths["dataset_dir"] / "test_dataset.pkl")
print(test_dataset.shape)

In [None]:
# Split the psychometric test set into `test_dataset_chunk_<i>.pkl` files.
save_as_chunks(test_dataset, paths["dataset_dir"])

## DDM task

In [None]:
from sbi_mcmc.tasks.ddm import CustomDDM

# Drift-diffusion model task; `dt` is passed straight to CustomDDM
# (presumably the simulation time step — confirm in sbi_mcmc.tasks.ddm).
task = CustomDDM(dt=0.0001)
paths = get_paths(task)

In [None]:
# Empirical DDM data shipped with the task; the `id` column is renamed to
# `subject`, and the number of distinct subjects is printed.
data = pd.read_csv(task.task_info_dir / "stan_data.csv")
data = data.rename(columns={"id": "subject"})
ids = data["subject"].unique()
print(len(ids))


def process_data_block(data, condition_val=0, n_trials=60):
    """Convert one condition block of a subject's trials into a fixed-size array.

    Args:
        data: DataFrame with an ``rt`` column (one row per observed trial).
            In this notebook ``rt`` is already signed by the response, so the
            response is not stored as a separate column.
        condition_val: Value written to the condition column (0 for
            congruent, 1 for incongruent in this notebook).
        n_trials: Number of rows in the output. Blocks with fewer observed
            trials are padded at the end.

    Returns:
        numpy float array of shape ``(n_trials, 3)`` with columns
        ``[rt, missing flag, condition type]``. Padded rows have ``rt = 0``
        and ``missing = 1``.

    Raises:
        ValueError: If ``data`` has more than ``n_trials`` rows.
    """
    rt_observed = data["rt"].to_numpy(dtype=float)
    n_observed = rt_observed.shape[0]
    if n_observed > n_trials:
        raise ValueError(
            f"Got {n_observed} trials, but a block holds at most {n_trials}"
        )

    rt = np.zeros(n_trials)
    rt[:n_observed] = rt_observed
    missing = np.zeros(n_trials)
    missing[n_observed:] = 1  # padded rows are flagged as missing
    condition_type = np.full(n_trials, condition_val, dtype=float)
    return np.stack([rt, missing, condition_type], axis=1)


# Stimulus type per row (picture == 1). Each 60-row condition block holds
# 30 rows of type 0 followed by 30 rows of type 1. Rows 0-59 are the
# congruent block and rows 60-119 the incongruent block. This pattern is
# the same for every subject, so it is built once.
stimulus_type = np.tile(np.repeat([0.0, 1.0], 30), 2)

test_dataset = []
num_missing_trials = {}
for subject_id in tqdm(ids):
    # Keep trials with a positive RT. Copy so the column edits below modify
    # an independent frame (avoids pandas' SettingWithCopyWarning).
    data_1p = data[(data["subject"] == subject_id) & (data["rt"] > 0)].copy()
    # Encode the response in the sign of the RT: response 0 -> -1.
    data_1p["response"] = data_1p["response"].replace(0, -1)
    data_1p["rt"] = data_1p["rt"] * data_1p["response"]
    data_c = data_1p[data_1p["block"] == 1]  # congruent
    data_i = data_1p[data_1p["block"] == 0]  # incongruent

    # Process congruent and incongruent blocks (each padded to 60 rows)
    data_c = process_data_block(data_c, condition_val=0)
    data_i = process_data_block(data_i, condition_val=1)

    # Rows 0-59: congruent, rows 60-119: incongruent
    observation = np.concatenate([data_c, data_i], axis=0)
    n_missing = np.sum(observation[:, 1])
    if n_missing > 0:
        num_missing_trials[subject_id] = n_missing

    # NOTE(review): the `rt > 0` filter and end-of-block padding shift trials
    # without tracking their position, so `stimulus_type` assumes the surviving
    # trials of each block are still ordered [30 x type 0, 30 x type 1].
    observation = np.concatenate([observation, stimulus_type[:, None]], axis=-1)
    test_dataset.append(observation)

print(f"{len(num_missing_trials)} subjects with missing trials")
# default=0 so this also works when no subject has missing trials
print(f"max missing trials: {max(num_missing_trials.values(), default=0)}")
test_dataset = np.stack(test_dataset, axis=0)
print(test_dataset.shape)

In [None]:
# Save the full DDM test dataset as a single unchunked file; the next cell
# reloads it and splits it into chunks.
save_to_file(test_dataset, paths["dataset_dir"] / "test_dataset.pkl")

In [None]:
import numpy as np
from sbi_mcmc.utils.utils import read_from_file

# Reload the saved DDM test set from disk and split it into
# `test_dataset_chunk_<i>.pkl` files in the same directory.
test_dataset = read_from_file(paths["dataset_dir"] / "test_dataset.pkl")
save_as_chunks(test_dataset, paths["dataset_dir"])

The code below was used to generate additional chunks of data for the DDM task from data provided by the original authors. From its output we used `test_dataset_chunk_3` and `test_dataset_chunk_4`, together with `test_dataset_chunk_1` from above.

In [None]:
# import os
# import pickle

# import numpy as np

# file_path = os.path.expanduser("~/Downloads/emp_data.p")
# with open(file_path, "rb") as f:
#     result = pickle.load(f)
# # Create a mask to identify people who meet the condition
# valid_people_mask = np.all(
#     (np.abs(result["data_array"][..., 0]) <= 20)
#     & (np.abs(result["data_array"][..., 0]) > 0),
#     axis=1,
# )

# # Filter the data_array to keep only valid people
# filtered_data_array = result["data_array"][valid_people_mask]

# reordered_data_array = np.zeros_like(filtered_data_array)

# for i in range(filtered_data_array.shape[0]):  # For each person
#     # Get indices that would sort the values in dimension 1
#     sort_indices = np.argsort(filtered_data_array[i, :, 1])

#     # Reorder the data for this person
#     reordered_data_array[i] = filtered_data_array[i, sort_indices]

# # Verify that 0s come before 1s for each person (argsort sorts ascending)
# assert np.all(reordered_data_array[:, :60, 1] == 0)
# assert np.all(reordered_data_array[:, 60:, 1] == 1)

# missing = np.zeros_like(reordered_data_array)[..., 0]
# stimulus_type = np.concatenate(
#     (
#         np.zeros(30),
#         np.ones(30),  # condition 1: congruent
#         np.zeros(30),
#         np.ones(30),
#     )
# )  # condition 2: incongruent
# stimulus_type = np.tile(stimulus_type, (reordered_data_array.shape[0], 1))

# data = np.stack(
#     [
#         reordered_data_array[..., 0],
#         missing,
#         reordered_data_array[..., 1],
#         stimulus_type,
#     ],
#     axis=-1,
# )


# save_as_chunks(
#     data[: 22 * 5000],
#     Path("ddm_data"),
# )