# Run the Pearson correlation analysis on the Opioids dataset

This notebook performs the subject-level and group-level Pearson correlation analyses of
the preprocessed Opioids dataset. These analyses consist of Pearson correlation
matrices and seed-based maps.

In [None]:
import warnings
from collections import ChainMap
from operator import itemgetter
from pathlib import Path

import h5py as h5
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
from bids import BIDSLayout
from joblib import Parallel, delayed
from tqdm import tqdm
from nilearn import plotting

from opioids_analysis.pearson import (
    compute_subject_level_pearson,
    compute_group_level_pearson,
    read_session_sample_masks,
    write_session_subject_level_pearson,
    read_session_subject_level_pearson,
    write_session_group_level_pearson,
)
from opioids_analysis.plotting import plot_group_level_pearson

## 1 Load template and ROIs

Template and ROIs are stored in the `params/` folder at the root of the repository.

In [None]:
params_path = Path("../params/")

# Both images live under params/; the ROI atlas sits in the rois/ sub-folder.
template_path = params_path / "opioids_template.nii.gz"
rois_name = params_path / "rois" / "Mask-autoROIs18-slim.nii.gz"

template_img = nib.load(template_path)
rois_img = nib.load(rois_name)

# Zero the sform code so that Nilearn relies on the qform affine instead.
template_img.set_sform(None, code=0)

# The ROI atlas must share the template's voxel grid.
if rois_img.shape != template_img.shape:
    raise ValueError("Template and ROIs image shapes do not match!")

# Brain mask: every voxel whose template intensity is strictly positive.
brain_mask_img = nib.Nifti1Image(
    np.greater(template_img.get_fdata(), 0).astype(int),
    template_img.affine,
    template_img.header,
)

# Quick visual check of the ROIs overlaid on the template (single coronal cut).
fig = plt.figure(dpi=100)
_ = plotting.plot_roi(
    rois_img,
    bg_img=template_img,
    figure=fig,
    title="Template & ROIs",
    display_mode="y",
    cut_coords=1,
    annotate=False,
    black_bg=True,
    # NOTE(review): presumably the voxel-size ratio of the displayed plane — confirm.
    aspect=0.1 / 0.11,
)

## 2 Analysis parameters

The following analysis parameters must be set:

- `registered_data_path`: path to the root folder of the preprocessed Opioids dataset
  (fUS-BIDS format).
- `sessions`: list of session labels that will be analyzed.
- `opioids_results_root`: root path of the opioids analysis results.
- `sample_masks_path`: path to the HDF5 file containing the sample masks.
- `control_session`: label of the control saline session.
- `fdr_threshold`: threshold at which to control the false discovery rate (FDR) during the group-level analysis.
- `subject_level_path`: path to the HDF5 file where the subject-level results will be
  stored.
- `group_level_path`: path to the HDF5 file where the group-level results will be
  stored.
- `figures_path`: path to the output folder where figures will be saved.
- `graph_roi_order`: ROI ordering for circular graphs.
- `graph_roi_labels`: ROI labels for circular graphs.
- `n_jobs`: maximum number of concurrently running jobs. If -1, all CPUs are used.

In [None]:
# --- Input data --------------------------------------------------------------------
# NOTE(review): absolute, machine-specific paths; adjust them to your environment.
# Root of the preprocessed fUS-BIDS dataset, indexed without BIDS validation.
registered_data_path = Path(
    "/mnt/feanor/datasets/opioids/derivatives/registration/derivatives/preprocessed/"
)
layout = BIDSLayout(registered_data_path, validate=False)
# Every session label found in the dataset is analyzed.
sessions = layout.get_sessions()

# --- Results -----------------------------------------------------------------------
opioids_results_root = Path("/mnt/feanor/home/sdiebolt/opioids-paper-results/")
sample_masks_path = opioids_results_root / "sample_masks.h5"
# HDF5 files holding the subject-level and group-level results, respectively.
subject_level_path = opioids_results_root / "subject_level_pearson.h5"
group_level_path = opioids_results_root / "group_level_pearson.h5"
# Figures are written next to the group-level results.
figures_path = group_level_path.parent / "figures"

# --- Statistics --------------------------------------------------------------------
# Session used as the control in the treatment-vs-control comparisons.
control_session = "salineControl"
# FDR level used to declare significance at the group level.
fdr_threshold = 0.05

# --- Circular graphs ---------------------------------------------------------------
graph_roi_order = (8, 9, 5, 4, 3, 2, 1, 0, 6, 7, 10, 11, 12, 16, 17, 13, 14, 15)
# Maps each position in `graph_roi_order` to the ROI value at that position plus one.
graph_roi_labels = dict(enumerate(roi + 1 for roi in graph_roi_order))

# --- Parallelism -------------------------------------------------------------------
# Upper bound on concurrently running jobs; -1 uses every CPU.
n_jobs = -1

## 3 Run the subject-level Pearson correlation analysis

The subject-level Pearson correlation analysis runs in parallel. You may modify the
`n_jobs` argument below if your computer becomes sluggish while running the analysis.

In [None]:
# Sessions already stored in the subject-level HDF5 file are skipped, so that an
# interrupted run can be resumed by re-running this cell.
subject_level_done = set()
if subject_level_path.is_file():
    with h5.File(subject_level_path, "r") as f:
        if "seed_maps" in f:
            subject_level_done = set(f["seed_maps"].keys())

with Parallel(n_jobs=n_jobs) as parallel:
    for session in tqdm(sessions):
        if session in subject_level_done:
            continue

        sample_masks = read_session_sample_masks(sample_masks_path, session)
        subjects = layout.get_subjects(session=session)

        # One job per subject. The query is restricted to NIfTI images so that other
        # files indexed for the same subject and session (e.g. JSON sidecars, if any)
        # are never passed as `nii_paths`. Paths are sorted to order them by run index
        # (lexicographic order).
        subject_results = parallel(
            delayed(compute_subject_level_pearson)(
                nii_paths=sorted(
                    layout.get(
                        subject=subject,
                        session=session,
                        extension=[".nii", ".nii.gz"],
                        return_type="file",
                    )
                ),
                brain_mask_img=brain_mask_img,
                rois_img=rois_img,
                sample_masks=sample_masks[subject],
            )
            for subject in subjects
        )

        # Matrices and maps are stacked into numpy arrays (first axis follows
        # `subjects`) for easier operations during the group-level analysis, then
        # keyed by subject.
        session_correlation_matrices = np.array(
            [res["correlation_matrices"] for res in subject_results]
        )
        session_seed_maps = np.array([res["seed_maps"] for res in subject_results])
        session_pearson = {
            "correlation_matrices": dict(zip(subjects, session_correlation_matrices)),
            "seed_maps": dict(zip(subjects, session_seed_maps)),
        }

        write_session_subject_level_pearson(
            subject_level_path,
            session,
            session_pearson,
        )

## 4 Create the `salineControl` session

Saline sessions `saline`, `saline2`, `WTFS1`, and `WTMS1` are merged into a single
session that serves as the control in the group-level (second-level) analysis.

In [None]:
saline_sessions = ["saline", "saline2", "WTFS1", "WTMS1"]

# Sessions already stored in the subject-level results (none if the group is absent).
with h5.File(subject_level_path, "r") as f:
    stored_sessions = set(f["seed_maps"].keys()) if "seed_maps" in f else set()
missing_pearson = [s for s in saline_sessions if s not in stored_sessions]

if missing_pearson:
    raise RuntimeError(
        "The following saline pearson are missing from the HDF5 file: "
        f"{missing_pearson}."
    )

saline_pearson = [
    read_session_subject_level_pearson(subject_level_path, s) for s in saline_sessions
]

# Record, for every subject, the saline sessions in which it appears.
sessions_by_subject = {}
for saline_session, session_pearson in zip(saline_sessions, saline_pearson):
    session_subjects = set(session_pearson["correlation_matrices"]) | set(
        session_pearson["seed_maps"]
    )
    for subject in session_subjects:
        sessions_by_subject.setdefault(subject, []).append(saline_session)

# ChainMap resolves a key to the first mapping containing it, so a subject found in
# several saline sessions only keeps the data of the first of them (in the order of
# `saline_sessions`); the rest is dropped. Make such collisions visible.
duplicated_subjects = {
    subject: found_in
    for subject, found_in in sessions_by_subject.items()
    if len(found_in) > 1
}
if duplicated_subjects:
    warnings.warn(
        "Some subjects appear in several saline sessions; only their data from the "
        f"first session listed in `saline_sessions` is kept in {control_session!r}: "
        f"{duplicated_subjects}"
    )

# Merge the per-subject results of all saline sessions into a single session.
saline_control_pearson = {
    key: dict(ChainMap(*map(itemgetter(key), saline_pearson)))
    for key in ("correlation_matrices", "seed_maps")
}

write_session_subject_level_pearson(
    subject_level_path,
    control_session,
    saline_control_pearson,
)

## 5 Run the group-level Pearson correlation analysis

The group-level Pearson correlation analysis runs in parallel. You may modify the
`n_jobs` argument below if your computer becomes sluggish while running the analysis.

In [None]:
# Boolean brain mask passed to the group-level analysis.
brain_mask = brain_mask_img.get_fdata().squeeze().astype(bool)

# Sessions compared against the control. A dedicated name is used so that the
# configuration variable `sessions` is not overwritten, which would silently change the
# subject-level section if it were re-run afterwards.
with h5.File(subject_level_path, "r") as f:
    treatment_sessions = list(f["seed_maps"].keys())

if control_session not in treatment_sessions:
    raise RuntimeError(
        f"The control session {control_session!r} is missing from "
        f"{subject_level_path}; create it before running the group-level analysis."
    )
# Removing control session since we don't want to compare it to itself.
treatment_sessions.remove(control_session)

# Sessions already stored in the group-level HDF5 file are skipped, so that an
# interrupted run can be resumed by re-running this cell.
group_level_done = set()
if group_level_path.is_file():
    with h5.File(group_level_path, "r") as f:
        if "seed_maps" in f:
            group_level_done = set(f["seed_maps"].keys())

for session in tqdm(treatment_sessions):
    if session in group_level_done:
        continue

    group_level_pearson = compute_group_level_pearson(
        subject_level_path=subject_level_path,
        session_treatment=session,
        session_control=control_session,
        brain_mask=brain_mask,
        fdr_threshold=fdr_threshold,
        n_jobs=n_jobs,
    )

    write_session_group_level_pearson(group_level_path, session, group_level_pearson)

## 6 Plot the group-level Pearson correlation figures

Plots are generated in parallel. You may modify the `n_jobs` argument
below if your computer becomes sluggish during plotting.

In [None]:
# Sessions with stored group-level results. A dedicated name avoids overwriting the
# configuration variable `sessions`.
with h5.File(group_level_path, "r") as f:
    group_level_sessions = list(f["seed_maps"].keys())

# Make sure the output folder exists before the plotting jobs save figures into it.
figures_path.mkdir(parents=True, exist_ok=True)

# One plotting job per session.
_ = Parallel(n_jobs=n_jobs)(
    delayed(plot_group_level_pearson)(
        group_level_path=group_level_path,
        session=session,
        template_img=template_img,
        rois_img=rois_img,
        graph_roi_order=graph_roi_order,
        graph_roi_labels=graph_roi_labels,
        output_path=figures_path,
    )
    for session in group_level_sessions
)