In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import requests
from scipy.io import loadmat
from scipy.optimize import minimize
from scipy.stats import norm


In [None]:
url = "https://labshare.cshl.edu/shares/library/repository/38957/lapseDataset2020.mat"
filename = "lapse_data.mat"

# Download once and cache locally. Fail loudly on HTTP errors and write via a
# temporary file. Otherwise an error page or a truncated download would be
# cached under `filename`, and every later run would skip the download and
# fail inside loadmat.
if not os.path.exists(filename):
    r = requests.get(url, timeout=60)  # seconds; avoids hanging indefinitely
    r.raise_for_status()
    partial = filename + ".part"
    with open(partial, "wb") as f:
        f.write(r.content)
    os.replace(partial, filename)

mat_dict = loadmat(filename)
# The MATLAB variable named "dataset" in the downloaded file.
mat_data = mat_dict["dataset"]


In [None]:
def mat_to_numpy_dict(mat_data, index=0, verbose=True):
    """Recursively unpack an array returned by ``scipy.io.loadmat``.

    Struct arrays (arrays with named fields) become a dict keyed by field
    name. Each field is unpacked from the first struct element. When the
    struct array has several elements (shape ``(1, n)`` with ``n > 1``), the
    field instead becomes a list with one unpacked entry per element.
    Plain arrays are returned as their only row when they have exactly one
    row, and as a list of rows otherwise.

    Parameters
    ----------
    mat_data : numpy.ndarray
        Array to unpack.
    index : int
        Recursion depth; only used to indent the verbose printout.
    verbose : bool
        Print the shape at each level (struct levels also print the field
        name). Struct elements after the first are never printed.

    Raises
    ------
    RuntimeError
        If a struct array is not 2-D with exactly one row.
    """
    if mat_data.dtype.names is None:
        if verbose:
            print("\t" * index, mat_data.shape)
        if mat_data.shape[0] != 1:
            return list(mat_data)
        return mat_data[0]

    unpacked = {}
    for name in mat_data.dtype.names:
        column = mat_data[name]
        shape = column.shape
        if verbose:
            print("\t" * index, name, shape)
        # Only single-row (1, n) struct arrays are supported.
        if len(shape) != 2 or shape[0] != 1:
            raise RuntimeError(f"Unexpected shape for {name}: {shape}.")
        head = mat_to_numpy_dict(column[0][0], index + 1, verbose)
        if shape[1] > 1:
            tail = [
                mat_to_numpy_dict(column[0, k], index + 1, False)
                for k in range(1, shape[1])
            ]
            unpacked[name] = [head, *tail]
        else:
            unpacked[name] = head
    return unpacked



In [None]:
# Unpack the MATLAB struct into nested dicts/lists of NumPy arrays; the
# structure printout is switched off.
data_dict = mat_to_numpy_dict(mat_data, index=0, verbose=False)


In [None]:
def get_experiment(data_dict, exp_name, rat_name=None):
    """Return stimulus rates and per-condition response counts for one rat.

    Parameters
    ----------
    data_dict : dict
        Nested dict as produced by ``mat_to_numpy_dict``.
    exp_name : str
        Experiment key in ``data_dict`` (e.g. ``"multisensory"``).
    rat_name : str, optional
        Entry of the experiment's ``"ratName"`` list to select. ``None``
        selects the first rat.

    Returns
    -------
    stim_rates, auditory, visual, multisensory
        ``stim_rates`` is the first entry of ``"stimRates"`` (checked to be
        9..16). Each response array is ``np.vstack((n_high, n_low))``: row 0
        holds the "high" response counts and row 1 the remaining trials.

    Raises
    ------
    ValueError
        If ``rat_name`` is given but is not listed in ``"ratName"``.
    """
    exp_dict = data_dict[exp_name]
    rat_names = exp_dict["ratName"]
    if rat_name is None:
        rat_idx = 0
    elif rat_name in rat_names:
        rat_idx = rat_names.index(rat_name)
    else:
        # An unknown name used to fall back silently to rat 0, returning
        # another animal's data without any indication.
        raise ValueError(f"Rat {rat_name!r} not found in experiment {exp_name!r}.")
    res_dict = exp_dict["controlSummaryData"][rat_idx]

    def get_responses(res_dict, stim_idx):
        # Stack (high, low) counts; low = total trials minus high responses.
        high_resp = res_dict["nHighResponses"][stim_idx]
        low_resp = res_dict["nTrials"][stim_idx] - high_resp
        return np.vstack((high_resp, low_resp))

    conditions = res_dict["condition"]
    stim_rates = res_dict["stimRates"][0]
    # The rest of the analysis assumes the 9..16 stimulus-rate grid.
    assert np.all(stim_rates == np.arange(9, 17)), f"Unexpected stimRates: {stim_rates}"
    auditory = get_responses(res_dict, conditions.index("Auditory"))
    visual = get_responses(res_dict, conditions.index("Visual"))
    multisensory = get_responses(res_dict, conditions.index("Multisensory"))
    return stim_rates, auditory, visual, multisensory



In [None]:
# Select the "metaRat" entry of the "multisensory" experiment
# (the name suggests data pooled across animals; TODO confirm).
stim_rates, auditory, visual, multisensory = get_experiment(
    data_dict, exp_name="multisensory", rat_name="metaRat"
)


In [None]:
def psychometric_fn(mu, sigma, gamma, lamda, x):
    """Four-parameter psychometric curve evaluated at ``x``.

    A cumulative Gaussian with mean ``mu`` and standard deviation ``sigma``,
    rescaled to rise from ``gamma`` (lower asymptote) to ``1 - lamda``
    (upper asymptote).
    """
    gaussian_cdf = norm.cdf(x, loc=mu, scale=sigma)
    dynamic_range = 1 - gamma - lamda
    return gamma + dynamic_range * gaussian_cdf


def psychometric_nll(params, stim_rates, responses, eps=1e-9):
    """Binomial negative log-likelihood of ``responses`` under the psychometric model.

    The binomial coefficient is omitted because it does not depend on ``params``.

    Parameters
    ----------
    params : sequence of 4 floats
        ``(mu, sigma, gamma, lamda)`` as taken by ``psychometric_fn``.
    stim_rates : array-like
        Stimulus values at which the response probability ``p`` is evaluated.
    responses : array-like
        ``responses[0]`` counts outcomes with probability ``p`` and
        ``responses[1]`` counts outcomes with probability ``1 - p``.
    eps : float, optional
        ``p`` is clipped to ``[eps, 1 - eps]`` before taking logs.

    Returns
    -------
    float
        The negative log-likelihood (to be minimized).
    """
    mu, sigma, gamma, lamda = params
    p = psychometric_fn(mu, sigma, gamma, lamda, stim_rates)
    # Keep p strictly inside (0, 1). While the optimizer explores extreme
    # parameters, norm.cdf can round to exactly 0 or 1 (log(0) = -inf), and
    # gamma + lamda > 1 can push p outside [0, 1] (log of a negative = nan).
    p = np.clip(p, eps, 1 - eps)
    return -np.sum(responses[0] * np.log(p) + responses[1] * np.log(1 - p))


def psychometric_fit(
    stim_rates,
    responses,
    init=(12.5, 1, 0, 0),
    bounds=((None, None), (1e-6, None), (0, 0.5), (0, 0.5)),
):
    """Fit ``(mu, sigma, gamma, lamda)`` by minimizing ``psychometric_nll``.

    Parameters
    ----------
    stim_rates : array-like
        Stimulus values, one per column of ``responses``.
    responses : array-like
        Row 0: counts of "high" responses; row 1: remaining trials.
    init : sequence of 4 floats, optional
        Starting point for the optimizer.
    bounds : sequence of 4 ``(low, high)`` pairs, optional
        Box constraints passed to ``scipy.optimize.minimize``. The defaults
        keep ``sigma`` strictly positive, because ``norm.cdf`` returns nan
        for a non-positive scale. They also keep each lapse rate in
        ``[0, 0.5]``, so ``gamma + lamda <= 1`` and the curve stays
        monotone within ``[0, 1]``.

    Returns
    -------
    numpy.ndarray
        Fitted ``(mu, sigma, gamma, lamda)``. If the optimizer reports
        failure, a ``RuntimeWarning`` is issued and the last iterate is
        still returned.
    """
    res = minimize(psychometric_nll, init, args=(stim_rates, responses), bounds=bounds)
    if not res.success:
        warnings.warn(f"psychometric_fit did not converge: {res.message}", RuntimeWarning)
    return res.x


def psychometric_plot(stim_rates, responses, params, labels, ax=None):
    """Plot fitted psychometric curves over the observed response fractions.

    Parameters
    ----------
    stim_rates : array-like
        Stimulus values shared by all conditions; curves span
        ``[min(stim_rates), max(stim_rates)]``.
    responses : sequence of numpy.ndarray
        Per-condition count arrays. Row 0 is plotted as a fraction of each
        column's total.
    params : sequence of ``(mu, sigma, gamma, lamda)``
        Fitted parameters, one set per curve.
    labels : sequence of str
        Legend label for each curve.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()
    x = np.linspace(min(stim_rates), max(stim_rates), 100)
    colors = []
    for (mu, sigma, gamma, lamda), label in zip(params, labels):
        (line,) = ax.plot(x, psychometric_fn(mu, sigma, gamma, lamda, x), label=label)
        colors.append(line.get_color())
    for i, resp in enumerate(responses):
        # Draw each condition's points in its curve's colour (black if there
        # are more response sets than curves). Columns with zero trials give
        # nan, which matplotlib skips, so their warnings are silenced.
        color = colors[i] if i < len(colors) else "k"
        with np.errstate(divide="ignore", invalid="ignore"):
            frac = resp[0] / resp.sum(axis=0)
        ax.scatter(stim_rates, frac, marker="x", color=color)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Stimulus rate")
    ax.set_ylabel("Fraction of high responses")
    ax.legend()



In [None]:
# One independent maximum-likelihood fit per sensory condition.
params_auditory, params_visual, params_multisensory = (
    psychometric_fit(stim_rates, counts) for counts in (auditory, visual, multisensory)
)


In [None]:
labels = ("Auditory", "Visual", "Multisensory")
resps = (auditory, visual, multisensory)
params = (params_auditory, params_visual, params_multisensory)

# Overlay the fitted curve for each condition on the observed response fractions.
psychometric_plot(stim_rates, resps, params, labels)
