In [None]:
# Render inline figures as SVG.
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline
import matplotlib.pyplot as plt

# NOTE(review): creates an otherwise unused figure before rpy2 is imported;
# presumably a workaround to initialise the matplotlib backend first — confirm
# that it is still needed.
plt.subplots()
# Hack: hide the globally installed R libraries (built for the wrong R version)
# by pointing R's library search path at the project-local library instead.
from rpy2 import robjects as ro

# NOTE(review): absolute, machine-specific path — the notebook will not run on
# another machine without editing it.
ro.r(".libPaths('/local/scratch/noddi/lib/R/library')")

In [None]:
import itertools as it
from pathlib import Path

import ciftipy as cp
import colormaps as cmaps
import dask.bag as db
import nibabel as nb
import numpy as np
import pandas as pd
import polars as pl
import polars.selectors as cs
import scipy.stats as scs
import seaborn as sns
import seaborn.objects as so
import statsmodels.formula.api as smf
import statsmodels.api as sm
import xarray as xr
from brainspace.plotting import plot_surf, plot_hemispheres
from brainstat.datasets import fetch_mask, fetch_template_surface
from brainstat.stats.SLM import SLM
from brainstat.stats.terms import FixedEffect, MixedEffect
from dask.diagnostics import ProgressBar
from matplotlib import cm, font_manager
from rsbids import BidsLayout
from statannotations.Annotator import Annotator
import rpy2_arrow.polars as rpy2polars
import colormaps as cmaps
from cycler import cycler

from lib import atlases
from lib.bidsarray import layout_map
from lib.dataset import Dataset
from lib.mesh import mesh_smooth
from lib.demographics import DemographicTable
from lib import polars_expr as ple
from lib.plotting import (
    add_colorbar,
    add_legend,
    comparison_plot,
    plot_hierachical_connectome,
)
from lib.seaborn_stats import Lme4CI, PearsonrAnnot, PolyCI
from styles import styles as Styles

# Reload edited project modules (``lib``, ``styles``) without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# Journal style sheet, then override the colour cycle with the colorblind_10 palette.
plt.style.use("styles/elsevier.mplstyle")
font_dirs = [Path.home() / ".fonts"]
font_files = font_manager.findSystemFonts(fontpaths=font_dirs)
plt.rcParams["axes.prop_cycle"] = cycler(color=cmaps.colorblind_10.colors)
# Make seaborn.objects plots pick up the same rcParams (including the colour cycle).
so.Plot.config.theme.update(plt.rcParams)

# Register every font found in ~/.fonts with matplotlib's font manager.
for font_file in font_files:
    font_manager.fontManager.addfont(font_file)

In [None]:
import rpy2
from templateflow import api as tflow
import rpy2.ipython.html
from rpy2 import robjects as ro
from rpy2.robjects.packages import importr

# HTML rendering of R objects in notebook output.
rpy2.ipython.html.init_printing()


# Python handles on the R packages used below.
rutils = importr("utils")
rbase = importr("base")
lme4 = importr("lme4")
rstats = importr("stats")
pbkrtest = importr("pbkrtest")
# Enables the %R / %%R magics.
%load_ext rpy2.ipython

# Attach the R-side packages used inside %%R cells.
%R library(polars); library(tidypolars); library(tidyverse); library("lme4")

In [None]:
from brainspace.mesh import mesh_io
# layout = BidsLayout(
#     ["../../derivatives/surfsample-0.1.0/", "../../derivatives/snakeanat-diffusion-v0.0.1/"],
#     cache=".cache",
#     reset_cache=True,
# )
# Dataset restricted to participants whose ``ddx`` code is 1 or 3.
# NOTE(review): the meaning of these codes is defined in hcp_metadata.yaml —
# presumably the control and patient groups; confirm.
hcp = (
    Dataset(".hcp.layout", "hcp")
    .add_phenotypes("hcp_metadata.yaml")
    .filter(pl.col("ddx").is_in([1, 3]))
)
# Per-hemisphere fsLR-32k surfaces (very-inflated layer), kept separate (join=False).
lh, rh = fetch_template_surface("fslr32k", layer="vinflated", join=False)
def get_sphere(hemi):
    """Read the fsLR 32k spherical surface of one hemisphere.

    Parameters
    ----------
    hemi
        Hemisphere label forwarded to the TemplateFlow query (called below
        with ``"L"`` and ``"R"``).

    Returns
    -------
    The surface object produced by ``mesh_io.read_surface``.
    """
    query = dict(template="fsLR", density="32k", suffix="sphere", space=None, hemi=hemi)
    sphere_file = tflow.get(**query)
    # read_surface is given a plain string rather than the object tflow returns.
    return mesh_io.read_surface(str(sphere_file))
lh_sphere = get_sphere("L")
rh_sphere = get_sphere("R")
# Inflated template surface and cortical mask used for smoothing and vertex-wise stats.
mesh = fetch_template_surface("fslr32k", layer="inflated")
mask = fetch_mask("fslr32k")

# Derived "Disorganization" score: row-wise sum of five PANSS items.
# NOTE(review): this mutates ``hcp.metadata`` in place; the cell adds the column
# on every run.
hcp.metadata = hcp.metadata.with_columns(
    pl.sum_horizontal("PANSSP2", "PANSSN5", "PANSSN6", "PANSSN7", "PANSSG11").alias(
        "Disorganization"
    )
)

# Initialize Data

## Gather all the surface sample files

In [None]:
def get_hems(img):
    """Extract the left and right cortex data of a CIFTI image as flat arrays.

    Parameters
    ----------
    img
        Image that can be indexed by the entries of ``img.struc`` and whose
        selections offer ``project(0)``.

    Returns
    -------
    tuple
        ``(left, right)`` flattened arrays, in that order.
    """
    structures = ("CIFTI_STRUCTURE_CORTEX_LEFT", "CIFTI_STRUCTURE_CORTEX_RIGHT")
    left, right = (
        img[img.struc[structure]].project(0).flatten() for structure in structures
    )
    return left, right


@layout_map(parallel=True, dtype=float, dims={"vertex": 64984})
def load_data(path):
    """Load one CIFTI file as a single vector: left cortex first, then right.

    Parameters
    ----------
    path
        File to read with ``cp.load``.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``len(left) + len(right)``. The decorator
        declares the ``vertex`` dimension as 64984 entries.
    """
    img = cp.load(path)
    left, right = get_hems(img)
    n_left = left.shape[0]
    # ``np.nan`` rather than the ``np.NaN`` alias, which NumPy 2.0 removed.
    data = np.full(n_left + right.shape[0], np.nan)
    data[:n_left] = left
    data[n_left:] = right
    return data

    # jhp_surface = load_data(
    #     jhp.layout.get(suffix=["curv", "thickness"], den="32k"),
    #     ["subject", "session", "suffix"],
    # ).to_dataset(name="surface")
    # jhp_surface.to_netcdf("checkpoint2.h5")


# Stack the diffusion-derived maps ("mdp" dscalar files, one per ``desc``) and
# the corrected cortical thickness along a single ``desc`` axis.
hcp_surface = xr.concat(
    [
        load_data(
            hcp.layout.get(suffix="mdp", extension=".dscalar.nii"),
            ["subject", "desc"],
        ),
        # Thickness files are keyed by ``suffix``; rename that dimension so it
        # lines up with ``desc`` of the maps above.
        load_data(
            hcp.layout.filter(scope="hcp-preproc").get(
                suffix="thickness", den="32k", desc="corr"
            ),
            ["subject", "suffix"],
        ).rename(suffix="desc"),
    ],
    dim="desc",
)

# Checkpoint to disk; later cells reload from this file.
with ProgressBar():
    hcp_surface.to_netcdf("hcp_surface.nc")

In [None]:
# Display the assembled array for a quick inspection.
hcp_surface

In [None]:
# Re-open the checkpoint lazily (``chunks={}``) and smooth it on the template
# mesh once per FWHM value in 5..14, stacking the results along ``smoothing``.
# NOTE(review): this rebinds ``hcp_surface`` to the on-disk version.
hcp_surface = xr.open_dataarray("hcp_surface.nc", chunks={})
hcp_surface_smooth = xr.concat(
    [
        mesh_smooth(
            hcp_surface.where(mask),
            surf=mesh,
            FWHM=smoothing,
            mask=mask,
            axis="vertex",

        ).expand_dims(smoothing=[smoothing])
        for smoothing in np.r_[5:15]
    ],
    dim="smoothing"
)


In [None]:
# Checkpoint the smoothed surfaces; the Analysis section reloads this file.
with ProgressBar():
    hcp_surface_smooth.to_netcdf("hcp_surface_smooth.nc")

## Do sampling

In [None]:
# Atlas label table (tab-separated); its ``label`` column is used as the ROI axis below.
dkmd = pd.read_csv("atlas-dkt_labels.tsv", sep="\t")
# Eagerly load the unsmoothed surface checkpoint.
hcp_surface = xr.load_dataarray("hcp_surface.nc")


def do_sampling(ds):
    """Average every surface parameter within each atlas ROI, subject by subject.

    Parameters
    ----------
    ds
        Surface array indexed by ``subject`` and ``desc``; after selecting both,
        the remaining axis is masked with a 64984-long (2 x 32492) label vector.

    Returns
    -------
    Dataset holding the sampled values as variable ``dk``, merged with the
    columns of the atlas table ``dkmd`` (indexed by ``roi``).
    """
    @layout_map(
        parallel=True,
        dims={"param": ds["desc"].rename(desc="param"), "roi": dkmd["label"]},
        dtype=float,
    )
    def sample_dk(path):
        # Per-subject parcellation, laid out as left hemisphere then right.
        dknii = cp.load(path)
        dkatlas = np.empty(32492 * 2)
        dkatlas[:32492] = (
            dknii[dknii.struc["CIFTI_STRUCTURE_CORTEX_LEFT"]].project().ravel()
        )
        dkatlas[32492:] = (
            dknii[dknii.struc["CIFTI_STRUCTURE_CORTEX_RIGHT"]].project().ravel()
        )
        # result[i, j] = mean of parameter i over the vertices labelled j.
        result = np.empty((len(ds["desc"]), len(dkmd)))
        x = ds.sel(subject=path.entities["subject"])
        for (i, param), (j, label) in it.product(
            enumerate(ds["desc"]), enumerate(dkmd["label"])
        ):
            result[i, j] = np.mean(x.sel(desc=param)[dkatlas == label])
        return result


    # Only request parcellations for subjects present in ``ds``.
    return (
        sample_dk(
            hcp.layout.get(
                suffix="dparc",
                subject=ds["subject"].data,
                atlas="dk",
                space="fsLR",
                den="32k",
            ),
            wildcards=["subject"],
        )
        .to_dataset(name="dk")
        .merge(dkmd.rename(columns={"label": "roi"}).set_index("roi").to_xarray())
    )
# Run the ROI sampling and checkpoint it; the Analysis section reloads this file.
with ProgressBar():
    do_sampling(hcp_surface).to_netcdf("hcp_sample.nc")

# Analysis

In [None]:
# Reload both checkpoints without subject "1032".
# NOTE(review): the reason for excluding this subject is not recorded here.
hcp_surface = xr.load_dataarray("hcp_surface.nc").drop_sel(subject="1032")
hcp_sampled = xr.load_dataset("hcp_sample.nc").drop_sel(subject="1032")
# hcp_sampled = xr.load_dataset("hcp_dk_sample_fullwdith.nc")


def get_stds(df):
    """Attach a per-subject outlier score in column ``stds``.

    Step 1 computes the absolute z-score of ``dk`` within each
    (lobe, param, group) window. Step 2 overwrites ``stds`` with the
    per-subject sum of ``clip(floor(|z|) - 1, 0, 5)``, so only values at least
    two standard deviations from the window mean contribute to the score.
    """
    return df.with_columns(
        stds=(
            pl.col.dk.map_elements(scs.zscore, return_dtype=pl.List(float))
            .abs()
            .over("lobe", "param", "group")
        )
    ).with_columns(
        # The expression keeps the name ``stds``, replacing the z-score column.
        pl.col.stds.floor().cast(int).sub(1).clip(0, 5).sum().over("subject")
    )


# Long-format lobe-level table: one row per subject x param x lobe x hemisphere,
# joined with demographics and scored for outliers.
hcp_df = (
    # xr.concat([pial_fw_sampled, hcp_sampled.sel(param="thickness")], dim="param")
    hcp_sampled
    .to_dataframe()
    .reset_index()
    .pipe(pl.from_pandas)
    # NOTE(review): ROIs 4 and 39 are dropped; the reason is not recorded here.
    .filter(~pl.col.roi.is_in([4, 39]))
    # Collapse ROIs to lobe means.
    .group_by("subject", "param", "lobe", "hemisphere")
    .agg(pl.mean("dk"))
    # Redundant with the drop_sel on load, but harmless.
    .filter(pl.col.subject != "1032")
    # .filter(pl.col.param.is_in(["FA", "ndi", "odi", "fw", "thickness"]))
    .join(
        hcp.metadata[
            [
                "subject",
                "group",
                "age",
                "sex",
                "PANSSP",
                "PANSSN",
                "Disorganization",
                "antipsychotic_dur",
            ]
        ],
        on="subject",
    )
    .pipe(get_stds)
    # .filter(~pl.col.subject.is_in(["2004", "4010"]))
)
# Subjects whose outlier score reaches the threshold are excluded everywhere below.
std_thresh = 30
dropped_subs = list(hcp_df.filter(pl.col.stds >= std_thresh)["subject"].unique())
# dropped_subs = []
hcp_df = hcp_df.filter(pl.col.stds < std_thresh)
# Lazily opened smoothed surfaces with the same exclusions applied.
hcp_smooth = xr.open_dataarray("hcp_surface_smooth.nc", chunks={}).drop_sel(
    subject=[*dropped_subs, "1032"]
)

### Demographics

In [None]:
# Predicate selecting the subjects that survived the outlier exclusion.
hcp_included = pl.col.subject.is_in(hcp_df["subject"])

# Same dataset as ``hcp`` but built with ``prefilter=False``.
# NOTE(review): presumably this keeps subjects the default prefilter would
# drop — confirm against ``lib.dataset.Dataset``.
hcp_demo_all = (
    Dataset(".hcp.layout", "hcp", prefilter=False)
    .add_phenotypes("hcp_metadata.yaml")
    .filter(pl.col("ddx").is_in([1, 3]))
)
hcp_demo = hcp_demo_all.filter(hcp_included)

In [None]:
def capitalize(label):
    """Turn a column name into a display label.

    Underscores become spaces; the first character is upper-cased and the
    remainder lower-cased (``str.capitalize`` semantics).
    """
    spaced = " ".join(label.split("_"))
    return spaced.capitalize()


def prepare_session_table_hcp(table):
    """Register the demographic and clinical rows on ``table``, in display order.

    Demographic rows take their label from ``capitalize``. Clinical rows carry
    an explicit label and are registered with ``skip_stats=True`` and
    ``skip_fields=["HC"]``.
    """
    auto = {"autoformatter": capitalize}

    def clinical():
        # Fresh dict and list on every call so rows never share a mutable value.
        return {"skip_stats": True, "skip_fields": ["HC"]}

    # Demographics
    table.add_nominal("sex", "{M}/{F}", **auto)
    table.add_scale("education", **auto)
    table.add_scale("age", **auto)
    table.add_nominal("handedness", "{R}/{L}", **auto)
    table.add_scale("SES")
    table.add_nominal("smoke", "{Yes}/{No}", **auto)
    table.add_scale("pack_years", **auto)

    # Substance use and medication
    table.add_nominal("cannabis_use", "{Yes}/{No}", "Cannabis User", **clinical())
    table.add_scale(
        "antipsychotic_dur",
        "Antipsychotic Duration (months)",
        report="median",
        **clinical(),
    )
    table.add_scale(
        "total_defined_daily_dose",
        "Lifetime Antipsychotics (Defined Daily Dose)",
        **clinical(),
    )

    # Symptom scales
    for column, title in (
        ("panss-total", "PANSS Total"),
        ("PANSSP", "PANSS Positive"),
        ("PANSSN", "PANSS Negative"),
        ("PANSS-G", "PANSS Global"),
    ):
        table.add_scale(column, title, **clinical())


# NOTE(review): ``parts`` is not used anywhere in the visible notebook.
parts = []
# Demographic table comparing HC and Patient, rendered for LaTeX.
table = DemographicTable(
    hcp_demo.metadata.with_columns(
        # pl.col.smoker.fill_null("no"),
        # Missing medication values are treated as zero exposure.
        pl.col.antipsychotic_dur.fill_null(0),
        pl.col.total_defined_daily_dose.fill_null(0),
        # pl.col.ethnicity.replace(TOPSY_ETHNICITIES),
        # Recode exposure codes 1 / 3 to "No" / "Yes".
        cannabis_use=pl.col.cannabis_exposure.replace({1: "No", 3: "Yes"}),
        # NOTE(review): assumes smoke_time is in years and smoke_amount in
        # cigarettes per day (20 per pack) — confirm the units.
        pack_years=pl.col.smoke_time * pl.col.smoke_amount / 20,
    ).to_pandas(),
    "group",
    ["HC", "Patient"],
    flavour="latex",
)
prepare_session_table_hcp(table)

# Print the LaTeX source so it can be pasted into the manuscript.
print(
    table.to_pandas()
    .style.to_latex(
        column_format="rllll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
    )
)

### Global Surface Analysis

In [None]:
from statsmodels.stats import multitest

from lib import polars_expr as ple


def get_global_avg(ds):
    """Average every parameter over each hemisphere, per subject.

    Parameters
    ----------
    ds
        Array indexed by ``subject`` and ``desc``. After selecting both, the
        remaining axis is sliced at 32492: the first 32492 entries are taken as
        the left hemisphere, the rest as the right.

    Returns
    -------
    xarray.DataArray
        Dims ``subject``, ``param`` (renamed from ``desc``) and ``hemisphere``
        (``"L"``, ``"R"``).
    """
    hem_slices = [np.s_[:32492], np.s_[32492:]]
    # Size the output from ``ds`` itself. It was previously sized from the
    # global ``hcp_surface``, which is only correct when that exact array is
    # passed in.
    result = np.empty((len(ds["subject"]), len(ds["desc"]), 2))
    for (i, subject), (j, param), hem in it.product(
        enumerate(ds["subject"]), enumerate(ds["desc"]), range(2)
    ):
        result[i, j, hem] = np.mean(
            ds.sel(subject=subject, desc=param)[hem_slices[hem]]
        )

    return xr.DataArray(
        result,
        coords={"subject": ds["subject"], "desc": ds["desc"], "hemisphere": ["L", "R"]},
    ).rename(desc="param")


# Long-format table of hemisphere-wide averages joined with demographics.
# The value column is named ``dk`` so it matches the lobe-level table.
hcp_glob_df = (
    # xr.concat([pial_fw_sampled, hcp_sampled.sel(param="thickness")], dim="param")
    get_global_avg(hcp_surface)
    .to_dataframe(name="dk")
    .reset_index()
    .pipe(pl.from_pandas)
    .join(
        hcp.metadata[
            [
                "subject",
                "group",
                "age",
                "sex",
                "PANSSP",
                "PANSSN",
                "Disorganization",
                "antipsychotic_dur",
            ]
        ],
        on="subject",
    )
    # Apply the same exclusions as the lobe-level analysis.
    .filter(~pl.col.subject.is_in([*dropped_subs, "1032"]))
    # .pivot(
    #     index=cs.exclude("param", "dk"),
    #     columns="param",
    #     values="dk",
    # )
)
hcp_glob_df

In [None]:
%%R -i hcp_glob_df -c rpy2polars.converter
# Exploratory mixed model of fw on the hemisphere-wide averages.
library(lmerTest)
# Wide format: one column per parameter, values taken from `dk`.
df <- hcp_glob_df |>
    as_tibble() |>
    pivot_wider(names_from=param, values_from=dk)
    # group_by(subject) |>
    # summarise(
    #     across(where(is.character) & ! hemisphere, first),
    #     across(!where(is.character), mean),
    # )
df
# Fixed effects: age, sex, group x hemisphere and csf; random intercept per subject.
lm1 <- lmer(fw ~ age + sex + group*hemisphere + csf + (1|subject), data=df)
# lm.m <- lm(fw ~ scale(age) + sex + csf + group, data=df)
# lm.mediated <- glm(as.factor(group) ~ scale(age) + sex + fw + thickness, data=df, family=binomial)
summary(lm1)

In [None]:
# One mixed-effects model per parameter (through the project's ``ple.lmer``
# expression): standardised value against age, sex and hemisphere x group,
# with a random intercept per subject.
hcp_glob_stats = (
    hcp_glob_df.group_by("param")
    .agg(
        ple.lmer("scale(dk) ~ scale(age) + sex + hemisphere*group + (1|subject)").alias(
            "stats"
        )
    )
    # Expand the model output to one row per coefficient.
    .unnest("stats")
    .explode("table")
    .unnest("table")
    # Keep only the three contrasts of interest.
    .filter(
        pl.col.index.is_in(
            ["hemisphereR", "groupPatient", "hemisphereR:groupPatient"]
        )
    )
    # One-tailed test for the group effect on the listed parameters: the
    # two-sided p-value is halved.
    # NOTE(review): halving is only valid when the effect lies in the
    # hypothesised direction; the sign of ``t`` is not checked here. "csf" is
    # also treated as one-tailed although the ROI models test it two-sided.
    .with_columns(
        pval=pl.when(
            pl.col.index == "groupPatient",
            pl.col.param.is_in(["ndi", "odi", "fw", "thickness", "csf"]),
        )
        .then(pl.col.pval / 2)
        .otherwise(pl.col.pval)
    )
    # Holm correction across the retained contrasts, separately per parameter.
    .with_columns(
        pvalcorr=pl.col.pval.map_elements(
            lambda a: pl.Series(multitest.multipletests(a, method="holm")[1]),
            return_dtype=pl.List(float),
        ).over("param"),
    )
)
hcp_glob_stats

In [None]:

# Supplementary spreadsheet of the global statistics, with publication names
# for two parameters (L1 -> AD, fw -> v_iso).
(
    hcp_glob_stats
    .with_columns(pl.col.param.replace({"L1": "AD", "fw": "v_iso"}))
    .sort(["param", "index"])
    .write_excel("suppl/hcp-glob-stats.xlsx")
)

In [None]:
#| fig-cap: Cortical thickness is reduced in patients. Statistics computed using linear
#|   mixed-effects models with subject as a random effect and group, hemisphere, age,
#|   and sex as fixed effects. For each parameter, two-tailed T-tests were used to evaluate
#|   the effect of hemisphere and group-hemisphere interaction on the parameter. One-tailed
#|   T-tests were used for group effects on NDI, ODI, thickness (reduced in patients),
#|   and $v_{iso}$ (increased in patients). Degrees of freedom for each comparison
#|   were estimated using Satterthwaite's method [@satterthwaiteApproximateDistributionEstimates1946].
#|   P-values from the three contrasts were corrected using Holm-Bonferroni corrections.
#|   A. Average cortical thickness is significantly lower ($T(110.4)=-4.8;P<.001$) in
#|   patients versus controls. B. Average parameter values split by hemisphere. Lines
#|   illustrate the per-group mean differences across hemispheres. Shaded bands represent
#|   95% CI computed by parametric bootstrapping of the mixed-effects model with 1000
#|   replicates. NDI is significantly lower in the right hemisphere ($T(103.0) = 3.5;P=.002$).
#|   No other hemisphere or group:hemisphere interaction effects are noted.
#| label: fig-surf-global
# Wide table (one column per parameter); the diffusivity columns are multiplied
# by 1000 to match the axis units used for them below.
_df = hcp_glob_df.pivot(
    index=["subject", "age", "sex", "hemisphere", "group", "PANSSP", "PANSSN"],
    columns="param",
    values="dk",
).with_columns(cs.matches("MD|L1|RD") * 1000)
# Two stacked panels: A (group comparison) and B (hemisphere x group).
fig = plt.figure(figsize=(7.48, 6), layout="constrained")
panels = fig.subfigures(2, 1, height_ratios=[3.5, 2.5])
panels[0].text(0.05, 1.02, "A", **Styles.panel_label)
panels[1].text(0.05, 1.02, "B", **Styles.panel_label)


# Axis label for each parameter.
ylabels = {
    "ndi": "NDI",
    "odi": "ODI",
    "fw": r"$\log{v_{iso}}$",
    "thickness": "Thickness (mm)",
    "FA": "FA",
    "MD": r"MD $\left(\frac{\mu m^2}{ms}\right)$",
    "RD": r"RD $\left(\frac{\mu m^2}{ms}\right)$",
    "L1": r"AD $\left(\frac{\mu m^2}{ms}\right)$",
}

# Panel A: HC vs Patient for each parameter.
axs = panels[0].subplots(2, 2)
params = ["thickness", "fw", "ndi", "odi"]
for i, (y, x) in enumerate(np.ndindex(2, 2)):
    ax = axs[y, x]
    # Only thickness gets a significance annotation (corrected group p-value).
    if params[i] in {"thickness"}:
        sig = {
            ("HC", "Patient"): hcp_glob_stats.filter(
                param=params[i], index="groupPatient"
            )["pvalcorr"][0]
        }
    else:
        sig = None
    comparison_plot(
        _df.to_pandas(),
        y=params[i],
        x="group",
        hue="group",
        ax=ax,
        alpha=0.7,
        size=3,
        order=["HC", "Patient"],
        hue_order=["HC", "Patient"],
        significance=sig,
    )
    ax.set_ylabel(ylabels[params[i]])
    # Only the bottom row carries an x-axis label.
    ax.set_xlabel("Group" if y == 1 else "")
    ax.set_xticks([1, 0])

# Panel B: per-hemisphere values with per-group trend line and model-based band.
axs = panels[1].subplots(2, 2)
params = ["thickness", "fw", "ndi", "odi"]#, "FA", "MD", "RD", "L1"]
for i, (y, x) in enumerate(np.ndindex(2, 2)):
    (
        so.Plot(_df, y=params[i], x="hemisphere", color="group", linestyle="group")
        .add(so.Dot(pointsize=2.5, alpha=0.7, edgealpha=0), so.Jitter(), legend=False)
        .add(so.Line(linewidth=2), so.PolyFit(1), legend=False)
        .add(so.Band(), Lme4CI(nsims=1000), group="subject", legend=False)
        .on(axs[y, x])
        .label(x="" if y == 0 else "Hemisphere", y=ylabels[params[i]])
        .scale(
            x=so.Nominal(order=["L", "R"]), color=so.Nominal(order=["HC", "Patient"])
        )
        .plot()
    )
# Shared legend; the palette matches the colour cycle configured at the top of
# the notebook.
add_legend(
    panels[1],
    ["HC", "Patient"],
    cmap=cmaps.colorblind_10,
    fontsize=10,
    loc="center",
    bbox_to_anchor=(0.55, 1.1),
    size=2.5,
)
# Suppress the cell's textual output.
None

In [None]:
#| fig-cap: No DTI changes in patient cortical gray matter. Statistics computed using
#|   linear mixed-effects models with subject as a random effect and group, hemisphere,
#|   age, and sex as fixed effects. For each parameter, two-tailed T-tests were used
#|   to evaluate the effect of group, hemisphere, and group-hemisphere interaction on
#|   the parameter. Degrees of freedom for each comparison were estimated using Satterthwaite's
#|   method [@satterthwaiteApproximateDistributionEstimates1946]. P-values from the three
#|   contrasts were corrected using Holm-Bonferroni corrections. Average parameter values
#|   are represented split by hemisphere. Lines illustrate the per-group mean differences
#|   across hemispheres. Shaded bands represent 95% CI computed by parametric bootstrapping
#|   of the mixed-effects model with 1000 replicates. No hemisphere or group:hemisphere
#|   interaction effects are noted.
#| label: fig-surf-global-dti
# Wide table (one column per parameter); the diffusivity columns are multiplied
# by 1000 to match the axis units used for them below.
_df = hcp_glob_df.pivot(
    index=["subject", "age", "sex", "hemisphere", "group", "PANSSP", "PANSSN"],
    columns="param",
    values="dk",
).with_columns(cs.matches("MD|L1|RD") * 1000)
fig = plt.figure(figsize=(7.48, 3), layout="constrained")

# Axis label for each DTI parameter.
ylabels = {
    "FA": "FA",
    "MD": r"MD $\left(\frac{\mu m^2}{ms}\right)$",
    "RD": r"RD $\left(\frac{\mu m^2}{ms}\right)$",
    "L1": r"AD $\left(\frac{\mu m^2}{ms}\right)$",
}

# One axes per parameter: per-hemisphere values with per-group trend line and
# model-based band.
axs = fig.subplots(2, 2)
params = ["FA", "MD", "RD", "L1"]
for i, (y, x) in enumerate(np.ndindex(2, 2)):
    (
        so.Plot(_df, y=params[i], x="hemisphere", color="group", linestyle="group")
        .add(so.Dot(pointsize=3, alpha=0.7), so.Jitter(), legend=False)
        .add(so.Line(linewidth=2), so.PolyFit(1), legend=False)
        .add(so.Band(), Lme4CI(nsims=1000), group="subject", legend=False)
        .on(axs[y, x])
        .label(x="" if y == 0 else "Hemisphere", y=ylabels[params[i]])
        .scale(
            x=so.Nominal(order=["L", "R"]), color=so.Nominal(order=["HC", "Patient"])
        )
        .plot()
    )
# The plots draw their colours from the colorblind_10 cycle configured at the
# top of the notebook, so the legend must use that palette too. It previously
# used matplotlib's tab10, whose swatches do not match the plotted colours.
add_legend(
    fig,
    ["HC", "Patient"],
    cmap=cmaps.colorblind_10,
    fontsize=10,
    loc="center",
    bbox_to_anchor=(0.55, 0.55),
)
# Suppress the cell's textual output.
None

In [None]:
# Corrected p-value as text: "< .001" below 0.001, otherwise rounded with the
# leading zero stripped.
p_format = (
    pl.when(pl.col("pvalcorr") < 0.001)
    .then(pl.lit("< .001"))
    .otherwise(
        pl.col("pvalcorr")
        .round(3)
        .round_sig_figs(2)
        .cast(str)
        .str.strip_chars_start("0")
    )
)
# Same text, wrapped in \textbf{...} when the corrected p-value is below 0.05.
sig_format = (
    pl.when(pl.col("pvalcorr") < 0.05)
    .then(pl.format(r"\textbf{{}}", p_format))
    .otherwise(p_format)
)
# Print the LaTeX source of the global-statistics table: one row per
# parameter, one column group per contrast.
print(
    hcp_glob_stats
    .select(
        pl.col.coef.round_sig_figs(2)
        .cast(str)
        .str.strip_chars_end("0")
        .alias(r"$\beta_{norm}$"),
        pl.col.t.round_sig_figs(2).cast(str).str.strip_chars_end("0").alias("T"),
        pl.col.df.round(1).cast(str),
        sig_format.alias(r"$P_{corr}$"),
        # Publication labels for contrasts and parameters.
        index=pl.col.index.replace(
            {
                "hemisphereR": r"$\sim Hem_{right}$",
                "groupPatient": r"$\sim G_{patient}$",
                "hemisphereR:groupPatient": r"$\sim Hem_{right}:G_{patient}$",
            }
        ),
        param=pl.col.param.replace(
            {
                "ndi": "NDI",
                "fw": r"$v_{iso}$",
                "odi": "ODI",
                "thickness": "Thickness",
                "csf": "CSF",
                "L1": "AD",
            }
        ),
    )
    .to_pandas()
    # Pivot contrasts into the column header: (contrast, statistic) pairs.
    .set_index(["param", "index"])
    .sort_index()
    .unstack()
    .reorder_levels([1, 0], axis=1)
    .sort_index(axis=1)
    # Fix the order of the statistics within each contrast.
    .reindex([r"$\beta_{norm}$", "T", "df", r"$P_{corr}$"], axis=1, level=1)
    .reset_index()
    .rename(columns={"": "Param", "param": ""})
    .style.hide()
    .to_latex(
        column_format="rllllllllllll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-surf-global",
        caption="Statistics for average cortical parameters modelled against group and hemisphere.",
    )
)

### ROI Analysis

In [None]:
from lib import polars_expr as ple


def get_ols(col, filter=None):
    """Build the OLS expression testing ``col`` with age, sex and csf as covariates.

    Parameters
    ----------
    col
        Predictor of interest. For ``"group"`` the tested coefficient is
        ``"group[T.Patient]"``; otherwise it is ``col`` itself.
    filter
        Optional row filter forwarded to ``ple.ols``.

    Returns
    -------
    Expression aliased ``"<col>_stats"``.
    """
    contr = "group[T.Patient]" if col == "group" else col

    # Per-parameter ``alternative`` code: 1 for fw, 0 for the DTI metrics and
    # csf, -1 for everything else.
    # NOTE(review): presumably 1 = greater, 0 = two-sided, -1 = less — confirm
    # against ``ple.ols``.
    is_fw = pl.col.param.is_in(["fw"])
    is_dti_or_csf = pl.col.param.is_in(["FA", "RD", "MD", "L1", "csf"])
    alternative = (
        pl.when(is_fw)
        .then(pl.lit(1))
        .when(is_dti_or_csf)
        .then(pl.lit(0))
        .otherwise(pl.lit(-1))
    )

    model = ple.ols(
        f"dk ~ {col} + age + sex + csf",
        contr,
        filter=filter,
        columns="param",
        alternative=alternative,
    )
    return model.alias(f"{col}_stats")


def get_all_stats(df):
    """Fit the group model in every lobe x param x hemisphere cell and correct the p-values.

    Returns one row per (param, lobe, hemisphere, model) with the unnested
    model statistics plus ``pvalcorr``: ``pval`` after
    ``scipy.stats.false_discovery_control`` applied within each
    (param, model) window, i.e. across lobes and hemispheres.
    """
    return (
        df.group_by("lobe", "param", "hemisphere")
        .agg(
            get_ols("group"),
            # get_ols("PANSSP", pl.col.group == "Patient"),
            # get_ols("PANSSN", pl.col.group == "Patient"),
        )
        # Long format: one row per model; "<name>_stats" -> "<name>".
        .melt(["param", "lobe", "hemisphere"], cs.matches(".*_stats"), "model", "stats")
        .with_columns(pl.col("model").str.split("_").list.first())
        .unnest("stats")
        .with_columns(
            pl.col("pval")
            .map_elements(
                lambda x: pl.Series(scs.false_discovery_control(x)),
                return_dtype=pl.List(pl.Float64),
            )
            .over("param", "model")
            .name.suffix("corr")
        )
    )


# Add each row's matching csf value (same subject, lobe and hemisphere) as a
# covariate column, then run all ROI models.
hcp_stats = get_all_stats(hcp_df.with_columns(
    csf=pl.col("dk")
    .filter(pl.col.param == "csf")
    .first()
    .over("subject", "lobe", "hemisphere")
))

In [None]:
# Supplementary spreadsheet of the ROI statistics, with publication names for
# two parameters (L1 -> AD, fw -> v_iso).
(
    hcp_stats
    # .with_columns(parameter=pl.concat_str("score", "feature", separator="_"))
    .with_columns(pl.col.param.replace({"L1": "AD", "fw": "v_iso"}))
    .sort(["param", "lobe", "hemisphere", "model"])
    .write_excel("suppl/surface-roi.xlsx")
)


In [None]:
# Diagnostic: residual histograms of the fw group model per lobe, comparing
# ``hcp_df`` (left column) with ``hcp_nonlog`` (right column).
# NOTE(review): ``hcp_nonlog`` is not defined anywhere in the visible notebook,
# so this cell fails on a fresh kernel. The formula here also omits the ``csf``
# covariate used by ``get_ols``.
fig = plt.figure(figsize=(6,10), layout="constrained")
axs = fig.subplots(6, 2)
lobes = ["insula", "frontal", "cingulate", "temporal", "occipital", "parietal"]
for i, (y, x) in enumerate(np.ndindex(6, 2)):
    df = hcp_df if x == 0 else hcp_nonlog
    lm  = smf.ols("dk ~ group + age + sex", df.filter(lobe=lobes[y], param='fw').to_pandas()).fit()
    axs[y,x].hist(lm.resid)
    axs[y,x].set_title(lobes[y])


In [None]:
#| fig-cap: Cortical ROIs affected in early psychosis. P-values corrected for multiple
#|   comparisons across ROIs for each parameter using FDR. Left and right hemispheres
#|   are illustrated separately. For each hemisphere, data from healthy controls are
#|   shown in the left curve, patients in the right. Curves correspond to kernel density
#|   estimates with bandwidth determined using the Scott method [@scott2015multivariate].
#|   Dashed lines represent first, second, and third quartiles. A. ROIs with significantly
#|   higher $v_{iso}$ in both hemispheres. B. The insula has lower FA in the left hemisphere
#|   only. Statistics shown in @tbl-surf-roi.
#| label: fig-surf-roi
from statannotations.Annotator import Annotator

# Significant (corrected p < 0.05) group effects for fw and FA; used to
# annotate the panels below.
roi_pvalues = hcp_stats.filter(
    pl.col.pvalcorr < 0.05, pl.col.param.is_in(["fw", "FA"]), model="group"
)[["param", "lobe", "hemisphere", "pvalcorr"]]

# Wide table: one column per parameter.
_df = hcp_df.pivot(
    index=["subject", "lobe", "hemisphere", "group"], columns="param", values="dk"
)
fig = plt.figure(figsize=(7.48, 7), layout="constrained")
axs = fig.subplots(3, 2)
# Panel i shows params[i] in lobes[i]: five fw panels, then FA in the insula.
params = ["fw", "fw", "fw", "fw", "fw", "FA"]
lobes = ["frontal", "insula", "temporal", "cingulate", "occipital", "insula"]
labels = {
    "fw": r"$\log{v_{iso}}$",
    "FA": "FA",
}

for i, (y, x) in enumerate(np.ndindex(3, 2)):
    ax = axs[y, x]
    __df = _df.filter(lobe=lobes[i])
    # Individual observations.
    # NOTE(review): no palette is passed here, while the violins and the legend
    # use tab10 explicitly — check that the colours agree.
    sns.stripplot(
        __df,
        y=params[i],
        x="hemisphere",
        hue="group",
        dodge=True,
        legend=False,
        hue_order=["HC", "Patient"],
        order=["L", "R"],
        alpha=0.8,
        ax=ax,
    )
    # Split violins (HC left, Patient right) with quartile lines.
    sns.violinplot(
        __df,
        y=params[i],
        x="hemisphere",
        hue="group",
        split=True,
        legend=False,
        palette=([cm.tab10(0), cm.tab10(1)]),
        alpha=0.4,
        linecolor="#f0f0f0",
        linewidth=1,
        inner="quart",
        hue_order=["HC", "Patient"],
        order=["L", "R"],
        ax=ax,
        cut=0,
    )
    ax.set_title(lobes[i].capitalize())
    # Only the bottom row carries an x-axis label.
    ax.set_xlabel("" if i < 4 else "Hemisphere")
    ax.set_ylabel(labels[params[i]])
    # Panel letters: "A" on the first fw panel, "B" on the FA panel.
    if (lett_ := {0: "A", 5: "B"}.get(i)) is not None:
        ax.text(-0.18, 1.03, lett_, **Styles.panel_label, transform=ax.transAxes)
    if i == 5:
        # Grey backdrop behind the FA panel to set it apart from the fw panels.
        # ax.set_facecolor("white")
        fig.patches.append(
            plt.Rectangle(
                (-0.2, -0.25),
                1.3,
                1.4,
                transform=ax.transAxes,
                color="#e0e0e0",
                zorder=-1,
            )
        )
    # Annotate the hemispheres with a significant group effect, using the
    # precomputed corrected p-values (no test is run by the annotator).
    hemis, pvals = roi_pvalues.filter(param=params[i], lobe=lobes[i])[
        ["hemisphere", "pvalcorr"]
    ]
    annot = Annotator(
        ax,
        data=__df.to_pandas(),
        y=params[i],
        x="hemisphere",
        hue="group",
        order=["L", "R"],
        pairs=[((h, "HC"), (h, "Patient")) for h in hemis],
        verbose=False,
    )
    annot.hide_not_significant = True
    annot.configure(test=None).set_pvalues(pvals).annotate()

add_legend(fig, ["HC", "Patient"], cmap=cm.tab10, fontsize=10, bbox_to_anchor=(0.6, 0.73))
# Suppress the cell's textual output.
None

---
label: fig-surf-roi
---
Cortical ROIs affected in early psychosis. P-values corrected for multiple comparisons across ROIs for each parameter using FDR. Left and right hemispheres are illustrated separately. For each hemisphere, data from healthy controls are shown in the left curve, patients in the right. Curves correspond to kernel density estimates with bandwidth determined using the Scott method [@scott2015multivariate]. Dashed lines represent first, second, and third quartiles. A. ROIs with significantly higher $\nu_{iso}$ in both hemispheres. B. The insula has lower FA in the left hemisphere only. Statistics shown in @tbl-surf-roi.

In [None]:
import textwrap

# Corrected p-value as text: "< .001" below 0.001, otherwise rounded with the
# leading zero stripped.
# NOTE(review): duplicates the definitions used for the global-statistics table.
p_format = (
    pl.when(pl.col("pvalcorr") < 0.001)
    .then(pl.lit("< .001"))
    .otherwise(
        pl.col("pvalcorr")
        .round(3)
        .round_sig_figs(2)
        .cast(str)
        .str.strip_chars_start("0")
    )
)
# Same text, wrapped in \textbf{...} when the corrected p-value is below 0.05.
sig_format = (
    pl.when(pl.col("pvalcorr") < 0.05)
    .then(pl.format(r"\textbf{{}}", p_format))
    .otherwise(p_format)
)


def format_stats_table(df, label, caption):
    """Render ROI statistics as a LaTeX table string.

    Parameters
    ----------
    df
        Polars frame with columns ``param``, ``hemisphere``, ``lobe``, ``beta``,
        ``statistic``, ``pvalcorr`` and ``df_resid``.
    label, caption
        Forwarded to ``Styler.to_latex``.

    Returns
    -------
    str
        LaTeX source, indexed by (Param, Lobe).
    """
    # The T column header shows the residual degrees of freedom of the FIRST
    # row only. NOTE(review): assumes every row shares the same df_resid;
    # raises on an empty frame.
    df_resid = df["df_resid"].cast(int)[0]
    return (
        df.select(
            pl.col.param.replace({"thickness": "Thickness", "fw": "$v_{iso}$"}).alias(
                "Param"
            ),
            # e.g. "L-Frontal"
            pl.format("{}-{}", pl.col.hemisphere, pl.col.lobe.str.to_titlecase()).alias(
                "Lobe"
            ),
            pl.col.beta.round_sig_figs(2)
            .cast(str)
            .str.strip_chars_end("0")
            .alias(r"$\beta_{norm}$"),
            pl.col.statistic.round_sig_figs(2)
            .cast(str)
            .str.strip_chars_end("0")
            .alias(f"$T({df_resid})$"),
            sig_format.alias("$P_{corr}$"),
        )
        .to_pandas()
        .set_index(["Param", "Lobe"])
        .sort_index()
        .style.to_latex(
            column_format="rrlll",
            hrules=True,
            multicol_align="c",
            multirow_align="t",
            convert_css=True,
            label=label,
            caption=caption,
        )
    )


# Print the LaTeX source for the table of significant group effects.
print(
    format_stats_table(
        hcp_stats.filter(pl.col.pvalcorr < 0.05, model="group"),
        "tbl-surf-roi",
        textwrap.dedent(
            """
            ROIs with significant group effects.
            """
        ),
    )
)

### Calculate Vertex-wise stats

In [None]:

# Smoothed surfaces merged with per-subject covariates; the inner join keeps
# only subjects present in both.
ds = (
    hcp_smooth.to_dataset(name="data")
    .merge(
        hcp.metadata[["subject", "group", "age", "sex", "PANSSP", "PANSSN"]]
        .to_pandas()
        .set_index("subject"),
        join="inner",
    )
)
# Patient-only subset, used for the PANSS terms.
ds_pt = ds.groupby("group")["Patient"]
# ds_ = ds_.sel(subject=~np.isnan(ds_["PANSS-N"]))

# Put ``subject`` and ``vertex`` last so that selecting desc/smoothing leaves a
# (subject, vertex) matrix.
surface = ds["data"].transpose("desc", "smoothing", "subject", "vertex")
df = ds[["group", "age", "sex", "PANSSP", "PANSSN"]].to_dataframe()
df_pt = ds_pt[["group", "age", "sex", "PANSSP", "PANSSN"]].to_dataframe()
term_group = FixedEffect(df["group"])
term_age = FixedEffect(df["age"])
term_sex = FixedEffect(df["sex"])
# NOTE(review): the PANSS terms are defined but not used by any model below.
term_panssn = FixedEffect(df_pt["PANSSN"])
term_panssp = FixedEffect(df_pt["PANSSP"])

# 0/1 indicator vectors; the contrast below is HC minus Patient.
group_pt = np.asarray(df["group"] == "Patient").astype(int)
group_hc = np.asarray(df["group"] == "HC").astype(int)

# Model registry consumed by the vertex-wise statistics cell.
models = {
    "group": {
        "model": term_group + term_age + term_sex,
        "contrast": group_hc - group_pt,
        "surface": surface,
    }

}

In [None]:
import dask.bag as db
from brainstat.stats.SLM import SLM
from brainstat.stats.terms import FixedEffect, MixedEffect

# Multiplier applied to the HC-minus-Patient contrast for each parameter before
# the one-tailed test: +1 keeps the contrast (HC > Patient), -1 flips it
# (Patient > HC).
signs = {
    "fw": -1,
    "thickness": 1,
    "csf": -1,
}

def compute_stats(ds):
    """Fit one vertex-wise linear model and write the results into ``ds``.

    ``ds`` is expected to hold exactly one ``smoothing``, ``suffix`` and
    ``model`` value (they are read with ``.item()``). The variables ``C``,
    ``P``, ``Q``, ``t`` and ``clusid`` are overwritten in place and ``ds`` is
    returned.

    Raises
    ------
    Exception
        Re-raised from an ``IndexError`` during fitting, with the block's
        parameters in the message.
    """
    smoothing = ds["smoothing"].item()
    suffix = ds["suffix"].item()
    sign = signs[suffix]
    model = models[ds["model"].item()]
    # One-tailed test of the sign-adjusted contrast with RFT and FDR corrections.
    slm = SLM(
        model["model"],
        model["contrast"] * sign,
        mask=mask,
        surf="fslr32k",
        correction=["rft", "fdr"],
        cluster_threshold=0.01,
        two_tailed=False,
    )
    try:
        # Missing values are replaced by 0 before fitting.
        slm.fit(
            np.asanyarray(
                model["surface"].sel(desc=suffix, smoothing=smoothing).fillna(0)
            )
        )
    except np.linalg.LinAlgError:
        # Fit failed numerically: every output becomes NaN below.
        entries = {
            "C": None,
            "P": None,
            "Q": None,
            "t": None,
            "clusid": None,
        }
    except IndexError as err:
        raise Exception(f"{smoothing=} {suffix=} {sign=} {model=}") from err
    else:
        entries = {
            "C": slm.P["pval"]["C"],
            "P": slm.P["pval"]["P"],
            "Q": slm.Q,
            "t": slm.t,
            "clusid": slm.P["clusid"][0],
        }
    # ``clusid`` is indexed a second time here.
    # NOTE(review): presumably slm.P["clusid"][0] is 2-D (1 x vertices) —
    # confirm against the brainstat version in use.
    if entries["clusid"] is not None:
        entries["clusid"] = entries["clusid"][0]
    for key, val in entries.items():
        if val is None:
            # NOTE(review): ``clusid`` is declared as int in the template but
            # receives a float NaN array here.
            ds[key].data = np.full(ds[key].shape, np.nan)
            continue
        ds[key].data = val.reshape(ds[key].shape)
    return ds


import dask.array as da

# Grid of analyses: every second smoothing level x three parameters x one model.
coords = {
    "smoothing": hcp_smooth["smoothing"].values[::2],
    "suffix": ["fw", "thickness", "csf"],
    "model": ["group"],
}
axes = {"vertex": 64984}
# Output variables produced per analysis, all defined per vertex.
variables = {
    "C": {"dims": ["vertex"], "dtype": float},
    "P": {"dims": ["vertex"], "dtype": float},
    "Q": {"dims": ["vertex"], "dtype": float},
    "t": {"dims": ["vertex"], "dtype": float},
    "clusid": {"dims": ["vertex"], "dtype": int},
}
dims = list(coords.keys())
comp_vars = {}
for label, v in variables.items():
    shape = tuple(len(x) for x in coords.values()) + tuple(axes[d] for d in v["dims"])
    # Chunk size 1 along every grid dimension: each dask block is exactly one
    # analysis, which is what ``compute_stats`` expects.
    chunks = (1,) * len(coords) + tuple(axes[d] for d in v["dims"])

    comp_vars[label] = xr.DataArray(
        da.empty(shape=shape, chunks=chunks, dtype=v["dtype"]),
        dims=dims + v["dims"],
        coords=coords,
    )
# Empty template doubling as the input and the declared output structure.
template = xr.Dataset(comp_vars)
from dask.diagnostics import ProgressBar

# Lazy: nothing is computed until ``stats`` is loaded.
stats = xr.map_blocks(compute_stats, template, template=template)
# stats.to_csv("stats.csv")
# stats.to_csv("stats.csv")

In [None]:
import dask

# Evaluate the lazy `stats` dataset with a dask `num_workers` setting of 2,
# showing a progress bar, then persist the result to NetCDF.
with dask.config.set(num_workers = 2):
    with ProgressBar():
        stats.load()
stats.to_netcdf("hcp_rft.nc")

### Plot Vertex-wise stats

In [None]:
# Reload the saved statistics and add masked copies of the t-values and
# cluster ids that keep only vertices where `C` < 0.05 (NaN elsewhere).
stats = xr.open_dataset("hcp_rft.nc", chunks={})
stats = stats.assign(
    ma_t=stats["t"].where(stats["C"] < 0.05),
    ma_clusid=stats["clusid"].where(stats["C"] < 0.05),
)

In [None]:
# For each modality, average the smoothed data (smoothing=5) over the vertices
# of every cluster whose mean `C` value is <= 0.05.
# NOTE: `coords` rebinds the name used earlier for the template coordinate dict.
coords = []
result = []
for suffix in ["fw", "thickness"]:
    x = stats.sel(smoothing=5, suffix=suffix, model="group")
    clusids = np.unique(x["clusid"])
    for i in clusids:
        # Vertex indices belonging to cluster `i`.
        vertices = np.nonzero(x["clusid"].data == i)
        if np.mean(x["C"][vertices]) <= 0.05:
            coords.append((suffix, i.astype(int)))
            result.append(
                hcp_smooth.sel(
                    smoothing=5,
                    desc=suffix,
                )
                # Put `vertex` first so the index tuple selects cluster vertices.
                .transpose("vertex", ...)[vertices]
                .mean("vertex")
            )

# Stack the per-cluster averages along a new dimension indexed by
# (suffix, clusid).
clusters = (
    xr.concat(result, dim="concat_dim")
    .assign_coords(dict(concat_dim=pd.Index(coords, name=("suffix", "clusid"))))
)

In [None]:
# Long-format table of per-cluster averages joined with subject metadata
# (inner join on `subject`, so subjects without metadata are dropped).
df = (
    clusters.to_dataframe(name="data")
    .drop(columns=["clusid", "suffix"])
    .reset_index()
    .pipe(pl.from_pandas)
    .join(hcp.metadata, on="subject", how="inner")
)

In [None]:
import scipy.sparse as scr


def get_vertex_adjacency(lh, rh):
    """Build a boolean vertex adjacency matrix spanning both hemispheres.

    Two vertices are marked adjacent when they appear in the same mesh cell.
    Right-hemisphere point ids are shifted by the number of left-hemisphere
    points so that both meshes share one index space.

    Parameters
    ----------
    lh, rh
        Mesh objects exposing ``GetNumberOfPoints``, ``GetNumberOfCells`` and
        ``GetCell(i).GetPointIds()``; the returned ids must support ``+ int``.

    Returns
    -------
    scipy.sparse.coo_matrix
        ``(npoints, npoints)`` boolean matrix with one entry per ordered pair
        of vertices sharing a cell (pairs shared by several cells are repeated).
    """
    n_left = lh.GetNumberOfPoints()
    total = n_left + rh.GetNumberOfPoints()

    def _cell_pairs(surface, shift):
        # Every ordered pair of distinct points within each cell.
        for cell_ix in range(surface.GetNumberOfCells()):
            point_ids = surface.GetCell(cell_ix).GetPointIds() + shift
            yield from it.permutations(point_ids, 2)

    pairs = np.array([*_cell_pairs(lh, 0), *_cell_pairs(rh, n_left)])
    row, col = pairs.T
    data = np.ones_like(row, dtype=bool)
    return scr.coo_matrix((data, (row, col)), shape=(total, total))


# CSR form of the adjacency matrix; `dilate` and `smooth` index it by row.
mesh_adj = get_vertex_adjacency(lh, rh).tocsr()

In [None]:
def dilate(arr, iter=3):
    """Grow every labelled region of ``arr`` along the mesh for ``iter`` passes.

    On each pass, for every non-NaN label present at the start of the pass,
    all mesh neighbours (per the module-level ``mesh_adj``) of vertices
    carrying that label are assigned the label. Labels are processed in the
    order returned by ``np.unique``, so where regions meet, a later label
    overwrites an earlier one.

    Returns a modified copy; ``arr`` itself is left untouched.
    """
    dilated = np.copy(arr)
    for _ in range(iter):
        for i in np.unique(dilated):
            if np.isnan(i):
                continue
            # Column indices of non-zero entries in the rows of labelled
            # vertices, i.e. the neighbours of the current region.
            indices = np.unique(np.argwhere(mesh_adj[dilated == i])[:, 1])
            dilated[indices] = i
    return dilated

def smooth(arr):
    """Remove weakly connected vertices from each labelled region.

    For every non-NaN label, counts how many same-label neighbours each
    labelled vertex has (using the module-level ``mesh_adj``) and sets
    vertices with fewer than 4 such neighbours to NaN.

    Returns a modified copy; ``arr`` itself is left untouched.
    """
    smoothed = np.copy(arr)
    for i in np.unique(smoothed):
        if np.isnan(i):
            continue
        mask = np.argwhere(smoothed == i).ravel()
        # Adjacency restricted to this region; column sums give the number of
        # same-label neighbours per vertex.
        isolated = (np.sum(mesh_adj[np.ix_(mask, mask)], axis=0) < 4).getA1()
        smoothed[mask[isolated]] = np.nan
    return smoothed

# Masked cluster ids for the fw map, dilated along the mesh and then pruned
# of weakly connected vertices; used below as the "outline" layer.
clus_arr = stats.sel(smoothing=5, suffix="fw", model="group")["ma_clusid"].load().data
dilated = smooth(dilate(clus_arr))

In [None]:
# Render the masked fw t-map with cluster outlines to an image array.
n_lh = lh.n_points
statmap = stats.sel(smoothing=5, suffix="fw", model="group")["ma_t"].load().data
# The stat arrays are the left hemisphere followed by the right; split at n_lh.
lh.append_array(statmap[:n_lh], at="p", name="t-val")
rh.append_array(statmap[n_lh:], at="p", name="t-val")
lh.append_array(dilated[:n_lh], at="p", name="outline")
rh.append_array(dilated[n_lh:], at="p", name="outline")
fw_trange = np.nanmin(statmap), np.nanmax(statmap)
fw_img = plot_surf(
    {
        "lh": lh,
        "rh": rh,
    },
    # One row of four views: two per hemisphere.
    [
        ["lh", "lh", "rh", "rh"],
    ],
    # Layers drawn on each view: no data, then the outline, then the t-values.
    [
        [ (None, "outline", "t-val")],
    ],
    # Explicit camera angles for the four views.
    [
        [(0, 60, 90), (-20, 180, 90), (0, -140, -90), (-40, 130, 90)],
    ],
    color_bar=True,
    cmap=[
        [
            ("autumn", cmaps.dark2_8,cmaps.matter_r)
        ],
    ],
    # Outline labels are mapped over 1..8; t-values over their observed range.
    color_range=[
        [(None, (1, 8), fw_trange)],
    ],
    embed_nb=True,
    size=(1500*3, 250*3),
    zoom=1.85,
    nan_color=(0.7, 0.7, 0.7, 0),
    cb__labelTextProperty={"fontSize": 12},
    transparent_bg=False,
    return_plotter=True,
).to_numpy()

In [None]:
# Render the masked thickness t-map to an image array and preview it.
n_lh = lh.n_points
statmap = stats.sel(smoothing=5, suffix="thickness", model="group")["ma_t"].load().data
# Overwrites the "t-val" point array attached to the meshes by the fw cell.
lh.append_array(statmap[:n_lh], at="p", name="t-val")
rh.append_array(statmap[n_lh:], at="p", name="t-val")
thi_trange = np.nanmin(statmap), np.nanmax(statmap)
thickness_img = plot_surf(
    {
        "lh": lh,
        "rh": rh,
    },
    # One row of four views: two per hemisphere.
    [
        ["lh", "lh", "rh", "rh"],
    ],
    # Layers drawn on each view: no data, then the t-values.
    [
        [(None, "t-val")],
    ],
    [
        ["lateral", "medial", "medial", "lateral"],
    ],
    color_bar=True,
    cmap=[
        [("autumn", cm.winter)],
    ],
    color_range=[
        [(None, thi_trange)],
    ],
    embed_nb=True,
    size=(2000 * 3, 400 * 3),
    zoom=1.55,
    nan_color=(0.7, 0.7, 0.7, 0),
    cb__labelTextProperty={"fontSize": 12},
    transparent_bg=False,
    return_plotter=True,
).to_numpy()
plt.imshow(thickness_img)

In [None]:
#| fig-cap: Cortical clusters with significantly higher $v_{iso}$ in patients. Clusters
#|   determined with random field theory with a P-value threshold of 0.01. Significance
#|   determined using one-tailed t-tests. A. Clusters with significantly higher $\log{v_{iso}}$.
#|   Colored outlines around the clusters correspond with scatter plots in B., and are
#|   not a part of the cluster. B. $\log{v_{iso}}$ in individual clusters. Dots show
#|   average $\log{v_{iso}}$ within clusters for individual subjects. Colors correspond
#|   with the cluster outlines in A.
#| label: fig-surf-rft
fig = plt.figure(figsize=(7.48, 7.5), layout="constrained")
panela, panelb = fig.subfigures(2, 1, height_ratios=[1.5, 6])

# Panel A: rendered surface map with its t-value colorbar.
panela.text(0, 1, "A", va="top", **Styles.panel_label)
ax, cax = panela.subplots(1, 2, width_ratios=[40, 1])
ax.imshow(fw_img)
ax.axis("off")
add_colorbar(*fw_trange, ax=cax, outline=False, cmap=cmaps.matter_r)
cax.set_ylabel("T-value")

# Panel B: one group-comparison plot per significant fw cluster.
panelb.text(0, 1.03, "B", **Styles.panel_label)

grid = panelb.add_gridspec(4, 2)
# `unique()` gives no ordering guarantee. Sort so that panel `i` (coloured
# with `dark2_8(i)`) always shows the i-th smallest cluster id, matching the
# outline layer in panel A, which is coloured by cluster id with `dark2_8`.
clusids = df.filter(suffix="fw")["clusid"].unique().sort()
titles = [
    "L-Insula",
    "R-Insula",
    "L-Middle Temporal Gyrus",
    "L-Occipital Lobe",
    "R-Occipital Lobe",
]
# `zip` stops at the shorter of the two, so having fewer clusters than grid
# slots leaves the remaining slots empty instead of raising IndexError.
for i, (pos, clusid) in enumerate(zip(grid, clusids)):
    ax = panelb.add_subplot(pos)
    comparison_plot(
        df.filter(pl.col.suffix == "fw", clusid=clusid),
        x="group",
        y="data",
        alpha=0.7,
        color=cmaps.dark2_8(i),
        ax=ax,
    )
    # Only the bottom row (slots 6 and 7) carries an x label.
    ax.set_xlabel("" if i < 6 else "Group")
    # Only the left column carries a y label.
    ax.set_ylabel(r"$\log{v_{iso}}$" if i % 2 == 0 else "")

# Suppress the cell's text output.
None

---
label: fig-surf-rft
---
Cortical clusters with significantly higher $\nu_{iso}$ in patients. Clusters were determined with random field theory with a P-value threshold of 0.01. Significance was determined using one-tailed t-tests. A. Clusters with significantly higher $\log{\nu_{iso}}$. Colored outlines around the clusters correspond with scatter plots in B and are not a part of the cluster. B. $\log{\nu_{iso}}$ in individual clusters. Dots show average $\log{\nu_{iso}}$ within clusters for individual subjects. Colors correspond with the cluster outlines in A.

In [None]:
#| fig-cap: Cortical clusters with significantly lower thickness in patients. Clusters
#|   were determined with random field theory with a P-value threshold of 0.01. One-tailed
#|   t-tests were used to test for decreased thickness.
#| label: fig-surf-thickness-rft
import seaborn.objects as so

# Single-row figure: the rendered thickness map with a narrow colorbar axis.
fig = plt.figure(figsize=(7.48, 1.5), layout="constrained")
ax, cax = fig.subplots(1, 2, width_ratios=[40, 1])
ax.imshow(thickness_img)
ax.axis("off")
# The colorbar spans the observed range of the masked thickness t-values.
add_colorbar(*thi_trange, ax=cax, outline=False, cmap=cm.winter)
cax.set_ylabel("T-value")
# Suppress the cell's text output.
None

---
label: fig-surf-thickness-rft
---
Cortical clusters with significantly lower thickness in patients. Clusters were determined with random field theory with a P-value threshold of 0.01. One-tailed t-tests were used to test for decreased thickness.

## White matter

### Connectivity

In [None]:
# Layout over the NODDI modelling derivatives; queried below for NBS results.
hcp_modelling = BidsLayout("../HCPpsych/derivatives/noddi-models")

In [None]:
import h5py


# Wrapped by `layout_map` with `src`/`dest` dimensions taken from the bn246
# atlas "Label ID" column and an integer dtype.
@layout_map(
    dims={"src": atlases.bn246["Label ID"], "dest": atlases.bn246["Label ID"]},
    dtype=int,
)
def load_h5_nbs(file):
    """Load a test-statistic matrix from an HDF5 file, weighted by `con_mat`.

    Reads ``nbs/con_mat`` and sums it over its last axis, then returns
    ``nbs/test_stat`` multiplied element-wise by that sum. If reading or
    summing ``con_mat`` raises ``IndexError``, the multiplier falls back to 0,
    so the returned matrix is all zeros.

    Side effect: prints the ``pval`` attribute of the ``nbs`` group.
    """
    with h5py.File(file, "r") as f:
        try:
            networks = f["nbs/con_mat"][:]
            sig = np.sum(networks, axis=-1)
        except IndexError:
            # NOTE(review): presumably raised when no network was detected
            # and `con_mat` has no last axis to reduce — confirm.
            sig = 0
        # NOTE(review): looks like leftover debugging output.
        print(f["nbs"].attrs["pval"])
        return f["nbs/test_stat"][:] * sig


# All parameter x model description strings, e.g. "fwphenotype", "ndipanssp".
descs = [
    "".join(x)
    for x in it.product(["fw", "ndi", "odi"], ["phenotype", "panssp", "panssn"])
]
# Load the NBS results for both "pos" and "neg" labels across `desc`.
hcp_nbs = load_h5_nbs(
    hcp_modelling.get(suffix="nbs", label=["pos", "neg"], desc=descs), ["desc"]
)

### White matter ROIs

#### Sampling

In [None]:
def get_wm_from_rois(path, wildcards, atlases):
    """Average each image within every label of every atlas.

    The ROI axis concatenates, atlas by atlas, one slot per integer label from
    0 to that atlas's maximum value. Slots for label 0 are filled with 0; all
    other slots hold the mean of the image voxels carrying that label.

    Parameters
    ----------
    path, wildcards
        Forwarded unchanged to the ``layout_map``-wrapped loader.
    atlases
        Sequence of label arrays, indexable by position.

    Returns
    -------
    Whatever the ``layout_map``-wrapped loader returns, with an ``roi``
    dimension of length equal to the total number of slots.
    """
    labels = [
        (atlas_ix, label_ix)
        for atlas_ix, atlas in enumerate(atlases)
        for label_ix in range(np.max(atlas).astype(int) + 1)
    ]
    nlabels = len(labels)

    @layout_map(parallel=True, dims={"roi": nlabels}, dtype=float)
    def inner(path):
        data = nb.load(path).get_fdata()
        result = np.empty((nlabels,))
        try:
            for slot, (atlas_ix, label_ix) in enumerate(labels):
                if label_ix == 0:
                    # Label 0 slots are never sampled.
                    result[slot] = 0
                else:
                    result[slot] = np.mean(data[atlases[atlas_ix] == label_ix])
            return result
        except:
            # Report which file failed before propagating the error.
            print(path)
            raise

    return inner(path, wildcards)


# Atlas label metadata table; read by `get_group_mapping` below.
atlas_md = pl.read_csv("atlas-study_labels.csv")


def get_group_mapping():
    """Build a lookup array mapping ``atlas_id`` values to a group's atlas id.

    Rows of the module-level ``atlas_md`` whose ``group`` equals "core" define
    the groups. Every row whose ``group`` matches one of those groups' ``name``
    has its group name replaced by that group's ``atlas_id``. The result is a
    float array indexed by ``atlas_id`` (0 where no mapping exists).

    NOTE(review): assumes ``atlas_id`` values are non-negative integers usable
    as array indices — confirm against the CSV.
    """
    groups = atlas_md.filter(group="core")
    # Unpacking the two-column frame yields its columns in order.
    group_id, atlas_id = atlas_md.join(
        groups[["name"]], left_on="group", right_on="name"
    ).with_columns(
        pl.col.group.replace(
            dict(zip(*groups[["name", "atlas_id"]])), return_dtype=pl.Int32
        )
    )[["group", "atlas_id"]]
    mapping = np.zeros(((atlas_id.max() + 1),))
    mapping[atlas_id] = group_id
    return mapping


def get_atlases(layout, dims, jhu_atlas, lobe_atlas):
    """Build the set of ROI label volumes restricted to the mean skeleton.

    Parameters
    ----------
    layout
        Collection of skeleton images, passed to a ``layout_map`` loader.
    dims
        Dimension spec for the loaded volumes, forwarded to ``layout_map``.
    jhu_atlas, lobe_atlas
        Paths of the two atlas images, loaded with nibabel.

    Returns
    -------
    dict
        ``lobe_mask`` (lobe labels where the JHU atlas is 0), ``jhu_atlas``,
        ``global_mask`` (boolean union of both), ``core-periph`` (1 for JHU
        voxels, 2 for lobe-mask voxels) and ``core_group_mask`` (JHU labels
        remapped through ``get_group_mapping``).
    """
    @layout_map(parallel=True, dims=dims, dtype=float)
    def get_skeleton(path):
        return nb.load(path).get_fdata()

    # Voxels whose across-subject mean is positive.
    mean_skeleton = get_skeleton(layout, ["subject"]).mean(["subject"]) > 0
    # Zero both atlases outside the skeleton.
    jhu_atlas, lobe_atlas = (
        np.where(mean_skeleton, nb.load(atlas).get_fdata(), 0)
        for atlas in (jhu_atlas, lobe_atlas)
    )
    core_groups = get_group_mapping().astype(int)
    # Lobe labels only where no JHU label exists, so the two never overlap.
    lobe_mask = np.where(jhu_atlas == 0, lobe_atlas, 0)
    return {
        "lobe_mask": lobe_mask,
        "jhu_atlas": jhu_atlas,
        "global_mask": (jhu_atlas > 0) | (lobe_mask > 0),
        "core-periph": (jhu_atlas > 0).astype(int) + ((lobe_mask > 0).astype(int)) * 2,
        "core_group_mask": core_groups[jhu_atlas.astype(int)],
    }


def run_roi_sampling(layout, jhu_atlas, lobe_atlas, skeleton_dims):
    """Sample the skeletonised parameter maps within every ROI volume.

    The ROI volumes are built by ``get_atlases`` from the ``ndi`` skeletons;
    the ``ndi``, ``fw``, ``odi`` and ``logfw`` skeletons are then averaged
    within each ROI via ``get_wm_from_rois`` across ``subject`` and ``desc``.
    """
    ndi_skeletons = layout.get(suffix="skeletonized", desc="ndi")
    roi_volumes = get_atlases(
        ndi_skeletons,
        dims=skeleton_dims,
        jhu_atlas=jhu_atlas,
        lobe_atlas=lobe_atlas,
    )
    param_skeletons = layout.get(
        suffix="skeletonized", desc=["ndi", "fw", "odi", "logfw"]
    )
    sample_dims = ["subject", "desc"]
    # The ROI volumes are passed in the dict's insertion order.
    return get_wm_from_rois(
        param_skeletons, sample_dims, atlases=list(roi_volumes.values())
    )


def sample_hemispheres(layout, atlas, skeleton_dims):
    """Sample skeletonised ``ndi``/``fw``/``odi`` maps within two hemisphere ROIs.

    The atlas is restricted to the mean skeleton and its labels are collapsed
    to 1 and 2 before sampling with ``get_wm_from_rois``.

    NOTE(review): ``get_mean_skeleton`` is not defined in the visible part of
    this notebook — confirm it exists before re-enabling the call to this
    function.
    """
    mean_skeleton = (
        get_mean_skeleton(
            layout.get(suffix="skeletonized", desc="ndi"),
            ["subject"],
            dims=skeleton_dims,
        ).mean(["subject"])
        > 0
    )
    atlas = np.where(mean_skeleton, nb.load(atlas).get_fdata(), 0)
    # Odd labels below 7 become 1.
    atlas[(atlas < 7) & (atlas % 2 == 1)] = 1
    # Every non-zero even label becomes 2.
    atlas[(atlas % 2 == 0) & (atlas != 0)] = 2
    # Whatever is still above 2 (odd labels >= 7) is discarded.
    atlas[atlas > 2] = 0
    return get_wm_from_rois(
        layout.get(suffix="skeletonized", desc=["ndi", "fw", "odi"]),
        ["subject", "desc"],
        atlases=[atlas],
    )


from dask.diagnostics import ProgressBar

# ROI-wise sampling of the skeletonised maps (saved to NetCDF further below).
hcp_wm_sampled = run_roi_sampling(
    hcp.layout.get(suffix="skeletonized", datatype=False),
    jhu_atlas="../HCPpsych/derivatives/atlases/atlas.nii.gz",
    lobe_atlas="../HCPpsych/derivatives/atlases/lobe-atlas.nii.gz",
    skeleton_dims={"x": 121, "y": 145, "z": 121},
)

# NOTE(review): hemisphere sampling is disabled here, yet the analysis section
# reads "hcp_hemi_sampled.nc" — that file must already exist on disk.
# hcp_hemi_sampled = sample_hemispheres(
#     hcp.layout.get(suffix="skeletonized", datatype=False),
#     atlas="../HCPpsych/derivatives/atlases/hemi-atlas.nii.gz",
#     skeleton_dims={"x": 121, "y": 145, "z": 121},
# )

In [None]:
jhu_atlas = "../HCPpsych/derivatives/atlases/atlas.nii.gz"
lobe_atlas = "../HCPpsych/derivatives/atlases/lobe-atlas.nii.gz"
atlas = "../HCPpsych/derivatives/atlases/hemi-atlas.nii.gz"

skeleton_dims = {"x": 121, "y": 145, "z": 121}
# NOTE(review): this rebinds `atlases`, which the import cell binds to the
# `lib.atlases` module. After this cell runs, re-running the cell defining
# `load_h5_nbs` (whose decorator reads `atlases.bn246`) fails until the
# imports are re-executed.
atlases = get_atlases(
    hcp.layout.get(suffix="skeletonized", desc="ndi"),
    dims=skeleton_dims,
    jhu_atlas=jhu_atlas,
    lobe_atlas=lobe_atlas,
)


In [None]:
from nilearn import plotting
# Overlay the core-group label volume on the study FA template, reusing the
# template's affine and header, as seven axial slices from z=-10 to z=50.
bg = "../HCPpsych/derivatives/tpl-FA/tpl-study/tpl-study_FA.nii.gz"
img = nb.load(bg)
param_map = nb.Nifti1Image(
    atlases["core_group_mask"],
    img.affine,
    img.header,
)
plotting.plot_roi(
    param_map,
    bg,
    cut_coords=np.r_[-10:50:7j],
    # Nearest-neighbour interpolation keeps the integer labels intact.
    resampling_interpolation="nearest",
    display_mode="z",
    annotate=False,
    colorbar=False,
)

In [None]:
# Display the sampled array's repr for inspection.
hcp_wm_sampled

In [None]:

# Persist the sampled ROI values to NetCDF, with a progress bar.
with ProgressBar():
    hcp_wm_sampled.to_netcdf("hcp_wm_sampled.nc")
    # hcp_hemi_sampled.to_netcdf("hcp_hemi_sampled.nc")

#### Analysis

In [None]:
# Load the saved hemisphere samples lazily and exclude subject "1032".
# NOTE(review): the cell that writes this file is commented out above.
hcp_hemi_sampled = xr.open_dataarray("hcp_hemi_sampled.nc", chunks={}).drop_sel(
    subject="1032"
)

In [None]:
# Wide table with one row per subject and hemisphere, and one column per
# parameter (`desc`).
hcp_hemi_df = (
    hcp_hemi_sampled.to_dataframe(name="data")
    .reset_index()
    .pipe(pl.from_pandas)
    .with_columns(
        # ROI codes 1 and 2 become hemisphere labels.
        pl.col.roi.replace(
            {
                1: "L",
                2: "R",
            },
            return_dtype=pl.String,
        )
    )
    # Drop any ROI that is not one of the two hemispheres.
    .filter(pl.col.roi.is_in(["L", "R"]))
    .join(hcp.metadata, on="subject", how="inner")
    .pivot(
        index=["subject", "roi", "group", "PANSSP", "PANSSN", "age", "sex"],
        columns="desc",
        values="data",
    )
)

# NDI: mixed model grouped by subject, testing the group x hemisphere interaction.
lm = smf.mixedlm(
    "scale(ndi) ~ scale(age) + sex + group*roi",
    groups="subject",
    data=hcp_hemi_df.to_pandas(),
).fit()
print(lm.summary())
# fw: same formula fitted with ordinary least squares (no subject grouping).
# NOTE: `lm` is rebound here; the mixed-model fit is no longer accessible.
lm = smf.ols(
    "scale(fw) ~ scale(age) + sex + group*roi",
    # groups="subject",
    data=hcp_hemi_df.to_pandas()
).fit()
print(lm.summary())

In [None]:
# Reload the ROI samples from disk, excluding `dropped_subs` (defined
# elsewhere in the notebook) and subject "1032".
hcp_wm_sampled = xr.open_dataarray("hcp_wm_sampled.nc", chunks={}).drop_sel(
    subject=[*dropped_subs, "1032"]
)

In [None]:
atlas_md = pl.read_csv("atlas-study_labels.csv")

# Row filters applied to the atlas-joined data: exclude four labels by
# abbreviation and every ROI in the "cerebellar" group.
atlas_filters = [
    ~pl.col("label").is_in(
        [
            "Med",
            "Po",
            "Mb",
            "FTS",
        ]
    ),
    pl.col("group") != "cerebellar"
]

In [None]:
# Map the name of every "core" or "global" group row to its `index` value.
group_indices = dict(
    zip(*atlas_md.filter(pl.col("group").is_in(["core", "global"]))[["name", "index"]])
)


def get_index(group: str):
    """Return the index of ``group`` as an Int64 polars literal.

    Raises ``KeyError`` if ``group`` is not a key of ``group_indices``.
    """
    return pl.lit(group_indices[group], dtype=pl.Int64)


def prepare_wm_rois(df, index):
    """Attach atlas metadata to ROI samples and average them per label.

    Parameters
    ----------
    df
        Polars frame with an ``roi`` column (matched against ``index`` in the
        module-level ``atlas_md``) and a ``data`` column.
    index
        Iterable of column names that, together with ``label``, define the
        groups to average over.

    Returns
    -------
    Polars frame with the ``index`` columns, ``label``, the first ``region``
    and ``hierarchy`` of each group, and the mean of ``data``.

    Notes
    -----
    The atlas table is joined and filtered exactly once. The previous version
    repeated the join on its own output, which only added discarded
    ``*_right`` columns and doubled the work.
    """
    return (
        df.join(atlas_md.rename({"index": "roi"}), on="roi")
        # Drop excluded labels and groups (see `atlas_filters`).
        .filter(*atlas_filters)
        .group_by(*index, "label")
        .agg(
            pl.col(
                "region",
                "hierarchy",
            ).first(),
            pl.col("data").mean(),
        )
    )

In [None]:
def all_sessions(da, hx):
    """Convert sampled ROI data to a label-level table with subject covariates.

    ``da`` is flattened to a long polars frame (rows with missing values
    dropped), aggregated per subject, ``desc`` and label by
    ``prepare_wm_rois``, and joined on ``subject`` with the group, age, sex
    and PANSS columns of ``hx``.
    """
    return prepare_wm_rois(
        da.to_dataset(name="data")
        .to_dataframe()
        .dropna()
        .reset_index()
        .pipe(pl.from_pandas),
        ["subject", "desc"],
    ).join(hx[["subject", "group", "age", "sex", "PANSSN", "PANSSP"]], on=["subject"])


hcp_wm_df = all_sessions(
    hcp_wm_sampled,
    hcp.metadata,
)

In [None]:
# Quick look at the distribution of fw values for the "WM" label.
plt.hist(hcp_wm_df.filter(label="WM", desc="fw")["data"])

In [None]:
def get_wm_ols(col, filter=None):
    """Build the ``ple.ols`` expression testing ``col`` with age and sex as covariates.

    The fitted formula is ``data ~ <col> + age + sex``. The tested term is
    ``group[T.Patient]`` when ``col`` is "group" and ``col`` itself otherwise.
    ``alternative`` is 1 for rows whose ``desc`` is "fw" and -1 for all others.
    The expression is aliased ``<col>_stats``.
    """
    contr = "group[T.Patient]" if col == "group" else col
    # NOTE(review): presumably the sign selects the direction of a one-sided
    # test — confirm against `ple.ols`.
    alternative = pl.when(pl.col.desc == "fw").then(pl.lit(1)).otherwise(pl.lit(-1))
    stats_expr = ple.ols(
        f"data ~ {col} + age + sex",
        contr,
        filter=filter,
        columns="desc",
        alternative=alternative,
    )
    return stats_expr.alias(f"{col}_stats")


def get_wm_stats(df):
    """Fit the group and PANSS models per label and parameter, then correct p-values.

    For each (label, desc) group, three models are built with ``get_wm_ols``:
    group over all rows, and PANSSP / PANSSN over rows where ``group`` is
    "Patient". Results are reshaped to one row per model, and ``pval`` is
    corrected with ``scipy.stats.false_discovery_control`` within each
    (desc, model, hierarchy) partition, stored as ``pvalcorr``.
    """
    return (
        df.group_by("label", "desc")
        .agg(
            get_wm_ols("group"),
            get_wm_ols("PANSSP", pl.col.group == "Patient"),
            get_wm_ols("PANSSN", pl.col.group == "Patient"),
            pl.first("hierarchy"),
        )
        # One row per (desc, label, hierarchy, model).
        .melt(["desc", "label", "hierarchy"], cs.matches(".*_stats"), "model", "stats")
        # "group_stats" -> "group", etc.
        .with_columns(pl.col("model").str.split("_").list.first())
        .unnest("stats")
        .with_columns(
            pl.col("pval")
            .map_elements(
                lambda x: pl.Series(scs.false_discovery_control(x)),
                return_dtype=pl.List(pl.Float64),
            )
            .over("desc", "model", "hierarchy")
            .name.suffix("corr")
        )
    )


hcp_wm_stats = get_wm_stats(hcp_wm_df)

In [None]:
# Export the white-matter statistics as a supplementary spreadsheet: attach
# one region per label, drop the raw "fw" rows, rename "logfw" to "v_iso",
# and sort for a stable row order.
(
    hcp_wm_stats
    .join(
        atlas_md.group_by("label").agg(pl.first("region")),
        on="label",
    )
    .filter(pl.col.desc != "fw")
    .with_columns(pl.col.desc.replace({"logfw": "v_iso"}))
    .sort(["desc", "label", "model"])
    .write_excel("suppl/white-matter-stats.xlsx")
)


### Plot WM changes

In [None]:
#| fig-cap: No effect of diagnosis in white matter &NODDI. Average parameter values are
#|   represented split by core and peripheral white matter. For each split, healthy controls
#|   are shown in the left curve, patients in the right. Curves correspond to kernel
#|   density estimates with bandwidth determined using the Scott method [@scott2015multivariate].
#|   Dashed lines represent first, second, and third quartiles. Diagnosis of early-stage
#|   psychosis does not significantly affect global NDI, ODI, or $v_{iso}$.
#| label: fig-wm
# NOTE(review): `side_title` is not used anywhere in this cell.
side_title = dict(
    x=-0.1,
    y=0.5,
    rotation="vertical",
    rotation_mode="anchor",
    size=10,
    ha="center",
    va="bottom",
    color=Styles.Colors.dark[0],
)
# Human-readable names for the two tract groupings.
_df = hcp_wm_df.with_columns(
    pl.col.label.replace(
        {"PWM": "Peripheral", "CWM": "Core"}
    )
)

fig = plt.figure(figsize=(7.48, 2.5), layout="constrained")

# One panel per parameter.
axs = fig.subplots(1, 3)
params = ["ndi", "odi", "logfw"]
# Axis-label text for each parameter/model key.
labels = {
    "logfw": r"$\log{v_{iso}}$",
    "fw": "$f_{fw}$",
    "ndi": "NDI",
    "odi": "ODI",
    "phenotype": "Phenotype",
    "panssp": "PANSS30-P",
    "panssn": "PANSS30-N",
}
for i in range(3):
    __df = _df.filter(pl.col.label.is_in(["Peripheral", "Core"]), desc=params[i])
    ax = axs[i]
    # Individual subjects, dodged by group.
    sns.stripplot(
        __df,
        y="data",
        x="label",
        hue="group",
        dodge=True,
        legend=False,
        hue_order=["HC", "Patient"],
        order=["Core", "Peripheral"],
        alpha=0.8,
        ax=ax,
    )
    # Split violins: one half per group, quartiles drawn inside.
    sns.violinplot(
        __df,
        y="data",
        x="label",
        hue="group",
        split=True,
        legend=False,
        palette=([cm.tab10(0), cm.tab10(1)]),
        alpha=0.4,
        linecolor="#f0f0f0",
        linewidth=1,
        inner="quart",
        hue_order=["HC", "Patient"],
        order=["Core", "Peripheral"],
        cut=0,
        ax=ax,
    )
    ax.set_xlabel("Tracts")
    ax.set_ylabel(labels[params[i]])
# Single shared legend for the whole figure.
add_legend(
    fig,
    ["HC", "Patient"],
    cmap=cm.tab10,
    fontsize=10,
    loc="lower center",
    bbox_to_anchor=(0.63, 0.2),
)
# Suppress the cell's text output.
None

---
label: fig-wm
---
No effect of diagnosis in white matter &NODDI. Average parameter values are represented split by core and peripheral white matter. For each split, healthy controls are shown in the left curve, patients in the right. Curves correspond to kernel density estimates with bandwidth determined using the Scott method [@scott2015multivariate]. Dashed lines represent first, second, and third quartiles. Diagnosis of early-stage psychosis does not significantly affect global NDI, ODI, or $\nu_{iso}$.

## Partial Voluming Investigation

In [None]:
# NOTE: rebinds `stats` to a fresh load of the file, without the `ma_*`
# variables added by the earlier plotting section.
stats = xr.open_dataset("hcp_rft.nc")

# Unthresholded vertex-wise t-maps for the group model at smoothing=5.
fw_t = stats.sel(smoothing=5, suffix="fw", model="group")["t"].data
csf_t = stats.sel(smoothing=5, suffix="csf", model="group")["t"].data
thickness_t = stats.sel(smoothing=5, suffix="thickness", model="group")["t"].data
# Maps averaged over the first axis of the smoothed data.
# NOTE(review): presumably that axis is `subject` — confirm the dimension order.
thickness = (hcp_smooth.sel(desc="thickness", smoothing=5).load().data).mean(axis=0)
fw = (hcp_smooth.sel(desc="fw", smoothing=5).load().data).mean(axis=0)


In [None]:
# Render the unthresholded fw t-map as a single column of four views.
n_lh = lh.n_points
data = fw_t
lh.append_array(data[:n_lh], at="p", name="t-val")
rh.append_array(data[n_lh:], at="p", name="t-val")
# NOTE: rebinds `fw_trange`, previously the range of the masked fw t-map.
fw_trange = np.nanmin(data), np.nanmax(data)
fw_t_img = plot_surf(
    {"lh": lh, "rh": rh},
    # Four rows, one view each.
    [
        ["lh"],
        ["lh"],
        ["rh"],
        ["rh"],
    ],
    [
        (None, "t-val"),
    ],
    [
        ["lateral"],
        ["medial"],
        ["lateral"],
        ["medial"],
    ],
    color_bar=True,
    cmap=[
        cm.viridis,
    ],
    # No explicit color_range is passed, so the plotting default applies.
    embed_nb=True,
    size=(250 * 3, 700 * 3),
    zoom=1.6,
    nan_color=(0.7, 0.7, 0.7, 0),
    cb__labelTextProperty={"fontSize": 12},
    transparent_bg=False,
    return_plotter=True,
).to_numpy()

In [None]:
n_lh = lh.n_points
def get_2by2(data):
    """Render ``data`` on both hemispheres as a 2x2 grid and return the image.

    ``data`` holds left-hemisphere values followed by right-hemisphere values
    and is split at the module-level ``n_lh``. Rows are hemispheres, columns
    are the lateral and medial views.

    Side effect: overwrites the "t-val" point array on ``lh`` and ``rh``.
    """
    lh.append_array(data[:n_lh], at="p", name="t-val")
    rh.append_array(data[n_lh:], at="p", name="t-val")
    return plot_surf(
        {"lh": lh, "rh": rh},
        [
            ["lh", "lh"],
            ["rh", "rh"],
        ],
        [
            [(None, "t-val"), (None, "t-val")],
            [(None, "t-val"), (None, "t-val")],
        ],
        [
            ["lateral", "medial"],
            ["lateral", "medial"],
        ],
        color_bar=True,
        cmap=[
            [cm.viridis, cm.viridis],
            [cm.viridis, cm.viridis],
        ],
        # No explicit color_range is passed, so the plotting default applies.
        embed_nb=True,
        size=(400 * 3, 300 * 3),
        zoom=1.6,
        nan_color=(0.7, 0.7, 0.7, 0),
        cb__labelTextProperty={"fontSize": 12},
        transparent_bg=False,
        return_plotter=True,
    ).to_numpy()
# Render each map and record its observed value range.
data = csf_t
csf_trange = np.nanmin(data), np.nanmax(data)
csf_t_img = get_2by2(data)

data = thickness_t
thickness_trange = np.nanmin(data), np.nanmax(data)
thickness_t_img = get_2by2(data)

# NOTE: `thickness_img` and `fw_img` below rebind names that earlier cells
# assign to the thresholded t-map renders; re-running those figure cells after
# this one shows these images instead.
data = thickness
thickness_range = np.nanmin(data), np.nanmax(data)
thickness_img = get_2by2(data)

data = fw
fw_range = np.nanmin(data), np.nanmax(data)
fw_img = get_2by2(data)


In [None]:
from brainspace.null_models import SpinPermutations

# Number of spin permutations; also read by `spin_test` below.
n_rand = 1000

# Fixed random_state for reproducible rotations; fitted on the two spheres.
sp = SpinPermutations(n_rep=n_rand, random_state=0)
sp.fit(lh_sphere, rh_sphere)

In [None]:
# Blank out vertices outside `mask` in place.
# NOTE: this mutates the arrays created in the earlier cell.
for arr in [fw, thickness, fw_t, thickness_t, csf_t]:
    arr[~mask] = np.nan

def do_rotation(arr):
    """Return the spin-permuted versions of ``arr``.

    ``arr`` is split into hemispheres at ``lh.n_points``, passed to
    ``sp.randomize``, and the two results are concatenated horizontally again.
    """
    return np.hstack(sp.randomize(arr[:lh.n_points], arr[lh.n_points:]))

fw_rot = do_rotation(fw)
fw_t_rot = do_rotation(fw_t)

In [None]:
from matplotlib import pyplot as plt
from scipy.stats import spearmanr


def spin_test(arr1, rot1, arr2):
    """Spearman correlation between two maps plus its spin-permutation null.

    Parameters
    ----------
    arr1, arr2
        Vertex-wise maps; only vertices inside the module-level ``mask`` are used.
    rot1
        Sequence of rotated versions of ``arr1``, one per permutation.

    Returns
    -------
    r_obs : float
        Observed Spearman correlation between ``arr1`` and ``arr2``.
    r_spin : numpy.ndarray
        Correlation between each rotated map and ``arr2``, one entry per
        element of ``rot1``.
    """
    r_obs = spearmanr(arr1[mask], arr2[mask])[0]

    # Size the null distribution from the rotations actually supplied rather
    # than a global count, so no uninitialised entries can remain.
    r_spin = np.empty(len(rot1))
    for i, perm in enumerate(rot1):
        # Exclude vertices outside the mask and those rotated onto NaN
        # (e.g. the midline).
        mask_rot = mask & ~np.isnan(perm)
        r_spin[i] = spearmanr(perm[mask_rot], arr2[mask_rot])[0]
    return r_obs, r_spin


# Spatial correlation of the mean maps, then of the diagnostic t-maps.
fw_thickness_corr, fw_thickness_spin = spin_test(fw, fw_rot, thickness)
fw_thickness_t_corr, fw_thickness_t_spin = spin_test(fw_t, fw_t_rot, thickness_t)
fw_csf_t_corr, fw_csf_t_spin = spin_test(fw_t, fw_t_rot, csf_t)


In [None]:
# Per-subject global averages of each parameter, one column per parameter.
# Unpacking the three-column frame yields its columns in order.
# NOTE: this rebinds `fw` and `thickness`, which earlier cells use for the
# vertex-wise mean maps — re-running those cells afterwards needs a fresh kernel.
fw, thickness, csf = (
    hcp_glob_df.group_by("subject", "param")
    .agg(pl.mean("dk"))
    .pivot(
        index=cs.exclude("param", "dk"),
        columns="param",
        values="dk",
    )[["fw", "thickness", "csf"]]
)
# Across-subject Pearson correlations and the sample size.
print(scs.pearsonr(fw, thickness))
print(scs.pearsonr(fw, csf))
print(len(fw))

In [None]:
#| fig-cap: $v_{iso}$ findings are distinct from partial volume effects. A. Significant
#|   correlations observed across subjects between global averages of $v_{iso}$ and
#|   cortical thickness ($r(105)=-0.57;P<.001$) and CSF volume fraction ($r(105)=0.31;p=.001$).
#|   Healthy controls and patients are plotted separately, but Pearson's R is calculated
#|   over the combined sample. B.  Thickness and $v_{iso}$ maps were made by averaging
#|   across subjects. T-value maps were made using a vertex-wise t-tests between patients
#|   and healthy controls, using age and sex as covariates. The center column shows the
#|   results of spin tests between parameter maps, testing for spatial correlation. 1000
#|   replicates were performed per test. The correlation between parameters was tested
#|   using Spearman's rank test. Red curves show kernel density estimate of null distribution.
#|   The empirical value is represented by the vertical line. P-value shows result of
#|   two-tailed test. The spatial distribution of thickness and $v_{iso}$ are significantly
#|   spatially correlated, but not the T-maps that represent diagnostic effect for $v_{iso}$
#|   and thickness, or for $v_{iso}$ and CSF volume fraction.
#| label: fig-pve
def perms_plot(obs, dist, *, ax):
    """Plot a permutation null distribution with the observed value marked.

    Draws a filled kernel density estimate of ``dist`` on ``ax``, a vertical
    line at ``obs``, and a label with the two-sided permutation p-value (the
    fraction of null values whose magnitude is at least ``|obs|``).

    Parameters
    ----------
    obs : float
        Observed statistic.
    dist : array-like
        Null distribution of the statistic.
    ax : matplotlib.axes.Axes
        Axes to draw on. All artists, including the label, go on this axes
        rather than on whichever axes happens to be current.
    """
    sns.kdeplot(dist, color="red", fill=True, alpha=0.1, edgecolor="#ff000088", ax=ax)
    ax.vlines(x=obs, ymin=0, ymax=2, color="#222222")
    ax.text(
        obs,
        2,
        "P = {}".format(np.mean(np.abs(dist) >= np.abs(obs))),
        va="bottom",
        ha="center",
    )


fig = plt.figure(figsize=(7.48, 7.5), layout="constrained")
panela, panelb = fig.subfigures(2, 1, height_ratios=[2.5, 5])

# Panel A: across-subject scatter of global fw against thickness and CSF.
panela.text(0.01, 1, "A", va="top", **Styles.panel_label)
axs = panela.subplots(1, 2)
params = ["thickness", "csf"]
labels = ["Thickness", "CSF"]

for i in range(2):
    (
        so.Plot(
            # Per-subject global averages, one column per parameter.
            hcp_glob_df.group_by("subject", "param")
            .agg(pl.mean("dk"), pl.first("group"))
            .pivot(
                index=cs.exclude("param", "dk"),
                columns="param",
                values="dk",
            ),
            x=params[i],
            y="fw",
        )
        # Points, linear fits and bands are split by group ...
        .add(so.Dot(alpha=0.9, pointsize=2.5, edgealpha=0), color="group", legend=False)
        .add(so.Line(), so.PolyFit(1), color="group", linestyle="group", legend=False)
        .add(
            so.Band(),
            PolyCI(nsims=1000),
            color="group",
            legend=False,
        )
        # ... while the correlation annotation has no group mapping.
        .add(so.Text(halign="right"), PearsonrAnnot("upper right"))
        .label(x=labels[i], y=r"$v_{iso}$", title="Global average")
        .on(axs[i])
        .plot()
    )
add_legend(
    panela,
    ["HC", "Patient"],
    cmap=cm.tab10,
    fontsize=10,
    size=4,
    bbox_to_anchor=(1, 0.5),
)

# Panel B: 3x3 grid with the fw maps on the left, the spin-test null
# distributions in the centre and the comparison maps on the right.
panelb.text(0.01, 1, "B", va="top", **Styles.panel_label)
grid = panelb.add_gridspec(3, 3)

# Centre column: spin-test results.
ax = panelb.add_subplot(grid[0, 1])
perms_plot(fw_thickness_corr, fw_thickness_spin, ax=ax)
ax.set_title(r"$v_{iso} \sim \text{Thickness}$")

ax = panelb.add_subplot(grid[1, 1])
perms_plot(fw_thickness_t_corr, fw_thickness_t_spin, ax=ax)
ax.set_title(r"$T_{v_{iso}} \sim T_{\text{Thickness}}$")

ax = panelb.add_subplot(grid[2, 1])
perms_plot(fw_csf_t_corr, fw_csf_t_spin, ax=ax)
ax.set_title(r"$T_{v_{iso}} \sim T_{\text{CSF}}$")

# Top row: the mean maps.
ax = panelb.add_subplot(grid[0, 2])
ax.imshow(thickness_img)
ax.axis("off")
ax.set_title("Thickness", fontweight="500")

ax = panelb.add_subplot(grid[0, 0])
ax.imshow(fw_img)
ax.axis("off")
ax.set_title(r"$v_{iso}$")

# The fw t-map spans the two lower rows of the left column.
ax = panelb.add_subplot(grid[1:3, 0])
ax.imshow(fw_t_img)
ax.axis("off")
ax.set_title(r"$T_{v_{iso}}$")

# Right column, lower rows: the t-maps compared against the fw t-map.
ax = panelb.add_subplot(grid[1, 2])
ax.imshow(thickness_t_img)
ax.axis("off")
ax.set_title(r"$T_{\text{Thickness}}$")

ax = panelb.add_subplot(grid[2, 2])
ax.imshow(csf_t_img)
ax.axis("off")
ax.set_title(r"$T_{\text{CSF}}$")
# Suppress the cell's text output.
None