In [None]:
import matplotlib.pyplot as plt

%matplotlib inline
# NOTE(review): this creates and discards an empty figure; its purpose is not
# evident from the notebook (possibly to initialise the inline backend) — confirm
# it is still needed.
plt.subplots()
# Hack to hide the globally installed R libraries, which were built for the wrong R
# version: make R search the local library below first.
# NOTE(review): absolute, machine-specific path.
from rpy2 import robjects as ro

ro.r(".libPaths('/local/scratch/gt/lib/R/library')")

In [None]:
import itertools as it
import re
from pathlib import Path
import textwrap

import colormaps as cmaps
import graph_tool.all as gt
import more_itertools as itx
import nibabel as nb
import numpy as np
import pandas as pd
import polars as pl
import polars.selectors as cs
import scipy.stats as scs
import seaborn as sns
import seaborn.objects as so
import statsmodels.api as sm
import statsmodels.formula.api as smf
import templateflow.api as tflow
import tqdm
import xarray as xr
from dask.diagnostics import ProgressBar
from matplotlib import font_manager
from matplotlib.ticker import FormatStrFormatter, FuncFormatter
from matplotlib import cm
from nilearn import plotting
from numpy.polynomial import Polynomial
import functools as ft
import operator as op
from rsbids import BidsLayout

from lib import atlases
from lib.bidsarray import layout_map
from lib.dataset import Dataset
from lib.plotting import (
    add_colorbar,
    annotate_axes,
    comparison_plot,
    fig_to_numpy,
    move_legend_fig_to_ax,
    plot_hierachical_connectome,
    add_legend,
)
from lib.seaborn_stats import Lme4CI, MLEFit, PearsonrAnnot, PolyCI
from lib.utils import concat_product
from lib.demographics import DemographicTable
from styles import styles as Styles

# Pick up edits to the local lib/ and styles/ modules without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# plt.switch_backend("cairo")
# Manuscript figure style, also applied to seaborn.objects plots.
plt.style.use("styles/manuscript.mplstyle")
so.Plot.config.theme.update(plt.rcParams)
# Register user fonts from ~/.fonts with matplotlib (presumably required by the
# manuscript style's font family).
font_dirs = [Path.home() / ".fonts"]
font_files = font_manager.findSystemFonts(fontpaths=font_dirs)

for font_file in font_files:
    font_manager.fontManager.addfont(font_file)

In [None]:
import rpy2
import rpy2.ipython.html
from rpy2 import robjects as ro

# HTML rendering of R objects in notebook output.
rpy2.ipython.html.init_printing()

from rpy2.robjects.packages import importr

# Handles to the R packages used for the mixed-effects models below.
rutils = importr("utils")
rbase = importr("base")
lme4 = importr("lme4")
rstats = importr("stats")
pbkrtest = importr("pbkrtest")
lmertest = importr("lmerTest")
%load_ext rpy2.ipython

### Initialize Datasets

In [None]:
# JHP cohort; rows whose dx is one of the listed diagnoses are excluded.
jhp = (
    Dataset(".jhp.layout", "jhp", group_label="group")
    .add_phenotypes("jhp_metadata.yaml")
    .filter(
        ~pl.col("dx").is_in(["BPADI", "MDD", "Substance_induced", "Others"]),
        # ~pl.col("subject").is_in(["1018", "1026", "1047", "2041"]),
    )
)


# TOPSY cohort restricted to the listed ddx codes.
# NOTE(review): the code meanings live in the TOPSY metadata; consider naming them here.
topsy = (
    Dataset(".topsy.layout", "topsy")
    .add_phenotypes("topsy_metadata.yaml")
    .filter(pl.col("ddx").is_in([1, 2, 3, 9, 10]))
)


# Axis labels for session numbers in the figures.
JHP_SESSIONS = {1: "Baseline", 2: "1yr", 3: "2yr", 4: "3yr"}
TOPSY_SESSIONS = {1: "Baseline", 2: "≈6 mo", 3: "1-2 yr"}
# Polars expression: number of distinct sessions of each row's subject.
num_sessions = pl.col("session").unique().len().over("subject")

### Prepare clinical data

In [None]:
def do_fit(struct):
    """Fit a straight line ``y = intercept + slope * x`` to one subject's points.

    Used through ``map_elements`` inside ``group_by().agg()`` (see ``sel_fit``),
    so ``struct`` is a struct Series holding every row of one subject, with
    fields ``x`` (0-based session index) and ``y`` (score).

    Rows where any field is null are dropped first. A line is only identifiable
    from at least two distinct ``x`` values. With fewer (zero rows, one row, or
    repeated sessions) both coefficients are reported as 0.0, as the original
    did for a single point. Previously zero remaining rows crashed inside
    ``Polynomial.fit``.

    Returns
    -------
    dict
        ``{"intercept": float, "slope": float}``
    """
    points = struct.struct.unnest().filter(pl.all_horizontal(~pl.col("*").is_null()))
    if points["x"].n_unique() < 2:
        coefs = [0.0, 0.0]
    else:
        coefs = Polynomial.fit(points["x"], points["y"], 1).convert().coef
        # ``convert()`` can trim a trailing zero coefficient (an exactly flat
        # line), leaving only the intercept.
        if coefs.shape[0] == 1:
            coefs = [coefs[0], 0.0]
    return {
        "intercept": coefs[0],
        "slope": coefs[1],
    }


def sel_fit(col):
    """Aggregation expression: per-subject linear trend of ``col`` across sessions.

    Session labels are cast to int and shifted to start at 0, so the
    ``{col}_intercept`` field is the fitted value at the first session and
    ``{col}_slope`` the change per session step (fitted by ``do_fit``). Groups
    with a single row get null because the ``when`` has no ``otherwise``. The
    struct column is named ``{col}_fit``; unnest it to get the two fields.
    """
    session_index = pl.col("session").cast(int) - 1
    fit_dtype = pl.Struct({"intercept": pl.Float64, "slope": pl.Float64})
    fitted = (
        pl.struct(session_index, col)
        .struct.rename_fields(["x", "y"])
        .map_elements(do_fit, return_dtype=fit_dtype)
        .struct.rename_fields([f"{col}_intercept", f"{col}_slope"])
    )
    has_several_rows = pl.len() > 1
    return pl.when(has_several_rows).then(fitted).alias(f"{col}_fit")


def _first_score_below(col, session, threshold):
    """Aggregation expression: is ``col`` at ``session`` (first match) below ``threshold``?"""
    return pl.col(col).filter(pl.col("session") == session).first() < threshold


def sel_recovery(col, threshold=4):
    """Boolean ``{col}_recovery``: the session-"2" value of ``col`` is below ``threshold``.

    ``threshold`` defaults to the previously hard-coded 4. The result is null
    when the subject has no session-2 value.
    """
    return _first_score_below(col, "2", threshold).name.suffix("_recovery")


def sel_start_zero(col, threshold=4):
    """Boolean ``{col}_nostart``: the session-"1" (baseline) value of ``col`` is below ``threshold``.

    ``threshold`` defaults to the previously hard-coded 4. The result is null
    when the subject has no baseline value.
    """
    return _first_score_below(col, "1", threshold).name.suffix("_nostart")


# Per-subject clinical summary for JHP patients: visit count, sex, mean age,
# mean and baseline SAPS/SANS, and per-subject linear trends (sel_fit) unnested
# into *_intercept / *_slope columns.
jhp_hx = (
    jhp.metadata.filter(pl.col("group") == "Patient")
    .group_by("subject")
    .agg(
        pl.len().alias("num_sessions"),
        pl.first("sex"),
        pl.mean("age"),
        pl.mean("sans", "saps").name.suffix("_mean"),
        pl.col("saps", "sans")
        .filter(pl.col("session") == "1")
        .first()
        .name.suffix("_baseline"),
        sel_fit("sans"),
        sel_fit("saps"),
    )
    .unnest("sans_fit", "saps_fit")
)

# Same summary for TOPSY FEP patients (sessions 1-2 only), plus boolean
# below-threshold flags at session 2 (*_recovery) and at baseline (*_nostart).
topsy_hx = (
    topsy.metadata.filter(pl.col("group") == "FEP", pl.col("session").cast(int) < 3)
    .group_by("subject")
    .agg(
        pl.len().alias("num_sessions"),
        pl.first("sex"),
        pl.mean("age"),
        pl.mean("PANSSP", "PANSSN", "SOFAS").name.suffix("_mean"),
        pl.col("PANSSP", "PANSSN", "SOFAS")
        .filter(pl.col("session") == "1")
        .first()
        .name.suffix("_baseline"),
        sel_fit("PANSSP"),
        sel_fit("PANSSN"),
        sel_fit("SOFAS"),
        sel_recovery("PANSSP"),
        sel_recovery("PANSSN"),
        sel_start_zero("PANSSN"),
    )
    .unnest(cs.matches(".*_fit"))
)

In [None]:
# JHP patients with at least one follow-up visit (jhp_hx already has one row per
# subject, so unique("subject") is only a safeguard).
jhp_hx.unique("subject").filter(pl.col.num_sessions > 1)

In [None]:
# Number of JHP subjects (all groups) with follow-up, per diagnosis.
jhp.metadata.filter(pl.col.num_sessions > 1).group_by("subject").agg(
    pl.first("group", "dx")
).group_by("dx").len()

In [None]:
# TOPSY patients with follow-up: is PANSS-N recovery associated with sex
# (chi-squared test) or with age (OLS)?
topsy_followup = topsy_hx.filter(pl.col("num_sessions") > 1)
table = (
    topsy_followup
    # Subjects without a session-2 PANSS-N have a null recovery flag; left in,
    # they would form a spurious third "null" column of the contingency table.
    .drop_nulls("PANSSN_recovery")
    .group_by("sex", "PANSSN_recovery")
    .len()
    .pivot(index="sex", columns="PANSSN_recovery", values="len")
    .drop("sex")
    # A sex x recovery combination with no subjects is a zero count, not NaN.
    .fill_null(0)
    .to_numpy()
)
# Not the last expression of the cell, so display it explicitly (previously the
# result was computed and silently discarded).
display(scs.chi2_contingency(table))

lm = smf.ols(
    "age ~ PANSSN_recovery",
    data=topsy_followup.with_columns(pl.col("PANSSN_recovery").cast(int)),
).fit()
lm.model.exog

#### Demographics and clinical scores

In [None]:
# Exploratory 6x2 grid for TOPSY patients with follow-up; columns are PANSS-P and
# PANSS-N. Rows: age by recovery, sex counts by recovery, slope vs age, slope by
# sex, intercept vs age, intercept by sex.
def _topsy_followup_frame():
    """Fresh pandas copy of the TOPSY patients with more than one session."""
    return topsy_hx.filter(pl.col("num_sessions") > 1).to_pandas()


def _scatter_against_age(y, ax):
    """Dot plot of ``y`` against age on ``ax``."""
    so.Plot(_topsy_followup_frame(), y=y, x="age").add(so.Dot()).on(ax).plot()


def _compare_by_sex(y, ax):
    """Comparison of ``y`` between sexes on ``ax``."""
    comparison_plot(_topsy_followup_frame(), x="sex", y=y, ax=ax, xlabel="Sex")


fig = plt.figure(figsize=(6, 10), layout="constrained")
axs = fig.subplots(6, 2)
scores = ("PANSSP", "PANSSN")
for col, score in enumerate(scores):
    comparison_plot(
        _topsy_followup_frame(),
        y="age",
        x=f"{score}_recovery",
        ax=axs[0, col],
        xlabel="Recovery?",
        ylabel="Age",
    )
for col, score in enumerate(scores):
    so.Plot(
        _topsy_followup_frame(), color="sex", x=f"{score}_recovery"
    ).add(so.Bar(), so.Hist()).on(axs[1, col]).plot()
# Rows 2-3 show the fitted slopes, rows 4-5 the fitted intercepts.
for first_row, fit_param in ((2, "slope"), (4, "intercept")):
    for col, score in enumerate(scores):
        _scatter_against_age(f"{score}_{fit_param}", axs[first_row, col])
    for col, score in enumerate(scores):
        _compare_by_sex(f"{score}_{fit_param}", axs[first_row + 1, col])

#### Clinical Score Changes

In [None]:
# Long-format clinical scores of patients with follow-up, for the R mixed models.
# NOTE(review): topsy_df / jhp_df are reassigned later with imaging data.
num_sessions = pl.col("session").unique().len().over("subject")
topsy_df = (
    topsy.metadata.with_columns(pl.col.session.cast(int))
    .filter(
        pl.col("group") == "FEP",
        num_sessions > 1,
        ~pl.col("PANSSP", "PANSSN").is_null(),
        pl.col.session < 3,
    )
    .to_pandas()
)
jhp_df = (
    jhp.metadata.with_columns(pl.col.session.cast(int))
    .filter(
        pl.col("group") == "Patient",
        num_sessions > 1,
        ~pl.col("saps", "sans").is_null(),
    )
    .to_pandas()
)
# pandas <-> R conversion is active only inside this context.
with (ro.default_converter + ro.pandas2ri.converter).context():
    #     lm1 = lme4.lmer("PANSSP ~ (1|subject)", data=topsy_df)
    #     lm2 = lme4.lmer("PANSSP ~ session + (1|subject)", data=topsy_df)
    #     res = pbkrtest.PBmodcomp(lm2, lm1)
    #     print(res["test"])
    # NOTE(review): lm1 is fitted but not used below.
    lm1 = lme4.lmer("saps ~ session + (1|subject)", data=jhp_df)
    lm2 = lme4.lmer("sans ~ session + (session|subject)", data=jhp_df)
    # res = pbkrtest.PBmodcomp(lm2, lm1)
    # print(res["test"])
    # Random-effects ANOVA (lmerTest::ranova) of the random-slope SANS model.
    res = lmertest.ranova(lm2)
res

In [None]:
# Same ranova call as the previous cell, on the lm2 fitted there.
with (ro.default_converter + ro.pandas2ri.converter).context():
    res = lmertest.ranova(lm2)
res

In [None]:
#| cell-offset: -1
#| fig-cap: Baseline and follow-up clinical scores in early schizophrenia patients. At
#|   baseline, empty circles show subjects with no follow-ups. Filled circles connected
#|   by a line represent the same subject across multiple visits. No significant differences
#|   in baseline scores were found between subjects with and without follow-up visits.
#|   Trendlines show a linear fixed effect model of parameter against session with random
#|   slopes and intercepts fit for every subject (only random intercepts for TOPSY).
#|   Shaded bands show a 95% CI computed with parametric bootstrapping resampling residuals
#|   and random effects 500 times. In the TOPSY dataset, the PANSS8-P score was significantly
#|   lower in the second session than the first (1000 perms, p < .001). No other symptom
#|   scores significantly changed across session.
#| label: fig-hx
# Shared seeded jitter so each subject's lines and dots get the same offsets.
jitter = so.Jitter(width=0.2, seed=1)
fig = plt.figure(figsize=(8, 5), layout="constrained")

axs = fig.subplots(2, 2)

# Rows: positive / negative symptom scale; columns: JHP / TOPSY.
variables = np.array([["saps", "PANSSP"], ["sans", "PANSSN"]])
datasets = [
    jhp.metadata.filter(pl.col("group") == "Patient"),
    topsy.metadata.filter(pl.col("group") == "FEP", pl.col("session").cast(int) < 3),
]
ses_labels = [JHP_SESSIONS, TOPSY_SESSIONS]

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
    "sans": "SANS",
    "saps": "SAPS",
}
# lme4 formula per dataset column for the CI bands: random slopes for JHP,
# random intercepts only for TOPSY.
formulae = ["y ~ x + (x|group)", "y ~ x + (1|group)"]

for x, y in np.ndindex(2, 2):
    variable = variables[x, y]
    label = labels[variable]
    dataset = datasets[y]
    ax = axs[x, y]
    num_sessions = pl.col("session").unique().len().over("subject")
    (
        so.Plot(
            # multi_* holds scores of subjects with follow-up, single_* of
            # baseline-only subjects, so the two can be drawn differently.
            dataset.with_columns(
                pl.col("session").cast(int),  # .replace(jhp_sessions),
                pl.when(num_sessions > 1).then(pl.col(variable)).name.prefix("multi_"),
                pl.when(num_sessions == 1)
                .then(pl.col(variable))
                .name.prefix("single_"),
            ).to_pandas(),
            x="session",
            y=f"multi_{variable}",
            group="subject",
        )
        # Per-subject trajectories.
        .add(
            so.Line(color="#555555", linestyle="dashed", linewidth=1, alpha=0.3), jitter
        )
        # Fitted trend and its bootstrapped (nsims=500) confidence band.
        .add(so.Line(linewidth=2), MLEFit())
        .add(so.Band(alpha=0.4), Lme4CI(formula=formulae[y], nsims=500))
        .add(so.Dot(color="#333333", edgewidth=0, alpha=0.5), jitter)
        # Baseline-only subjects: hollow dots shifted left of the tick.
        .add(
            so.Dot(color="#333333", fill=False, alpha=0.8),
            so.Shift(x=-0.3),
            so.Jitter(width=0.2, seed=5),
            y=f"single_{variable}",
        )
        .scale(
            x=so.Continuous().tick(at=[1, 2, 3, 4]).label(like=ses_labels[y].get),
            y=so.Continuous().tick(every=4),
        )
        .label(y=label, x="Session")
        .on(ax)
        .plot()
    )


axs[0, 0].set_title("JHP", **Styles.col_title)
axs[0, 1].set_title("TOPSY", **Styles.col_title)
# Bare None as last expression suppresses the repr of set_title's return value.
None

---
label: fig-hx
cell-offset: -1

---
Baseline and follow-up clinical scores in early schizophrenia patients. At baseline, empty circles show subjects with no follow-ups. Filled circles connected by a line represent the same subject across multiple visits. No significant differences in baseline scores were found between subjects with and without follow-up visits. Trendlines show a linear mixed-effects model of each score against session, with random slopes and intercepts for every subject (only random intercepts for TOPSY). Shaded bands show 95% CIs computed by parametric bootstrapping (resampling residuals and random effects 500 times). In the TOPSY dataset, the PANSS8-P score was significantly lower in the second session than in the first (1000 perms, p < .001). No other symptom scores changed significantly across sessions.

In [None]:
#| fig-cap: Clinical scores from TOPSY patients across all sessions. Each dashed line
#|   corresponds to a different subject.
#| label: fig-topsy-full-hx
# Same layout as fig-hx but TOPSY only, all sessions, raw trajectories without fits.
jitter = so.Jitter(width=0.2, seed=1)
fig = plt.figure(figsize=(8, 3), layout="constrained")

axs = fig.subplots(1, 2)

variables = np.array(["PANSSP", "PANSSN"])
datasets = [topsy.metadata.filter(pl.col("group") == "FEP")]
ses_labels = [TOPSY_SESSIONS]

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
    "sans": "SANS",
    "saps": "SAPS",
}
# NOTE(review): formulae (and the sans/saps labels) are unused in this cell.
formulae = ["y ~ x + (1|group)", "y ~ x + (x|group)"]

# y is always 0 here; x selects the panel / score.
for x, y in np.ndindex(2, 1):
    variable = variables[x]
    label = labels[variable]
    dataset = datasets[0]
    ax = axs[x]
    num_sessions = pl.col("session").unique().len().over("subject")
    (
        so.Plot(
            dataset.with_columns(
                pl.col("session").cast(int),  # .replace(jhp_sessions),
                pl.when(num_sessions > 1).then(pl.col(variable)).name.prefix("multi_"),
                pl.when(num_sessions == 1)
                .then(pl.col(variable))
                .name.prefix("single_"),
            ).to_pandas(),
            x="session",
            y=f"multi_{variable}",
            group="subject",
        )
        .add(
            so.Line(color="#555555", linestyle="dashed", linewidth=1, alpha=0.3), jitter
        )
        .add(so.Dot(color="#333333", edgewidth=0, alpha=0.5), jitter)
        .scale(
            x=so.Continuous().tick(at=[1, 2, 3, 4]).label(like=ses_labels[y].get),
            y=so.Continuous().tick(every=4),
        )
        .label(y=label, x="Session")
        .on(ax)
        .plot()
    )



---
label: fig-topsy-full-hx
key: val
---

Clinical scores from TOPSY patients across all sessions. Each dashed line corresponds to a different subject.

### Gather white matter data

In [None]:
def get_wm_from_rois(path, wildcards, atlases):
    """Mean image value inside every label of every atlas, for each matched image.

    Parameters
    ----------
    path
        Layout query of the images to sample, passed with ``wildcards`` to the
        ``layout_map``-wrapped worker.
    wildcards : list[str]
        Entities indexing the output, e.g. ``["subject", "session", "desc"]``.
    atlases : list of array
        Label volumes (integer-valued or boolean) on the images' voxel grid.
        Every label ``0..max(atlas)`` of each atlas, in list order, becomes one
        entry of the ``roi`` dimension. Label-0 entries are placeholders set to
        0. A label with no voxels yields NaN (mean of an empty selection).

    Returns
    -------
    Whatever ``layout_map`` builds with a ``roi`` dimension of length equal to
    the total label count (used downstream as an xarray DataArray — TODO confirm).
    """
    # Flatten every (atlas position, label value) pair into one contiguous roi axis.
    labels = list(
        it.chain.from_iterable(
            zip(it.repeat(i), range(np.max(atlas).astype(int) + 1))
            for i, atlas in enumerate(atlases)
        )
    )
    nlabels = len(labels)
    atlas_shapes = [np.shape(atlas) for atlas in atlases]
    # Each label's voxel coordinates are identical for every image, so find them
    # once instead of comparing the whole atlas per label per image.
    # ``np.nonzero`` returns voxels in the same (C) order as boolean masking, so
    # the means are unchanged.
    regions = [
        None if ix == 0 else np.nonzero(atlases[atlas_ix] == ix)
        for atlas_ix, ix in labels
    ]

    @layout_map(parallel=True, dims={"roi": nlabels}, dtype=float)
    def inner(path):
        try:
            data = nb.load(path).get_fdata()
            # Boolean masking used to reject a mismatched grid with IndexError;
            # integer indices would not, so check explicitly.
            for shape in atlas_shapes:
                if data.shape[: len(shape)] != shape:
                    raise IndexError(
                        f"image shape {data.shape} does not match atlas shape {shape}"
                    )
            result = np.empty((nlabels,))
            for i, region in enumerate(regions):
                result[i] = 0 if region is None else np.mean(data[region])
            return result
        except Exception:
            # Name the failing file; errors from parallel workers are otherwise hard to trace.
            print(path)
            raise

    return inner(path, wildcards)


# ROI metadata table (read by get_group_mapping below).
atlas_md = pl.read_csv("atlas-study_labels.csv")


def get_group_mapping():
    """Lookup array from atlas label id to the ``atlas_id`` of its parent "core" group.

    Reads the notebook-global ``atlas_md``. Rows with ``group == "core"`` define
    the core groups. Every row whose ``group`` names one of them maps its own
    ``atlas_id`` (array index) to that group's ``atlas_id`` (value). Ids that
    belong to no core group stay 0.
    """
    core = atlas_md.filter(group="core")
    core_id_by_name = dict(zip(*core[["name", "atlas_id"]]))
    members = atlas_md.join(core[["name"]], left_on="group", right_on="name")
    members = members.with_columns(
        pl.col.group.replace(core_id_by_name, return_dtype=pl.Int32)
    )
    group_id = members["group"]
    atlas_id = members["atlas_id"]
    mapping = np.zeros(((atlas_id.max() + 1),))
    mapping[atlas_id] = group_id
    return mapping


def get_atlases(layout, dims, jhu_atlas, lobe_atlas):
    """Region maps sampled by ``get_wm_from_rois``, clipped to the group skeleton.

    The group skeleton is every voxel whose mean over all images in ``layout``
    (across subject and session) is > 0. Both atlas files are zeroed outside it.

    Returns a dict whose key order fixes the downstream roi numbering:
    ``lobe_mask`` (lobe labels where no JHU label exists), ``jhu_atlas``,
    ``global_mask`` (either), ``core-periph`` (1 = JHU, 2 = lobe) and
    ``core_group_mask`` (JHU labels mapped through ``get_group_mapping``).
    """

    @layout_map(parallel=True, dims=dims, dtype=float)
    def get_skeleton(path):
        return nb.load(path).get_fdata()

    def on_skeleton(atlas_path):
        # Zero the atlas outside the group skeleton.
        return np.where(skeleton, nb.load(atlas_path).get_fdata(), 0)

    images = get_skeleton(layout, ["subject", "session"])
    skeleton = images.mean(["subject", "session"]) > 0
    jhu = on_skeleton(jhu_atlas)
    lobes = on_skeleton(lobe_atlas)
    core_groups = get_group_mapping().astype(int)
    lobe_mask = np.where(jhu == 0, lobes, 0)
    is_core = jhu > 0
    is_periph = lobe_mask > 0
    return {
        "lobe_mask": lobe_mask,
        "jhu_atlas": jhu,
        "global_mask": is_core | is_periph,
        "core-periph": is_core.astype(int) + is_periph.astype(int) * 2,
        "core_group_mask": core_groups[jhu.astype(int)],
    }


def run_roi_sampling(layout, jhu_atlas, lobe_atlas, skeleton_dims):
    """Sample FA, MD, RD and L1 skeleton values in every atlas region.

    The group skeleton (and thus the regions) is derived from the FA images
    only. The regions are then sampled from all four diffusion metrics.
    """
    fa_images = layout.get(suffix="skeletonized", desc="FA")
    # Named so as not to shadow the imported ``atlases`` module.
    region_maps = get_atlases(
        fa_images,
        dims=skeleton_dims,
        jhu_atlas=jhu_atlas,
        lobe_atlas=lobe_atlas,
    )
    metric_images = layout.get(suffix="skeletonized", desc=["FA", "MD", "RD", "L1"])
    return get_wm_from_rois(
        metric_images,
        ["subject", "session", "desc"],
        atlases=list(region_maps.values()),
    )


from dask.diagnostics import ProgressBar

# Sample every skeleton image of both cohorts into atlas ROIs; results are
# written to NetCDF in the next cell.
# NOTE(review): skeleton_dims must match each cohort's template grid.
topsy_wm_sampled = run_roi_sampling(
    topsy.layout.get(suffix="skeletonized"),
    jhu_atlas="../topsy/code/jhp-atlas/atlas.nii.gz",
    lobe_atlas="../topsy/code/jhp-atlas/lobe-atlas.nii.gz",
    skeleton_dims={"x": 78, "y": 109, "z": 79},
)
jhp_wm_sampled = run_roi_sampling(
    jhp.layout.get(suffix="skeletonized"),
    jhu_atlas="../jhp/derivatives/atlases/atlas.nii.gz",
    lobe_atlas="../jhp/derivatives/atlases/lobe-atlas.nii.gz",
    skeleton_dims={"x": 248, "y": 295, "z": 93},
)

In [None]:
# Compute and cache the ROI samples; later cells reload them with xr.open_dataarray.
with ProgressBar():
    topsy_wm_sampled.to_netcdf("topsy_wm_sampled.nc")
    jhp_wm_sampled.to_netcdf("jhp_wm_sampled.nc")

In [None]:
# Group-level TOPSY skeleton mask (voxels whose mean FA over all images is > 0),
# used by the mesh-export cells below.
skeleton_dims = {"x": 78, "y": 109, "z": 79}


# ``get_mean_skeleton`` was called here but not defined in this notebook, so the
# cell raised NameError on a fresh kernel. It mirrors the ``get_skeleton`` loader
# inside ``get_atlases``: load one image; ``layout_map`` maps it over the layout.
@layout_map(parallel=True, dims=skeleton_dims, dtype=float)
def get_mean_skeleton(path):
    """Load one skeletonized image as a float array."""
    return nb.load(path).get_fdata()


mean_skeleton = (
    get_mean_skeleton(
        topsy.layout.get(suffix="skeletonized", desc="FA"),
        ["subject", "session"],
    ).mean(["subject", "session"])
    > 0
)

In [None]:
import mcubes

# Surface mesh of the JHU (core) labels restricted to the group skeleton.
jhu_atlas = "../topsy/code/jhp-atlas/atlas.nii.gz"
lobe_atlas = "../topsy/code/jhp-atlas/lobe-atlas.nii.gz"

jhu_mask = nb.load(jhu_atlas).get_fdata() > 0
jhu_mask &= mean_skeleton.load().data
# jhu_smoothed = mcubes.smooth(jhu_mask)
verts, faces = mcubes.marching_cubes(jhu_mask, 0)
# NOTE(review): generic output name; consider e.g. "jhu.obj".
mcubes.export_obj(verts, faces, "test.obj")

In [None]:
# Surface mesh of the peripheral (lobe) white matter: lobe-atlas voxels on the
# group skeleton that carry no JHU label.
jhu_mask = nb.load(jhu_atlas).get_fdata() > 0
lobe_mask = (nb.load(lobe_atlas).get_fdata() > 0) & ~jhu_mask & mean_skeleton.load().data
verts, faces = mcubes.marching_cubes(lobe_mask, 0)
mcubes.export_obj(verts, faces, "lobes.obj")

In [None]:
# Surface mesh of the whole group skeleton (iso-level 0.5 on the boolean mask).
verts, faces = mcubes.marching_cubes(mean_skeleton.load().data, 0.5)
mcubes.export_obj(verts, faces, "skeleton.obj")

In [None]:
# Exploratory: inspect the last exported mesh's faces.
faces

In [None]:
# Exploratory: inspect the TOPSY BIDS layout.
topsy.layout

## Investigate!

In [None]:
# Reload the cached ROI samples; chunks={} opens them lazily as dask arrays.
topsy_wm_sampled = xr.open_dataarray("topsy_wm_sampled.nc", chunks={})
jhp_wm_sampled = xr.open_dataarray("jhp_wm_sampled.nc", chunks={})

In [None]:
# ROI metadata, re-read here (presumably so this section works from the cached
# NetCDF files without re-running the sampling cells).
atlas_md = pl.read_csv("atlas-study_labels.csv")

# ROIs excluded from the analyses: the listed labels and every cerebellar ROI.
atlas_filters = [
    ~pl.col("label").is_in(
        [
            "Med",
            "Po",
            "Mb",
            "FTS",
        ]
    ),
    pl.col("group") != "cerebellar",
]

In [None]:
# Map from "core"/"global" group name to its ROI index.
group_indices = dict(
    zip(*atlas_md.filter(pl.col("group").is_in(["core", "global"]))[["name", "index"]])
)


def get_index(group: str):
    """Polars Int64 literal holding the ROI index of a "core"/"global" group."""
    index = group_indices[group]
    return pl.lit(index, dtype=pl.Int64)


def prepare_wm_rois(df, index):
    """Average sampled ROI values within each atlas label.

    Parameters
    ----------
    df : pl.DataFrame
        Long table with a ``roi`` column (matching ``atlas_md["index"]``), a
        ``data`` column and the columns named in ``index``.
    index : list[str]
        Grouping columns kept in the output, e.g. subject, session, desc.

    Returns
    -------
    pl.DataFrame
        One row per ``(*index, label)`` with the first ``region`` / ``hierarchy``
        and the mean ``data`` over that label's ROIs. ROIs rejected by
        ``atlas_filters`` are dropped.
    """
    # Join the metadata once. The previous version joined and filtered twice,
    # duplicating every metadata column as ``*_right``. If ``atlas_md`` has
    # several rows per index, it also multiplied rows and skewed the label means.
    return (
        df.join(atlas_md.rename({"index": "roi"}), on="roi")
        .filter(*atlas_filters)
        .group_by(*index, "label")
        .agg(
            pl.col(
                "region",
                "hierarchy",
            ).first(),
            pl.col("data").mean(),
        )
    )

In [None]:
# Distinct sessions per subject (window expression); also used by the demographic cells below.
num_sessions = pl.col("session").unique().len().over("subject")


def all_sessions(da, hx):
    """Label-averaged white-matter values for every sampled session, joined to metadata.

    Parameters
    ----------
    da : xarray.DataArray
        ROI samples indexed by subject, session, desc and roi (as written by
        ``get_wm_from_rois``). Missing values are dropped.
    hx : pl.DataFrame
        Per-session metadata; only rows matching on ``(subject, session)`` are kept.

    Returns
    -------
    pl.DataFrame
        ``prepare_wm_rois`` output joined to ``hx``, plus ``num_sessions``: the
        number of distinct sessions of each subject remaining after the join.
    """
    # Built here rather than read from the notebook-global ``num_sessions``,
    # which several cells reassign.
    sessions_per_subject = pl.col("session").unique().len().over("subject")
    samples = (
        da.to_dataset(name="data")
        .to_dataframe()
        .dropna()
        .reset_index()
        .pipe(pl.from_pandas)
    )
    return (
        prepare_wm_rois(samples, ["subject", "session", "desc"])
        .join(hx, on=["subject", "session"])
        .with_columns(num_sessions=sessions_per_subject)
    )


# Label-averaged WM values with metadata for patients and HCs; TOPSY restricted
# to sessions 1-2.
topsy_df = all_sessions(
    topsy_wm_sampled,
    topsy.metadata.filter(
        pl.col("group").is_in(["FEP", "HC"]), pl.col("session").cast(int) < 3
    ),
)
jhp_df = all_sessions(
    jhp_wm_sampled,
    jhp.metadata.filter(pl.col("group").is_in(["Patient", "HC"])),
)

### Demographic Data

#### Group X Session Distribution

In [None]:
# JHP subjects with follow-up (sessions 1-3): is group composition independent of session?
scs.chi2_contingency(
    jhp.metadata.filter(pl.col.num_sessions > 1, pl.col.session != "4")
    .group_by("session", "group")
    .agg(count=pl.len())
    .pivot(columns="group", index="session", values="count")
    .to_pandas()
    .set_index("session")
)

In [None]:
# JHP baseline: does having follow-up differ between groups?
scs.chi2_contingency(
    jhp.metadata.with_columns(multises=pl.col.num_sessions > 1)
    .filter(pl.col.session == "1")
    .group_by("multises", "group")
    .agg(count=pl.len())
    .pivot(columns="group", index="multises", values="count")
    .to_pandas()
    .set_index("multises")
)

In [None]:
# TOPSY baseline (chronic excluded): does having follow-up differ between groups?
# NOTE(review): num_sessions here also counts session 3, whereas the TOPSY
# analyses use sessions 1-2 only — confirm this is intended.
scs.chi2_contingency(
    topsy.metadata.with_columns(multises=num_sessions > 1)
    .filter(pl.col.session == "1", pl.col.group != "chronic")
    .group_by("multises", "group")
    .agg(count=pl.len())
    .pivot(columns="group", index="multises", values="count")
    .to_pandas()
    .set_index("multises")
)

#### JHP

In [None]:
# JHP subjects entering the demographic tables: more than one session and present
# with follow-up in the imaging table jhp_df.
jhp_included = [
    num_sessions > 1,
    pl.col.subject.is_in(jhp_df.filter(num_sessions > 1)["subject"]),
]

jhp_demo = jhp.metadata.filter(jhp_included)

In [None]:
def capitalize(label):
    """Sentence-case a snake_case field name, e.g. ``"smoke_status"`` -> ``"Smoke status"``."""
    return " ".join(label.split("_")).capitalize()


def prepare_session_table_jhp(table, clinical=None):
    """Register the JHP demographic (and clinical) rows on a ``DemographicTable``.

    Parameters
    ----------
    table : DemographicTable
        Table populated in place.
    clinical : bool or None, default None
        How to report the clinical rows (duration of illness, CPZ, SAPS, SANS):

        - ``None``: descriptive only, with no statistics and blank in the ``HC``
          column (used by the per-session tables).
        - ``True``: reported with statistics for every column.
        - ``False``: omitted.

    Notes
    -----
    Keep this identical to any other ``prepare_session_table_jhp`` in this
    notebook. A later definition silently replaces an earlier one, so differing
    versions make the tables depend on cell execution order.
    """
    table.add_nominal("sex", "{M}/{F}", autoformatter=capitalize)
    table.add_scale("age", autoformatter=capitalize)
    table.add_nominal("handedness", "{R}/{L}", autoformatter=capitalize)
    table.add_nominal("smoke", "{Yes}/{No}", name="Smoker")
    table.add_nominal("cannabis", "{Yes}/{No}", autoformatter=capitalize)
    if clinical is not None and not clinical:
        return

    def add_clinical(column, name, **kwargs):
        # A fresh skip_fields list per call, as when the calls were written out.
        if clinical is None:
            kwargs.update(skip_stats=True, skip_fields=["HC"])
        table.add_scale(column, name, **kwargs)

    add_clinical("doi", "Duration of Illness (weeks)", report="median")
    add_clinical("cpz", "CPZ (mg)")
    add_clinical("saps", "SAPS")
    add_clinical("sans", "SANS")


# Table-ready JHP values: short labels, missing handedness as R, missing CPZ as 0,
# and doi converted with /12*52 (assumes doi is recorded in months -> weeks).
df = jhp_demo.with_columns(
    pl.col.handedness.fill_null("R").replace({"Left": "L", "Right": "R"}),
    pl.col.sex.replace({"Male": "M", "Female": "F"}),
    pl.col.cpz.fill_null(0),
    pl.col.doi / 12 * 52,
    pl.col.group.replace({"Patient": "EP"}),
)
# Subjects per group, for the column headers.
jhp_n = dict(
    zip(
        *df[["group", "subject"]]
        .unique()
        .group_by("group")
        .len()
    )
)
parts = []
labels = [
    f"Healthy Control (n={jhp_n['HC']})",
    f"Early Psychosis (n={jhp_n['EP']})",
]
for session, label in JHP_SESSIONS.items():
    if session == 4:
        continue
    table = DemographicTable(
        df.filter(pl.col.session == str(session)).to_pandas(), "group", ["HC", "EP"], flavour="latex"
    )

    prepare_session_table_jhp(table)

    # Bold the EP cell of every row whose HC-vs-EP comparison is significant.
    table = (
        pl.from_pandas(table.to_pandas(significance=True).reset_index())
        .select(
            pl.col.index,
            cs.matches(r"HC \(n=.*\)"),
            pl.when(pl.col("HC vs EP sig"))
            .then(cs.matches("^EP.*").str.replace("^.*$", r"\textbf{$0}"))
            .otherwise(cs.matches("^EP.*")),
        )
        .to_pandas()
        .set_index("index")
    )
    table.columns = pd.MultiIndex.from_arrays(
        [
            labels,
            table.columns.map(
                lambda s: label + " " + re.search(r"\(n=.*\)", s).group(0)
            ),
        ]
    )
    parts.append(table)
# Session 4 (3yr) is tabulated for EP only. It goes into the same LaTeX table, so it
# needs the same flavour as the other sessions (previously missing).
table = DemographicTable(
    df.filter(pl.col.session == "4").to_pandas(), "group", ["EP"], flavour="latex"
)
prepare_session_table_jhp(table)
table = table.to_pandas()
table.columns = pd.MultiIndex.from_arrays(
    [
        labels[1:],
        table.columns.map(lambda s: "3yr " + re.search(r"\(n=.*\)", s).group(0)),
    ]
)
parts.append(table)

print(
    pd.concat(parts, axis=1)
    .reindex(labels, axis=1, level=0)
    .style.to_latex(
        column_format="rlllllll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-jhp-demographics",
        caption="JHP demographics.",
    )
)

In [None]:
# HC-vs-EP test statistics per session for the JHP demographics table; rows with
# no statistic in any session are dropped.
df = jhp_demo.filter(pl.col.num_sessions > 1).with_columns(
    pl.col.handedness.fill_null("R").replace({"Left": "L", "Right": "R"}),
    pl.col.sex.replace({"Male": "M", "Female": "F"}),
    pl.col.cpz.fill_null(0),
    pl.col.doi / 12 * 52,
    pl.col.group.replace({"Patient": "EP"}),
)
parts = []
for session, label in JHP_SESSIONS.items():
    if session == 4:
        continue
    table = DemographicTable(
        df.filter(pl.col.session == str(session)).to_pandas(),
        "group",
        ["HC", "EP"],
        flavour="latex",
    )

    prepare_session_table_jhp(table)
    table = table.to_pandas()["HC vs EP"].rename(label)
    parts.append(table)

table = pd.concat(parts, axis=1)
table = table.loc[~(table == "").all(axis=1)]
print(
    table.style.to_latex(
        column_format="rlll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-jhp-demographics-stats",
        caption="JHP demographic statistics: HC vs patient.",
    )
)

In [None]:
# Same as the JHP demographics table, restricted to age < 24 (sessions 1-3 only).
# NOTE(review): near-copy of the main demographics cell; a shared helper would
# keep the two in sync.
df = jhp_demo.filter(pl.col.age < 24).with_columns(
    pl.col.handedness.fill_null("R").replace({"Left": "L", "Right": "R"}),
    pl.col.sex.replace({"Male": "M", "Female": "F"}),
    pl.col.cpz.fill_null(0),
    pl.col.doi / 12 * 52,
    pl.col.group.replace({"Patient": "EP"}),
)
jhp_n = dict(zip(*df[["group", "subject"]].unique().group_by("group").len()))
parts = []
labels = [
    f"Healthy Control (n={jhp_n['HC']})",
    f"Early Psychosis (n={jhp_n['EP']})",
]
for session, label in JHP_SESSIONS.items():
    if session == 4:
        continue
    table = DemographicTable(
        df.filter(pl.col.session == str(session)).to_pandas(),
        "group",
        ["HC", "EP"],
        flavour="latex",
    )

    prepare_session_table_jhp(table)

    # Bold the EP cell of every row whose HC-vs-EP comparison is significant.
    table = (
        pl.from_pandas(table.to_pandas(significance=True).reset_index())
        .select(
            pl.col.index,
            cs.matches(r"HC \(n=.*\)"),
            pl.when(pl.col("HC vs EP sig"))
            .then(cs.matches("^EP.*").str.replace("^.*$", r"\textbf{$0}"))
            .otherwise(cs.matches("^EP.*")),
        )
        .to_pandas()
        .set_index("index")
    )
    table.columns = pd.MultiIndex.from_arrays(
        [
            labels,
            table.columns.map(
                lambda s: label + " " + re.search(r"\(n=.*\)", s).group(0)
            ),
        ]
    )
    parts.append(table)

print(
    pd.concat(parts, axis=1)
    .reindex(labels, axis=1, level=0)
    .style.to_latex(
        column_format="rllllll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-jhp-demographics-agematch",
        caption="JHP age-matched demographics (age < 24yr).",
    )
)

In [None]:
# HC-vs-EP test statistics for the age-matched (age < 24) JHP table. It draws from
# jhp_demo, like the descriptive table it accompanies. Previously it used
# jhp.metadata, which also included subjects without usable imaging follow-up.
df = jhp_demo.filter(pl.col.num_sessions > 1, pl.col.age < 24).with_columns(
    pl.col.handedness.fill_null("R").replace({"Left": "L", "Right": "R"}),
    pl.col.sex.replace({"Male": "M", "Female": "F"}),
    pl.col.cpz.fill_null(0),
    pl.col.doi / 12 * 52,
    pl.col.group.replace({"Patient": "EP"}),
)
parts = []
for session, label in JHP_SESSIONS.items():
    if session == 4:
        continue
    table = DemographicTable(
        df.filter(pl.col.session == str(session)).to_pandas(),
        "group",
        ["HC", "EP"],
        flavour="latex",
    )

    prepare_session_table_jhp(table)
    table = table.to_pandas()["HC vs EP"].rename(label)
    parts.append(table)

table = pd.concat(parts, axis=1)
# Drop rows without a statistic in any session.
print(
    table.loc[~(table == "").all(axis=1)].style.to_latex(
        column_format="rlll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-jhp-demographics-agematch-stats",
        caption="JHP age-matched demographics statistics (age < 24yr).",
    )
)

In [None]:
def prepare_session_table_jhp(
    table,
    clinical=None,
):
    """Register the JHP demographic (and clinical) rows on a ``DemographicTable``.

    Parameters
    ----------
    table : DemographicTable
        Table populated in place.
    clinical : bool or None, default None
        How to report the clinical rows (duration of illness, CPZ, SAPS, SANS):

        - ``None``: descriptive only, with no statistics and blank in the ``HC``
          column (used by the per-session tables).
        - ``True``: reported with statistics for every column.
        - ``False``: omitted.

    Notes
    -----
    Keep this identical to any other ``prepare_session_table_jhp`` in this
    notebook. A later definition silently replaces an earlier one, so differing
    versions make the tables depend on cell execution order.
    """
    table.add_nominal("sex", "{M}/{F}", autoformatter=capitalize)
    table.add_scale("age", autoformatter=capitalize)
    table.add_nominal("handedness", "{R}/{L}", autoformatter=capitalize)
    table.add_nominal("smoke", "{Yes}/{No}", name="Smoker")
    table.add_nominal("cannabis", "{Yes}/{No}", autoformatter=capitalize)
    if clinical is not None and not clinical:
        return

    def add_clinical(column, name, **kwargs):
        # A fresh skip_fields list per call, as when the calls were written out.
        if clinical is None:
            kwargs.update(skip_stats=True, skip_fields=["HC"])
        table.add_scale(column, name, **kwargs)

    add_clinical("doi", "Duration of Illness (weeks)", report="median")
    add_clinical("cpz", "CPZ (mg)")
    add_clinical("saps", "SAPS")
    add_clinical("sans", "SANS")


# Baseline characteristics of included (jhp_included) vs dropped-out subjects,
# separately for HCs (demographics only) and EP (with clinical rows and stats).
df = jhp.metadata.with_columns(
    pl.col.handedness.fill_null("R").replace({"Left": "L", "Right": "R"}),
    pl.col.sex.replace({"Male": "M", "Female": "F"}),
    pl.col.cpz.fill_null(0),
    pl.col.doi / 12 * 52,
    pl.col.group.replace({"Patient": "EP"}),
    multises=ft.reduce(op.and_, jhp_included),
).filter(pl.col.session == "1")
parts = []
groups = ["HC", "EP"]
clinical = [False, True]
labels = ["Healthy Control", "Early Psychosis"]
for i in range(2):
    table = DemographicTable(
        df.filter(group=groups[i]).to_pandas(),
        "multises",
        [False, True],
        flavour="latex",
    )

    prepare_session_table_jhp(table, clinical=clinical[i])

    table = table.to_pandas()
    # Assumes three columns per group: dropout, included, and the comparison.
    table.columns = pd.MultiIndex.from_arrays(
        [
            [labels[i]] * 3,
            table.columns.map(
                lambda s: s.replace("False", "Dropout").replace("True", "Included")
            ),
        ]
    )
    parts.append(table)

print(
    pd.concat(parts, axis=1)
    .fillna("")
    .style.to_latex(
        column_format="rllllll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-jhp-demographics-dropouts",
        caption="JHP dropout demographics.",
    )
)

#### TOPSY

In [None]:
# TOPSY subjects entering the demographic tables: more than one session, present
# with follow-up in the imaging table topsy_df, and not in the chronic group.
topsy_included = [
    num_sessions > 1,
    pl.col.subject.is_in(topsy_df.filter(num_sessions > 1)["subject"]),
    pl.col.group != "chronic",
]

topsy_demo = topsy.metadata.filter(topsy_included)

In [None]:
def prepare_session_table_topsy(table):
    """Add the rows of the TOPSY HC-vs-FEP demographic table to ``table``.

    Demographic rows are added for every group. The clinical rows (illness
    duration, antipsychotic dose, PANSS-8 scores) are added with
    ``skip_fields=["HC"]`` and ``skip_stats=True``. Presumably this means they are
    reported for patients only and without a group test (see ``DemographicTable``).

    Note: a later cell redefines this name with an extra ``clinical`` argument.
    """
    with_cap = {"autoformatter": capitalize}

    # -- demographics
    table.add_nominal("sex", "{M}/{F}", **with_cap)
    table.add_scale("age", **with_cap)
    table.add_nominal("handedness", "{R}/{L}/{A}", **with_cap)
    table.add_scale("education", **with_cap)
    for measure in ("SES", "CAST", "AUDIT-C"):
        table.add_scale(measure)
    for habit in ("smoker", "Cannabis"):
        table.add_nominal(habit, "{yes}/{no}", **with_cap)
    table.add_scale("SOFAS")

    # -- clinical measures (HC column skipped, no between-group statistics)
    for column, title in (
        ("doi", "Duration of Illness (weeks)"),
        ("ddd_dayofscan", "Antipsychotics (Defined Daily Dose)"),
    ):
        table.add_scale(
            column, title, report="median", skip_stats=True, skip_fields=["HC"]
        )
    for column, title in (
        ("PANSSTOTAL", "PANSS-8 Total"),
        ("PANSSP", "PANSS-8 Positive"),
        ("PANSSN", "PANSS-8 Negative"),
        ("PANSSG", "PANSS-8 General"),
    ):
        table.add_scale(column, title, skip_stats=True, skip_fields=["HC"])


# HC vs FEP demographics for TOPSY sessions 1 and 2 (session 3 is skipped).
# FEP cells are bolded wherever the HC-vs-FEP comparison is significant.
parts = []
for session, label in TOPSY_SESSIONS.items():
    if session == 3:
        continue

    # Recode this session's metadata, then wrap it in a DemographicTable.
    table = (
        topsy_demo.with_columns(
            pl.col.smoker.fill_null("no"),
            # score > 4 -> "R", score < -4 -> "L", otherwise "A"
            handedness=(
                pl.when(pl.col.HandednessTotal > 4)
                .then(pl.lit("R"))
                .when(pl.col.HandednessTotal < -4)
                .then(pl.lit("L"))
                .otherwise(pl.lit("A"))
            ),
            # presumably `time` in days and `DUI` in weeks -- confirm
            doi=pl.col.time / 7 + pl.col.DUI,
        )
        .filter(session=str(session))
        .to_pandas()
    )
    table = DemographicTable(table, "group", ["HC", "FEP"], flavour="latex")
    prepare_session_table_topsy(table)

    # Keep the row index, the HC column and the (conditionally bolded) FEP column.
    table = table.to_pandas(significance=True).reset_index()
    table = (
        pl.from_pandas(table)
        .select(
            pl.col.index,
            cs.matches(r"HC \(n=.*\)"),
            pl.when(pl.col("HC vs FEP sig"))
            .then(cs.matches("^FEP.*").str.replace("^.*$", r"\textbf{$0}"))
            .otherwise(cs.matches("^FEP.*")),
        )
        .to_pandas()
        .set_index("index")
    )
    # Two-level header: group name over "<session label> (n=...)".
    table.columns = pd.MultiIndex.from_arrays(
        [
            ["Healthy Control", "First Episode Psychosis"],
            table.columns.map(
                lambda s: label + " " + re.search(r"\(n=.*\)", s).group(0)
            ),
        ]
    )
    parts.append(table)
print(
    pd.concat(parts, axis=1)
    .reindex(["Healthy Control", "First Episode Psychosis"], axis=1, level=0)
    .style.to_latex(
        column_format="rllll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-topsy-demographics",
        caption="TOPSY demographics.",
    )
)

In [None]:
# HC-vs-FEP test statistics for TOPSY sessions 1 and 2: one column per session.
parts = []
for session, label in TOPSY_SESSIONS.items():
    if session == 3:
        continue

    table = (
        topsy_demo.with_columns(
            pl.col.smoker.fill_null("no"),
            # score > 4 -> "R", score < -4 -> "L", otherwise "A"
            handedness=(
                pl.when(pl.col.HandednessTotal > 4)
                .then(pl.lit("R"))
                .when(pl.col.HandednessTotal < -4)
                .then(pl.lit("L"))
                .otherwise(pl.lit("A"))
            ),
            doi=pl.col.time / 7 + pl.col.DUI,
        )
        .filter(session=str(session))
        .to_pandas()
    )
    table = DemographicTable(table, "group", ["HC", "FEP"], flavour="latex")
    prepare_session_table_topsy(table)
    parts.append(table.to_pandas()["HC vs FEP"].rename(label))

table = pd.concat(parts, axis=1)
# Drop rows whose statistic is empty in every session.
print(
    table.loc[~(table == "").all(axis=1)]
    .fillna("")
    .style.to_latex(
        column_format="rll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-topsy-demographics-stats",
        caption="TOPSY demographic statistics.",
    )
)

In [None]:
def prepare_session_table_topsy(
    table,
    clinical=None,
):
    """Add the TOPSY demographic (and optionally clinical) rows to ``table``.

    This cell redefines the function of the same name used by the HC-vs-FEP
    cells above, and the later definition shadows the earlier one. It therefore
    supports both call patterns, so every cell gets its intended rows regardless
    of execution order.

    Parameters
    ----------
    table : DemographicTable
        Table that rows are added to (mutated in place).
    clinical : bool or None, optional
        * ``None`` (default): the HC-vs-FEP layout. Clinical rows are added with
          ``skip_stats=True`` and ``skip_fields=["HC"]``.
        * truthy: clinical rows are added without skipping (patient-only table).
        * falsy (e.g. ``False``): clinical rows are omitted (control-only table).
    """
    table.add_nominal("sex", "{M}/{F}", autoformatter=capitalize)
    table.add_scale("age", autoformatter=capitalize)
    table.add_nominal("handedness", "{R}/{L}/{A}", autoformatter=capitalize)
    table.add_scale("education", autoformatter=capitalize)
    table.add_scale("SES")
    table.add_scale("CAST")
    table.add_scale("AUDIT-C")
    table.add_nominal("smoker", "{yes}/{no}", autoformatter=capitalize)
    table.add_nominal("Cannabis", "{yes}/{no}", autoformatter=capitalize)
    table.add_scale("SOFAS")

    if clinical is not None and not clinical:
        return

    def clinical_opts(**kwargs):
        # HC-vs-FEP layout: skip the HC column and the between-group statistics.
        if clinical is None:
            kwargs.update(skip_stats=True, skip_fields=["HC"])
        return kwargs

    table.add_scale(
        "doi",
        "Duration of Illness (weeks)",
        **clinical_opts(report="median"),
    )
    table.add_scale(
        "ddd_dayofscan",
        "Antipsychotics (Defined Daily Dose)",
        **clinical_opts(report="median"),
    )
    table.add_scale("PANSSTOTAL", "PANSS-8 Total", **clinical_opts())
    table.add_scale("PANSSP", "PANSS-8 Positive", **clinical_opts())
    table.add_scale("PANSSN", "PANSS-8 Negative", **clinical_opts())
    table.add_scale("PANSSG", "PANSS-8 General", **clinical_opts())


# Baseline (session "1"), non-chronic TOPSY metadata recoded for the dropout
# table. `multises` is True for subjects satisfying every `topsy_included` condition.
df = topsy.metadata.with_columns(
    pl.col.smoker.fill_null("no"),
    # score > 4 -> "R", score < -4 -> "L", otherwise "A"
    handedness=pl.when(pl.col("HandednessTotal") > 4)
    .then(pl.lit("R"))
    .when(pl.col("HandednessTotal") < -4)
    .then(pl.lit("L"))
    .otherwise(pl.lit("A")),
    doi=pl.col.time / 7 + pl.col.DUI,
    multises=ft.reduce(op.and_, topsy_included),
).filter(pl.col.session == "1", pl.col.group != "chronic")
parts = []
groups = ["HC", "FEP"]
clinical = [False, True]
labels = ["Healthy Control", "First Episode Psychosis"]
# One sub-table per group comparing dropouts (multises=False) with included
# subjects (multises=True); clinical rows only for the patient group.
for i in range(len(groups)):
    table = DemographicTable(
        df.filter(group=groups[i]).to_pandas(), "multises", [False, True], flavour="latex"
    )

    prepare_session_table_topsy(table, clinical=clinical[i])

    table = table.to_pandas()
    # The group label spans every column of this sub-table; take the width from
    # the table itself instead of assuming exactly three columns.
    table.columns = pd.MultiIndex.from_arrays(
        [
            [labels[i]] * len(table.columns),
            table.columns.map(
                lambda s: s.replace("False", "Dropout").replace("True", "Included")
            ),
        ]
    )
    parts.append(table)

print(
    pd.concat(parts, axis=1)
    .fillna("")
    .style.to_latex(
        column_format="rllllll",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-topsy-demographics-dropouts",
        caption="TOPSY dropout demographics.",
    )
)

### Scan Interval

In [None]:
from numpy.random import default_rng


def since_inst():
    """Expression converting ``date`` strings into years since the earliest date.

    The origin is the minimum date over the whole frame (not per subject), and a
    year is counted as 365 days. The expression is rooted at ``date``, so it
    replaces that column when used in ``with_columns``.
    """
    parsed = pl.col.date.str.to_date()
    elapsed_days = (parsed - parsed.min()).dt.total_days()
    return elapsed_days / 365

def random_date(n: int, max: int, rng = None):
    """Literal expression of ``n`` uniform random values scaled to ``[0, max)``.

    Uses ``rng`` when it is given (and truthy); otherwise a fresh, unseeded
    ``default_rng()``.
    """
    generator = rng if rng else default_rng()
    draws = generator.random(n)
    return pl.lit(draws) * max


def target_time(base: int | pl.Expr, delta: float):
    """Planned scan time: ``base`` plus ``delta`` per session since the subject's first.

    ``session`` is parsed as an integer. The number of elapsed sessions is taken
    relative to the subject's lowest session number.
    """
    session_number = pl.col.session.str.to_integer()
    sessions_elapsed = session_number - session_number.min().over("subject")
    return sessions_elapsed * delta + base

def get_target_scan_times(dataset, delta, offset=0):
    """Add actual (``date``) and planned (``target``) scan times, in years.

    * ``date`` is replaced by years since the earliest date in ``dataset`` (``since_inst``).
    * ``start_rand`` is a copy of the converted ``date``.
    * ``target`` is the subject's first converted date plus ``delta`` per elapsed
      session, shifted by ``+offset`` for HC and ``-offset`` for everyone else.

    Only rows passing the notebook-global ``num_sessions > 1`` filter are kept.
    """
    # Evaluated in the last with_columns, i.e. on the already-converted `date`.
    planned = target_time(pl.col.date.min().over("subject"), delta)
    shifted = (
        pl.when(pl.col.group == "HC")
        .then(planned + offset)
        .otherwise(planned - offset)
    )
    converted = dataset.with_columns(since_inst())
    with_copy = converted.with_columns(start_rand=pl.col.date)
    return with_copy.with_columns(shifted.alias("target")).filter(num_sessions > 1)

In [None]:
# JHP: deviation of the actual from the planned (yearly) scan time at sessions
# 2 and 3. Prints the implied follow-up interval (mean +- SD, in years) and an
# OLS test for a group difference in the deviation.
for follow_up, years in (("2", 1), ("3", 2)):
    deviation = (
        get_target_scan_times(jhp_demo, 1)
        .filter(pl.col.session == follow_up)
        .select("group", "subject", deviation=pl.col.date - pl.col.target)
        .to_pandas()
    )
    lm = smf.ols("deviation ~ group", data=deviation).fit()
    print(
        "Mean F/U",
        deviation["deviation"].mean() + years,
        "+-",
        (deviation["deviation"] + years).std(),
    )
    print(lm.summary())

In [None]:
# TOPSY (non-chronic): deviation of the session-2 scan from its planned
# 0.5-year follow-up, the implied interval (mean +- SD) and a group test.
deviation = (
    topsy_demo.filter(pl.col.group != "chronic")
    .pipe(get_target_scan_times, 0.5)
    .filter(pl.col.session == "2")
    .select("group", "subject", deviation=pl.col.date - pl.col.target)
    .to_pandas()
)
lm = smf.ols("deviation ~ group", data=deviation).fit()
print(
    "Mean F/U",
    deviation["deviation"].mean() + 0.5,
    "+-",
    (deviation["deviation"] + 0.5).std(),
)
print(lm.summary())

In [None]:
#| fig-cap: 'Actual scan dates versus target dates. Individual line segments connect
#|   the scan dates from individual subjects. Dates are given as time relative to study
#|   onset. Dashed, black line shows the expected slope of subjects scanned at the protocol-specified
#|   frequency: 1yr for JHP and 6 months for TOPSY. Shallower slopes reflect longer than
#|   expected scan-scan intervals; steeper represent shorter intervals.'
#| label: fig-longt-scan-interval
fig = plt.figure(figsize=(8, 3))
axs = fig.subplots(1, 2)
# NOTE(review): `rng` is not used in this cell.
rng = default_rng(400)
deltas = [1, 0.5]
patient_grps = ["Patient", "FEP"]
datasets = [jhp_demo, topsy_demo.filter(pl.col.session.str.to_integer() < 3)]
titles = ["JHP", "TOPSY"]
offsets = [0, 0]
panel_labels = ["A", "B"]
for i in range(2):
    ax, delta, patient_grp, offset = axs[i], deltas[i], patient_grps[i], offsets[i]
    ax.text(-0.18, 1.05, panel_labels[i], Styles.panel_label, transform=ax.transAxes)
    data = get_target_scan_times(datasets[i], delta)
    ax.set_title(titles[i])
    # Reference lines y = x +/- offset: subjects scanned exactly on schedule.
    for shift in (offset, -offset):
        ax.plot(
            np.arange(data["date"].max() + 1) + shift,
            color="black",
            lw=0.5,
            linestyle="--",
        )
    # One line per subject connecting (actual, planned) scan times.
    (
        so.Plot(data, x="date", y="target", group="subject", color="group")
        .add(so.Line(alpha=0.8, linestyle="-", linewidth=0.3), legend=False)
        .scale(
            color=so.Nominal(order=["HC", patient_grp]),
            x=so.Continuous().tick(every=1).label(like="{x:.0f}y"),
            y=so.Continuous().tick(every=1).label(like="{x:.0f}y"),
        )
        .label(x="Actual Date", y="Planned Date" if i == 0 else None, title=titles[i])
        .on(ax)
        .plot()
    )

add_legend(
    fig,
    ["HC", "Patient"],
    cmap=cm.tab10,
    fontsize=12,
    bbox_to_anchor=(0.73, 0.86),
    linestyle="-",
    markers=False,
)

---
label: fig-longt-scan-interval
key: val
---

Actual versus target scan dates. Each line segment connects the scans of one subject. Dates are given as time relative to study onset. The dashed black line shows the expected slope for subjects scanned at the protocol-specified frequency: 1 year for JHP and 6 months for TOPSY. Shallower slopes reflect longer-than-expected scan-to-scan intervals; steeper slopes reflect shorter intervals.

### Longitudinal Analysis

In [None]:
# Exploratory check: fit, via R's lmer called through rpy2, a mixed model of
# standardised whole-WM RD in JHP subjects younger than 24 with more than one
# session, and print its summary.
# NOTE(review): relies on `ro.pandas2ri` having been imported earlier and on
# lme4 being attached in the R session -- confirm.
with (ro.default_converter + ro.pandas2ri.converter).context():
    # Earlier model-comparison experiments (parametric bootstrap via pbkrtest),
    # kept for reference:
    #     lm1 = lme4.lmer("PANSSP ~ (1|subject)", data=topsy_df)
    #     lm2 = lme4.lmer("PANSSP ~ session + (1|subject)", data=topsy_df)
    #     res = pbkrtest.PBmodcomp(lm2, lm1)
    #     print(res["test"])
    # lm1 = lme4.lmer(
    #     "data ~ session*group + age + sex + (1|subject)",
    # )
    # lm2 = lme4.lmer("data ~ session + (session|subject)", data=jhp_df)
    # res = pbkrtest.PBmodcomp(lm2, lm1)
    # session x group interaction, adjusted for age and sex, with a random
    # intercept per subject. `summ` receives the R function's return value.
    summ = ro.r("""
    function (df) {
        lm1 <- lmer(scale(data) ~ session * group + age + sex + (1|subject), data=df)
        print(summary(lm1))
    }
    """)(
        # sessions are shifted down by one (session 1 -> 0)
        jhp_df.filter(pl.col.num_sessions > 1, pl.col.age < 24, label="WM", desc="RD")
        .with_columns(pl.col.session.cast(int).sub(1))
        .to_pandas()
    )
    # print(res["test"])
    # res = lmertest.ranova(lm2)

In [None]:
# | cell-offset: -1
# | fig-cap: Global longitudinal changes of white matter microstructure in early schizophrenia
# |   patients. Trendlines show a linear mixed effect model of parameter against session
# |   with random intercepts fit for every subject. Shaded bands show a 95% CI computed
# |   with parametric bootstrapping resampling residuals and random effects 1000 times.
# |   No significant differences were found between the slopes of HCs and patients in
# |   either dataset for any of the parameters measured. In the JHP sample, fitting random
# |   slopes to each subject did not significantly improve the fit of the model (not tested
# |   in TOPSY because each subject had only two time points).
# | fig-env: FPfigure
# | label: fig-longt
jitter = so.Jitter(seed=3009)
fig = plt.figure(figsize=(8, 10), layout="constrained")
axs = fig.subplots(4, 2)

# Grid: one row per diffusion parameter, one column per dataset (JHP, TOPSY).
variables = np.array(["FA", "MD", "RD", "L1"])
dfs = [
    jhp_df,
    topsy_df,
]
ses_labels = [JHP_SESSIONS, TOPSY_SESSIONS]
orders = [["HC", "Patient"], ["HC", "FEP"]]
# FA keeps its native scale; the diffusivities are displayed multiplied by 1000.
yscales = [
    so.Continuous(),
    *([so.Continuous().label(like=lambda x, _: f"{x*1000:.2f}")] * 3),
]
units = ["", *([r"$\frac{\mu m^2}{ms}$"] * 3)]

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
    "sans": "SANS",
    "saps": "SAPS",
    "L1": "AD",
}

for x in range(4):
    for y in range(2):
        variable = variables[x]
        label = labels.get(variable, variable)
        labely = label if not units[x] else f"{label} ({units[x]})"
        df = dfs[y]
        ax = axs[x, y]
        # Rebinds the notebook-global `num_sessions` expression; the filter below
        # uses the `num_sessions` column instead.
        num_sessions = pl.col("session").unique().len().over("subject")
        (
            so.Plot(
                df.filter(
                    pl.col.label == "WM", pl.col.desc == variable, pl.col.num_sessions > 1
                )
                .with_columns(pl.col.session.cast(int))
                .to_pandas(),
                x="session",
                y="data",
                group="subject",
                color="group",
            )
            # jittered per-subject trajectories
            .add(
                so.Line(linestyle="dashed", linewidth=0.5, alpha=0.5), jitter, legend=False
            )
            # group-level fit (MLEFit) and its simulated CI (Lme4CI, 1000 sims)
            .add(so.Line(linewidth=2), MLEFit(), legend=False)
            .add(so.Band(), Lme4CI(nsims=1000), legend=False)
            .add(so.Dot(edgewidth=0, alpha=0.5, pointsize=3), jitter, legend=False)
            .scale(
                x=so.Continuous().tick(at=[1, 2, 3, 4]).label(like=ses_labels[y].get),
                y=yscales[x],
                color=so.Nominal(order=orders[y]),
            )
            # y label on the left column only, x label on the bottom row only
            .label(y=labely if y == 0 else None, x="Session" if x == 3 else None)
            .on(ax)
            .plot()
        )
axs[0, 0].set_title("JHP", **Styles.col_title)
axs[0, 1].set_title("TOPSY", **Styles.col_title)
add_legend(
    fig,
    ["HC", "Patient"],
    cmap=cm.tab10,
    fontsize=12,
    bbox_to_anchor=(0.73, 0.26),
    linestyle="-",
    band=True,
)
None

---
label: fig-longt
fig-env: FPfigure
cell-offset: -1

---
Global longitudinal changes of white matter microstructure in early schizophrenia patients. Trendlines show a linear mixed-effects model of each parameter against session, with random intercepts for every subject. Shaded bands show 95% CIs computed by parametric bootstrapping (residuals and random effects resampled 1000 times). No significant differences were found between the slopes of HCs and patients in either dataset for any of the parameters measured. In the JHP sample, fitting random slopes for each subject did not significantly improve the model fit (this was not tested in TOPSY because each subject had only two time points).

In [None]:
#| fig-cap: Global longitudinal changes of white matter microstructure in age-matched
#|   subset of JHP dataset. Trendlines show a linear mixed effect model of parameter
#|   against session with random intercepts fit for every subject. Shaded bands show
#|   a 95% CI computed with parametric bootstrapping resampling residuals and random
#|   effects 1000 times. No significant differences were found between the slopes of
#|   HCs and patients for any of the parameters measured. Fitting random slopes to each
#|   subject did not significantly improve the fit of the model.
#| label: fig-longt-agematched
jitter = so.Jitter(seed=3009)
fig = plt.figure(figsize=(8, 5), layout="constrained")
axs = fig.subplots(2, 2)

# 2x2 grid of diffusion parameters for the age-matched (age < 24) JHP subset.
variables = np.array(["FA", "MD", "RD", "L1"]).reshape(2, 2)
# NOTE(review): unused below (ticks use JHP_SESSIONS directly); its order is also
# reversed relative to the previous figure.
ses_labels = [TOPSY_SESSIONS, JHP_SESSIONS]
yscales = np.array(
    [
        so.Continuous(),
        *([so.Continuous().label(like=lambda x, _: f"{x*1000:.2f}")] * 3),
    ]
).reshape(2, 2)
units = np.array(["", *([r"$\frac{\mu m^2}{ms}$"] * 3)]).reshape(2, 2)

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
    "sans": "SANS",
    "saps": "SAPS",
    "L1": "AD",
}

for x in range(2):
    for y in range(2):
        variable = variables[x, y]
        label = labels.get(variable, variable)
        labely = label if not units[x, y] else f"{label} ({units[x, y]})"
        df = jhp_df.filter(pl.col.age < 24)
        ax = axs[x, y]
        # Rebinds the notebook-global `num_sessions`; unused in this loop.
        num_sessions = pl.col("session").unique().len().over("subject")
        (
            so.Plot(
                df.filter(
                    pl.col.label == "WM", pl.col.desc == variable, pl.col.num_sessions > 1
                )
                .with_columns(pl.col.session.cast(int))
                .to_pandas(),
                x="session",
                y="data",
                group="subject",
                color="group",
            )
            .add(
                so.Line(linestyle="dashed", linewidth=0.5, alpha=0.5), jitter, legend=False
            )
            .add(so.Line(linewidth=2), MLEFit(), legend=False)
            .add(so.Band(), Lme4CI(nsims=1000), legend=False)
            .add(so.Dot(edgewidth=0, alpha=0.5, pointsize=3), jitter, legend=False)
            .scale(
                x=so.Continuous().tick(at=[1, 2, 3, 4]).label(like=JHP_SESSIONS.get),
                y=yscales[x, y],
                color=so.Nominal(order=["HC", "Patient"]),
            )
            # every panel shows a different parameter, so every panel gets a y label
            .label(y=labely, x="Session" if x == 1 else None)
            .on(ax)
            .plot()
        )
add_legend(
    fig,
    ["HC", "Patient"],
    cmap=cm.tab10,
    fontsize=12,
    bbox_to_anchor=(0.75, 0.52),
    linestyle="-",
    band=True,
)
None

---
label: fig-longt-agematched
key: val
---

Global longitudinal changes of white matter microstructure in the age-matched subset of the JHP dataset. Trendlines show a linear mixed-effects model of each parameter against session, with random intercepts for every subject. Shaded bands show 95% CIs computed by parametric bootstrapping (residuals and random effects resampled 1000 times). No significant differences were found between the slopes of HCs and patients for any of the parameters measured. Fitting random slopes for each subject did not significantly improve the model fit.

### Correlations with Symptoms

In [None]:
def avg_sessions(da, hx):
    """Average ``da`` over its ``session`` dimension and attach subject-level data.

    Parameters
    ----------
    da
        Presumably an xarray ``DataArray`` with a ``session`` dimension.
    hx : pl.DataFrame
        Subject-level table, joined on ``subject``.

    Returns
    -------
    pl.DataFrame
        ``prepare_wm_rois`` output (called with ``["subject", "desc"]``) joined with ``hx``.
    """
    session_mean = da.mean("session")
    long_frame = session_mean.to_dataframe(name="data").reset_index()
    rois = prepare_wm_rois(pl.from_pandas(long_frame), ["subject", "desc"])
    return rois.join(hx, on="subject")


# NOTE: rebinds `topsy_df` / `jhp_df` to session-averaged frames. Cells above
# that expect the per-session frames (e.g. the longitudinal figures) will no
# longer work if re-run after this cell.
topsy_df = avg_sessions(topsy_wm_sampled, topsy_hx)
jhp_df = avg_sessions(jhp_wm_sampled, jhp_hx)

In [None]:
# Exploratory: does the SANS intercept predict whole-WM FA, adjusting for age
# and sex, among multi-session JHP subjects?
data = jhp_df.filter(
    pl.col.desc == "FA",
    pl.col.num_sessions > 1,
    pl.col.label == "WM",
)

df = data.to_pandas().apply(pd.Series)
model = smf.ols("data ~ sans_intercept + age + sex", data=df).fit()

contr = model.t_test("sans_intercept")
model.summary()

In [None]:
def _1tail(pval, stat, desc, param):
    val = 1
    if desc in {"FA", "thickness"}:
        val *= -1
    if "recovery" in param:
        val *= -1
    val *= np.sign(stat)
    return (min(0, val) * -1) + (pval * val / 2)


def do_stats(col):
    """Aggregation expression fitting ``data ~ {col} + age + sex`` by OLS.

    Intended for ``group_by(...).agg(...)``: each group arrives as a Series of
    structs (desc, data, col, age, sex). The output is a struct column named
    ``f"{col}_stats"`` with the intercept, the ``col`` coefficient, its one-sided
    p-value (``_1tail``), the t statistic, and the model's nobs/df_model/df_resid.
    """
    fields = (
        "intercept",
        "beta",
        "pval",
        "statistic",
        "nobs",
        "df_model",
        "df_resid",
    )
    dtype = pl.Struct({field: pl.Float64 for field in fields})

    def fit_group(data):
        frame = data.to_pandas().apply(pd.Series)
        fitted = smf.ols(f"data ~ {col} + age + sex", data=frame).fit()

        # Exactly one design column is expected to contain the score's name.
        term = itx.one(name for name in fitted.model.exog_names if col in name)
        contrast = fitted.t_test(term)
        values = (
            fitted.params.loc["Intercept"],
            fitted.params.loc[term],
            _1tail(contrast.pvalue, contrast.statistic, data[0]["desc"], col),
            contrast.statistic,
            fitted.nobs,
            fitted.df_model,
            fitted.df_resid,
        )
        return dict(zip(fields, values))

    packed = pl.struct("desc", "data", col, "age", "sex")
    return packed.map_elements(fit_group, return_dtype=dtype).alias(f"{col}_stats")


def pearsonr(col):
    """Aggregation expression: Pearson r between ``col`` and ``data``.

    Intended for ``group_by(...).agg(...)``. Rows with a null ``col`` are dropped
    before correlating. The output column is named ``f"{col}_pearsonr"``.
    """

    def correlate(group):
        score = np.asarray(group.struct[col])
        value = np.asarray(group.struct["data"])
        return scs.pearsonr(score, value).statistic

    paired = pl.struct("data", col).filter(pl.col(col).is_not_null())
    return paired.map_elements(correlate, return_dtype=pl.Float64).alias(
        f"{col}_pearsonr"
    )


def get_pvals(col):
    """``do_stats(col)`` with every struct field prefixed by ``f"{col}_"``."""
    stats_expr = do_stats(col)
    return stats_expr.name.prefix_fields(f"{col}_")


def get_all_stats(df, scores, features):
    """Per-ROI associations between ``data`` and clinical score features.

    Keeps rows with ``num_sessions > 1``. For every (label, desc) group it fits
    ``data ~ {score}_{feature} + age + sex`` (``do_stats``) for each combination
    of ``scores`` x ``features``. The result is reshaped to one row per
    (label, desc, hierarchy, score, feature), with FDR-corrected p-values in
    ``pvalcorr``.

    Parameters
    ----------
    df : pl.DataFrame
        Long frame with at least ``num_sessions``, ``label``, ``desc``, ``data``,
        ``hierarchy``, ``age``, ``sex`` and the ``{score}_{feature}`` columns.
    scores, features : iterable of str
        Name parts of the predictor columns ``f"{score}_{feature}"``.

    Returns
    -------
    pl.DataFrame
    """
    return (
        df.filter(pl.col("num_sessions") > 1)
        .group_by("label", "desc")
        .agg(
            # One OLS fit per group and per score/feature column.
            *(
                do_stats(f"{score}_{suffix}")
                for score, suffix in it.product(scores, features)
            ),
            # NOTE(review): these *_pearsonr columns are neither id nor value
            # columns of the melt below, so they appear to be dropped from the
            # result -- confirm whether they are still needed.
            *(
                pearsonr(f"{score}_{suffix}")
                for score, suffix in it.product(scores, features)
            ),
            pl.first("hierarchy"),
        )
        # Wide -> long: one row per (label, desc, hierarchy, *_stats column).
        .melt(["label", "desc", "hierarchy"], cs.matches(".*_stats"), "param", "stats")
        .with_columns(
            # "{score}_{feature}_stats" -> struct(score, feature). This assumes the
            # score and feature names themselves contain no underscores.
            pl.col("param").str.split("_").list.to_struct(fields=["score", "feature"])
        )
        .unnest("param", "stats")
        .with_columns(
            # FDR correction (scipy's false_discovery_control, Benjamini-Hochberg
            # by default) within each (hierarchy, desc, score, feature) family,
            # stored as "pvalcorr".
            pl.col("pval")
            .map_elements(
                lambda x: pl.Series(scs.false_discovery_control(x)),
                return_dtype=pl.List(pl.Float64),
            )
            .over("hierarchy", "desc", "score", "feature")
            .name.suffix("corr")
        )
    )


# Only the features that gave significant results are computed here. The full
# sets originally tested were:
#   TOPSY: ["baseline", "mean", "slope", "intercept", "recovery"]
#   JHP:   ["baseline", "mean", "slope", "intercept"]
topsy_stats = get_all_stats(
    topsy_df, scores=["PANSSP", "PANSSN"], features=["recovery"]
)
jhp_stats = get_all_stats(
    jhp_df, scores=["saps", "sans"], features=["baseline", "mean", "intercept"]
)

In [None]:
# Significant (FDR-corrected p < 0.05) JHP associations with their region name,
# pivoted to one column of corrected p-values per clinical score.
sig = (
    jhp_stats.filter(pl.col.pvalcorr < 0.05)
    .join(atlas_md.group_by("label").agg(pl.first("region")), on="label")
    .pivot(
        values="pvalcorr",
        index=["region", "desc", "hierarchy", "feature"],
        columns="score",
    )
)

In [None]:
# Does age explain variance in whole-WM AD (L1) beyond PANSS-N recovery?
# This uses the session-averaged TOPSY frame: `PANSSN_recovery` is a TOPSY score
# (see `topsy_stats`). The global `df` at this point is the pandas frame built
# in an earlier exploratory cell, so it cannot be filtered with polars
# expressions.
df_r = (
    topsy_df.filter(
        pl.col("num_sessions") > 1,
        pl.col("label") == "WM",
        pl.col("desc") == "L1",
    )
    .group_by("label", "subject")
    .agg(pl.col("*").exclude("data").first(), pl.col("data").mean())
)
df_r_pd = df_r.to_pandas()
lm1 = smf.ols("data ~ Q('PANSSN_recovery')", data=df_r_pd).fit()
lm2 = smf.ols("data ~ Q('PANSSN_recovery') + age", data=df_r_pd).fit()
# Nested-model F test: (F statistic, p-value, difference in df).
lm2.compare_f_test(lm1)

In [None]:
#| cell-offset: -1
#| fig-cap: Correlation between microstructural parameters and SANS intercept. Microstructural
#|   measurements are averaged across all sessions for each subject. The SANS intercept
#|   was computed using a first order linear model for each subject, with the baseline
#|   scan as time 0. Shaded bands show 95% CI computed with nonparametric bootstrap paired
#|   resampling (1000 resamples). Relationships were tested with a linear model
#|   with age and sex as covariates. Solid lines represent statistically significant
#|   relationships; dotted lines are nonsignificant. T-values and P-values are shown
#|   in @tbl-jhp-intercept.
#| fig-env: FPfigure
#| label: fig-jhp-roi
fig = plt.figure(figsize=(8, 10), layout="constrained")
axs = fig.subfigures(4, 1)

variables = np.array(["FA", "MD", "RD", "L1"])
# FA keeps its native scale. Diffusivities are shown x1000 with the same unit
# label as the longitudinal figures.
yscales = [
    so.Continuous(),
    *([so.Continuous().label(like=lambda x, _: f"{x*1000:.2f}")] * 3),
]
units = ["", *([r"$\frac{\mu m^2}{ms}$"] * 3)]

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
    "sans": "SANS",
    "saps": "SAPS",
    # display L1 as axial diffusivity, as in the other figures
    "L1": "AD",
}
# Facet (column) order of each row and the matching ROI labels in jhp_stats.
facet_order = ["White Matter", "Core White Matter", "Peripheral White Matter"]
facet_rois = ["WM", "CWM", "PWM"]

for y, x in np.ndindex(4, 1):
    variable = variables[y]
    label = labels.get(variable, variable)
    labely = f"{label} ({units[y]})" if units[y] else label
    ax = axs[y]
    (
        so.Plot(
            jhp_df.filter(pl.col("hierarchy") < 2, pl.col.desc == variable)
            .join(
                jhp_stats.filter(
                    pl.col.desc == variable, feature="intercept", score="sans"
                )[["label", "pvalcorr", "desc"]],
                on="label",
            )
            .with_columns(sig=pl.col.pvalcorr < 0.05),
            y="data",
            x="sans_intercept",
        )
        .facet(col="region", order=facet_order)
        .share(y=False)
        .add(so.Dot(color="#555555"))
        .add(so.Line(), so.PolyFit(1))
        .add(so.Band(), PolyCI(ci=95, nsims=1000), legend=False)
        .add(
            so.Text({"ha": "right" if variable == "FA" else "left"}),
            PearsonrAnnot("upper right" if variable == "FA" else None),
        )
        .label(
            title="" if y > 0 else None, x=None if y < 3 else "SANS Intercept", y=labely
        )
        .scale(
            y=yscales[y],
            color=["#88c", "#555"],
        )
        .on(ax)
        .plot()
    )

    # Restyle the fit of each non-significant ROI: grey band, dark dotted line.
    # Each facet is its own Axes in this row's subfigure, created in facet order,
    # so it is looked up via `ax.axes` instead of a counter from outside the loop.
    for facet, roi in enumerate(facet_rois):
        pvalcorr = jhp_stats.filter(
            score="sans", feature="intercept", desc=variable, label=roi
        )["pvalcorr"][0]
        if pvalcorr >= 0.05:
            # Presumably children follow layer order: [0] dots, [1] fit line,
            # [2] CI band.
            artists = ax.axes[facet].get_children()
            artists[2].set_facecolor((0, 0, 0, 0.2))
            artists[1].set_color((0, 0, 0, 0.8))
            artists[1].set_linestyle(":")

---
label: fig-jhp-roi
fig-env: FPfigure
cell-offset: -1

---
Correlation between microstructural parameters and SANS intercept. Microstructural measurements are averaged across all sessions for each subject. The SANS intercept was computed using a first-order linear model for each subject, with the baseline scan as time 0. Shaded bands show 95% CIs computed by nonparametric paired bootstrap resampling (1000 resamples). Relationships were tested with a linear model with age and sex as covariates. Solid lines represent statistically significant relationships; dotted lines are nonsignificant. T-values and P-values are shown in @tbl-jhp-intercept.

### Heatmap of other regions

In [None]:
from nilearn import plotting
from scipy.ndimage import binary_dilation

# FA skeleton image (first match in the JHP layout). Its non-zero voxels,
# dilated once with scipy's default structuring element, mask the JHU atlas.
img = nb.load(jhp.layout.get(suffix="skeletonized", desc="FA")[0])
skeleton_mask = binary_dilation(img.get_fdata() > 0)

# JHU atlas labels inside the dilated skeleton mask; 0 elsewhere.
jhu_atlas = np.where(
    skeleton_mask,
    nb.load("../jhp/derivatives/atlases/atlas.nii.gz").get_fdata(),
    0,
)
# Lobe atlas labels inside the *undilated* skeleton; 0 elsewhere.
# NOTE(review): the two atlases use different masks (dilated vs. undilated) --
# confirm this is intended.
lobe_atlas = np.where(
    img.get_fdata() > 0,
    nb.load("../jhp/derivatives/atlases/lobe-atlas.nii.gz").get_fdata(),
    0,
)
# core_mask = (jhu_atlas > 0).astype(int)
# periph_mask = ((lobe_atlas > 0) & (jhu_atlas == 0)).astype(int)
# Lobe labels only where no JHU label is present (presumably the "peripheral"
# white matter used by the level-2 atlas).
lobe_mask = np.where(jhu_atlas == 0, lobe_atlas, 0)

In [None]:
def get_level2_atlas():
    """Integer voxel label image for hierarchy level 2 (core tract groups + lobes).

    Each level-3 JHU ROI, except those in the "unclassified" and "cerebellar"
    groups, is relabelled with the ``index`` of its core tract group.
    Voxels without a JHU label carry their lobe label from ``lobe_mask``.
    This assumes no relabelled ROI has ``atlas_id`` 0.

    Relies on the notebook globals ``atlas_md``, ``jhu_atlas`` and ``lobe_mask``.
    """
    core = atlas_md.filter(group="core")
    tract_name_to_id = dict(zip(core["name"], core["index"]))

    level3 = atlas_md.filter(
        pl.col.hierarchy == 3, ~pl.col.group.is_in({"unclassified", "cerebellar"})
    )
    tract_id, jhu_id = level3.select(
        pl.col.group.replace(tract_name_to_id).cast(int),
        pl.col.atlas_id,
    ).get_columns()

    # Lookup table: JHU atlas value -> core tract group index (0 if unmapped).
    lookup = np.zeros(jhu_atlas.max().astype(int) + 1)
    np.put(lookup, jhu_id, tract_id)
    core_labels = lookup[jhu_atlas.astype(int)]
    return np.asarray(lobe_mask + core_labels, dtype=int)


def project_level3(df):
    """Voxel map of ``statistic`` for significant level-3 (JHU) ROIs.

    ``df`` is joined to ``atlas_md`` on ``label``. Rows with ``pvalcorr < 0.05``
    and hierarchy 3 are painted into the JHU atlas via their ``atlas_id``; all
    other voxels are 0.
    """
    significant = atlas_md.join(df, on="label").filter(
        pl.col.pvalcorr < 0.05,
        pl.col.hierarchy == 3,
    )
    atlas_ids, stats = significant[["atlas_id", "statistic"]].get_columns()

    lookup = np.zeros(jhu_atlas.max().astype(int) + 1)
    np.put(lookup, atlas_ids, stats)
    return lookup[jhu_atlas.astype(int)]


def project_level2(df):
    """Paint significant hierarchy-level-2 statistics onto the level-2 label volume.

    Rows of ``df`` are matched to ``atlas_md`` by ``label``; each level-2 region
    with ``pvalcorr < 0.05`` contributes its ``statistic`` to every voxel of
    ``get_level2_atlas()`` labelled with the region's ``index``. All other voxels
    are 0.
    """
    significant = atlas_md.join(df, on="label").filter(
        pl.col.pvalcorr < 0.05,
        pl.col.hierarchy == 2,
    )
    level2_atlas = get_level2_atlas()
    lookup = np.zeros(level2_atlas.max() + 1)
    np.put(lookup, significant["index"], significant["statistic"])

    return lookup[level2_atlas]

In [None]:
#| cell-offset: -1
#| fig-cap: Correlations between microstructural parameters and SANS intercept. Microstructural
#|   measures and SANS intercepts were computed as in @fig-jhp-roi. Relationships were
#|   tested with a linear model with age and sex as covariates. Significant ROIs are
#|   coloured according to their T-value. A and B show two nested hierarchical layers,
#|   with B at a finer resolution. Comparisons within each layer were corrected for multiple
#|   comparisons using FDR. T-values and P-values are shown in @tbl-jhp-intercept.
#| label: fig-jhp-heatmap
# Colour scale shared by every map and the colourbar: the range of significant SANS
# intercept T-values across hierarchy levels 2 and 3. `.row(0)` unpacks the one-row
# result into plain scalars (unpacking the DataFrame itself would yield Series).
vmin, vmax = (
    jhp_stats.filter(
        pl.col.pvalcorr < 0.05,
        pl.col.hierarchy.is_in([2, 3]),
        pl.col.feature == "intercept",
        score="sans",
    )
    .select(min=pl.min("statistic"), max=pl.max("statistic"))
    .row(0)
)

fig = plt.figure(figsize=(8, 8), facecolor="black")
bg = "../jhp/derivatives/tpl-fa/tpl-study/tpl-study_FA.nii.gz"

descs = ["MD", "RD", "L1"]
panel_labels = ["A", "B"]

*panels, gutter = fig.subfigures(3, 1, height_ratios=[10, 10, 0.5])
panels[0].suptitle(
    "Peripheral and Core Groups", **{**Styles.col_title, "color": "#ffffff"}
)
panels[1].suptitle("JHU ROIs", **{**Styles.col_title, "color": "#ffffff"})
# Panel A: hierarchy level 2 (core tract groups + peripheral lobes); B: level 3 ROIs.
for i, project in enumerate([project_level2, project_level3]):
    panel = panels[i]
    panel.text(
        0.1,
        0.95,
        panel_labels[i],
        **(Styles.panel_label | {"color": "white"}),
    )
    axs = panel.subplots(3, 1)

    for y in range(3):
        if y == 0:
            # Hemisphere markers in axes coordinates (as in fig-tbss), so they do not
            # depend on the data limits of the frame axes.
            axs[y].text(0, 1, "L", color="white", transform=axs[y].transAxes)
            axs[y].text(
                1, 1, "R", color="white", ha="right", transform=axs[y].transAxes
            )
        # Restrict to the SANS model, matching the tbl-jhp-intercept table.
        param_map = nb.Nifti1Image(
            project(jhp_stats.filter(desc=descs[y], feature="intercept", score="sans")),
            img.affine,
            img.header,
        )
        plotting.plot_stat_map(
            param_map,
            bg,
            cut_coords=np.r_[1:45:7j],
            resampling_interpolation="nearest",
            display_mode="z",
            cmap="autumn",
            symmetric_cbar=False,
            vmin=vmin,
            vmax=vmax,
            annotate=False,
            colorbar=False,
            axes=axs[y],
            figure=panel,
        )
        # Re-enable the frame axes (presumably hidden by nilearn) only to show a row title.
        axs[y].axis("on")
        axs[y].get_yaxis().set_ticks([])
        axs[y].get_xaxis().set_visible(False)
        # L1 is the axial diffusivity; label it AD as in the tables.
        axs[y].set_ylabel(
            "AD" if descs[y] == "L1" else descs[y],
            **{
                **Styles.row_title,
                "color": "#ffffff",
                "ha": "center",
                "fontweight": "bold",
            },
        )
guttergrid = gutter.add_gridspec(1, 3, width_ratios=[10, 80, 10])
cbar = gutter.add_subplot(guttergrid[1])
cb = add_colorbar(
    vmin, vmax, ax=cbar, cmap="autumn", orientation="horizontal", outline=False
)
cbar.xaxis.label.set_color("white")
cbar.tick_params(axis="x", colors="white", labelsize=10)
cbar.set_xlabel("T-value", color="white", size=12)
None

---
label: fig-jhp-heatmap
cell-offset: -1

---
Correlations between microstructural parameters and SANS intercept. Microstructural measures and SANS intercepts were computed as in @fig-jhp-roi. Relationships were tested with a linear model with age and sex as covariates. Significant ROIs are coloured according to their T-value. A and B show two nested hierarchical layers, with B at a finer resolution. Comparisons within each layer were corrected for multiple comparisons using FDR. T-values and P-values are shown in @tbl-jhp-intercept.

### TOPSY correlations

In [None]:
# Colour scale for the TOPSY T-map (also reused by the fig-topsy-roi colourbar): range
# of significant PANSS8-N recovery T-values across hierarchy levels 2 and 3, restricted
# to the same score as the map below. `.row(0)` unpacks the result into plain scalars.
vmin, vmax = (
    topsy_stats.filter(
        pl.col.pvalcorr < 0.05,
        pl.col.hierarchy.is_in([2, 3]),
        pl.col.feature == "recovery",
        pl.col.score == "PANSSN",
    )
    .select(min=pl.min("statistic"), max=pl.max("statistic"))
    .row(0)
)
# NOTE(review): the fig-topsy-roi caption says panel D shows the regions of panel C
# (the finest layer), but this projects hierarchy level 2 — confirm which is intended.
param_map = nb.Nifti1Image(
    project_level2(topsy_stats.filter(feature="recovery", desc="FA", score="PANSSN")),
    img.affine,
    img.header,
)
bg = "../jhp/derivatives/tpl-fa/tpl-study/tpl-study_FA.nii.gz"
# Axial (z) cut for each tile of a 2x2 mosaic; each tile is rendered to its own image.
cuts = np.array([[10, 25], [35, 45]])
imgs = np.zeros((2, 2), dtype=object)
for y, x in np.ndindex(2, 2):
    fig = plt.figure(figsize=(5, 5), facecolor="black")
    ax = fig.subplots(1, 1)
    plotting.plot_stat_map(
        param_map,
        bg,
        cut_coords=[cuts[y, x]],
        resampling_interpolation="nearest",
        display_mode="z",
        cmap="autumn",
        symmetric_cbar=False,
        vmin=vmin,
        vmax=vmax,
        annotate=False,
        colorbar=False,
        axes=ax,
    )
    imgs[y, x] = fig_to_numpy(fig)
    # Close this tile's figure explicitly rather than whichever figure is current.
    plt.close(fig)
# Crop each tile (pixel bounds presumably tuned to this figure size/dpi) and stitch
# the 2x2 mosaic into one image.
topsy_tmap = np.vstack([np.hstack([a[150:1300, 300:1200] for a in r]) for r in imgs])
plt.imshow(topsy_tmap)

In [None]:
# Significant TOPSY results (FDR-corrected p < 0.05), for inspection.
topsy_stats.filter(pl.col("pvalcorr").lt(0.05))

In [None]:
# Display the TOPSY ROI table that fig-topsy-roi joins against `topsy_stats`.
topsy_df

In [None]:
#| cell-offset: -1
#| fig-cap: 'Correlations between microstructural parameters and the PANSS8-N follow-up
#|   score. Microstructural measures are computed as in @fig-jhp-roi. Subjects are grouped
#|   based on whether their PANSS8-N score at their follow-up session was equal to 3,
#|   the lowest possible score. Relationships were tested with a linear model with age
#|   and sex as covariates. All comparisons shown are significant. A, B, C: ROIs from
#|   different nested hierarchical layers at successively higher resolutions. D: the
#|   location of the regions in C with their T-values. Comparisons within each layer
#|   were corrected for multiple comparisons using FDR. T-values and P-values are shown
#|   in @tbl-topsy-hx.'
#| label: fig-topsy-roi
fig = plt.figure(figsize=(8, 6), layout="constrained")
grid = fig.add_gridspec(4, 3)

# Slots for up to six region plots: across the top row, then down the left column.
slots = [grid[0, 0], grid[0, 1], grid[0, 2], grid[1, 0], grid[2, 0], grid[3, 0]]

# Keep ROI values only for (label, desc) pairs whose follow-up PANSS8-N recovery test
# is significant. The stats are restricted to that model first: joining on
# (label, desc) alone pairs each row with every feature/score tested for the ROI,
# duplicating subjects and selecting on unrelated p-values.
# NOTE(review): if several descs are significant in one region they share a plot,
# while the y-label below is fixed to "FA" — confirm only FA survives.
sig_topsy = topsy_df.join(
    topsy_stats.filter(feature="recovery", score="PANSSN")[
        ["desc", "label", "pvalcorr"]
    ],
    on=["label", "desc"],
).filter(pl.col.pvalcorr < 0.05, pl.col.num_sessions > 1)

# Regions ordered coarse-to-fine by hierarchy level, then by name.
regions = (
    sig_topsy.group_by("region")
    .agg(pl.first("hierarchy"))
    .sort("hierarchy", "region")["region"]
    .to_numpy()
)

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
}

panels = ["A", "B", None, "C", None, None]

# zip() stops at the shortest input, so fewer than six significant regions leaves the
# remaining slots empty instead of raising IndexError.
for i, (slot, region, panel) in enumerate(zip(slots, regions, panels)):
    ax = fig.add_subplot(slot)
    ax.set_title(region)
    ax.set_xticks([False, True], labels=[">3", "3"])
    if panel is not None:
        ax.text(
            -0.25,
            1,
            panel,
            transform=ax.transAxes,
            **Styles.panel_label,
        )

    comparison_plot(
        sig_topsy.filter(region=region), x="PANSSN_recovery", y="data", ax=ax
    )
    # x labels only on the bottom plot of each column.
    if i in {5, 1, 2}:
        ax.set_xlabel("Follow-up PANSS8-N Score", size=10)
    else:
        ax.set(xlabel=None)
    # y labels only on the left-most plot of each row.
    if i in {0, 3, 4, 5}:
        ax.set_ylabel("FA")
    else:
        ax.set_ylabel(None)


# Black backdrop (figure coordinates) behind the T-map panel.
fig.patches.append(
    plt.Rectangle(
        (0.43, 0.02), 0.57, 0.7, color="black", transform=fig.transFigure, zorder=-1
    )
)
atlas_panel = fig.add_subfigure(grid[1:, 1:])
atlas_panel.text(
    -0.05,
    1,
    "D",
    **Styles.panel_label,
)
axs = atlas_panel.subplots(1, 2, width_ratios=[15, 1])
axs[0].imshow(topsy_tmap)
axs[0].axis("off")
axs[0].set_position(axs[0].get_position().expanded(1.3, 1.3))
axs[0].text(
    0, 0.45, "L", color="white", size=14, fontweight="bold", transform=axs[0].transAxes
)
axs[0].text(
    1,
    0.45,
    "R",
    ha="right",
    color="white",
    size=14,
    fontweight="bold",
    transform=axs[0].transAxes,
)

# vmin/vmax come from the T-map cell above, so the colourbar matches the mosaic.
add_colorbar(vmin, vmax, ax=axs[1], cmap="autumn", outline=False)
axs[1].set_ylabel("T-value")
axs[1].yaxis.label.set_color("white")
axs[1].yaxis.set_major_formatter(FormatStrFormatter("%0.02f"))
axs[1].tick_params(axis="y", colors="white", labelsize=10)

---
label: fig-topsy-roi
cell-offset: -1

---
Correlations between microstructural parameters and the PANSS8-N follow-up score. Microstructural measures are computed as in @fig-jhp-roi. Subjects are grouped based on whether their PANSS8-N score at their follow-up session was equal to 3, the lowest possible score. Relationships were tested with a linear model with age and sex as covariates. All comparisons shown are significant. A, B, C: ROIs from different nested hierarchical layers at successively higher resolutions. D: the location of the regions in C with their T-values. Comparisons within each layer were corrected for multiple comparisons using FDR. T-values and P-values are shown in @tbl-topsy-hx.

In [None]:
# BIDS layout over the TOPSY model outputs; the NBS .h5 result files are queried from it
# below. NOTE(review): the same directory is indexed again later as `topsy_model_layout`.
topsy_modelling = BidsLayout("../topsy/derivatives/models-v0.1.0")

In [None]:
import h5py


@layout_map(
    dims={"src": atlases.bn246["Label ID"], "dest": atlases.bn246["Label ID"]},
    dtype=int,
)
def load_h5_nbs(file):
    """Return the NBS test-statistic matrix masked to significant components.

    ``nbs/con_mat`` is indexed along its last axis with the mask
    ``pval < 0.05`` (``pval`` being the first row of the ``nbs`` group's
    ``pval`` attribute, presumably one entry per component). The selected slices
    are summed into an edge weight that multiplies ``nbs/test_stat``. If any of
    these steps raises IndexError (e.g. no components), every edge is zeroed.

    NOTE(review): ``layout_map`` is declared with ``dtype=int`` while the test
    statistics are presumably floating point — confirm they are not truncated.
    """
    with h5py.File(file, "r") as h5:
        try:
            components = h5["nbs/con_mat"][:]
            is_significant = h5["nbs"].attrs["pval"][0] < 0.05
            edge_weight = np.sum(components[..., is_significant], axis=-1)
        except IndexError:
            # No usable component information: zero out every edge.
            edge_weight = 0
        return h5["nbs/test_stat"][:] * edge_weight


# Model descriptors "<param><p|n><term>", e.g. "FAnrecov" (p/n = PANSSP/PANSSN).
descs = [
    f"{param}{score}{term}"
    for param, score, term in it.product(
        ["AD", "FA", "MD", "RD"], ["p", "n"], ["intercept", "recov", "slope"]
    )
]
# NBS result files (labels "pos" and "neg") for every model descriptor.
modelling_paths = topsy_modelling.get(suffix="nbs", label=["pos", "neg"], desc=descs)
# Masked T-statistic matrices, indexed by a "desc" dimension (see `.sel(desc=...)`).
topsy_nbs = load_h5_nbs(modelling_paths, ["desc"])

In [None]:
def get_nbs_cluster_stats(file):
    """Summarise each NBS component stored in ``file``.

    Returns one dict per component with its position ``i``, the file's ``desc``
    entity, its p-value (first row of the ``nbs`` group's ``pval`` attribute) and
    its ``size``, the sum of its slice of ``nbs/con_mat``. Returns an empty list
    when reading the component data raises IndexError, as `load_h5_nbs` does.

    NOTE(review): if ``con_mat`` stores each undirected edge in both triangles,
    ``size`` counts every edge twice — confirm against the NBS output format.
    """
    with h5py.File(file, "r") as f:
        try:
            networks = f["nbs/con_mat"][:]
            pvals = f["nbs"].attrs["pval"][0]
            # Flatten all but the last axis so each column holds one component.
            sizes = np.sum(networks.reshape(-1, networks.shape[-1]), axis=0)
            return [
                {"i": i, "desc": file.entities["desc"], "pval": pval, "size": size}
                for i, (pval, size) in enumerate(zip(pvals, sizes))
            ]
        except IndexError:
            return []


topsy_nbs_stats = pl.DataFrame(
    it.chain.from_iterable(get_nbs_cluster_stats(f) for f in modelling_paths)
)
if topsy_nbs_stats.is_empty():
    # No component in any file: keep the expected columns so the downstream
    # filter/select yield an empty table instead of a missing-column error.
    topsy_nbs_stats = pl.DataFrame(
        schema={"i": pl.Int64, "desc": pl.Utf8, "pval": pl.Float64, "size": pl.Int64}
    )

In [None]:
# Number of surviving NBS edges per model descriptor, with the descriptor split into
# parameter (e.g. "FA"), score ("p"/"n" -> PANSSP/PANSSN) and model term.
models = pl.from_pandas(
    (topsy_nbs > 0)
    .sum(["src", "dest"])
    .to_dataframe(name="count")
    .reset_index()
).with_columns(
    param=pl.col("desc").str.slice(0, 2),
    score=pl.col("desc").str.slice(2, 1).replace({"n": "PANSSN", "p": "PANSSP"}),
    term=pl.col("desc").str.slice(3),
)

In [None]:
#| fig-cap: Networks associated with negative symptoms in the TOPSY dataset. Microstructural
#|   measures are averaged across sessions. Values for each connection were measured
#|   by sampling along the constituent streamlines. The recovery score is computed as
#|   in @fig-topsy-roi. In each network diagram, lines represent connections significantly
#|   correlated with the corresponding PANSS8-N derivative measure as determined using
#|   NBS (10,000 samples, $T_{thresh}=3$, FWER $<0.5$). Left diagram represents connections
#|   with significantly higher FA in patients a PANSS8-N score of 3 (the lowest possible
#|   score) a their follow-up session. Right hemisphere is the same, but with lower &RD
#|   as the correlate. Gyral abbreviations are given in @tbl-bn246-abbr. Subnetwork size
#|   and p-values are given in @tbl-nbs.
#| label: fig-nbs
from colormaps.utils import concat as cmaps_concat

plt.switch_backend("cairo")

side_title = dict(
    x=-0,
    y=0.5,
    rotation="vertical",
    rotation_mode="anchor",
    size=10,
    ha="center",
    va="bottom",
    color=Styles.Colors.dark[0],
)

fig = plt.figure(figsize=(7.5, 4.3), layout="constrained")
main, gutter = fig.subfigures(2, 1, height_ratios=[3.5, 0.8])
axs = main.subplots(1, 2)
params = ["FA", "RD"]
terms = ["recov"]
labels = {
    "recov": r"$\text{Follow-up PANSS8-N} = 3$",
    "slope": r"$\text{PANSS8-N Slope}$",
}
cms = {1: cmaps.ember.cut(0.4, "left"), -1: cmaps.cosmic.cut(0.2, "left")}
vcms = np.array([[1, -1], *([[-1, 1]] * 3)])
max_edge = topsy_nbs.max()
for y, x in np.ndindex(1, 2):
    ax = axs[x]
    ax.axis("off")
    plot_hierachical_connectome(
        topsy_nbs.sel(desc=f"{params[x]}n{terms[y]}"),
        nodes=atlases.bn246,
        ax=ax,
        emin=3,
        emax=max_edge,
        # ecmap=cmaps.vivid,
        ecmap=cms[vcms[y, x]],
        vcmap=cmaps_concat([cmaps.gray_5.cut(0.4, "right")] * 4).discrete(8),
        hierarchy=["Gyrus", "hemisphere"],
    )
    if x == 0:
        ax.set_title(labels[terms[y]], **(side_title | {"size": 12, "weight": "bold"}))
    if y == 0:
        ax.text(
            0.5,
            1,
            params[x],
            color=Styles.Colors.dark[0],
            size=14,
            ha="center",
            transform=ax.transAxes,
        )

grid = gutter.add_gridspec(1, 8)
cbar1 = gutter.add_subplot(grid[1:3])
add_colorbar(3, max_edge, cms[1], cbar1, outline=False, orientation="horizontal")
cbar1.set_xlabel("T-value", size=10, color=Styles.Colors.dark[0])
cbar1.set_title("Positive correlations", size=10)
cbar1.tick_params(axis="both", which="major", labelsize=8)

cbar2 = gutter.add_subplot(grid[5:7])
add_colorbar(3, max_edge, cms[-1], cbar2, outline=False, orientation="horizontal")
cbar2.set_xlabel("T-value", size=10, color=Styles.Colors.dark[0])
cbar2.tick_params(axis="both", which="major", labelsize=8)
cbar2.set_title("Negative correlations", size=10)

from io import BytesIO

from IPython.display import Image

buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight")

plt.close()
%matplotlib inline
Image(buf.getbuffer())

---
label: fig-nbs

---
Networks associated with negative symptoms in the TOPSY dataset. Microstructural measures are averaged across sessions. Values for each connection were measured by sampling along the constituent streamlines. The recovery score is computed as in @fig-topsy-roi. In each network diagram, lines represent connections significantly correlated with the corresponding PANSS8-N derivative measure as determined using NBS (10,000 samples, $T_{thresh}=3$, FWER $<0.05$). The left diagram represents connections with significantly higher FA in patients with a PANSS8-N score of 3 (the lowest possible score) at their follow-up session. The right diagram is the same, but with lower RD as the correlate. Gyral abbreviations are given in @tbl-bn246-abbr. Subnetwork size and p-values are given in @tbl-topsy-nbs-recovery.

In [None]:
# BIDS layout over the JHP model outputs; fig-tbss queries its "filled" cluster maps.
jhp_model_layout = BidsLayout("../jhp/derivatives/models-v0.1.0")

In [None]:
# BIDS layout over the TOPSY model outputs; fig-tbss queries its "filled" cluster maps.
# NOTE(review): same directory as `topsy_modelling`, indexed a second time here.
topsy_model_layout = BidsLayout("../topsy/derivatives/models-v0.1.0")

In [None]:
#| cell-offset: -1
#| fig-cap: 'Regions associated with negative symptoms. Displayed clusters significantly
#|   correlate with the metric of interest as determined using TFCE (10,000 samples,
#|   FWER $< 0.05$). Clusters are localized to the TBSS-derived FA skeleton and inflated
#|   for visualization. A: Measures in JHP patients compared with the SANS intercept,
#|   as described in @fig-jhp-roi. B: Effect of PANSS8-N recovery, as described in @fig-topsy-roi,
#|   on microstructure in TOPSY patients.'
#| label: fig-tbss
fig = plt.figure(figsize=(8, 7), facecolor="black")

# Rows 0-3: JHP (SANS intercept); rows 4-5: TOPSY (follow-up PANSS8-N recovery).
descs = ["FA", "MD", "RD", "AD", "FA", "RD"]
panel_labels = ["A", "B"]

panels = fig.subfigures(2, 1, height_ratios=[2, 1])
panels[0].suptitle("JHP | SANS Intercept", **{**Styles.col_title, "color": "#ffffff"})
panels[1].suptitle(
    "TOPSY | Follow-up PANSS8-N = 3", **{**Styles.col_title, "color": "#ffffff"}
)
cm_types = {1: "autumn", -1: "winter"}
# Colour-map key for each row (see cm_types). Deliberately not named `cmaps`: that
# would shadow the `colormaps as cmaps` import and break re-running fig-nbs.
cmap_signs = [-1, 1, 1, 1, 1, -1]
for i in range(2):
    panels[i].text(
        0.1,
        0.95,
        panel_labels[i],
        **(Styles.panel_label | {"color": "white"}),
    )
axs = np.hstack([panels[0].subplots(4, 1), panels[1].subplots(2, 1)])

for y in range(6):
    if y < 4:
        panel = panels[0]
        data = jhp_model_layout.get(suffix="filled", desc=f"{descs[y]}nintercept").one
        bg = "../jhp/derivatives/tpl-fa/tpl-study/tpl-study_FA.nii.gz"
        cut_coords = np.r_[1:45:7j]
    else:
        panel = panels[1]
        data = topsy_model_layout.get(suffix="filled", desc=f"{descs[y]}nrecov").one
        # NOTE(review): "tpl-FA" here vs "tpl-fa" for JHP — confirm both match disk.
        bg = "../topsy/derivatives/tpl-FA/tpl-study/tpl-study_FA.nii.gz"
        cut_coords = np.r_[-30:30:7j]
    if y == 0:
        axs[y].text(0, 1, "L", color="white", transform=axs[y].transAxes)
        axs[y].text(1, 1, "R", color="white", ha="right", transform=axs[y].transAxes)
    plotting.plot_stat_map(
        data,
        bg,
        cut_coords=cut_coords,
        resampling_interpolation="nearest",
        display_mode="z",
        cmap=cm_types[cmap_signs[y]],
        symmetric_cbar=False,
        annotate=False,
        colorbar=False,
        axes=axs[y],
        figure=panel,
    )
    # Re-enable the frame axes only to show the row title as a y-label.
    axs[y].axis("on")
    axs[y].get_yaxis().set_ticks([])
    axs[y].get_xaxis().set_visible(False)
    axs[y].set_ylabel(
        descs[y],
        **{
            **Styles.row_title,
            "color": "#ffffff",
            "ha": "center",
            "fontweight": "bold",
        },
    )

---
label: fig-tbss
cell-offset: -1

---
Regions associated with negative symptoms. Displayed clusters significantly correlate with the metric of interest as determined using TFCE (10,000 samples, FWER $< 0.05$). Clusters are localized to the TBSS-derived FA skeleton and inflated for visualization. A: Measures in JHP patients compared with the SANS intercept, as described in @fig-jhp-roi. B: Effect of PANSS8-N recovery, as described in @fig-topsy-roi, on microstructure in TOPSY patients.

## Tables

In [None]:
from IPython import display as d


def p_format(col):
    """Polars expression rendering the p-values in column ``col`` as text.

    Values below 0.001 become ``"< .001"``. Everything else is rounded to three
    decimals, then to two significant figures, and printed without its leading
    zero (e.g. 0.0456 -> ".046").
    """
    pvals = pl.col(col)
    rounded_text = (
        pvals.round(3).round_sig_figs(2).cast(str).str.strip_chars_start("0")
    )
    return pl.when(pvals < 0.001).then(pl.lit("< .001")).otherwise(rounded_text)


# Corrected p-value text, wrapped in LaTeX \textbf{...} when significant (< 0.05).
sig_format = (
    pl.when(pl.col("pvalcorr") < 0.05)
    .then(pl.concat_str([pl.lit("\\textbf{"), p_format("pvalcorr"), pl.lit("}")]))
    .otherwise(p_format("pvalcorr"))
)


def proc_section_headings(latex: str):
    """Turn placeholder heading rows of a LaTeX table into rule-delimited headings.

    A line that starts with ``\\multicolumn{...}{...}{...}`` and whose remaining
    cells are all NaN/nan is rewritten as ``\\midrule <multicolumn> \\\\ \\midrule``.
    Every other line is returned unchanged. Lines are re-joined with ``\\n``, so a
    trailing newline in ``latex`` is not preserved.
    """
    heading_row = re.compile(
        r"(\\multicolumn\{\d*\}\{[^}]*\}\{.*\})(\s*&\s*[Nn]a[Nn])+\s*\\\\"
    )

    def rewrite(line):
        if not line.startswith(r"\multicolumn"):
            return line
        return heading_row.sub(r"\\midrule \1 \\\\ \\midrule", line)

    return "\n".join(map(rewrite, latex.splitlines()))


def format_stats_table(df, label, caption):
    """Render per-ROI test results as a LaTeX table string.

    Rows of ``df`` are joined to ``atlas_md`` (on ``label`` and ``hierarchy``)
    for region names and groups. They are grouped under "Global ROIs" /
    "Regional ROIs" / "Local ROIs" headings (hierarchy levels 1 / 2 / 3) and
    pivoted so that each of FA, MD, RD and AD (``L1`` is renamed AD) gets a T
    column and a corrected-p column. Levels with no result rows are dropped.

    Args:
        df: polars DataFrame with at least ``label``, ``hierarchy``, ``desc``,
            ``statistic``, ``pvalcorr`` and ``df_resid`` columns.
        label: label passed to ``Styler.to_latex``.
        caption: caption passed to ``Styler.to_latex``.

    Returns:
        The LaTeX source, with heading rows rewritten by ``proc_section_headings``.
    """
    # All rows are assumed to share one residual df (taken from the first row); it is
    # shown in the T column header. NOTE(review): raises IndexError on an empty `df`.
    df_resid = df["df_resid"].cast(int)[0]
    return proc_section_headings(
        df.join(
            atlas_md.group_by("label").agg(pl.first("region", "hierarchy", "group")),
            on=["label", "hierarchy"],
        )
        # One placeholder row per hierarchy level carrying its LaTeX section heading;
        # proc_section_headings later turns these rows into \midrule-delimited headings.
        .join(
            pl.DataFrame(
                {
                    "region": [
                        rf"\multicolumn{{8}}{{l}}{{\textit{{{s}}}}}"
                        for s in ["Global ROIs", "Regional ROIs", "Local ROIs"]
                    ],
                    "hierarchy": [1, 2, 3],
                    "desc": ["FA", "FA", "FA"],
                    "group": ["_", "_", "_"],
                }
            ),
            on=["region", "hierarchy", "desc", "group"],
            how="outer",
            coalesce=True,
        )
        # Drop hierarchy levels that contain nothing but their heading row.
        .filter(pl.len().over("hierarchy") > 1)
        .fill_null(np.nan)
        .select(
            "region",
            pl.col.desc.replace({"L1": "AD"}),
            "hierarchy",
            "group",
            # Two significant figures without trailing zeros. Also strip the "."
            # left behind for whole numbers (e.g. "3.0" -> "3." -> "3").
            pl.col.statistic.round_sig_figs(2)
            .cast(str)
            .str.strip_chars_end("0")
            .str.strip_chars_end(".")
            .alias(f"$T({df_resid})$"),
            sig_format.alias("$P_{corr}$"),
        )
        .to_pandas()
        .set_index(["hierarchy", "group", "region", "desc"])
        # One (statistic, desc) column per measure, then put desc on the top level.
        .unstack()
        .reorder_levels([1, 0], axis=1)
        # Descending sort puts the "$T..." column before "$P..." within each measure.
        .sort_index(axis=1, ascending=False)
        .reindex(["FA", "MD", "RD", "AD"], axis=1, level=0)
        .reset_index()
        .rename(columns={"": "Region"})
        .style
        .hide()
        .hide(subset=["hierarchy", "group"], axis=1)
        # Top-level headers: "" over the Region column, then one label per T/P pair.
        # Braces are doubled because relabel_index runs string labels through
        # str.format. NOTE(review): these headers describe the SANS-intercept model
        # but the function is also used for the TOPSY recovery table — confirm.
        .relabel_index(
            [
                "",
                r"$N_{{int}} \sim -FA$",
                r"$N_{{int}} \sim -FA$",
                r"$N_{{int}} \sim MD$",
                r"$N_{{int}} \sim MD$",
                r"$N_{{int}} \sim RD$",
                r"$N_{{int}} \sim RD$",
                r"$N_{{int}} \sim AD$",
                r"$N_{{int}} \sim AD$",
            ],
            axis=1,
            level=0,
        )
        .to_latex(
            column_format="rllllllll",
            hrules=True,
            multicol_align="c",
            convert_css=True,
            label=label,
            caption=caption,
        )
    )

In [None]:
# LaTeX table of the NBS subnetworks that survive FWER correction (p < 0.05).
print(
    topsy_nbs_stats.filter(pl.col.pval < 0.05)
    .select(
        pl.col.desc.str.slice(0, 2).alias("Param"),
        # "#" is a LaTeX special character, and Styler does not escape headers.
        pl.col.size.cast(int).alias(r"\# Connections"),
        p_format("pval").alias("$P_{corr}$"),
    )
    .to_pandas()
    .style.hide()
    .to_latex(
        # One spec per visible column (index hidden): Param, # Connections, P.
        column_format="rrl",
        hrules=True,
        multicol_align="c",
        convert_css=True,
        label="tbl-topsy-nbs-recovery",
        caption="Subnetworks with a significant association between DTI parameter and "
        "a follow-up PANSS8-N score of 3.",
    )
)

In [None]:
print(
    format_stats_table(
        jhp_stats.filter(
            pl.col.feature.is_in(["intercept"]),
            score="sans",
        )
        # Separate filter so the per-label minimum only sees intercept/SANS rows:
        # inside a single .filter() the window is evaluated over the unfiltered
        # frame, letting other features/scores decide which ROIs are shown.
        .filter(pl.col("pvalcorr").min().over("label") < 0.05),
        label="tbl-jhp-intercept",
        caption=textwrap.dedent(
            """
            Microstructural measures versus SANS intercept in JHP patients. Intercept
            computed for each subject by fitting a first-order linear model. Statistics
            computed using unpaired, 1-tailed T-test after regressing age and sex. Bold
            results indicate significant results (only rows with at least one such
            result are shown). All p-values corrected using FDR with comparisons in the
            same hierarchical level.
            """
        ),
    )
)

In [None]:
print(
    format_stats_table(
        topsy_stats.filter(
            pl.col.feature.is_in(["recovery"]),
            score="PANSSN",
        )
        # Separate filter so the per-label minimum only sees recovery/PANSSN rows:
        # inside a single .filter() the window is evaluated over the unfiltered
        # frame, letting other features/scores decide which ROIs are shown.
        .filter(pl.col("pvalcorr").min().over("label") < 0.05),
        label="tbl-topsy-hx",
        caption=textwrap.dedent(
            """
            Microstructural measures in TOPSY patients with a follow-up PANSS8-N score
            of 3 versus higher. Statistics computed using paired, 1-tailed T-tests after
            regressing age and sex. Bold results indicate significant results (only rows
            with at least one such result are shown). All p-values corrected using FDR
            with comparisons in the same hierarchical level.
            """
        ),
    )
)

In [None]:
# LaTeX table mapping each gyrus abbreviation to its full name and lobe.
print(
    atlases.bn246.loc[:, ["Gyrus Abbr", "Gyrus", "Lobe"]]
    .drop_duplicates()
    .style.hide()
    .to_latex(
        caption="Abbreviations of cortical regions.",
        label="tbl-bn246-abbr",
        convert_css=True,
        hrules=True,
        column_format="rll",
    )
)