# Plot RSA results + statistics + first/second half

- based on the `rsa_analysis_9x9.ipynb` notebook
- based on statistics from `clusterperm.py` script

---

1. Select which model to look at
1. Plot the data across sampling conditions
1. Find significant clusters
1. Plot clusters
1. Plot 2x2 "summaries" for each significant cluster
    1. Repeat for "first_half" and "second_half" of trials
    1. Add statistics for the "summaries"
    1. Add post-hoc tests for significant statistics


---

Also plot the RDMs:

- for the significant observed clusters
    - per subject
        - per task
        - averaged over tasks
    - averaged over subjects
        - per task
        - averaged over tasks

In [None]:
import glob
import os
import os.path as op
import itertools
import json
import multiprocessing

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pingouin
import seaborn as sns
import scipy.stats
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable

from clusterperm import (
    evaluate_significance,
    return_observed_clusters,
    _return_clusters,
)
from utils import BIDS_ROOT

# Settings and paths

In [None]:
# Template for the RSA derivative folders; the two "{}" slots are filled
# with the RSA dimensions and the RSA kind once the settings are known.
_template_parts = ("derivatives", "rsa_{}", "{}")
path_template = op.join(BIDS_ROOT, *_template_parts)

In [None]:
# Settings
modelname = "numberline"
orth = False
thresh = 0.05
cluster_defining_thresh_ave = 0.05  # cluster defining threshold for average effect
clusterstat = "length"  # "length" or "mass"
clusterthresh = 0.05
rsa_dims = "9x9"
nth_perm_results = 0  # if multiple perm result distrs, which one to take?
do_2x2_with_halfs_ecFalse = False

# some more settings for one sample ttest
niterations = 10_000  # requested number of permutations for the null distribution
y = 0  # value that the similarities are tested against

# set njobs to 1 if you do not want to run in parallel
# Leave one core free and use at most 25 workers. Clamp to at least 1:
# on a single-core machine "cpu_count() - 1" is 0, and multiprocessing.Pool
# refuses to start with 0 processes.
njobs = max(1, min(multiprocessing.cpu_count() - 1, 25))

# which RSA to do the analyses on
# This is the folder name of the results nested in the 9x9 ... folder
rsa_kind = (
    "cv-False_flip-False_rsa-pearson_dist-euclidean_half-both_exclude-Identity_mnn-FalseFalseFalse_ec-False_c-(0.6, 1.6)t--0.8s-150b-(None, 0)s-True_outcome"
)

In [None]:
# Optional JSON config in the "code" folder of BIDS_ROOT: every key it
# defines replaces the corresponding setting from above.
configpath = op.join(BIDS_ROOT, "code", "rsa_plotting_config.json")
if op.exists(configpath):
    print(f"\n\nOVERWRITTING SETTINGS USING CONFIG:\n{configpath}\n\n")
    with open(configpath, "r") as fin:
        config = json.load(fin)
    print(config)

    # keys that are missing from the config keep their current value
    (
        rsa_dims,
        rsa_kind,
        modelname,
        orth,
        nth_perm_results,
        clusterstat,
        do_2x2_with_halfs_ecFalse,
    ) = (
        config.get("rsa_dims", rsa_dims),
        config.get("rsa_kind", rsa_kind),
        config.get("modelname", modelname),
        config.get("orth", orth),
        config.get("nth_perm_results", nth_perm_results),
        config.get("clusterstat", clusterstat),
        config.get("do_2x2_with_halfs_ecFalse", do_2x2_with_halfs_ecFalse),
    )

In [None]:
rsa_folder = path_template.format(rsa_dims, rsa_kind)

# Result tables: one over all samples, and one per half of the trials
results = op.join(rsa_folder, "rsa_results.csv")
results_first_half, results_second_half = (
    results.replace("half-both", f"half-{half}") for half in ("first", "second")
)

if do_2x2_with_halfs_ecFalse:
    results_first_half, results_second_half = (
        path.replace("ec-True", "ec-False")
        for path in (results_first_half, results_second_half)
    )

full_modelname = f"orth{modelname}" if orth else modelname

# Shared pieces of the cluster permutation file names
_perm_folders = op.join(rsa_folder, "clusterperm_results_*T*")
_distr_stem = f"model-{full_modelname}_stat-{clusterstat}_thresh-{thresh}_distr"

cluster_distr_files = glob.glob(op.join(_perm_folders, _distr_stem + ".npy"))

if not cluster_distr_files:
    raise RuntimeError("no clusterperm_results found.")

if len(cluster_distr_files) == 1:
    cluster_distr_file = cluster_distr_files[0]
else:
    cluster_distr_file = sorted(cluster_distr_files)[nth_perm_results]
    print(f"found multiple cluster perm results distrs: {cluster_distr_files}")
    print(f"picking index {nth_perm_results}")

# Column-order files of all matching permutation runs; when the distribution
# is loaded, these are required to agree with each other.
cluster_distr_col_order_file = glob.glob(
    op.join(_perm_folders, _distr_stem + "_column_order.txt")
)

In [None]:
# Output folder for the permutation and 2x2 results
outfolder = op.join(
    rsa_folder, "perm_and_2x2_outputs", f"{full_modelname}_{clusterstat}"
)
os.makedirs(outfolder, exist_ok=True)

# Output folder for the RDM plots (named after the model)
model = f"o_{modelname}" if orth else modelname
outfolder_rdms = op.join(rsa_folder, "perm_and_RDM_outputs", model)
os.makedirs(outfolder_rdms, exist_ok=True)

# Load RSA results

In [None]:
# Read the result tables in the order: all samples, first half, second half
df_rsa, df_rsa_first_half, df_rsa_second_half = (
    pd.read_csv(csv_path)
    for csv_path in (results, results_first_half, results_second_half)
)

In [None]:
def _single_value(column, label):
    """Return the only value in ``df_rsa[column]``, failing if there are several."""
    assert (
        df_rsa[column].nunique() == 1
    ), f"more than one {label} detected: {df_rsa[column]}"
    return df_rsa[column].unique()[0]


# find the rsa_method and the distance metric that were used
rsa_method = _single_value("method", "rsa_method")
distance_metric = _single_value("distance_metric", "distance_metric")

# Get the data of interest

In [None]:
# Keep only the rows of the selected model and orthogonalization setting
_is_selected = (df_rsa["model"] == modelname) & (df_rsa["orth"] == orth)
df_rsa_model_orth = df_rsa[_is_selected]

# plot the effect

In [None]:
with sns.plotting_context("talk"):

    # One panel; hue encodes stopping, line style encodes sampling
    fig, ax = plt.subplots(sharex=True, sharey=True, figsize=(8, 8))

    lineplot_kwargs = dict(
        x="time_s",
        y="similarity",
        ci=68,
        hue="stopping",
        style="sampling",
        hue_order=["fixed", "variable"],
        style_order=["active", "yoked"],
        legend=False,
    )
    sns.lineplot(data=df_rsa_model_orth, ax=ax, **lineplot_kwargs)

    ax.set_title(model)
    ax.axhline(0, color="black")
    ax.axvline(0.0, color="black", linestyle="--")
    ax.set_xlabel("time (s)")
    ax.set_ylabel(f"similarity ({rsa_method})")

    # The automatic legend is switched off above; build one by hand
    # https://matplotlib.org/3.1.1/gallery/text_labels_and_annotations/custom_legends.html
    palette = sns.color_palette()
    legend_elements = [
        Line2D(
            [0],
            [0],
            color=palette[icolor],
            linestyle=linestyle,
            label=f"{sampling}/{stopping}",
        )
        for stopping, icolor in (("variable", 1), ("fixed", 0))
        for sampling, linestyle in (("active", "-"), ("yoked", "--"))
    ]

    legend = plt.legend(
        handles=legend_elements,
        loc="best",
        prop={"size": 10},
        framealpha=1,
    )
    # register as an artist so that a later ax.legend() call does not replace it
    ax.add_artist(legend)

    fig.tight_layout()

    fname = op.join(outfolder, "effect_plot.pdf")
    fig.savefig(fname)

# find significant clusters

In [None]:
# find the observed clusters in the data
# NOTE(review): `return_observed_clusters` lives in clusterperm.py and is not
# visible here; presumably it applies the passed test function
# (`pingouin.mixed_anova`) per timepoint and clusters timepoints using
# `thresh` -- confirm against clusterperm.py.
# Both return values are handed on to `evaluate_significance` below.
clusters_obs, models_obs = return_observed_clusters(
    df_rsa_model_orth, thresh, pingouin.mixed_anova
)

In [None]:
# Load cluster distribution array
cluster_distr = np.load(cluster_distr_file)

# The array carries no column labels: take them from the column-order file(s)
if cluster_distr_col_order_file:
    column_orders = []
    for thisfile in cluster_distr_col_order_file:
        with open(thisfile, "r") as fin:
            column_orders.append([line.strip() for line in fin])

    # all read in column orders must be the same
    column_order = column_orders[0]
    assert all(other == column_order for other in column_orders)

else:
    # NOTE: If we fail to get the column order from a file,
    #       we assume it's the same as usual. this is
    #       probably safe, and we can double check the results
    effect_order = ["stopping", "sampling", "Interaction"]
    column_order = [name.lower() for name in effect_order]

cluster_distr = pd.DataFrame(cluster_distr, columns=column_order)

In [None]:
# evaluate the significance of the observed clusters
# NOTE(review): `evaluate_significance` is defined in clusterperm.py and not
# shown here. Judging from how the results are used further down:
#   - clusters_obs_stats maps each effect to the statistics of its observed
#     clusters (compared against the null distribution for the p values)
#   - clusters_obs_sig maps each effect to its significant clusters, each a
#     sequence of time indices
# The meaning of clustersig_threshs cannot be told from this file -- confirm.
(clustersig_threshs, clusters_obs_stats, clusters_obs_sig,) = evaluate_significance(
    cluster_distr, clusters_obs, clusterstat, clusterthresh, models_obs
)

In [None]:
# Permutation p value for every observed cluster:
# p = (1 + number of null statistics >= observed statistic) / (1 + number of permutations)
pval_lines = []
for effect in clusters_obs_stats:
    effect_distr = cluster_distr[effect].to_numpy()
    n_perms = len(effect_distr)

    for i, stat in enumerate(clusters_obs_stats[effect], start=1):
        n_as_extreme = np.sum(effect_distr >= stat)
        pval = (1 + n_as_extreme) / (1 + n_perms)
        pval_lines.append(
            f"effect: {effect}, observed cluster {i} (ascending), p={pval}\n"
        )

# store the p values as a text file next to the distribution file
pval_fname = cluster_distr_file.replace("distr.npy", "pvals.txt")
with open(pval_fname, "w") as fout:
    fout.writelines(pval_lines)

# plot significant clusters

In [None]:
patches = []
effect_colors = dict(zip(column_order, ["r", "g", "m"]))
for effect, clusters in clusters_obs_sig.items():
    if len(clusters) == 0:
        significant = []
        continue  # no significant clusters for this effect

    if len(clusters) > 1:
        _summary = "\n".join(
            [f"{len(iclu)}: from {min(iclu)} to {max(iclu)} (idx)" for iclu in clusters]
        )
        print(
            f"found several clusters for {effect}, see by length:\n{_summary}\n...plotting all."
        )
        significant = np.concatenate(clusters)
    else:
        significant = clusters[0]

    # Mark the significant timepoints with dots at the lower y-limit.
    # The limit is queried inside the loop on purpose: it is read after the
    # markers of earlier effects were drawn and may have shifted by then.
    xs = df_rsa_model_orth["time_s"].unique()
    xs = xs[significant]
    ys = np.repeat(ax.get_ylim()[0], xs.shape[-1])
    color = effect_colors[effect]
    (line,) = ax.plot(xs, ys, marker=".", color=color, linestyle="None")
    line.set_label(effect)


ax.legend(loc=2, title=f"p < {clusterthresh}")

fname = op.join(outfolder, "effect_perm_plot.pdf")
fig.savefig(fname)

fig

# Find significant "average" effect

That is, average the effect over the 4 sampling conditions, and perform a one-sample t-test cluster permutation test to find significant clusters.

In [None]:
# Model for which the effect averaged over conditions is tested
average_effect_model = modelname  # e.g., "numberline"
average_effect_orth = orth  # e.g., False
if average_effect_orth:
    average_effect_modelname = "o" + average_effect_model
else:
    average_effect_modelname = average_effect_model

_is_average_model = (df_rsa["orth"] == average_effect_orth) & (
    df_rsa["model"] == average_effect_model
)
df_modelorth = df_rsa[_is_average_model]

In [None]:
# take mean over tasks for each subject
# numeric_only=True is required because the table also holds string columns
# (e.g. "model", "method", "stopping", "sampling"): pandas >= 2.0 raises a
# TypeError when asked to average those, whereas older versions silently
# dropped them. The numeric columns are averaged exactly as before.
df_modelorthmean = (
    df_modelorth.groupby(["subject", "itime"]).mean(numeric_only=True).reset_index()
)

In [None]:
# Arrange the per-subject means as a (subjects x timepoints) matrix
nsubjs = df_modelorthmean["subject"].nunique()
ntimes = df_modelorthmean["itime"].nunique()
X = np.full((nsubjs, ntimes), np.nan)
for (subj, itime), grp in df_modelorthmean.groupby(["subject", "itime"]):
    # subject IDs are used as 1-based row positions, itime as column position
    X[subj - 1, itime] = grp["similarity"].to_numpy()

# Spot check: 100 randomly drawn cells must match the dataframe
for _ in range(100):
    testtime = np.random.choice(range(ntimes))
    testsubj = np.random.choice(range(nsubjs))

    _is_cell = (df_modelorthmean["subject"] == testsubj + 1) & (
        df_modelorthmean["itime"] == testtime
    )
    expected = df_modelorthmean[_is_cell]["similarity"].to_numpy()[0]

    np.testing.assert_allclose(X[testsubj, testtime], expected)

In [None]:
def calc_stats_per_timepoint(X, y, rng):
    """Calculate one-sample t-test statistics per timepoint in X against y.

    Parameters
    ----------
    X : ndarray, shape(subjects, timepoints)
        The similarities per subject and timepoint. Not modified.
    y : float
        The value against which to perform a one-sample t-test.
    rng : np.random.RandomState | None
        A stable rng object for seeding. If an rng object is passed, the
        data are permuted by sign flipping: one random sign is drawn per
        subject and applied to the deviation from `y` of that subject's
        entire time course. If None is passed, the data are not permuted.

    Returns
    -------
    stats : ndarray, shape(2, timepoints)
        The T values on dim 0,t and p values of dim 1,t for each timepoint t.

    Notes
    -----
    The sign is flipped per subject, not per (subject, timepoint) cell.
    Independent flips per timepoint would destroy the temporal
    autocorrelation within each subject, so that supra-threshold timepoints
    of permuted data would hardly ever form long runs. A null distribution
    of cluster lengths built that way is too narrow and makes observed
    clusters look more significant than they are.

    """
    X = np.asarray(X)
    nsubjs, ntimes = X.shape
    if rng is not None:
        # shape (nsubjs, 1) broadcasts the same sign over all timepoints
        flip = rng.choice([-1, 1], size=(nsubjs, 1), replace=True)
        # flip around y (the null value); for y == 0 this is X * flip.
        # This creates a new array, the caller's X is left untouched.
        X = y + (X - y) * flip

    # one vectorized call over the subject axis instead of a loop over time
    tstat, pval = scipy.stats.ttest_1samp(X, y, axis=0)
    stats = np.vstack([tstat, pval])

    return stats

In [None]:
def gen_null_distr(niterations, X, y, rng, thresh):
    """Generate a null distribution of maximum cluster lengths.

    Parameters
    ----------
    niterations : int
        Number of iterations, i.e., length of the returned distribution.
    X, y, rng
        Passed on unchanged to `calc_stats_per_timepoint` in every iteration.
    thresh : float
        Timepoints with a p value below this threshold form the boolean
        mask that is passed to `_return_clusters`.

    Returns
    -------
    distr : ndarray, shape(niterations,)
        Per iteration, the length of the longest cluster, or 0 if no
        cluster was found.

    """

    def _longest_cluster():
        stats = calc_stats_per_timepoint(X, y, rng)
        clusters = _return_clusters(stats[1, ...] < thresh)
        # cluster statistic is the length; without clusters it is 0
        return max((len(cluster) for cluster in clusters), default=0)

    distr = np.zeros(niterations)
    for iteration in range(niterations):
        distr[iteration] = _longest_cluster()

    return distr

In [None]:
# Generate the distribution for the null hypothesis
# done in parallel
with multiprocessing.Pool(njobs) as pool:
    # every worker runs the same number of iterations (rounded up), so the
    # pooled distribution holds at least `niterations` values
    niter = int(np.ceil(niterations / njobs))
    # one separately seeded rng per worker, for reproducible results
    rngs = [np.random.RandomState(i) for i in range(njobs)]

    inputs = [(niter, X, y, rng, cluster_defining_thresh_ave) for rng in rngs]

    # NOTE: deliberately not named `results` -- that name holds the path to
    # rsa_results.csv and must stay intact for re-running the loading cell
    null_distr_chunks = pool.starmap(gen_null_distr, inputs)

distr = np.hstack(null_distr_chunks)

In [None]:
# Calculate the observed statistics
observed_stats = calc_stats_per_timepoint(X, y, rng=None)
observed_clusters = _return_clusters(observed_stats[1, ...] < cluster_defining_thresh_ave)

# evaluate pvalue of observed clusters against null distribution:
# share of null cluster lengths that are at least as long as the observed one
pvals = [
    (1 + np.sum(distr >= len(cluster))) / (1 + len(distr))
    for cluster in observed_clusters
]

observed_clusters_sig = [
    clu for clu, p in zip(observed_clusters, pvals) if p < clusterthresh
]

# save clusters for later plotting
_times = df_rsa_model_orth["time_s"].unique()
general_average_effect_clusters = {}
for i, (clu, p) in enumerate(zip(observed_clusters, pvals)):

    # compute stats (tval, dof, pval, cohens d) on the per-subject mean
    # similarity within the cluster window
    sel = df_modelorthmean[df_modelorthmean["itime"].isin(clu)]
    sel = sel[["subject", "similarity", "itime"]]
    sel = sel.groupby(["subject"]).mean().reset_index()
    # NOTE: not named `X`, so the (subjects x timepoints) matrix X survives
    cluster_means = sel["similarity"].to_numpy()
    # every subject must contribute exactly one value
    assert len(cluster_means) == nsubjs
    ttest_results = pingouin.ttest(x=cluster_means, y=0)
    # .iloc[0]: the result table has a single row with a non-integer label
    ttest_stats = dict(
        tval=float(ttest_results["T"].iloc[0]),
        dof=float(ttest_results["dof"].iloc[0]),
        pval=float(ttest_results["p-val"].iloc[0]),
        cohend=float(ttest_results["cohen-d"].iloc[0]),
    )

    # add the data, converted to builtin types so that json can serialize
    # them regardless of whether the clusters hold numpy integers
    general_average_effect_clusters[i] = {
        "p-val": float(p),
        "cluster": [int(itime) for itime in clu],
        "tstart": float(_times[clu[0]]),
        "tstop": float(_times[clu[-1]]),
        "ttest-stats": ttest_stats,
    }

fname = op.join(outfolder, f"average_{average_effect_modelname}_clusters.json")
with open(fname, "w") as fout:
    json.dump(general_average_effect_clusters, fout)

In [None]:
# add the significant clusters of the average effect to the existing plot
patches = []
average_effect = f"average_{average_effect_modelname}"
effect, clusters = average_effect, observed_clusters_sig
if len(clusters) < 1:
    significant = []  # nothing to draw without significant clusters
else:
    if len(clusters) == 1:
        significant = clusters[0]
    else:
        _summary = "\n".join(
            [f"{len(iclu)}: from {min(iclu)} to {max(iclu)} (idx)" for iclu in clusters]
        )
        print(
            f"found several clusters for {effect}, see by length:\n{_summary}\n...plotting all."
        )
        significant = np.concatenate(clusters)

    # dots at the current lower y-limit mark the significant timepoints
    xs = df_rsa_model_orth["time_s"].unique()
    xs = xs[significant]
    ys = np.repeat(ax.get_ylim()[0], xs.shape[-1])
    color = "y"
    (line,) = ax.plot(xs, ys, marker=".", color=color, linestyle="None")
    line.set_label(effect)


ax.legend(loc=2, title=f"p < {clusterthresh}")

fname = op.join(outfolder, f"effect_perm_plot_plus_{average_effect_modelname}.pdf")
fig.savefig(fname)

fig

# Plot RDMs over averaged significant cluster windows

- for the significant observed clusters
    - per subject
        - per task
        - averaged over tasks
    - averaged over subjects
        - per task
        - averaged over tasks

In [None]:
def prep_to_plot(rdm):
    """Return a copy of rdm with the upper triangle (incl. diagonal) set to NaN."""
    rows, cols = np.triu_indices(rdm.shape[0])
    lower_only = rdm.copy()
    lower_only[rows, cols] = np.nan
    return lower_only

In [None]:
# Subject IDs 1..40
subjects = range(1, 40 + 1)
# Task codes: sampling (A, Y) followed by stopping (F, V) -> AF, AV, YF, YV
tasks = tuple(sampling + stopping for sampling in "AY" for stopping in "FV")

# Per subject and task, the file holding the RDMs over time
rdms_folder_template = op.join(
    op.join(rsa_folder, "single_subj_plots"), "sub-{:02}_task-{}_rdm_times.npy"
)

In [None]:
def _imshow_rdm(ax, rdm, title):
    """Draw the lower triangle of `rdm` on `ax`, with its own colorbar and a title."""
    img = ax.imshow(prep_to_plot(rdm))
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(img, cax=cax, label=distance_metric)
    ax.axis("off")
    ax.set_title(title)


rdm_plotting_dict = dict(clusters_obs_sig)
rdm_plotting_dict.update({average_effect: observed_clusters_sig})

# go over effects
for effect, clusters in rdm_plotting_dict.items():
    if len(clusters) < 1:
        significant = []
        continue  # nothing to do if no significant clusters
    elif len(clusters) == 1:
        significant = clusters[0]
    else:
        _summary = "\n".join(
            [f"{len(iclu)}: from {min(iclu)} to {max(iclu)} (idx)" for iclu in clusters]
        )
        print(
            f"found several clusters for {effect}, see by length:\n{_summary}\n...picking the largest."
        )
        largest_cluster_idx = np.argmax([len(iclu) for iclu in clusters])
        significant = clusters[largest_cluster_idx]

    # Go over subjects and tasks
    # save RDMs in a large array for later averaging
    all_rdms = np.full((len(subjects), len(tasks), 9, 9), np.nan)
    did_not_find = 0
    for isubj, subj in enumerate(subjects):

        for itask, task in enumerate(tasks):

            fname = rdms_folder_template.format(subj, task)
            if not op.exists(fname):
                did_not_find += 1
                continue

            # average the RDMs over the timepoints of the cluster window
            rdm_times = np.load(fname)
            rdm_average = np.mean(rdm_times[..., significant], axis=-1)
            all_rdms[isubj, itask, ...] = rdm_average

    if did_not_find == len(subjects) * 2:
        pass  # everything as expected, 2 of 4 tasks for each subj should be skipped
    elif did_not_find == len(subjects) * len(tasks):
        print(f"\n plotting RDMs for effect: {effect}")
        print(
            "did not save the rdm_times.npy array in a previous run. "
            "need to re-run the rsa_analysis with newer version of the script."
        )
        continue
    else:
        raise ValueError(f"unexpected number of rdm_times.npy found: {did_not_find}")

    # all_rdms is (40 x 4 x 9 x 9) containing the RDMs averaged over the
    # significant cluster window ... half of the arrays are NaN, because each
    # of the 40 subjs only had 2, instead of 4 tasks

    effect_folder = op.join(outfolder_rdms, effect)
    os.makedirs(effect_folder, exist_ok=True)

    # Make RDM over all subjects
    # 2x2 plot (4 panels)
    subj_mean = np.nanmean(all_rdms, axis=0)

    fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))
    for ax, task, rdm in zip(axs.flat, tasks, subj_mean):
        _imshow_rdm(ax, rdm, task)

    fig.suptitle("Mean over subjects", y=1.02)
    fig.tight_layout()

    fname = op.join(effect_folder, "average_subj_2x2.pdf")
    fig.savefig(fname)

    # Make RDM over all subjects and tasks
    # 1x1 plot (1 panel)
    subj_task_mean = np.nanmean(subj_mean, axis=0)

    fig, ax = plt.subplots(figsize=(10, 10))
    _imshow_rdm(ax, subj_task_mean, "Mean over subjects and tasks")

    fname = op.join(effect_folder, "average_subj_task_1x1.pdf")
    fig.savefig(fname)

    # Make RDM for each subject over tasks
    # 8x5 plot (40 panels)
    task_mean = np.nanmean(all_rdms, axis=1)

    fig, axs = plt.subplots(8, 5, sharex=True, sharey=True, figsize=(10, 10))
    for ax, subj, rdm in zip(axs.flat, subjects, task_mean):
        _imshow_rdm(ax, rdm, subj)

    fig.suptitle("Mean over tasks", y=1.01)
    fig.tight_layout()

    fname = op.join(effect_folder, "average_task_8x5.pdf")
    fig.savefig(fname)

    # Make 4 RDMs for each subject and task
    # 4 5x4 plots (20 panels per plot)
    subject_ids = np.asarray(list(subjects))
    for itask, task in enumerate(tasks):
        rdms = all_rdms[:, itask, ...]

        # Keep only the subjects that have data for this task.
        # This must be a boolean mask: applying "~" to an array of integer
        # positions is a bitwise NOT (i -> -i - 1), which picks rows counted
        # from the end, i.e. the RDMs came out in reversed subject order and
        # did not match their titles.
        has_data = ~np.all(np.isnan(rdms), axis=(1, 2))
        rdms = rdms[has_data]
        # take the titles from the same mask, so they always match the RDMs
        subj_label = subject_ids[has_data]

        fig, axs = plt.subplots(5, 4, sharex=True, sharey=True, figsize=(10, 10))
        if len(rdms) > axs.size:
            raise ValueError(
                f"found {len(rdms)} RDMs for task {task}, "
                f"but the figure only has {axs.size} panels"
            )

        for ax, rdm, label in zip(axs.flat, rdms, subj_label):
            _imshow_rdm(ax, rdm, str(label))

        # hide panels that stay empty
        for ax in axs.ravel()[len(rdms):]:
            ax.axis("off")

        fig.suptitle(f"task: {task}", y=1.01)
        fig.tight_layout()

        fname = op.join(effect_folder, f"task-{task}_5x4.pdf")
        fig.savefig(fname)

# 2x2 "summaries" (plots + stats)

In [None]:
two_x_two_dict = dict(clusters_obs_sig)
two_x_two_dict.update({average_effect: observed_clusters_sig})

# three kinds of splits of the data in terms of
# how many samples from each trial are included
data = {
    "both": df_rsa,
    "first_half": df_rsa_first_half,
    "second_half": df_rsa_second_half,
}
_group_cols = ["subject", "stopping", "sampling", "orth", "model"]

# split_data[effect][icluster][half] -> mean similarity per subject and
# condition within the time window of that cluster
split_data = {}
for effect, clusters in two_x_two_dict.items():
    if len(clusters) < 1:
        continue

    split_data[effect] = {}
    for icluster, cluster in enumerate(clusters):

        dfs = {}
        for half, df_half in data.items():
            grp = df_half[(df_half["orth"] == orth) & (df_half["model"] == modelname)]

            # take mean across significant cluster window
            time_idxs = grp["itime"].unique()[cluster]
            in_window = grp[grp["itime"].isin(time_idxs)]

            # Average only the "similarity" column: the table also holds
            # string columns, and pandas >= 2.0 raises a TypeError when a
            # groupby mean is requested for those.
            dfs[half] = in_window.groupby(_group_cols)["similarity"].mean().reset_index()

        # save to overall data dict
        split_data[effect][icluster] = dfs

In [None]:
# Make plots: one figure per effect and cluster, one panel per data split.
# Every panel comes with a mixed ANOVA (within: sampling, between: stopping)
# whose table is written to an html file.
_factor_effects = ["sampling", "stopping", "interaction"]
do_posthocs = []
for effect in split_data:
    # effects other than the ANOVA factors are "average effects"
    is_factor_effect = effect in _factor_effects

    for icluster in split_data[effect]:

        # make a figure
        fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(10, 5))

        for isplit, split in enumerate(split_data[effect][icluster]):
            df = split_data[effect][icluster][split]
            ax = axs.flat[isplit]

            combination = f"{effect}-{icluster}-{split}"
            print(combination)

            # stats
            # NOTE: not named `model`, which holds the model name used for
            # the plot title and the RDM output folder
            anova = pingouin.mixed_anova(
                data=df,
                dv="similarity",
                within="sampling",
                subject="subject",
                between="stopping",
            )

            # the context manager closes the file even if writing fails
            fname = op.join(outfolder, "anova_" + combination + ".html")
            with open(fname, "w") as fout:
                print(combination, file=fout)
                print(anova.to_html(), file=fout)

            if is_factor_effect:
                effect_pval = anova[anova["Source"].str.lower() == effect][
                    "p-unc"
                ].to_numpy()[0]

                if effect_pval < clusterthresh:
                    do_posthocs.append((effect, icluster, split))

                # special treatment for AV YV and AF YF comparisons:
                # paired t-test active vs. yoked within each stopping level
                # NOTE: not named `y`, which holds the value for the
                # one-sample test of the average effect
                ttest_pvals = {}
                for _stopping in ["variable", "fixed"]:
                    is_stopping = df["stopping"] == _stopping
                    active = df[(df["sampling"] == "active") & is_stopping][
                        "similarity"
                    ].to_numpy()
                    yoked = df[(df["sampling"] == "yoked") & is_stopping][
                        "similarity"
                    ].to_numpy()
                    ttest_results = pingouin.ttest(active, yoked, paired=True)
                    # .iloc[0]: single result row with a non-integer label
                    ttest_pvals[_stopping] = ttest_results["p-val"].iloc[0]

                title = f"{split}, p={effect_pval:.3f}\n"
                for _stop, _pval in ttest_pvals.items():
                    title += (
                        f"\nA{_stop[0].upper()} vs Y{_stop[0].upper()}: p={_pval:.3f}"
                    )
            else:
                # we are dealing with an "average effect", so we do a posthoc
                # test no matter what the p values are for sampling, stopping,
                # and interaction ... and we report all of these "effects"
                do_posthocs.append((effect, icluster, split))

                title = f"{split}"
                for _src, _pval in zip(
                    anova["Source"].to_list(), anova["p-unc"].to_list()
                ):
                    title += f"\n{_src}: p={_pval:.3f}"

            # plot
            sns.pointplot(
                x="sampling",
                y="similarity",
                hue="stopping",
                hue_order=["fixed", "variable"],
                data=df,
                ci=68,
                dodge=True,
                ax=ax,
            )

            # plot a "zero" line
            ax.axhline(0, color="black", linestyle="--")

            ax.set_title(title)

            # keep the legend in the first panel only
            if isplit != 0:
                ax.get_legend().remove()

        fig.suptitle(f"{effect}: cluster {icluster}", y=1.1)

        fname = op.join(outfolder, f"2x2_{effect}_{icluster}.pdf")
        fig.savefig(fname, bbox_inches="tight")

In [None]:
# Make posthoc comparisons for significant 2x2 summaries
# do for each "half split" ... disregarding significance

# do_posthocs can list the same (effect, cluster) several times, once per
# significant split. All three splits are processed for each entry anyway,
# so de-duplicate (keeping the order) to compute and write each file once.
_posthoc_clusters = list(
    dict.fromkeys((effect, icluster) for effect, icluster, _ in do_posthocs)
)

for effect, icluster in _posthoc_clusters:
    for split in ("both", "first_half", "second_half"):
        df = split_data[effect][icluster][split]

        combination = f"{effect}-{icluster}-{split}"
        print(combination)
        fname = op.join(outfolder, "posthocs_" + combination + ".html")

        # the context manager closes the file even if a test fails
        with open(fname, "w") as fout:
            print(combination, file=fout)

            # run the pairwise tests with either factor taken first
            for within_first in (True, False):

                stat = pingouin.pairwise_ttests(
                    data=df,
                    dv="similarity",
                    within="sampling",
                    subject="subject",
                    between="stopping",
                    within_first=within_first,
                    padjust="bonf",
                    effsize="cohen",
                    return_desc=True,
                )

                # `display` only exists in IPython/Jupyter sessions
                try:
                    display(stat)
                except NameError:
                    print(stat)
                print(stat.to_html(), file=fout)

# Correlation with behavioral accuracy, BNT, and CPP accumulation

In [None]:
def _drop_desc_and_add_design(df):
    """Drop the DESC task and derive sampling/stopping from the task code."""
    # Drop DESC task for now
    df = df[df["task"] != "DESC"].reset_index(drop=True)
    # first letter of the task is the sampling, second letter the stopping
    df["sampling"] = df["task"].str[0].map({"A": "active", "Y": "yoked"})
    df["stopping"] = df["task"].str[1].map({"F": "fixed", "V": "variable"})
    return df


# Behavioral accuracy data
fname_beh_acc = "beh_accs.csv"
accuracy_df = _drop_desc_and_add_design(pd.read_csv(fname_beh_acc))

# BNT data, which is part of the overall behavioral DF
fname_beh_overall = "behavioral_data.csv"
beh_df = _drop_desc_and_add_design(pd.read_csv(fname_beh_overall))

# CPP accumulation data: the last column is renamed to "cpp_accumulation"
fname_accumul = op.join(BIDS_ROOT, "code", "publication_plots", "eeg_accumulation.csv")
accumul_df = pd.read_csv(fname_accumul)
accumul_colname = "mean_amp_diff_late_minus_early"
assert accumul_df.columns[-1] == accumul_colname
accumul_df = accumul_df.rename(columns={accumul_colname: "cpp_accumulation"})

# lower case, to match the condition labels of the other tables
for _col in ("sampling", "stopping"):
    accumul_df[_col] = accumul_df[_col].str.lower()

# work on "experienced" accuracy
accuracy_type = "experienced"
accuracy_df = accuracy_df[accuracy_df["accuracy_type"] == accuracy_type]

# work on behavioral df to make it focussed on bnt only
bnt_df = beh_df[["subject", "stopping", "sampling", "bnt_quartile"]].drop_duplicates()

In [None]:
# Correlation method to use for each variable that is correlated with the
# similarity values
corr_methods = {
    "correct_choice": "pearson",
    "bnt_quartile": "kendall",
    "cpp_accumulation": "pearson",
}
# The hypothesis is directional (positive coefficient: better encoding =
# higher accuracy or BNT), but the test is run two-sided.
# NOTE(review): confirm that a two-sided test is intended here.
tail = "two-sided"

In [None]:
info = f"tail -> {tail}\naccuracy -> {accuracy_type}\n"
print(info)

# variables to correlate with the similarity, in reporting order
_corr_targets = ("correct_choice", "bnt_quartile", "cpp_accumulation")
_merge_keys = ["subject", "stopping", "sampling"]

for effect, icluster, split in do_posthocs:

    combination = f"{effect}-{icluster}-{split}"
    print(combination)

    df = split_data[effect][icluster][split]

    # Merge accuracy, BNT, and CPP-accumulation data onto current df
    tmp = df.merge(accuracy_df, on=_merge_keys)
    tmp = tmp.merge(bnt_df, on=_merge_keys)
    tmp = tmp.merge(accumul_df, on=_merge_keys)

    # the context manager closes the file even if a correlation fails
    fname = op.join(outfolder, "corrs_" + combination + ".html")
    with open(fname, "w") as fout:
        print(combination, file=fout)
        print("<br>" + info, file=fout)

        # Calculate correlations for each task.
        # Group by the plain column name: grouping by a one-element list
        # yields 1-tuples as keys in pandas >= 2.0, which cannot be
        # concatenated with a string below.
        for task, grp in tmp.groupby("task"):
            print(task)
            print("<br>" + task, file=fout)

            # NOTE: the loop variable is not named `y`, which holds the
            # value for the one-sample test of the average effect
            stats = {
                target: pingouin.correlation.corr(
                    grp["similarity"],
                    grp[target],
                    method=corr_methods[target],
                    tail=tail,
                )
                for target in _corr_targets
            }

            # `display` only exists in IPython/Jupyter sessions;
            # fall back to printing otherwise
            tmp_print_info = False
            try:
                for stat in stats.values():
                    display(stat)
            except NameError:
                tmp_print_info = True

            for target, stat in stats.items():
                print(f"<br> {target}", file=fout)
                print(stat.to_html(), file=fout)
                if tmp_print_info:
                    print(f"{target}\n", stat)