# Run the cross-correlation analysis on the Opioids dataset

This notebook performs the subject-level and group-level cross-correlation analyses of
the preprocessed Opioids dataset. These analyses are run separately from the Pearson
correlation analyses in `pearson_analysis.ipynb` because the cross-correlation maps
generate a large amount of data (~10 GB per session).

In [1]:
from collections import ChainMap
from operator import itemgetter
from pathlib import Path

import h5py as h5
import numpy as np
import nibabel as nib
from bids import BIDSLayout
from joblib import Parallel, delayed
from tqdm import tqdm

from opioids_analysis.cross_correlation import (
    compute_subject_level_xcorr,
    compute_group_level_xcorr,
    write_session_subject_level_xcorr,
    read_session_subject_level_xcorr,
    write_session_group_level_xcorr,
)
from opioids_analysis.plotting import plot_group_level_xcorr

## 1 Load template and ROIs

Template and ROIs are stored in the `params/` folder at the root of the repository.

In [2]:
params_path = Path("../params/")

# Anatomical template. Its sform code is zeroed so that Nilearn falls back on the
# qform affine.
template_path = params_path / "opioids_template.nii.gz"
template_img = nib.load(template_path)
template_img.set_sform(None, code=0)

# Region-of-interest atlas; it must live on the same voxel grid as the template.
rois_name = params_path / "rois" / "Mask-autoROIs18-slim.nii.gz"
rois_img = nib.load(rois_name)
if rois_img.shape != template_img.shape:
    raise ValueError("Template and ROIs image shapes do not match!")

# Brain mask: 1 wherever the template intensity is strictly positive, 0 elsewhere.
# It reuses the template affine and header.
brain_mask_img = nib.Nifti1Image(
    np.greater(template_img.get_fdata(), 0).astype(int),
    header=template_img.header,
    affine=template_img.affine,
)

## 2 Analysis parameters

The following analysis parameters must be set:

- `registered_data_path`: path to the folder containing the preprocessed Opioids
  dataset (fUS-BIDS).
- `sessions`: list of session labels that will be analyzed.
- `opioids_results_root`: root path of the opioids analysis results.
- `max_lag`: Maximum cross-correlation lag, in seconds.
- `control_session`: label of the control saline session.
- `fdr_threshold`: Threshold at which to control the FDR during group-level analysis.
- `subject_level_path`: path to the HDF5 file where the subject-level cross-correlation
  maps will be stored.
- `group_level_path`: path to the HDF5 file where the group-level cross-correlation
  maps will be stored.
- `figures_path`: folder where the group-level figures will be saved.
- `n_jobs`: The maximum number of concurrently running jobs. If -1, all CPUs are used.

In [3]:
# --- Input data -------------------------------------------------------------------
# Preprocessed fUS-BIDS dataset, indexed with PyBIDS (BIDS validation disabled).
registered_data_path = Path(
    "/mnt/feanor/datasets/opioids/derivatives/registration/derivatives/preprocessed/"
)
layout = BIDSLayout(registered_data_path, validate=False)

# Labels of the sessions to analyze.
sessions = [
    "WTM70",
    "WTMS1",
    "WTFS1",
    "saline",
    "saline2",
    "WTSC11",
]

# --- Analysis settings ------------------------------------------------------------
max_lag = 20  # Maximum cross-correlation lag.
control_session = "salineControl"  # Session used as control in statistical tests.
fdr_threshold = 0.05  # Significance threshold after FDR correction.
n_jobs = -1  # Maximum number of concurrent jobs; -1 uses every CPU.

# --- Outputs ----------------------------------------------------------------------
opioids_results_root = Path("/mnt/feanor/home/sdiebolt/opioids-paper-results/")
# HDF5 files holding the subject-level and group-level cross-correlation maps.
subject_level_path = opioids_results_root / "subject_level_xcorr.h5"
group_level_path = opioids_results_root / "group_level_xcorr.h5"
# Figures are written next to the group-level results.
figures_path = group_level_path.parent / "figures"

## 3 Run the subject-level cross-correlation analysis

The subject-level cross-correlation analysis runs in parallel. You may modify the
`n_jobs` argument below if your computer becomes sluggish while running the analysis.

In [None]:
with Parallel(n_jobs=n_jobs) as parallel:
    for session in tqdm(sessions):
        # Skip sessions whose maps were already written by a previous run, so that
        # this cell can be resumed after an interruption.
        if subject_level_path.is_file():
            with h5.File(subject_level_path, "r") as f:
                if "xcorr_maps" in f and session in f["xcorr_maps"]:
                    continue

        subjects = layout.get_subjects(session=session)
        # One job per subject; joblib returns results in submission order, so they
        # line up with `subjects`.
        results = parallel(
            delayed(compute_subject_level_xcorr)(
                nii_paths=sorted(
                    layout.get(subject=subject, session=session, return_type="file")
                ),
                brain_mask_img=brain_mask_img,
                rois_img=rois_img,
                max_lag=max_lag,
            )
            for subject in subjects
        )

        # Maps are saved as numpy arrays for easier operations during the group-level
        # analysis. Each map is converted on its own (no copy if it already is an
        # ndarray) instead of stacking all of them into one array, which would copy
        # every map of the session and double the peak memory usage.
        subject_level_xcorr = {
            "xcorr_maps": {
                subject: np.asarray(res["xcorr_maps"])
                for subject, res in zip(subjects, results)
            },
        }
        # Release the raw results now so that they do not stay in memory while the
        # next session is being computed.
        del results

        write_session_subject_level_xcorr(
            subject_level_path,
            session,
            subject_level_xcorr,
        )
        del subject_level_xcorr

## 4 Create the `salineControl` session

Saline sessions `saline`, `saline2`, `WTFS1`, and `WTMS1` are merged into a single
session that will constitute the control in the second-level analysis.

In [None]:
import warnings

# Sessions merged into the control session, in order of precedence: a subject present
# in several of them keeps its maps from the first session listed here.
saline_sessions = ["saline", "saline2", "WTFS1", "WTMS1"]

with h5.File(subject_level_path, "r") as f:
    # `.get` turns a file without any subject-level maps into the explicit error
    # below instead of an opaque KeyError.
    available_sessions = set(f.get("xcorr_maps", {}).keys())
missing_results = [s for s in saline_sessions if s not in available_sessions]

if missing_results:
    raise RuntimeError(
        "The following saline cross-correlation maps are missing from the HDF5 file: "
        f"{missing_results}."
    )

saline_xcorr = [
    read_session_subject_level_xcorr(subject_level_path, s) for s in saline_sessions
]

# ChainMap resolves duplicate keys with the first mapping, so a subject present in
# several saline sessions silently keeps only one of its maps. Warn about it so that
# dropped data does not go unnoticed.
subject_to_sessions = {}
for session_label, session_xcorr in zip(saline_sessions, saline_xcorr):
    for subject in session_xcorr["xcorr_maps"]:
        subject_to_sessions.setdefault(subject, []).append(session_label)
duplicated_subjects = {
    subject: labels
    for subject, labels in subject_to_sessions.items()
    if len(labels) > 1
}
if duplicated_subjects:
    warnings.warn(
        "Some subjects appear in several saline sessions; only their maps from the "
        f"first listed session are kept in {control_session!r}: "
        f"{duplicated_subjects}."
    )

saline_control_xcorr = {
    "xcorr_maps": dict(ChainMap(*map(itemgetter("xcorr_maps"), saline_xcorr)))
}
del saline_xcorr

write_session_subject_level_xcorr(
    subject_level_path, control_session, saline_control_xcorr
)
del saline_control_xcorr

## 5 Run the group-level cross-correlation analysis

The group-level cross-correlation analysis runs in parallel. You may modify the
`n_jobs` argument below if your computer becomes sluggish while running the analysis.

In [None]:
# Boolean brain mask on the template grid (singleton dimensions removed).
brain_mask = brain_mask_img.get_fdata().squeeze().astype(bool)

with h5.File(subject_level_path, "r") as f:
    # Every session with subject-level maps is analyzed, including the control
    # session itself. `.get` avoids an opaque KeyError when the file holds no maps.
    sessions = list(f.get("xcorr_maps", {}).keys())

# Every session is compared against the control session, so fail early with an
# actionable message if it has not been created yet.
if control_session not in sessions:
    raise RuntimeError(
        f"The control session {control_session!r} is missing from "
        f"{subject_level_path}. Create it (section 4) before running the "
        "group-level analysis."
    )

for session in tqdm(sessions):
    # Skip sessions already written by a previous run so the cell can be resumed.
    if group_level_path.is_file():
        with h5.File(group_level_path, "r") as f:
            if "xcorr_maps" in f and session in f["xcorr_maps"]:
                continue

    group_level_xcorr = compute_group_level_xcorr(
        subject_level_path=subject_level_path,
        session_treatment=session,
        session_control=control_session,
        brain_mask=brain_mask,
        fdr_threshold=fdr_threshold,
        n_jobs=n_jobs,
    )

    write_session_group_level_xcorr(group_level_path, session, group_level_xcorr)

## 6 Plot the group-level cross-correlation figures

Plots are generated in parallel. You may modify the `n_jobs` argument
below if your computer becomes sluggish during plotting.

In [4]:
# Plot every session present in the group-level results file.
with h5.File(group_level_path, "r") as f:
    sessions = list(f["xcorr_maps"].keys())

# The figures folder is not created anywhere else in this notebook; make sure it
# exists before the parallel plotting jobs are given it as their output path.
figures_path.mkdir(parents=True, exist_ok=True)

_ = Parallel(n_jobs=n_jobs)(
    delayed(plot_group_level_xcorr)(
        group_level_path=group_level_path,
        session=session,
        template_img=template_img,
        rois_img=rois_img,
        lags_to_skip=7,
        output_path=figures_path,
    )
    for session in sessions
)