# Run the multimodal analysis on the Opioids dataset

This notebook performs the multimodal analysis of the preprocessed Opioids dataset.
This analysis consists of computing an index of functional connectivity following
Opioids treatment and comparing it to other physiological readouts: cerebral blood
volume, animal motion, analgesia (as measured using the hot plate test), µ-opioid
receptor phosphorylation, and respiratory depression.

In [None]:
import warnings
from pathlib import Path

import colorcet as cc
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from bids import BIDSLayout
from nilearn.image import math_img
from nilearn.masking import apply_mask
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from opioids_analysis.multimodal import (
    load_instant_velocity,
    compute_moving_time_percentage,
)
from opioids_analysis.pearson import (
    read_session_subject_level_pearson,
    read_session_group_level_pearson,
    read_session_sample_masks,
    _matrix2tril,
    _tril2matrix,
)
from opioids_analysis.plotting import plot_correlation_matrix

## 1 Analysis parameters

The following analysis parameters must be set:

- `params_path`: path to the folder containing the `opioids_template.nii.gz` template
  file, and the `Mask-autoROIs18-slim.{csv, nii.gz}` ROI files.
- `opioids_dataset_root`: root path of the Opioids dataset, containing:
  - `derivatives/registration/rawdata`: registered and resampled dataset.
  - `derivatives/tracking`: tracking data.
  - `derivatives/phosphorylation_data.xlsx`: phosphorylation data.
  - `derivatives/hot_plate_data.xlsx`: hot plate data.
  - `derivatives/respiratory_data.xlsx`: respiratory depression data.
- `opioids_results_root`: root path of the opioids analysis results, containing:
  - `sample_masks.h5`: sample masks result file.
  - `subject_level_pearson.h5`: subject-level result file of the Pearson correlation
    analysis.
  - `group_level_pearson.h5`: group-level result file of the Pearson correlation
    analysis.
- `sessions`: session labels to analyze.
- `sample_masks_path`: path to the HDF5 file where the sample masks are stored.
- `subject_level_path`: path to the HDF5 file where the subject-level results of the
  Pearson correlation analysis are stored.
- `group_level_path`: path to the HDF5 file where the group-level results of the
  Pearson correlation analysis are stored.
- `tracking_root`: path to the folder where the Mobile HomeCage tracking data is saved.
- `registered_root`: path to the registered and resampled (but not preprocessed)
  fUSI-BIDS dataset.
- `phosphorylation_path`: path to the Excel file containing the phosphorylation
  quantification data.
- `analgesia_path`: path to the Excel file containing the hot-plate test data.
- `respiratory_path`: path to the Excel file containing the respiratory depression data.
- `output_path`: path to the output folder where analysis results will be saved.
- `figures_path`: path to the folder where figures will be saved.

In [None]:
params_path = Path("../params/")

opioids_dataset_root = Path("/mnt/feanor/datasets/opioids/")
opioids_results_root = Path("/mnt/feanor/home/sdiebolt/opioids-paper-results/")

sessions = ["WTM10", "WTM20", "WTM30", "WTM70"]

sample_masks_path = opioids_results_root / "sample_masks.h5"
subject_level_path = opioids_results_root / "subject_level_pearson.h5"
group_level_path = opioids_results_root / "group_level_pearson.h5"

tracking_root = opioids_dataset_root / "derivatives" / "tracking"

registered_root = opioids_dataset_root / "derivatives" / "registration" / "rawdata"

phosphorylation_path = (
    opioids_dataset_root / "derivatives" / "phosphorylation_data.xlsx"
)

analgesia_path = opioids_dataset_root / "derivatives" / "hot_plate_data.xlsx"

respiratory_path = opioids_dataset_root / "derivatives" / "respiratory_data.xlsx"

# Tabular results (CSV files) are written to this folder.
output_path = opioids_results_root / "multimodal"
output_path.mkdir(exist_ok=True, parents=True)

# Figures are saved in a sibling "figures" folder. It must exist before the
# `fig.savefig(figures_path / ...)` calls later in the notebook.
figures_path = output_path.parent / "figures"
figures_path.mkdir(exist_ok=True, parents=True)

## 2 Functional connectivity index


The functional connectivity index is computed by measuring how much each subject-level
correlation matrix at each phase resembles the Opioids fingerprint matrix. This
fingerprint matrix is computed by singular value decomposition (SVD) of group-level
matrices of the morphine 70 mg/kg cohort.

In [None]:
# Load the group-level Pearson correlation results of the morphine 70 mg/kg session
# ("WTM70"). Its correlation matrices are decomposed below to obtain the fingerprint.
group_level_results = read_session_group_level_pearson(group_level_path, "WTM70")

### 2.1 Compute the SVD of the group-level morphine 70 mg/kg matrices

In [None]:
# Take index 0 along the second axis of the group-level correlation matrices and
# flatten each matrix to its lower triangle. NOTE(review): the result is presumably
# a (n_matrices, n_pairs) array; confirm against `_matrix2tril`.
group_level_matrices = group_level_results["correlation_matrices"][:, 0]
group_level_lower_triangle = _matrix2tril(group_level_matrices)

# Thin SVD; each row of `vh` is a right singular vector in lower-triangle space.
u, s, vh = np.linalg.svd(group_level_lower_triangle, full_matrices=False)

# Fold the two leading right singular vectors back into square matrices.
mode0, mode1 = _tril2matrix(vh[0]), _tril2matrix(vh[1])

### 2.2 Visualize the first two modes

Since group-level matrices were not centered prior to computing the SVD, the first
singular vector corresponds to a "baseline" mode, while the second singular vector
corresponds to our Opioids fingerprint.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(5, 3), dpi=200)

# Each mode gets its own color scale, symmetric around zero.
for ax, mode, title in zip(axes, (mode0, mode1), ("mode 0", "mode 1")):
    vmax = np.abs(mode).max()
    _ = plot_correlation_matrix(
        ax,
        correlation_matrix=mode,
        vmax=vmax,
        vmin=-vmax,
        cmap=cc.cm.coolwarm,
    )
    ax.set_title(title)
    _ = ax.axis("off")

### 2.3 Compute the FC index across all sessions

The FC index is computed by correlation of `mode1` with all subject-level correlation
matrices across the `salineControl` and all morphine sessions. The `salineControl`
session is defined as a merge of the `saline`, `saline2`, `WTFS1`, and `WTMS1` sessions.

In [None]:
saline_control_sessions = ["saline", "saline2", "WTFS1", "WTMS1"]

mode_ref = mode1
# The fingerprint is the same for every subject and phase: flatten it only once.
mode_ref_tril = _matrix2tril(mode_ref)

fc_indices = {}
for session in saline_control_sessions + sessions:
    subject_level_results = read_session_subject_level_pearson(
        subject_level_path, session
    )
    subject_level_matrices = subject_level_results["correlation_matrices"]
    # One row per subject, one column per phase: Pearson correlation between the
    # lower triangle of each phase matrix and that of the fingerprint.
    fc_index = np.array(
        [
            [
                np.corrcoef(_matrix2tril(phase_matrix), mode_ref_tril)[0, 1]
                for phase_matrix in phase_matrices
            ]
            for phase_matrices in subject_level_matrices.values()
        ]
    )
    # Subtract the mean, over all subjects, of the first two (baseline) phases.
    fc_index -= fc_index[:, :2].mean()
    fc_indices[session] = fc_index

# The control group pools the subjects of all saline control sessions.
fc_indices["salineControl"] = np.concatenate(
    [fc_indices[s] for s in saline_control_sessions], axis=0
)

# Subject-averaged FC index: control first, then the morphine sessions.
fc_index_df = pd.DataFrame(
    {name: fc_indices[name].mean(0) for name in ["salineControl", *sessions]}
)

# Phase centers: -15, -5, 5, ..., 55 minutes.
fc_index_df.index = pd.to_timedelta((np.arange(-2, 6) * 10 + 5) * 60, unit="s")

fc_index_df.to_csv(output_path / "fc_index.csv")

### 2.4 Visualize the FC index

In [None]:
fig, axes = plt.subplots(len(sessions), 1, figsize=(15, 9), sharex=True, dpi=200)

# Color 0 is used for the control; colors 1..4 for the morphine sessions.
cmap = cc.cm["CET_D1"]
cmap0 = cmap([0, 0.6, 0.7, 0.8, 0.9])

# Control curve (mean and 2 x sample SD across subjects), repeated in every panel.
control_average_fc_index = fc_indices["salineControl"].mean(0)
control_std_fc_index = 2 * fc_indices["salineControl"].std(0, ddof=1)

for i, session in enumerate(sessions):
    ax = axes[i]
    average_fc_index = fc_indices[session].mean(0)
    std_fc_index = 2 * fc_indices[session].std(0, ddof=1)

    # Control first, then the current session: line, error bars, and markers.
    curves = (
        ("salineControl", control_average_fc_index, control_std_fc_index, cmap0[0]),
        (session, average_fc_index, std_fc_index, cmap0[i + 1]),
    )
    for label, center, spread, color in curves:
        ax.plot(center, c=color, label=label)
        ax.errorbar(
            np.arange(8),
            center,
            yerr=spread,
            capsize=5,
            color=color,
        )
        ax.scatter(np.arange(8), center, color=color)

    # Separator between the second (BP2) and third (SP1) phases.
    ax.axvline(1.5, c="k", ls=":", lw=0.5, zorder=3)

    ax.set_ylabel("$r$-$r_0$")
    ax.legend(loc="upper left")
    ax.set_ylim(-1, 1)


_ = axes[-1].set_xlabel("Phases")
_ = axes[-1].set_xticks(
    np.arange(8), ["BP1", "BP2", "SP1", "SP2", "SP3", "SP4", "SP5", "SP6"]
)

## 3 Animal motion

Tracking data is obtained via the Mobile HomeCage recording. The mouse instant velocity
is re-computed from its resampled position. The percentage of time in the "moving" state
in each window corresponding to the different acquisition phases (e.g. BP1, BP2, SP1,
etc.) is computed by using a threshold of 5 cm/s to differentiate between moving and
resting.

In [None]:
tracking_layout = BIDSLayout(tracking_root, validate=False)

mobility = {}
for session in tqdm(["WTSC11"] + sessions):
    subjects = tracking_layout.get_subjects(session=session)

    session_mobility = []
    for subject in subjects:
        tracking_paths = sorted(
            tracking_layout.get(
                subject=subject, session=session, extension=".h5", return_type="file"
            )
        )
        # Subjects that do not have exactly three tracking files are skipped.
        if len(tracking_paths) != 3:
            continue

        speeds = np.concatenate(
            [load_instant_velocity(path) for path in tracking_paths]
        )
        session_mobility.append(compute_moving_time_percentage(speeds))

    # Subtract the mean, over all subjects, of the first two (baseline) phases.
    session_mobility = np.array(session_mobility)
    session_mobility -= session_mobility[:, :2].mean()
    mobility[session] = session_mobility

# Average across subjects; one column per session, one row per phase. The WTSC11
# session is stored under the "salineControl" label used by the other readouts.
mobility_df = pd.concat(
    {name: pd.Series(values.mean(0)) for name, values in mobility.items()}, axis=1
).rename(columns={"WTSC11": "salineControl"})
mobility_df.index = pd.to_timedelta((np.arange(-2, 6) * 10 + 5) * 60, unit="s")

mobility_df.to_csv(output_path / "mobility.csv")

## 4 CBV

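The cerebral blood volume (CBV) is computed as the average power Doppler signal inside a
brain mask, expressed relative to its baseline average. The resulting signal is then
smoothed and detrended with a second-order polynomial.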

In [None]:
registered_layout = BIDSLayout(registered_root, validate=False)

# Diagonal affine shared by the template and the ROI atlas. It is set as the qform
# (code 1) and as the sform (code 0) of both images.
qform = np.diag([0.11, 0.4, 0.1, 1])

template_name = params_path / "opioids_template.nii.gz"
rois_path = params_path / "Mask-autoROIs18-slim.nii.gz"
rois_names_path = params_path / "Mask-autoROIs18-slim.csv"

template_img = nib.load(template_name)
rois_img = nib.load(rois_path)
for img in (template_img, rois_img):
    img.set_qform(qform, code=1)
    img.set_sform(qform, code=0)

rois_names = pd.read_csv(rois_names_path)

# Every distinct label but one (presumably the 0 background) counts as an ROI.
# Labels are assumed to run from 1 to n_rois.
n_rois = np.unique(rois_img.get_fdata()).size - 1

# Ignore warnings about casting to int32 emitted by `math_img`.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    # Binary mask of the non-zero template voxels.
    brain_mask_img = math_img("(img != 0).astype(int)", img=template_img)
    # One binary mask per ROI label.
    seed_masks = [
        math_img(f"(img == {roi + 1}).astype(int)", img=rois_img)
        for roi in range(n_rois)
    ]

In [None]:
cbvs = {}
for session in tqdm(saline_control_sessions + sessions):
    subjects = registered_layout.get_subjects(session=session)

    session_sample_masks = read_session_sample_masks(sample_masks_path, session)

    subject_cbvs = []
    for subject in subjects:
        nii_paths = sorted(
            registered_layout.get(
                subject=subject,
                session=session,
                extension=".nii.gz",
                return_type="file",
            )
        )

        # Brain-masked signal of all runs, concatenated along the first (time) axis.
        time_series = np.concatenate(
            [apply_mask(path, brain_mask_img) for path in nii_paths], axis=0
        )

        # Samples rejected by the sample mask are set to NaN so that the nan-aware
        # means below ignore them.
        subject_sample_mask = session_sample_masks[subject].reshape((-1,))
        time_series[~subject_sample_mask] = np.nan

        # The first 2400 samples are the baseline.
        baseline_average = np.nanmean(time_series[:2400, :], axis=0)

        # Ignore division by zero warnings when normalizing the time series.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            time_series = (time_series - baseline_average) / baseline_average

        # Voxels with a zero baseline produce infinities; set them to 0.
        time_series[np.isinf(time_series)] = 0

        subject_cbvs.append(time_series)

    # Average over subjects (axis 0) and voxels (axis 2): one trace per session.
    # Ignore mean of empty slice warnings when computing the mean CBV signal.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        cbvs[session] = np.nanmean(np.array(subject_cbvs), axis=(0, 2))

cbv_df = pd.DataFrame(cbvs)
# Sample index -> seconds relative to the end of the 2400-sample baseline
# (two samples per second).
cbv_df.index = pd.to_timedelta((cbv_df.index - 2400) / 2, unit="s")

# Control trace: average of the control sessions weighted by their subject count.
cbv_df["salineControl"] = np.average(
    cbv_df[saline_control_sessions],
    weights=[
        len(registered_layout.get_subjects(session=session))
        for session in saline_control_sessions
    ],
    axis=1,
)

# Smooth the CBV signal with a centered rolling mean over 300 samples, i.e. 150 s
# at the 0.5 s sample spacing set above. This smooths out spikes and helps the
# detrending fit. NOTE(review): if a 5-minute window was intended, use 600 samples.
cbv_df = cbv_df.interpolate().rolling(300, center=True).mean().dropna()

# Remove a second-order polynomial trend from every column. NaN rows were dropped
# above, so the fit can use the frame directly.
p = np.polyfit(cbv_df.index.total_seconds(), cbv_df, 2)
cbv_df = cbv_df - np.polyval(p, cbv_df.index.total_seconds().to_numpy()[:, None])

cbv_df.to_csv(output_path / "cbv.csv")

## 5 Phosphorylation

Phosphorylation data is extracted from an Excel sheet provided by Andrea Kliewer.

In [None]:
# Columns of the sheet: a time column, then two runs per morphine dose.
phosphorylation_df = pd.read_excel(
    phosphorylation_path,
    skiprows=16,
    skipfooter=12,
    names=["time", *(f"{s}_{i}" for s in sessions for i in (1, 2))],
).drop(0)

# The time column is removed from the table, converted from minutes to seconds,
# and used as a timedelta index.
phosphorylation_df.index = pd.to_timedelta(
    phosphorylation_df.pop("time") * 60, unit="s"
)

# We average across runs for each dose: columns are grouped by their WTMxx prefix.
run_doses = phosphorylation_df.columns.str.extract(r"(WTM\d{2})_(?:1|2)", expand=False)
phosphorylation_df = phosphorylation_df.T.groupby(run_doses).mean().T
phosphorylation_df["salineControl"] = 0

phosphorylation_df.to_csv(output_path / "phosphorylation.csv")

## 6 Analgesia

Analgesia data, obtained as a result from hot-plate tests, is extracted from an Excel
sheet provided by Andrea Kliewer.

In [None]:
# Hot-plate (analgesia) table: 22 header rows are skipped, and the "Time (min)"
# column is converted to seconds and used as a timedelta index.
analgesia_df = pd.read_excel(
    analgesia_path,
    skiprows=22,
)
analgesia_df.rename(columns={"Time (min)": "time"}, inplace=True)
analgesia_df.time *= 60
analgesia_df.index = pd.to_timedelta(analgesia_df.time, unit="s")
analgesia_df.drop(columns=["time"], inplace=True)

# Renaming columns and dropping null columns.
# Each column with at least one value is renamed to a session label. The session
# index `group_num` advances whenever the column immediately to the left is entirely
# empty, so the sheet is assumed to contain one block of subject columns per
# session, in the order of `sessions`, separated by empty columns.
# NOTE(review): for the first column, `col_index - 1` is -1, which inspects the LAST
# column. If that column is entirely empty, the first block is labeled
# `sessions[1]` instead of `sessions[0]`. Confirm against the sheet layout.
group_num = 0
for col_index, col in enumerate(analgesia_df.columns):
    if analgesia_df[col].notnull().any():
        if analgesia_df.iloc[:, col_index - 1].isnull().all():
            group_num += 1
        analgesia_df.rename(columns={col: sessions[group_num]}, inplace=True)
# Note: this drops every column containing at least one NaN, not only empty ones.
analgesia_df = analgesia_df.dropna(axis=1)

# The average across subjects is computed to compare with other readouts.
analgesia_df = analgesia_df.T.groupby(by=analgesia_df.columns).mean().T
# No hot-plate data are available for the control; it is set to zero.
analgesia_df["salineControl"] = 0

analgesia_df.to_csv(output_path / "analgesia.csv")

## 7 Respiratory

Respiratory depression data is extracted from an Excel sheet provided by Andrea Kliewer.

In [None]:
# The sampling frequency of the respiratory data is 0.5 Hz, as indicated in the
# Excel file.
respiratory_sampling_frequency = 1 / 2

# Morphine sheet: six subject columns per dose, doses in the order listed below.
respiratory_morphine = pd.read_excel(
    respiratory_path,
    sheet_name=1,
    skiprows=12,
).dropna(axis=1)
respiratory_morphine.columns = [
    f"WTM{dose:02d}" for dose in (30, 20, 10, 5, 70) for _ in range(6)
]

# Control sheet: six subject columns.
respiratory_control = pd.read_excel(
    respiratory_path,
    sheet_name=0,
    skiprows=11,
).dropna(axis=1)
respiratory_control.columns = ["WT"] * 6

respiratory_df = pd.concat((respiratory_control, respiratory_morphine), axis=1)
# Sample index -> time, derived from the sampling frequency (one sample every 2 s).
respiratory_df.index = pd.to_timedelta(
    respiratory_df.index / respiratory_sampling_frequency, unit="s"
)

# Interpolate outliers: values below 40 or above 270 bpm are replaced by
# interpolated values.
respiratory_df[respiratory_df < 40] = np.nan
respiratory_df[respiratory_df > 270] = np.nan
respiratory_df.interpolate(inplace=True)

# Smooth data using a window of size 10 minutes to match the duration of the phases.
respiratory_df = (
    respiratory_df.rolling(int(60 * respiratory_sampling_frequency * 10), center=True)
    .mean()
    .dropna()
)

# Transform to BPM relative to control. Column k of every group (control included)
# is offset by the mean of control column k. After the per-group average below,
# this amounts to subtracting the overall control mean.
control_means = respiratory_df["WT"].mean().to_numpy()
n_groups = respiratory_df.shape[1] // control_means.size
respiratory_df = respiratory_df - np.tile(control_means, n_groups)

# The average across subjects is computed to compare with other readouts.
respiratory_df = respiratory_df.T.groupby(by=respiratory_df.columns).mean().T

respiratory_df.to_csv(output_path / "respiratory.csv")

### 7.1 Visualize respiratory depression data

In [None]:
# One curve per group of the subject-averaged respiratory change. The x-axis is
# in minutes; total_seconds() is used because `.seconds` only returns the
# seconds component of each timedelta.
respiratory_minutes = respiratory_df.index.total_seconds() / 60

plt.figure(figsize=(10, 3), dpi=200)

for col in respiratory_df:
    plt.plot(respiratory_minutes, respiratory_df[col], label=col)

plt.xlabel("Time (min)")
plt.ylabel(r"$\Delta$ Respiration (bpm)")
_ = plt.legend()

## 8 Principal component analysis

### 8.1 Visualize correlation matrices of physiological readouts for saline and morphine 

In [None]:
session_labels = {
    "salineControl": "Saline",
    "WTM10": "Morphine (10 mg/kg)",
    "WTM20": "Morphine (20 mg/kg)",
    "WTM30": "Morphine (30 mg/kg)",
    "WTM70": "Morphine (70 mg/kg)",
}

# Column order of the readout matrices built by `stack_readouts`.
readout_names = ["FC", "Mobility", "CBV", "Analgesia", "Phosphorylation"]

# Analgesia and phosphorylation are resampled to 300 s bins and interpolated once.
# This does not depend on the session, so it is kept out of the loop.
analgesia_resampled = analgesia_df.resample("300s").mean().interpolate()
phosphorylation_resampled = phosphorylation_df.resample("300s").mean().interpolate()


def stack_readouts(session):
    """Stack the readouts of ``session`` into an array of shape (n_phases, 5).

    Rows follow the phases of ``fc_index_df``. Columns are, in order: FC index,
    mobility, CBV, analgesia and phosphorylation. Analgesia and phosphorylation are
    only read for the phases ``fc_index_df.index[2:]``; two zeros are prepended for
    the first two (baseline) phases.
    """
    post_baseline = fc_index_df.index[2:]
    return np.stack(
        (
            fc_index_df[session].to_numpy(),
            mobility_df[session].to_numpy(),
            cbv_df.loc[fc_index_df.index, session].to_numpy(),
            np.insert(
                analgesia_resampled.loc[post_baseline, session].to_numpy(),
                0,
                (0, 0),
            ),
            np.insert(
                phosphorylation_resampled.loc[post_baseline, session].to_numpy(),
                0,
                (0, 0),
            ),
        ),
        axis=1,
    )


for session, session_title in session_labels.items():
    data = stack_readouts(session)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 3), dpi=300, facecolor="none")

    # Left panel: readouts z-scored per readout (column) across phases.
    im = ax1.imshow(
        StandardScaler().fit_transform(data), vmax=2, vmin=-2, cmap="coolwarm"
    )
    fig.colorbar(im, ax=ax1, label=r"$z$-score")
    ax1.set_xticks(
        ticks=np.arange(len(readout_names)),
        labels=readout_names,
        rotation=45,
        ha="right",
    )
    ax1.set_yticks(
        ticks=np.arange(8), labels=["BP1", "BP2"] + [f"SP{i}" for i in range(1, 7)]
    )
    ax1.set_title("Readouts")

    # Right panel: Pearson correlation between readouts across phases.
    # Ignore invalid value warnings for the saline session (missing analgesia and
    # phosphorylation data).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        im = ax2.imshow(np.corrcoef(data.T), cmap="coolwarm", vmin=-1, vmax=1)
    fig.colorbar(im, ax=ax2, label="Pearson correlation")
    ax2.set_xticks(
        ticks=np.arange(len(readout_names)),
        labels=readout_names,
        rotation=45,
        ha="right",
    )
    ax2.set_yticks(ticks=np.arange(len(readout_names)), labels=readout_names)
    ax2.set_title("Correlation matrix")

    _ = fig.suptitle(session_title, x=0.55, y=1.1, fontweight="bold")

    fig.savefig(figures_path / f"readouts_{session}.tiff", dpi=300, bbox_inches="tight")

### 8.2 Visualize the projection on PC1 and PC2

In [None]:
pca = make_pipeline(StandardScaler(), PCA(n_components=3))

# Analgesia and phosphorylation are resampled to 300 s bins and interpolated once.
# This does not depend on the session.
analgesia_resampled = analgesia_df.resample("300s").mean().interpolate()
phosphorylation_resampled = phosphorylation_df.resample("300s").mean().interpolate()


def stack_readouts(session):
    """Stack the readouts of ``session`` into an array of shape (n_phases, 5).

    Rows follow the phases of ``fc_index_df``. Columns are, in order: FC index,
    mobility, CBV, analgesia and phosphorylation. Analgesia and phosphorylation are
    only read for the phases ``fc_index_df.index[2:]``; two zeros are prepended for
    the first two (baseline) phases.
    """
    post_baseline = fc_index_df.index[2:]
    return np.stack(
        (
            fc_index_df[session].to_numpy(),
            mobility_df[session].to_numpy(),
            cbv_df.loc[fc_index_df.index, session].to_numpy(),
            np.insert(
                analgesia_resampled.loc[post_baseline, session].to_numpy(),
                0,
                (0, 0),
            ),
            np.insert(
                phosphorylation_resampled.loc[post_baseline, session].to_numpy(),
                0,
                (0, 0),
            ),
        ),
        axis=1,
    )


# Readouts of every projected session, built once and reused for fit and transform.
readouts = {name: stack_readouts(name) for name in ["salineControl", *sessions]}

# The standardization + PCA pipeline is fit on the morphine 70 mg/kg session only.
session = "WTM70"
data = readouts[session]
pca.fit(data)

session_labels = {
    "salineControl": "Saline",
    "WTM10": "Morphine (10 mg/kg)",
    "WTM20": "Morphine (20 mg/kg)",
    "WTM30": "Morphine (30 mg/kg)",
    "WTM70": "Morphine (70 mg/kg)",
}
phase_labels = ["BP1", "BP2", "SP1", "SP2", "SP3", "SP4", "SP5", "SP6"]

fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=300)
for session in ["salineControl"] + sessions:
    data = readouts[session]
    data_transformed = pca.transform(data)
    # PC1 is sign-flipped for display. Marker opacity increases with phase order.
    ax.scatter(
        -data_transformed[:, 0],
        data_transformed[:, 1],
        alpha=np.arange(1, 9) / 8,
        label=session_labels[session],
    )
    for (pc1, pc2), phase_label in zip(data_transformed[:, :2], phase_labels):
        ax.text(-pc1, pc2, phase_label, fontsize=5)

# Hand-placed annotations; the coordinates are hard-coded for this projection.
ax.add_patch(
    patches.Rectangle((1, -1.1), 1, 1.2, linewidth=0.5, edgecolor="k", facecolor="none")
)
ax.text(2.1, -2, "Highest FC/Analgesia/Phosphorylation", fontsize=5, rotation=90)

ax.add_patch(
    patches.Rectangle(
        (-3.45, -1.3), 1.1, 1.1, linewidth=0.5, edgecolor="k", facecolor="none"
    )
)
ax.text(-3.7, -2.15, "Lowest FC/Analgesia/Phosphorylation", fontsize=5, rotation=90)

ax.add_patch(
    patches.Rectangle(
        (-1.3, 2.95), 1.2, 1.2, linewidth=0.5, edgecolor="k", facecolor="none"
    )
)
ax.text(-1.5, 2.75, "Highest CBV/Mobility", fontsize=5)

ax.arrow(-3.1, -3, 6, 0, head_width=0.1, edgecolor="none", facecolor="k", width=0.02)
ax.text(-1.2, -3.4, "Increasing dose", fontsize=10)


ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.set_xlim(-4.3, 4.3)
ax.set_ylim(-4.3, 4.3)
ax.set_title("Projection on two principal components")
ax.legend(fontsize=8)

fig.savefig(figures_path / "projection.tiff", dpi=300, bbox_inches="tight")

### 8.3 Visualize PC loadings and explained variances

In [None]:
# Loadings: principal axes scaled by the square root of the explained variance.
# PC1 is sign-flipped to match the (sign-flipped) projection plot.
loadings = pca["pca"].components_.T * np.sqrt(pca["pca"].explained_variance_)
loadings[:, 0] *= -1

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3), dpi=300, facecolor="none")
im = ax1.imshow(loadings[:, :2], vmax=1, vmin=-1, cmap="coolwarm")
ax1.set_xticks(ticks=np.arange(2), labels=["PC1", "PC2"])
ax1.set_yticks(
    ticks=np.arange(5), labels=["FC", "Mobility", "CBV", "Analgesia", "Phosphorylation"]
)
ax1.set_title("Loadings")
fig.colorbar(im, ax=ax1, label="Loading (a.u.)", ticks=[-1, -0.5, 0, 0.5, 1])

explained_variance_ratio = 100 * pca["pca"].explained_variance_ratio_

# Null reference: the same pipeline fit on white noise with the shape of the data
# `pca` was trained on. The shape is read from the fitted estimator, so the result
# does not depend on whichever `data` another cell left behind.
training_shape = (pca["pca"].n_samples_, pca["pca"].n_features_in_)
rng = np.random.default_rng(seed=500)
random_pca = make_pipeline(StandardScaler(), PCA(n_components=3))
random_pca.fit(rng.normal(size=training_shape))
random_variance_ratio = 100 * random_pca["pca"].explained_variance_ratio_

ax2.scatter(np.arange(3), explained_variance_ratio, label="Actual data")
ax2.plot(np.arange(3), explained_variance_ratio)
ax2.scatter(np.arange(3), random_variance_ratio, label="White noise")
ax2.plot(np.arange(3), random_variance_ratio)
ax2.set_xticks(ticks=np.arange(3), labels=["PC1", "PC2", "PC3"])
ax2.set_ylim(0, 100)
ax2.set_ylabel("Explained variance ratio (%)")
# The limit line is hard-coded between PC2 and PC3.
ax2.axvline(1.5, color="red")
ax2.text(1.53, 30, "Marchenko-Pastur limit", color="red", rotation=90)
ax2.yaxis.tick_right()
ax2.yaxis.set_label_position("right")
ax2.set_title("Explained variance")
_ = ax2.legend()

fig.savefig(figures_path / "loadings.tiff", dpi=300, bbox_inches="tight")