In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
# NOTE(review): this creates (and displays) an empty figure; presumably a workaround
# to initialise the inline backend early — confirm it is still needed.
plt.subplots()
# Hack: make R search this library directory first so that the packages installed
# here take precedence over the globally installed ones, which were built for the
# wrong R version.
from rpy2 import robjects as ro
# NOTE(review): absolute, machine-specific path; consider making it configurable.
ro.r(".libPaths('/local/scratch/gt/lib/R/library')")

In [None]:
from rsbids import BidsLayout
import numpy as np
import pandas as pd
from numpy.polynomial import Polynomial
import tqdm
import seaborn as sns
import seaborn.objects as so
import scipy.stats as scs
import more_itertools as itx
import itertools as it
import xarray as xr
import statsmodels.api as sm
import statsmodels.formula.api as smf
from lib.bidsarray import layout_map
from lib.plotting import move_legend_fig_to_ax, fig_to_numpy, annotate_axes
from pathlib import Path
from matplotlib import font_manager
from matplotlib.ticker import FuncFormatter, FormatStrFormatter
import templateflow.api as tflow
import nibabel as nb
import polars as pl
import polars.selectors as cs
import graph_tool.all as gt
import colormaps as cmaps
from lib.dataset import Dataset
from lib.bidsarray import layout_map
from lib.plotting import comparison_plot, add_colorbar, plot_hierachical_connectome
from dask.diagnostics import ProgressBar
from lib.seaborn_stats import MLEFit, Lme4CI, PolyCI, PearsonrAnnot
from lib import atlases
from lib.utils import concat_product

from styles import styles as Styles

%load_ext autoreload
%autoreload 2
%matplotlib inline
# plt.switch_backend("cairo")
# Apply the manuscript style, then copy the resulting rcParams into the
# seaborn.objects theme so so.Plot figures use the same style.
plt.style.use("styles/manuscript.mplstyle")
so.Plot.config.theme.update(plt.rcParams)
# Register every font file found in ~/.fonts with matplotlib's font manager.
font_dirs = [Path.home() / ".fonts"]
font_files = font_manager.findSystemFonts(fontpaths=font_dirs)

for font_file in font_files:
    font_manager.fontManager.addfont(font_file)

In [None]:
import rpy2
from rpy2 import robjects as ro
import rpy2.ipython.html
rpy2.ipython.html.init_printing()


In [None]:
from rpy2.robjects.packages import importr
rutils = importr("utils")
rbase = importr("base")
lme4 = importr("lme4")
rstats = importr("stats")
pbkrtest = importr("pbkrtest")
lmertest = importr("lmerTest")

### Initialize Datasets

In [None]:
# JHP: load the layout and phenotypes, excluding the listed diagnoses.
jhp = (
    Dataset(".jhp.layout", "jhp", group_label="group")
    .add_phenotypes("jhp_metadata.yaml")
    .filter(
        ~pl.col("dx").is_in(["BPADI", "MDD", "Substance_induced", "Others"]),
        # ~pl.col("subject").is_in(["1018", "1026", "1047", "2041"]),
    )
)


# TOPSY: keep only rows whose ``ddx`` code is in the listed set
# (meaning of the codes is defined in topsy_metadata.yaml — not visible here).
topsy = (
    Dataset(".topsy.layout", "topsy")
    .add_phenotypes("topsy_metadata.yaml")
    .filter(pl.col("ddx").is_in([1, 2, 3, 9, 10]))
)



# Session number -> axis tick label, used by the longitudinal figures.
JHP_SESSIONS = {1: "Baseline", 2: "1yr", 3: "2yr", 4: "3yr"}
TOPSY_SESSIONS = {1: "Baseline", 2: "≈6 mo", 3: "1-2 yr"}

### Prepare clinical data

In [None]:
# Window expression: number of distinct sessions recorded for each row's subject.
# NOTE(review): the same expression is re-assigned in several later cells.
num_sessions = pl.col("session").unique().len().over("subject")


def do_fit(struct):
    """Fit ``y = intercept + slope * x`` to one subject's clinical scores.

    Parameters
    ----------
    struct : polars.Series
        Struct series with fields ``x`` (0-based session index) and ``y`` (score),
        as built by ``sel_fit``.

    Returns
    -------
    dict
        ``{"intercept": ..., "slope": ...}``. Rows with a null ``x`` or ``y`` are
        dropped first. With exactly one remaining observation both values are
        ``0.0``. With none, both are ``None`` (a null fit) instead of letting
        ``Polynomial.fit`` raise on empty input.
    """
    points = struct.struct.unnest().filter(pl.all_horizontal(~pl.col("*").is_null()))
    n_points = points["x"].len()
    if n_points == 0:
        # Every row had a missing value: there is nothing to fit.
        return {"intercept": None, "slope": None}
    if n_points == 1:
        coefs = [0.0, 0.0]
    else:
        coefs = Polynomial.fit(points["x"], points["y"], 1).convert().coef
        if coefs.shape[0] == 1:
            # Only a constant term came back; treat the slope as zero.
            coefs = [coefs[0], 0.0]
    return {
        "intercept": coefs[0],
        "slope": coefs[1],
    }


def sel_fit(col):
    """Expression giving a per-subject linear fit of ``col`` against session.

    Meant for ``group_by("subject").agg(...)``. Session labels are cast to int and
    shifted so the first session is ``x = 0``. Groups with a single row yield
    null. Otherwise ``do_fit`` produces a struct whose fields are renamed to
    ``{col}_intercept`` and ``{col}_slope``. The output column is ``{col}_fit``.
    """
    fit_dtype = pl.Struct({"intercept": pl.Float64, "slope": pl.Float64})
    session_index = pl.col("session").cast(int) - 1
    points = pl.struct(session_index, col).struct.rename_fields(["x", "y"])
    fitted = points.map_elements(do_fit, return_dtype=fit_dtype).struct.rename_fields(
        [f"{col}_intercept", f"{col}_slope"]
    )
    return pl.when(pl.len() > 1).then(fitted).alias(f"{col}_fit")


def sel_recovery(col):
    """Flag whether ``col`` is below 4 at session ``"2"``; output ``{col}_recovery``.

    Meant for a per-subject ``agg`` context. Uses the first value recorded at
    session ``"2"``, and is null when the subject has no such value.
    """
    second_visit = pl.col(col).filter(pl.col("session") == "2").first()
    return (second_visit < 4).name.suffix("_recovery")

def sel_start_zero(col):
    """Flag whether ``col`` is below 4 at session ``"1"``; output ``{col}_nostart``.

    Despite the function name, this is not a test for a zero score. It applies the
    same ``< 4`` threshold as ``sel_recovery``, but to the baseline visit.
    """
    baseline = pl.col(col).filter(pl.col("session") == "1").first()
    return (baseline < 4).name.suffix("_nostart")


# Per-subject clinical summaries (patients only). Each row holds demographics,
# mean and baseline (session "1") scores, and per-subject linear fits of score
# against session; the fit structs are unnested into *_intercept / *_slope columns.
# NOTE: ``num_sessions`` here is ``pl.len()``, i.e. the number of metadata rows
# per subject, not a count of distinct sessions.
jhp_hx = (
    jhp.metadata.filter(pl.col("group") == "Patient")
    .group_by("subject")
    .agg(
        pl.len().alias("num_sessions"),
        pl.first("sex"),
        pl.mean("age"),
        pl.mean("sans", "saps").name.suffix("_mean"),
        pl.col("saps", "sans").filter(pl.col("session") == "1").first()
        .name.suffix("_baseline"),
        sel_fit("sans"),
        sel_fit("saps"),
    )
    .unnest("sans_fit", "saps_fit")
)

# TOPSY: FEP subjects, restricted to the first two sessions; also adds the
# "< 4" recovery/baseline flags from sel_recovery / sel_start_zero.
topsy_hx = (
    topsy.metadata.filter(pl.col("group") == "FEP", pl.col("session").cast(int) < 3)
    .group_by("subject")
    .agg(
        pl.len().alias("num_sessions"),
        pl.first("sex"),
        pl.mean("age"),
        pl.mean("PANSSP", "PANSSN", "SOFAS").name.suffix("_mean"),
        pl.col("PANSSP", "PANSSN", "SOFAS").filter(pl.col("session") == "1").first()
        .name.suffix("_baseline"),
        sel_fit("PANSSP"),
        sel_fit("PANSSN"),
        sel_fit("SOFAS"),
        sel_recovery("PANSSP"),
        sel_recovery("PANSSN"),
        sel_start_zero("PANSSN"),
    )
    .unnest(cs.matches(".*_fit"))
)

In [None]:
from IPython.display import display

# Is PANSS8-N recovery associated with sex (chi-squared test of independence)
# or with age (OLS), among TOPSY FEP subjects with a follow-up visit?
multi_session_hx = topsy_hx.filter(pl.col("num_sessions") > 1)
table = (
    multi_session_hx
    .group_by("sex", "PANSSN_recovery")
    .len()
    .pivot(index="sex", columns="PANSSN_recovery", values="len")
    .drop("sex")
    # pivot leaves null where a sex/recovery combination has no subjects. That is
    # a zero count; left as NaN it would break the chi-squared test.
    .fill_null(0)
    .to_numpy()
)
chi2_result = scs.chi2_contingency(table)
# Only a cell's last expression is rendered, so show the test result explicitly.
display(chi2_result)

lm = smf.ols(
    "age ~ PANSSN_recovery",
    data=multi_session_hx.with_columns(pl.col("PANSSN_recovery").cast(int)),
).fit()
lm.model.exog

#### Demographics and clinical scores

In [None]:
# Demographic correlates of symptom course in TOPSY FEP subjects with a
# follow-up visit. Columns: positive (PANSSP) / negative (PANSSN) scale.
# Rows: 0 age by recovery, 1 sex counts by recovery, 2 slope vs age,
# 3 slope by sex, 4 intercept vs age, 5 intercept by sex.
followed_up = topsy_hx.filter(pl.col("num_sessions") > 1)
scales = ("PANSSP", "PANSSN")

fig = plt.figure(figsize=(6, 10), layout="constrained")
axs = fig.subplots(6, 2)

for col_ix, scale in enumerate(scales):
    comparison_plot(
        followed_up.to_pandas(),
        y="age",
        x=f"{scale}_recovery",
        ax=axs[0, col_ix],
        xlabel="Recovery?",
        ylabel="Age",
    )
for col_ix, scale in enumerate(scales):
    spec = so.Plot(followed_up.to_pandas(), color="sex", x=f"{scale}_recovery")
    spec.add(so.Bar(), so.Hist()).on(axs[1, col_ix]).plot()

for row, feature in ((2, "slope"), (4, "intercept")):
    # Scatter against age on `row`, group comparison by sex on `row + 1`.
    for col_ix, scale in enumerate(scales):
        spec = so.Plot(followed_up.to_pandas(), y=f"{scale}_{feature}", x="age")
        spec.add(so.Dot()).on(axs[row, col_ix]).plot()
    for col_ix, scale in enumerate(scales):
        comparison_plot(
            followed_up.to_pandas(),
            x="sex",
            y=f"{scale}_{feature}",
            ax=axs[row + 1, col_ix],
            xlabel="Sex",
        )


#### Clinical Score Changes

In [None]:
# Longitudinal mixed models of symptom scores for patients with >1 session.
num_sessions = pl.col("session").unique().len().over("subject")
topsy_df = (
    topsy.metadata.with_columns(pl.col.session.cast(int))
    .filter(
        pl.col("group") == "FEP",
        num_sessions > 1,
        ~pl.col("PANSSP", "PANSSN").is_null(),
        pl.col.session < 3,
    )
    .to_pandas()
)
jhp_df = (
    jhp.metadata.with_columns(pl.col.session.cast(int))
    .filter(
        pl.col("group") == "Patient",
        num_sessions > 1,
        ~pl.col("saps", "sans").is_null(),
    )
    .to_pandas()
)
with (ro.default_converter + ro.pandas2ri.converter).context():
#     lm1 = lme4.lmer("PANSSP ~ (1|subject)", data=topsy_df)
#     lm2 = lme4.lmer("PANSSP ~ session + (1|subject)", data=topsy_df)
#     res = pbkrtest.PBmodcomp(lm2, lm1)
#     print(res["test"])
    # Random-intercept vs random-slope models of SANS. Both models must have the
    # same response for the parametric-bootstrap comparison to be meaningful
    # (lm1 previously modelled saps while lm2 modelled sans).
    lm1 = lme4.lmer("sans ~ session + (1|subject)", data=jhp_df)
    lm2 = lme4.lmer("sans ~ session + (session|subject)", data=jhp_df)
    # Keep the (expensive) bootstrap comparison instead of overwriting it.
    pb_res = pbkrtest.PBmodcomp(lm2, lm1)
    # print(pb_res["test"])
    res = lmertest.ranova(lm2)
res

In [None]:

# Random-effects ANOVA (lmerTest) for lm2 from the previous cell.
with (ro.default_converter + ro.pandas2ri.converter).context():
    res = lmertest.ranova(lm2)
res

In [None]:
# Inspect the random-effects ANOVA function. ranova belongs to lmerTest, not lme4
# (lme4.ranova is not an attribute of the lme4 package and raises).
lmertest.ranova

In [None]:
# Print R's summary() of lm2 (fitted in an earlier cell).
with (ro.default_converter + ro.pandas2ri.converter).context():

    print(ro.r("summary")(lm2))

In [None]:

#| test: foo

#| hello: world
# 2x2 grid: rows = positive / negative symptom scale, columns = TOPSY / JHP.
# Subjects with >1 session are drawn as jittered lines and dots; single-session
# subjects are drawn as hollow dots shifted left of their session.
jitter = so.Jitter(width=0.2, seed=1)
fig = plt.figure(figsize=(8, 5), layout="constrained")

axs = fig.subplots(2, 2)

variables = np.array([["PANSSP", "saps"], ["PANSSN", "sans"]])
datasets = [
    topsy.metadata.filter(pl.col("group") == "FEP", pl.col("session").cast(int) < 3),
    jhp.metadata.filter(pl.col("group") == "Patient"),
]
ses_labels = [TOPSY_SESSIONS, JHP_SESSIONS]

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
    "sans": "SANS",
    "saps": "SAPS",
}
# Per-column lme4 formula for the CI band: random intercepts only for TOPSY,
# random slopes and intercepts for JHP.
formulae = ["y ~ x + (1|group)", "y ~ x + (x|group)"]

for x, y in np.ndindex(2, 2):
    variable = variables[x, y]
    label = labels[variable]
    dataset = datasets[y]
    ax = axs[x, y]
    num_sessions = pl.col("session").unique().len().over("subject")
    (
        so.Plot(
            # multi_* holds the score only for subjects seen more than once,
            # single_* only for subjects seen once (null otherwise).
            dataset.with_columns(
                pl.col("session").cast(int),#.replace(jhp_sessions),
                pl.when(num_sessions > 1).then(pl.col(variable)).name.prefix("multi_"),
                pl.when(num_sessions == 1)
                .then(pl.col(variable))
                .name.prefix("single_"),
            ).to_pandas(),
            x="session",
            y=f"multi_{variable}",
            group="subject",
        )
        .add(
            so.Line(color="#555555", linestyle="dashed", linewidth=1, alpha=0.3), jitter
        )
        .add(so.Line(linewidth=2), MLEFit())
        .add(so.Band(alpha=0.4), Lme4CI(formula=formulae[y], nsims=500))
        .add(so.Dot(color="#333333", edgewidth=0, alpha=0.5), jitter)
        .add(
            so.Dot(color="#333333", fill=False, alpha=0.8),
            so.Shift(x=-0.3),
            so.Jitter(width=0.2, seed=5),
            y=f"single_{variable}",
        )
        .scale(
            x=so.Continuous().tick(at=[1,2,3,4]).label(like=ses_labels[y].get),
            y=so.Continuous().tick(every=4),
        )
        .label(y=label, x="Session")
        .on(ax)
        .plot()
    )

axs[0, 0].set_title("TOPSY", **Styles.col_title)
axs[0, 1].set_title("JHP", **Styles.col_title)

---
label: fig-hx
cell-offset: -1

---
Baseline and follow-up clinical scores in early schizophrenia patients. At baseline, empty circles show subjects with no follow-up visits. Filled circles connected by a line represent the same subject across multiple visits. No significant differences in scores were found between subjects with and without follow-up visits. Trendlines show a linear mixed-effects model of each score against session, with random slopes and intercepts fit for every subject (random intercepts only for TOPSY). Shaded bands show a 95% CI computed by parametric bootstrap, resampling residuals and random effects 500 times. In the TOPSY dataset, the PANSS8-P score was significantly lower in the second session than in the first (1000 perms, p < .001). No other symptom scores changed significantly across sessions.

### Gather white matter data

In [None]:
def get_wm_from_rois(path, wildcards, atlases):
    """Mean image value within every label of one or more label atlases.

    Parameters
    ----------
    path
        Files to sample; passed through to the ``layout_map``-wrapped worker.
    wildcards
        Entity names that ``layout_map`` uses to arrange the outputs.
    atlases
        Sequence of integer-valued label volumes aligned with the images. For atlas
        ``k`` the labels ``0 .. max(atlas_k)`` are enumerated. Outputs for all atlases
        are concatenated along the ``roi`` dimension, in atlas order.

    Returns
    -------
    The ``layout_map`` result for the worker, with an ``roi`` dimension of length
    ``sum(max(atlas_k) + 1)``. Label 0 is reported as 0, and labels with no voxels
    come out as NaN (mean of an empty selection).
    """
    labels = [
        (atlas_ix, label)
        for atlas_ix, atlas in enumerate(atlases)
        for label in range(np.max(atlas).astype(int) + 1)
    ]
    nlabels = len(labels)

    # Voxel locations of each label are identical for every image. Find them once
    # here instead of re-scanning the full atlas for every label of every file.
    # ``data[mask.nonzero()]`` selects exactly the same elements as ``data[mask]``.
    label_voxels = [
        None if label == 0 else np.nonzero(atlases[atlas_ix] == label)
        for atlas_ix, label in labels
    ]

    @layout_map(parallel=True, dims={"roi": nlabels}, dtype=float)
    def inner(path):
        try:
            data = nb.load(path).get_fdata()
            result = np.empty((nlabels,))
            for i, voxels in enumerate(label_voxels):
                result[i] = 0 if voxels is None else np.mean(data[voxels])
            return result
        except Exception:
            # Report which file failed, then propagate the original error.
            print(path)
            raise

    return inner(path, wildcards)


def run_roi_sampling(layout, jhu_atlas, lobe_atlas, skeleton_dims):
    """Sample FA/MD/RD/L1 skeleton images within JHU ROIs and peripheral lobe ROIs.

    The group skeleton is every voxel whose FA, averaged over all subjects and
    sessions, is positive. Both atlases are restricted to that skeleton. Lobe
    labels are kept only where there is no JHU label, so lobe ROIs cover the
    non-JHU (peripheral) part of the skeleton.

    Parameters
    ----------
    layout
        Query result supporting ``.get(suffix=..., desc=...)`` for skeletonized images.
    jhu_atlas, lobe_atlas
        Paths to label volumes in the same space as the skeleton images.
    skeleton_dims
        Mapping of spatial dimension name to size, passed to ``layout_map``.
    """
    load_volume = layout_map(parallel=True, dims=skeleton_dims, dtype=float)(
        lambda p: nb.load(p).get_fdata()
    )
    fa_maps = load_volume(
        layout.get(suffix="skeletonized", desc="FA"), ["subject", "session"]
    )
    mean_skeleton = np.asanyarray(fa_maps.mean(["subject", "session"]) > 0)

    def _restrict(atlas_path):
        # Zero every atlas voxel that lies outside the group skeleton.
        return np.where(mean_skeleton, nb.load(atlas_path).get_fdata(), 0)

    jhu_labels = _restrict(jhu_atlas)
    lobe_labels = _restrict(lobe_atlas)
    lobe_mask = np.where(jhu_labels == 0, lobe_labels, 0)
    return get_wm_from_rois(
        layout.get(suffix="skeletonized", desc=["FA", "MD", "RD", "L1"]),
        ["subject", "session", "desc"],
        atlases=[lobe_mask, jhu_labels],
    )

from dask.diagnostics import ProgressBar

# Build ROI samplings for both datasets; each uses its own atlas paths and
# skeleton grid size. Presumably lazy (dask) until written to disk in the next
# cell, which wraps that step in a ProgressBar — confirm.
topsy_wm_sampled = run_roi_sampling(
    topsy.layout.get(suffix="skeletonized"),
    jhu_atlas="../topsy/code/jhp-atlas/atlas.nii.gz",
    lobe_atlas="../topsy/code/jhp-atlas/lobe-atlas.nii.gz",
    skeleton_dims={"x": 78, "y": 109, "z": 79}
)
jhp_wm_sampled = run_roi_sampling(
    jhp.layout.get(suffix="skeletonized"),
    jhu_atlas="../jhp/derivatives/atlases/atlas.nii.gz",
    lobe_atlas="../jhp/derivatives/atlases/lobe-atlas.nii.gz",
    skeleton_dims={"x": 248, "y": 295, "z": 93}
)


In [None]:

# Cache the WM ROI samples to netCDF; later cells reload them with open_dataarray.
with ProgressBar():
    topsy_wm_sampled.to_netcdf("topsy_wm_sampled.nc")
    jhp_wm_sampled.to_netcdf("jhp_wm_sampled.nc")

### Gather all the surface sample files

In [None]:
def get_hems(img):
    """Split a CIFTI image into flattened left and right cortical hemisphere arrays.

    Returns ``(left, right)``: for each ``CIFTI_STRUCTURE_CORTEX_*`` structure of
    ``img``, the result of ``project(0).flatten()``.
    """
    def _hemisphere(structure):
        return img[img.struc[structure]].project(0).flatten()

    left = _hemisphere("CIFTI_STRUCTURE_CORTEX_LEFT")
    right = _hemisphere("CIFTI_STRUCTURE_CORTEX_RIGHT")
    return left, right



@layout_map(parallel=True, dtype=float, dims={"vertex": 64984})
def load_data(path):
    """Load one CIFTI surface file as a single vector of cortical vertex values.

    Left-hemisphere vertices come first, followed by right-hemisphere vertices
    (as split by ``get_hems``). ``layout_map`` declares 64984 vertices,
    presumably 2 x 32492 for fsLR 32k — confirm for other densities.

    NOTE(review): ``cp`` is not imported anywhere in the visible notebook; it is
    presumably a CIFTI reader module. Add its import so a fresh kernel can run this.
    """
    img = cp.load(path)
    lh, rh = get_hems(img)
    bound = lh.shape[0]
    # np.NaN was removed in NumPy 2.0; np.nan works on every version.
    data = np.full(bound + rh.shape[0], np.nan)
    data[:bound] = lh
    data[bound:] = rh
    return data

from dask.diagnostics import ProgressBar
# Load curvature and thickness surfaces (fsLR 32k) for every subject/session.
# NOTE(review): JHP is converted to a Dataset ("surface") before saving while TOPSY
# is saved as a DataArray; both are read back with xr.open_dataarray later.
topsy_surface = load_data(
    topsy.layout.get(suffix=["curv", "thickness"], den="32k"),
    ["subject", "session", "suffix"]
)
jhp_surface = load_data(
    jhp.layout.get(suffix=["curv", "thickness"], den="32k"),
    ["subject", "session", "suffix"],
).to_dataset(name="surface")
with ProgressBar():
    jhp_surface.to_netcdf("jhp_surface.nc")
    topsy_surface.to_netcdf("topsy_surface.nc")

In [None]:
# DKT atlas label table; its ``label`` column supplies the ROI ids used by do_sampling.
dkmd = pd.read_csv("atlas-dkt_labels.tsv", sep="\t")


def do_sampling(ds, layout):
    """Average surface maps within each DKT atlas region, per subject and session.

    Parameters
    ----------
    ds : xarray.DataArray
        Surface data selectable by ``subject``, ``session`` and ``param`` over a
        vertex dimension (presumably as written by ``load_data`` — confirm).
    layout
        BIDS layout queried for each subject's ``dparc`` DKT parcellation.

    Returns
    -------
    xarray.Dataset with a ``dk`` variable over ``param`` and ``roi`` (plus the
    subject/session wildcards), merged with the DKT label metadata from ``dkmd``.
    """

    @layout_map(
        parallel=True, dims={"param": ds["param"], "roi": dkmd["label"]}, dtype=float
    )
    def sample_dk(path):
        # NOTE(review): ``cp`` is not imported anywhere in the visible notebook.
        dknii = cp.load(path)
        # Hard-coded 32492 vertices per hemisphere (fsLR 32k); left then right,
        # the same order in which load_data concatenates the surface data.
        dkatlas = np.empty(32492 * 2)
        dkatlas[:32492] = (
            dknii[dknii.struc["CIFTI_STRUCTURE_CORTEX_LEFT"]].project().ravel()
        )
        dkatlas[32492:] = (
            dknii[dknii.struc["CIFTI_STRUCTURE_CORTEX_RIGHT"]].project().ravel()
        )
        result = np.empty((len(ds["param"]), len(dkmd)))
        x = ds.sel(
            subject=path.entities["subject"], session=path.entities["session"]
        )
        for (i, param), (j, label) in it.product(
            enumerate(ds["param"]), enumerate(dkmd["label"])
        ):
            # ``x`` is an xarray object, so np.mean presumably dispatches to its
            # NaN-skipping .mean() — confirm if NaN vertices matter here.
            result[i, j] = np.mean(x.sel(param=param)[dkatlas == label])
        return result

    return (
        sample_dk(
            layout.get(
                suffix="dparc",
                subject=ds["subject"].data,
                atlas="dkt",
                space="fsLR",
                den="32k",
            ),
            wildcards=["subject", "session"],
        )
        .to_dataset(name="dk")
        .merge(dkmd.rename(columns={"label": "roi"}).set_index("roi").to_xarray())
    )


# Reload the cached surfaces lazily (chunks={}), average them per DKT region,
# and cache the regional results.
topsy_surface = xr.open_dataarray("topsy_surface.nc", chunks={})
jhp_surface = xr.open_dataarray("jhp_surface.nc", chunks={})
topsy_sampled = do_sampling(topsy_surface.rename(suffix="param"), topsy.layout)
jhp_sampled = do_sampling(jhp_surface.rename(suffix="param"), jhp.layout)
with ProgressBar():
    jhp_sampled.to_netcdf("jhp_sampled.nc")
    topsy_sampled.to_netcdf("topsy_sampled.nc")

## Investigate!

In [None]:
# Reload the cached WM ROI samples lazily (written by the earlier sampling cell).
topsy_wm_sampled = xr.open_dataarray("topsy_wm_sampled.nc", chunks={})
jhp_wm_sampled = xr.open_dataarray("jhp_wm_sampled.nc", chunks={})

In [None]:
# Study atlas metadata: one row per atlas index with label, group, region,
# hierarchy, etc. (columns as used in later cells).
atlas_md = pl.read_csv("atlas-study_labels.csv")

# ROI exclusions applied before any analysis: the listed labels and all
# cerebellar ROIs.
atlas_filters = [
    ~pl.col("label").is_in(
        [
            "Med",
            "Po",
            "Mb",
            "FTS",
        ]
    ),
    pl.col("group") != "cerebellar"
]

In [None]:
# Number of distinct labels in each group after applying atlas_filters.
atlas_md.group_by("label").first().filter(*atlas_filters).group_by("group").len()

In [None]:
# name -> atlas index for every row whose group is "core" or "global"; used to
# give aggregate ROIs (global, core, peripheral, core tract groups) an index.
group_indices = dict(
    zip(*atlas_md.filter(pl.col("group").is_in(["core", "global"]))[["name", "index"]])
)


def get_index(group: str):
    """Atlas index of an aggregate ROI name (from ``group_indices``) as an Int64 literal."""
    index = group_indices[group]
    return pl.lit(index, dtype=pl.Int64)


def prepare_wm_rois(df, index):
    """Add aggregate ROIs to a long-format WM table and collapse rows by label.

    Parameters
    ----------
    df : polars.DataFrame
        Columns ``index`` + ``roi`` (atlas index) + ``data``.
    index : list[str]
        Identifier columns kept through every aggregation (e.g. subject, desc).

    Returns
    -------
    polars.DataFrame keyed by ``index`` + ``label`` with ``region``, ``hierarchy``
    and mean ``data``. Besides the individual ROIs, it contains a global mean, a
    core mean, a peripheral mean and one mean per core tract group, all computed
    as unweighted means of the per-ROI values. Rows that share a label are averaged.
    """
    # Attach metadata and drop excluded ROIs before building any aggregates.
    df = df.join(atlas_md.rename({"index": "roi"}), on="roi").filter(*atlas_filters)
    periph_rois = pl.col("group") == "peripheral"
    core_rois = pl.col("group").is_in(
        ["projection", "callosal", "limbic", "association"]
    )
    return (
        pl.concat(
            [
                df[[*index, "roi", "data"]],
                # # Global
                df.group_by(index).agg(
                    get_index("global").alias("roi"), pl.mean("data")
                ),
                # # Core
                df.filter(core_rois)
                .group_by(index)
                .agg(get_index("core").alias("roi"), pl.mean("data")),
                # # Peripheral
                df.filter(periph_rois)
                .group_by(index)
                .agg(get_index("peripheral").alias("roi"), pl.mean("data")),
                # # Core groups
                # Relabel each core group name with its atlas index via group_indices.
                df.filter(core_rois)
                .group_by(*index, "group")
                .agg(pl.mean("data"))
                .with_columns(pl.col("group").replace(group_indices).cast(pl.Int64))
                .rename({"group": "roi"}),
            ],
        )
        .join(atlas_md.rename({"index": "roi"}), on="roi")
        .group_by(*index, "label")
        .agg(
            pl.col(
                "region",
                "hierarchy",
            ).first(),
            pl.col("data").mean(),
        )
    )

In [None]:
# Window expression: distinct sessions per subject (same definition as earlier cells);
# read by all_sessions below.
num_sessions = pl.col("session").unique().len().over("subject")


def all_sessions(da, hx):
    """Per-session WM ROI table joined with session-level metadata.

    NaN samples in ``da`` are dropped. The rest are reshaped with
    ``prepare_wm_rois`` using ``["subject", "session", "desc"]`` as the index,
    inner-joined to ``hx`` on subject and session, and given a ``num_sessions``
    column from the notebook-level ``num_sessions`` expression.
    """
    sampled = pl.from_pandas(
        da.to_dataset(name="data").to_dataframe().dropna().reset_index()
    )
    rois = prepare_wm_rois(sampled, ["subject", "session", "desc"])
    joined = rois.join(hx, on=["subject", "session"])
    return joined.with_columns(num_sessions=num_sessions)


# Per-session WM tables for patients and controls; TOPSY limited to sessions 1-2.
topsy_df = all_sessions(
    topsy_wm_sampled,
    topsy.metadata.filter(
        pl.col("group").is_in(["FEP", "HC"]), pl.col("session").cast(int) < 3
    ),
)
jhp_df = all_sessions(
    jhp_wm_sampled,
    jhp.metadata.filter(pl.col("group").is_in(["Patient", "HC"])),
)

In [None]:
# Global WM ("WM" label) microstructure over sessions. Rows are the DTI measures,
# columns are TOPSY / JHP; colour is patient vs control.
jitter = so.Jitter(seed=3009)
fig = plt.figure(figsize=(8, 10), layout="constrained")
axs = fig.subplots(4, 2)

variables = np.array(["FA", "MD", "RD", "L1"])
dfs = [
    topsy_df,
    jhp_df,
]
ses_labels = [TOPSY_SESSIONS, JHP_SESSIONS]
orders = [["HC", "FEP"], ["HC", "Patient"]]
# Diffusivities are scaled by 1000 for display.
# NOTE(review): if values are in mm^2/s, x*1000 is µm^2/ms rather than nm^2/ms —
# confirm the unit string below.
yscales = [
    so.Continuous(),
    *([so.Continuous().label(like=lambda x, _: f"{x*1000:.2f}")] * 3)
]
units = [
    "",
    *([r"$\frac{nm^2}{ms}$"] * 3)
]

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
    "sans": "SANS",
    "saps": "SAPS",
}

for x, y in np.ndindex(4, 2):
    variable = variables[x]
    label = labels.get(variable, variable)
    labely  = f"{label} ({units[x]})" if units[x] else label
    df = dfs[y]
    ax = axs[x, y]
    # NOTE(review): unused here; the filter below reads the num_sessions column.
    num_sessions = pl.col("session").unique().len().over("subject")
    (
        so.Plot(
            df.filter(
                pl.col.label == "WM", pl.col.desc == variable, pl.col.num_sessions > 1
            )
            .with_columns(pl.col.session.cast(int))
            .to_pandas(),
            x="session",
            y="data",
            group="subject",
            color="group",
        )
        # .add(so.Line(linestyle="dashed", linewidth=0.5, alpha=0.5), jitter, legend=False)
        .add(so.Line(linewidth=2), MLEFit(), legend=False)
        .add(so.Band(), Lme4CI(nsims=500), legend=False)
        .add(so.Dot(edgewidth=0, alpha=0.5, fill=None), jitter, legend=False)
        .scale(
            x=so.Continuous().tick(at=[1, 2, 3, 4]).label(like=ses_labels[y].get),
            y=yscales[x],
            color=so.Nominal(order=orders[y])
        )
        .label(y=None if y > 0 else labely, x=None if x < 3 else "Session")
        .on(ax)
        .plot()
    )
axs[0, 0].set_title("TOPSY", **Styles.col_title)
axs[0, 1].set_title("JHP", **Styles.col_title)

---
label: fig-longt
cell-offset: -1

---
Global longitudinal changes of white matter microstructure in early schizophrenia patients and healthy controls. Trendlines show a linear mixed-effects model of each parameter against session, with random intercepts fit for every subject. Shaded bands show a 95% CI computed by parametric bootstrap, resampling residuals and random effects 500 times. No significant differences were found between the slopes of HCs and patients in either dataset for any of the parameters measured. In the JHP sample, fitting random slopes for each subject did not significantly improve the model fit (not tested in TOPSY because each subject had only two time points).

### Correlations with Symptoms

In [None]:
def avg_sessions(da, hx):
    """Session-averaged WM ROI table joined with per-subject clinical summaries.

    ``da`` is averaged over ``session``, reshaped with ``prepare_wm_rois`` using
    ``["subject", "desc"]`` as the index, then inner-joined to ``hx`` on ``subject``.
    """
    averaged = da.mean("session").to_dataframe(name="data").reset_index()
    rois = prepare_wm_rois(pl.from_pandas(averaged), ["subject", "desc"])
    return rois.join(hx, on="subject")


# Session-averaged WM tables joined with the per-subject clinical summaries.
# NOTE: this rebinds topsy_df / jhp_df, which earlier cells used for per-session data.
topsy_df = avg_sessions(topsy_wm_sampled, topsy_hx)
jhp_df = avg_sessions(jhp_wm_sampled, jhp_hx)

In [None]:
def _1tail(pval, stat, desc, param):
    val = 1
    if desc in {"FA", "thickness"}:
        val *= -1
    if "recovery" in param:
        val *= -1
    val *= np.sign(stat)
    return (min(0, val) * -1) + (pval * val / 2)


def do_stats(col):
    """Expression fitting OLS ``data ~ {col} + age + sex`` within each group.

    Intended for ``group_by(...).agg(...)``: each group's rows are converted to
    pandas and fitted with statsmodels. The ``col`` coefficient is t-tested, and
    its two-sided p-value is made one-sided by ``_1tail``, with the direction set
    by the group's ``desc`` and by ``col``. Returns a struct column
    ``{col}_stats`` with the fields listed in ``dtype``.
    """
    def inner(data):
        # ``data`` holds the group's struct values; expand them into a pandas
        # DataFrame with one column per struct field.
        df = data.to_pandas().apply(pd.Series)
        model = smf.ols(
            f"data ~ {col} + age + sex", data=df
        ).fit()

        # Design-matrix column whose name contains ``col``; itx.one raises unless
        # exactly one name matches.
        token = itx.one(n for n in model.model.exog_names if col in n)
        contr = model.t_test(token)
        return {
            "intercept": model.params.loc["Intercept"],
            "beta": model.params.loc[token],
            # All rows of a group share one desc, so the first row's is used.
            "pval": _1tail(contr.pvalue, contr.statistic, data[0]["desc"], col),
            # NOTE(review): contr.statistic / pvalue may be small arrays rather than
            # floats; this relies on polars coercing them to Float64.
            "statistic": contr.statistic,
            "nobs": model.nobs,
            "df_model": model.df_model,
            "df_resid": model.df_resid,
        }

    dtype = pl.Struct(
        {
            "intercept": pl.Float64,
            "beta": pl.Float64,
            "pval": pl.Float64,
            "statistic": pl.Float64,
            "nobs": pl.Float64,
            "df_model": pl.Float64,
            "df_resid": pl.Float64,
        }
    )
    return (
        pl.struct("desc", "data", col, "age", "sex")
        .map_elements(inner, return_dtype=dtype)
        .alias(f"{col}_stats")
    )


def pearsonr(col):
    """Per-group Pearson r between ``col`` and ``data``; output ``{col}_pearsonr``.

    Rows where ``col`` is null are dropped before correlating. Intended for a
    ``group_by(...).agg(...)`` context.
    """
    def _correlate(group):
        xs = np.asarray(group.struct[col])
        ys = np.asarray(group.struct["data"])
        return scs.pearsonr(xs, ys).statistic

    has_value = ~pl.col(col).is_null()
    paired = pl.struct("data", col).filter(has_value)
    return paired.map_elements(_correlate, return_dtype=pl.Float64).alias(
        f"{col}_pearsonr"
    )


def get_pvals(col):
    """``do_stats(col)`` with every struct field name prefixed by ``{col}_``."""
    stats = do_stats(col)
    return stats.name.prefix_fields(f"{col}_")


def get_all_stats(df, scores, features):
    """Regress ROI data on every ``{score}_{feature}`` predictor, per label and desc.

    Only subjects with more than one session are used. Returns one row per
    (label, desc, score, feature) holding the ``do_stats`` fields, plus
    ``pvalcorr``: FDR-corrected p-values computed within each
    (hierarchy, desc, score, feature) family.
    """
    return (
        df.filter(pl.col("num_sessions") > 1)
        .group_by("label", "desc")
        .agg(
            *(
                do_stats(f"{score}_{suffix}")
                for score, suffix in it.product(scores, features)
            ),
            # NOTE(review): the *_pearsonr columns are computed here but dropped by
            # the melt below (only *_stats columns are kept).
            *(
                pearsonr(f"{score}_{suffix}")
                for score, suffix in it.product(scores, features)
            ),
            pl.first("hierarchy"),
        )
        .melt(["label", "desc", "hierarchy"], cs.matches(".*_stats"), "param", "stats")
        # "{score}_{feature}_stats" -> score, feature; assumes neither name
        # contains an underscore.
        .with_columns(
            pl.col("param").str.split("_").list.to_struct(fields=["score", "feature"])
        )
        .unnest("param", "stats")
        .with_columns(
            pl.col("pval")
            .map_elements(
                lambda x: pl.Series(scs.false_discovery_control(x)),
                return_dtype=pl.List(pl.Float64),
            )
            .over("hierarchy", "desc", "score", "feature")
            .name.suffix("corr")
        )
    )


# Symptom/microstructure regressions. Each call tests only the features listed;
# FDR correction inside get_all_stats is applied separately per score/feature.
topsy_stats = get_all_stats(
    topsy_df,
    ["PANSSP", "PANSSN"],
    ["recovery"],
    # The rest are not significant
    # ["baseline", "mean", "slope", "intercept", "recovery"],
)
jhp_stats = get_all_stats(
    jhp_df,
    ["saps", "sans"],
    ["baseline", "mean", "intercept"],
    # The rest are not significant
    # ["baseline", "mean", "slope", "intercept"],
)

In [None]:
# Table of FDR-significant JHP results: one row per region/desc/hierarchy/feature,
# with one corrected p-value column per score.
sig = (
    jhp_stats.filter(pl.col("pvalcorr") < 0.05)
    # .with_columns(parameter=pl.concat_str("score", "feature", separator="_"))
    .join(
        # atlas_md.rename({"index": "roi"}),
        atlas_md.group_by("label").agg(pl.first("region")),
        on="label",
    )
    .pivot(
        values="pvalcorr",
        index=["region", "desc", "hierarchy", "feature"],
        columns="score",
    )
)
# sig.write_clipboard()

In [None]:
# sig.write_clipboard()

In [None]:
# Does adding age to a model of global-WM L1 on PANSS8-N recovery improve the
# fit? This compares nested OLS models with an F-test.
# Uses the session-averaged TOPSY table (topsy_df from avg_sessions), which
# carries PANSSN_recovery from topsy_hx. The previous bare ``df`` was a loop
# variable left over from the longitudinal-figure cell (the JHP per-session table).
df_r = (
    topsy_df.filter(
        pl.col("num_sessions") > 1,
        pl.col("label") == "WM",
        pl.col("desc") == "L1",
    )
    .group_by("label", "subject")
    .agg(pl.col("*").exclude("data").first(), pl.col("data").mean())
)
df_r_pd = df_r.to_pandas()
lm1 = smf.ols("data ~ Q('PANSSN_recovery')", data=df_r_pd).fit()
lm2 = smf.ols("data ~ Q('PANSSN_recovery') + age", data=df_r_pd).fit()
lm2.compare_f_test(lm1)

In [None]:
# JHP: session-averaged microstructure vs SANS intercept for the top-level WM
# regions (hierarchy < 2). One subfigure per measure, one facet per region.
fig = plt.figure(figsize=(8, 10), layout="constrained")
axs = fig.subfigures(4, 1)

variables = np.array(["FA", "MD", "RD", "L1"])
# NOTE(review): same x*1000 / nm^2/ms unit question as the longitudinal figure.
yscales = [
    so.Continuous(),
    *([so.Continuous().label(like=lambda x, _: f"{x*1000:.2f}")] * 3),
]
units = ["", *([r"$\frac{nm^2}{ms}$"] * 3)]

labels = {
    "PANSSP": "PANSS8-P",
    "PANSSN": "PANSS8-N",
    "sans": "SANS",
    "saps": "SAPS",
}

for y, x in np.ndindex(4, 1):
    variable = variables[y]
    label = labels.get(variable, variable)
    labely = f"{label} ({units[y]})" if units[y] else label
    ax = axs[y]
    # num_sessions = pl.col("session").unique().len().over("subject")
    (
        so.Plot(
            # ``sig`` is computed from the SANS-intercept FDR p-value but is not
            # mapped to any property; significance styling is applied by hand below.
            jhp_df.filter(pl.col("hierarchy") < 2, pl.col.desc == variable)
            .join(
                jhp_stats.filter(
                    pl.col.desc == variable, feature="intercept", score="sans"
                )[["label", "pvalcorr", "desc"]],
                on="label",
            )
            .with_columns(sig=pl.col.pvalcorr < 0.05),
            y="data",
            x="sans_intercept",
        )
        .facet(
            col="region",
            order=["White Matter", "Core White Matter", "Peripheral White Matter"],
        )
        .share(y=False)
        .add(so.Dot(color="#555555"))
        .add(so.Line(), so.PolyFit(1))
        .add(so.Band(), PolyCI(ci=95), legend=False)
        .add(
            so.Text({"ha": "right" if variable == "FA" else "left"}),
            PearsonrAnnot("upper right" if variable == "FA" else None),
        )
        .label(
            title="" if y > 0 else None, x=None if y < 3 else "SANS Intercept", y=labely
        )
        .scale(y=yscales[y], color=["#88c", "#555"], )
        .on(ax)
        .plot()
    )
    if y == 0:
        # Hard-coded: in the FA row, restyle two panels as nonsignificant (grey
        # band, dotted line). The child indices presumably select facet axes and
        # their band/line artists — fragile; confirm against the sig column.
        for i in [2, 3]:
            ax.get_children()[i].get_children()[2].set_facecolor((0, 0, 0, 0.2))
            ax.get_children()[i].get_children()[1].set_color((0, 0, 0, 0.8))
            ax.get_children()[i].get_children()[1].set_linestyle(":")

---
label: fig-jhp-roi
cell-offset: -1

---
Correlation between microstructural parameters and SANS intercept. Microstructural measurements are averaged across all sessions for each subject. The SANS intercept was computed using a first-order linear model for each subject, with the baseline scan as time 0. Shaded bands show 95% CIs computed with a paired nonparametric bootstrap (10,000 resamples). Relationships were tested with a linear model with age and sex as covariates. Solid lines represent statistically significant relationships; dotted lines are nonsignificant. T-values and P-values are shown in @tbl-XXX.

In [None]:
# Debugging scratch: inspect artists of ``ax`` (the last subfigure from the figure
# loop above). The dir(...) result is not displayed; only the last line is.
dir(ax.get_children()[2])
ax.get_children()[2].get_children()[2].get_facecolor()

In [None]:
# Debugging scratch: the data frame plotted by the figure loop above.
# NOTE(review): relies on ``variable`` left over from that loop (its last value),
# so it only works after that cell has run.
(
    jhp_df.filter(pl.col("hierarchy") < 2, pl.col.desc == variable)
    .join(
        jhp_stats.filter(
            pl.col.desc == variable, feature="intercept", score="sans"
        )[["label", "pvalcorr", "desc"]],
        on="label",
    )
    .with_columns(sig=pl.col.pvalcorr < 0.05)
)

---
label: fig-jhp-rois
cell-offset: -1

---
Correlations between microstructural measures and 

### Heatmap of other regions

In [None]:
from nilearn import plotting
from scipy.ndimage import binary_dilation

# Display atlases in JHP template space. The first subject's FA skeleton is the
# reference image and mask (unlike run_roi_sampling, which uses the group mean).
img = nb.load(jhp.layout.get(suffix="skeletonized", desc="FA")[0])
# Dilate the skeleton so JHU ROIs are visible in slice views.
skeleton_mask = binary_dilation(img.get_fdata() > 0)

jhu_atlas = np.where(
    skeleton_mask,
    nb.load("../jhp/derivatives/atlases/atlas.nii.gz").get_fdata(),
    0,
)
# NOTE(review): the lobe atlas is masked with the undilated skeleton, unlike JHU.
lobe_atlas = np.where(
    img.get_fdata() > 0, nb.load("../jhp/derivatives/atlases/lobe-atlas.nii.gz").get_fdata(), 0
)
# core_mask = (jhu_atlas > 0).astype(int)
# periph_mask = ((lobe_atlas > 0) & (jhu_atlas == 0)).astype(int)
# Peripheral (lobe) labels only where no JHU label is present.
lobe_mask = np.where(jhu_atlas == 0, lobe_atlas, 0)


In [None]:
def get_level2_atlas():
    """Voxel atlas at hierarchy level 2: core tract groups plus peripheral lobes.

    Each hierarchy-3 JHU label (excluding the ``unclassified`` and ``cerebellar``
    groups) is relabelled with the atlas ``index`` of its core group. Voxels with
    no JHU label keep their ``lobe_mask`` value. Reads the notebook-level
    ``atlas_md``, ``jhu_atlas`` and ``lobe_mask``.
    """
    core_groups = atlas_md.filter(group="core")
    tract_name_to_id = dict(zip(core_groups["name"], core_groups["index"]))
    jhu_rois = atlas_md.filter(
        pl.col.hierarchy == 3, ~pl.col.group.is_in({"unclassified", "cerebellar"})
    ).select(
        pl.col.group.replace(tract_name_to_id).cast(int),
        pl.col.atlas_id,
    )
    # Lookup table from JHU label value to tract-group index (0 if unmapped).
    lookup = np.zeros(jhu_atlas.max().astype(int) + 1)
    np.put(lookup, jhu_rois["atlas_id"], jhu_rois["group"])
    relabelled = lookup[jhu_atlas.astype(int)]
    return np.asarray(lobe_mask + relabelled, dtype=int)
    
def project_level3(df):
    """Paint significant hierarchy-3 (JHU ROI) statistics onto ``jhu_atlas`` voxels.

    Rows of ``df`` joined to ``atlas_md`` on ``label`` that have ``pvalcorr < 0.05``
    and ``hierarchy == 3`` contribute their ``statistic`` at their ``atlas_id``.
    All other voxels are 0. If several rows share an ``atlas_id`` (e.g. more than
    one score), the last one written wins.
    """
    significant = atlas_md.join(df, on="label").filter(
        pl.col.pvalcorr < 0.05,
        pl.col.hierarchy == 3,
    )
    lookup = np.zeros(jhu_atlas.max().astype(int) + 1)
    np.put(lookup, significant["atlas_id"], significant["statistic"])
    return lookup[jhu_atlas.astype(int)]


def project_level2(df, alpha=0.05):
    """Project significant level-2 statistics onto the level-2 label volume.

    Args:
        df: Per-region statistics with at least ``label``, ``pvalcorr`` and
            ``statistic`` columns. It is joined to the notebook global
            ``atlas_md`` on ``label``; ``hierarchy`` and ``index`` come from
            ``atlas_md``.
        alpha: Threshold on the corrected p-value. Regions with
            ``pvalcorr >= alpha`` are left at 0. Defaults to 0.05, the value
            previously hard-coded here.

    Returns:
        Float array shaped like the volume built by ``get_level2_atlas()``.
        Each significant region's voxels hold its statistic; all other voxels
        are 0.
    """
    atlas_ids, stats = atlas_md.join(df, on="label").filter(
        pl.col.pvalcorr < alpha,
        pl.col.hierarchy == 2,
    )[["index", "statistic"]]
    # For core tracts, get_level2_atlas() writes this same `index` value into
    # the volume. Lobe voxels carry lobe-atlas values (presumably matching
    # `index`; TODO confirm).
    atlas = get_level2_atlas()
    atlas_vals = np.zeros(atlas.max() + 1)
    # NOTE(review): if df holds several rows per label (e.g. several scores),
    # np.put keeps the last one written.
    np.put(atlas_vals, atlas_ids, stats)

    return atlas_vals[atlas]


In [None]:
# Shared colour limits: range of significant (pvalcorr < 0.05) intercept
# statistics over hierarchy levels 2 and 3, so every map uses one scale.
# `.row(0)` unpacks the single aggregate row into plain scalars. The same
# values then go to both nilearn and `add_colorbar`; previously one got
# `vmin[0]` and the other a 1-element Series.
# NOTE(review): the range spans every `desc` (and score) in jhp_stats, not just
# the three parameters plotted below. Confirm that is intended.
vmin, vmax = (
    jhp_stats.filter(
        pl.col.pvalcorr < 0.05,
        pl.col.hierarchy.is_in([2, 3]),
        pl.col.feature == "intercept",
    )
    .select(min=pl.min("statistic"), max=pl.max("statistic"))
    .row(0)
)

fig = plt.figure(figsize=(8, 8), facecolor="black")
bg = "../jhp/derivatives/tpl-fa/tpl-study/tpl-study_FA.nii.gz"

descs = ["MD", "RD", "L1"]
panel_labels = ["A", "B"]

# Two map panels (A: level-2 groups, B: level-3 JHU ROIs) plus a thin gutter
# row for the shared colourbar.
*panels, gutter = fig.subfigures(3, 1, height_ratios=[10, 10, 0.5])
panels[0].suptitle(
    "Peripheral and Core Groups", **{**Styles.col_title, "color": "#ffffff"}
)
panels[1].suptitle("JHU ROIs", **{**Styles.col_title, "color": "#ffffff"})
for i, project in enumerate([project_level2, project_level3]):
    panel = panels[i]
    panel.text(
        0.1,
        0.95,
        panel_labels[i],
        **(Styles.panel_label | {"color": "white"}),
    )
    # One row per diffusion parameter in `descs`.
    axs = panel.subplots(3, 1)

    for y in range(3):
        if y == 0:
            # Hemisphere markers pinned to the top corners in axes coordinates,
            # independent of the axes' data limits.
            axs[y].text(0, 1, "L", color="white", transform=axs[y].transAxes)
            axs[y].text(
                1, 1, "R", color="white", ha="right", transform=axs[y].transAxes
            )
        # NOTE(review): unlike the ROI table cell (score="sans"), no score
        # filter is applied here. Confirm jhp_stats has one score per label.
        param_map = nb.Nifti1Image(
            project(jhp_stats.filter(desc=descs[y], feature="intercept")),
            img.affine,
            img.header,
        )
        plotting.plot_stat_map(
            param_map,
            bg,
            # 7 evenly spaced axial slices from z=1 to z=45 inclusive.
            cut_coords=np.r_[1:45:7j],
            resampling_interpolation="nearest",
            display_mode="z",
            cmap="autumn",
            symmetric_cbar=False,
            vmin=vmin,
            vmax=vmax,
            annotate=False,
            colorbar=False,
            axes=axs[y],
            figure=panel,
        )
        # Turn the frame axes back on with ticks hidden, so only the row
        # label (ylabel) is drawn.
        axs[y].axis("on")
        axs[y].get_yaxis().set_ticks([])
        axs[y].get_xaxis().set_visible(False)
        axs[y].set_ylabel(
            descs[y],
            **{
                **Styles.row_title,
                "color": "#ffffff",
                "ha": "center",
                "fontweight": "bold",
            },
        )
# Horizontal colourbar centred in the gutter row (10/80/10 width split).
guttergrid = gutter.add_gridspec(1, 3, width_ratios=[10, 80, 10])
cbar = gutter.add_subplot(guttergrid[1])
cb = add_colorbar(
    vmin, vmax, ax=cbar, cmap="autumn", orientation="horizontal", outline=False
)
cbar.xaxis.label.set_color("white")
cbar.tick_params(axis="x", colors="white", labelsize=10)
cbar.set_xlabel("T-value", color="white", size=12)

---
label: fig-jhp-heatmap
cell-offset: -1

---
Correlations between microstructural parameters and the SANS intercept. Microstructural measures and SANS intercepts were computed as in @fig-jhp-roi. Relationships were tested with a linear model with age and sex as covariates. Significant ROIs are coloured according to their T-value. A and B show two nested hierarchical layers, with B at a finer resolution. Comparisons within each layer were corrected for multiple comparisons using FDR. T-values and P-values are shown in @tbl-XXX.

### TOPSY correlations

In [None]:
# Shared colour limits: range of significant (pvalcorr < 0.05) recovery
# statistics over hierarchy levels 2 and 3. They are kept as 1-element Series
# (indexed with [0] below) because the following figure cell reuses them.
# NOTE(review): the range spans every desc/score in topsy_stats, while the map
# below shows only FA vs PANSSN. Confirm that is intended.
vmin, vmax = topsy_stats.filter(
    pl.col.pvalcorr < 0.05,
    pl.col.hierarchy.is_in([2, 3]),
    pl.col.feature == "recovery",
).select(min=pl.min("statistic"), max=pl.max("statistic"))
# Level-2 map of significant FA associations with PANSS8-N recovery.
param_map = nb.Nifti1Image(
    project_level2(topsy_stats.filter(feature="recovery", desc="FA", score="PANSSN")),
    img.affine,
    img.header,
)
bg = "../jhp/derivatives/tpl-fa/tpl-study/tpl-study_FA.nii.gz"
# Axial (z) cut coordinates for a 2x2 mosaic of single-slice renders.
cuts = np.array([[23, 30], [37, 45]])
imgs = np.zeros((2, 2), dtype=object)
for y, x in np.ndindex(2, 2):
    fig = plt.figure(figsize=(5, 5), facecolor="black")
    ax = fig.subplots(1, 1)
    plotting.plot_stat_map(
        param_map,
        bg,
        cut_coords=[cuts[y, x]],
        resampling_interpolation="nearest",
        display_mode="z",
        cmap="autumn",
        symmetric_cbar=False,
        vmin=vmin[0],
        vmax=vmax[0],
        annotate=False,
        colorbar=False,
        axes=ax,
    )
    imgs[y, x] = fig_to_numpy(fig)
    # Close this render explicitly. A bare plt.close() closes whichever figure
    # pyplot considers "current"; passing `fig` makes the target certain.
    plt.close(fig)
# Crop each rasterised render to hard-coded pixel bounds and tile the four
# slices into a single 2x2 image.
topsy_tmap = np.vstack([np.hstack([a[150:1300, 300:1200] for a in r]) for r in imgs])

In [None]:
fig = plt.figure(figsize=(8, 6), layout="constrained")
grid = fig.add_gridspec(3, 3)

# Grid cells for the five ROI comparison plots: the top row plus the left
# column. The remaining bottom-right 2x2 block holds the atlas panel (D).
slots = [
    grid[0, 0],
    grid[0, 1],
    grid[0, 2],
    grid[1, 0],
    grid[2, 0],
]

# Restrict the stats to the association this figure reports (PANSS8-N
# recovery) before joining. topsy_stats also has `feature` and `score`
# columns. A join on (label, desc) alone would pair each subject row with every
# matching feature/score row, which duplicates subjects and selects ROIs that
# are significant only for a different feature or score.
sig_topsy = topsy_df.join(
    topsy_stats.filter(feature="recovery", score="PANSSN")[
        ["desc", "label", "pvalcorr"]
    ],
    on=["label", "desc"],
).filter(pl.col.pvalcorr < 0.05, pl.col.num_sessions > 1)

# Significant regions ordered by hierarchy level, then by name.
regions = (
    sig_topsy.group_by("region")
    .agg(pl.first("hierarchy"))
    .sort("hierarchy", "region")["region"]
    .to_numpy()
)

panels = ["A", "B", None, "C", None]

# zip() stops at the shortest input. At most five regions are drawn, and with
# fewer significant regions the extra slots stay empty instead of raising
# IndexError.
for i, (region, slot, panel) in enumerate(zip(regions, slots, panels)):
    ax = fig.add_subplot(slot)
    ax.set_title(region)
    # x positions False/True are labelled ">3"/"3" (presumably PANSSN_recovery
    # is True when the follow-up score equals 3).
    ax.set_xticks([False, True], labels=[">3", "3"])
    if panel is not None:
        ax.text(
            -0.25,
            1,
            panel,
            transform=ax.transAxes,
            **Styles.panel_label,
        )

    comparison_plot(
        sig_topsy.filter(region=region), x="PANSSN_recovery", y="data", ax=ax
    )
    # x-axis label on the bottom-left plot and on the two top-row plots above
    # the atlas panel.
    if i in {4, 1, 2}:
        ax.set_xlabel("Follow-up PANSS8-N Score", size=10)
    else:
        ax.set(xlabel=None)
    # y-axis label only on the left column.
    if i in {0, 3, 4}:
        ax.set_ylabel("FA")
    else:
        ax.set_ylabel(None)


# Black backdrop (figure coordinates) behind the atlas panel.
fig.patches.append(
    plt.Rectangle(
        (0.43, 0.02), 0.57, 0.6, color="black", transform=fig.transFigure, zorder=-1
    )
)
atlas_panel = fig.add_subfigure(grid[1:, 1:])
atlas_panel.text(
    -0.05,
    1,
    "D",
    **Styles.panel_label,
)
axs = atlas_panel.subplots(1, 2, width_ratios=[15, 1])
axs[0].imshow(topsy_tmap)
axs[0].axis("off")
# Enlarge the mosaic beyond its constrained-layout slot.
axs[0].set_position(axs[0].get_position().expanded(1.3, 1.3))
axs[0].text(
    0, 0.45, "L", color="white", size=14, fontweight="bold", transform=axs[0].transAxes
)
axs[0].text(
    1,
    0.45,
    "R",
    ha="right",
    color="white",
    size=14,
    fontweight="bold",
    transform=axs[0].transAxes,
)

# vmin/vmax come from the T-map cell, so the colourbar matches the mosaic.
add_colorbar(vmin, vmax, ax=axs[1], cmap="autumn", outline=False)
axs[1].set_ylabel("T-value")
axs[1].yaxis.label.set_color("white")
axs[1].yaxis.set_major_formatter(FormatStrFormatter("%0.02f"))
axs[1].tick_params(axis="y", colors="white", labelsize=10)

---
label: fig-topsy-roi
cell-offset: -1

---
Correlations between microstructural parameters and the PANSS8-N follow-up score. Microstructural measures are computed as in @fig-jhp-roi. Subjects are grouped by whether their PANSS8-N score at their follow-up session was equal to 3, the lowest possible score. Relationships were tested with a linear model with age and sex as covariates. All comparisons shown are significant. A, B and C show ROIs from different nested hierarchical layers, at successively higher resolutions. D shows the location of the regions in C with their T-values. Comparisons within each layer were corrected for multiple comparisons using FDR. T-values and P-values are shown in @tbl-XXX.

In [None]:

# BIDS derivative layout for the TOPSY model outputs; the NBS result files are
# queried from it below (suffix="nbs").
topsy_modelling = BidsLayout("../topsy/derivatives/models-v0.1.0")

In [None]:
import h5py


# NOTE(review): layout_map is given dtype=int, while the returned test
# statistics are floats. Confirm whether this truncates the T-values.
@layout_map(dims={"src": atlases.bn246["Label ID"], "dest": atlases.bn246["Label ID"]}, dtype=int)
def load_h5_nbs(file):
    """Read one NBS result file and return its masked edge statistics.

    ``nbs/con_mat`` is summed over its last axis to give a per-edge weight
    (presumably 0/1 if the stored networks are disjoint; TODO confirm). The
    weight multiplies ``nbs/test_stat``. If reading or summing ``con_mat``
    raises ``IndexError`` (numpy's AxisError is a subclass), the weight falls
    back to 0 and every edge is zeroed.
    """
    with h5py.File(file, "r") as h5:
        try:
            edge_weight = np.sum(h5["nbs/con_mat"][:], axis=-1)
        except IndexError:
            # Presumably: the file stores no significant network. TODO confirm.
            edge_weight = 0
        test_stat = h5["nbs/test_stat"][:]
    return test_stat * edge_weight


# Model descriptors: parameter + score letter ("p"/"n") + term, e.g. "FAnrecov".
descs = [
    f"{param}{score}{term}"
    for param, score, term in it.product(
        ("AD", "FA", "MD", "RD"), ("p", "n"), ("intercept", "recov", "slope")
    )
]
topsy_nbs = load_h5_nbs(
    topsy_modelling.get(suffix="nbs", label=["pos", "neg"], desc=descs), ["desc"]
)

In [None]:
# Sanity check: the distinct edge values present in the loaded NBS maps.
np.unique(topsy_nbs)

In [None]:
# Number of edges with a positive value per model. The `desc` code is decoded
# into its parts, e.g. "FAnrecov" -> param "FA", score "PANSSN", term "recov".
models = pl.from_pandas(
    (topsy_nbs > 0)
    .sum(["src", "dest"])
    .to_dataframe(name="count")
    .reset_index()
).with_columns(
    param=pl.col("desc").str.slice(0, 2),
    score=pl.col("desc").str.slice(2, 1).replace({"p": "PANSSP", "n": "PANSSN"}),
    term=pl.col("desc").str.slice(3),
)

In [None]:
# Inspect the node table passed as `nodes=` to the connectome plots below.
# Its "Label ID" column indexes the NBS src/dest dimensions.
atlases.bn246

In [None]:
from colormaps.utils import concat as cmaps_concat

# Render this figure with the Cairo backend; `%matplotlib inline` at the end of
# the cell switches back.
# NOTE(review): matplotlib's switch_backend closes all open figures when the
# backend changes. The closing `%matplotlib inline` may therefore close `fig`
# before the inline backend displays it, and `plt.figure()` then shows only an
# empty figure. Confirm the intended figure is what gets rendered.
plt.switch_backend("cairo")

# Shared style for vertical side titles (row labels and colourbar titles).
side_title = dict(
    x=-0.1,
    y=0.5,
    rotation="vertical",
    rotation_mode="anchor",
    size=10,
    ha="center",
    va="bottom",
    color=Styles.Colors.dark[0],
)

fig = plt.figure(figsize=(5.3, 10))
main, gutter = fig.subfigures(1, 2, width_ratios=[5, 0.3])
# Rows: diffusion parameters; columns: PANSS8-N derived terms.
axs = main.subplots(4, 2)
params = ["FA", "MD", "RD", "AD"]
terms = ["recov", "slope"]
labels = {
    "recov": r"$\text{Follow-up PANSS8-N} = 3$",
    "slope": r"$\text{PANSS8-N Slope}$",
}
# Edge colormaps keyed by sign: 1 -> positive correlations, -1 -> negative.
cms = {1: cmaps.ember.cut(0.4, "left"), -1: cmaps.cosmic.cut(0.2, "left")}
# Sign for each (param row, term column): FA is [recov: +, slope: -];
# MD, RD and AD are [recov: -, slope: +].
vcms = np.array([[1, -1], *([[-1, 1]] * 3)])
# One upper limit shared by every panel and both colourbars.
max_edge = topsy_nbs.max()
for y, x in np.ndindex(4, 2):
    ax = axs[y, x]
    ax.axis("off")
    plot_hierachical_connectome(
        # PANSS8-N ("n") models only.
        topsy_nbs.sel(desc=f"{params[y]}n{terms[x]}"),
        nodes=atlases.bn246,
        ax=ax,
        # Lower edge limit 3, matching the colourbars below.
        emin=3,
        emax=max_edge,
        ecmap=cms[vcms[y, x]],
        vcmap=cmaps_concat([cmaps.gray_5.cut(0.4, "right")] * 4).discrete(8),
        hierarchy=["Gyrus", "hemisphere"],
    )
    if x == 0:
        ax.set_title(params[y], **(side_title | {"size": 12, "weight": "bold"}))
    if y == 0:
        ax.text(
            0.5,
            1.1,
            labels[terms[x]],
            color=Styles.Colors.dark[0],
            size=10,
            ha="center",
            transform=ax.transAxes,
        )

# Gutter: two vertical colourbars (positive, then negative) in rows 1-2 and 3-4
# of a 6-row grid.
grid = gutter.add_gridspec(6, 1)
cbar1 = gutter.add_subplot(grid[1:3])
add_colorbar(3, max_edge, cms[1], cbar1, outline=False)
cbar1.set_ylabel("T-value", size=10, color=Styles.Colors.dark[0])
cbar1.set_title(
    "Positive correlations",
    **side_title,
)
cbar1.tick_params(axis="both", which="major", labelsize=8)

cbar2 = gutter.add_subplot(grid[3:5])
add_colorbar(3, max_edge, cms[-1], cbar2, outline=False)
cbar2.set_ylabel("T-value", size=10, color=Styles.Colors.dark[0])
cbar2.tick_params(axis="both", which="major", labelsize=8)
cbar2.set_title(
    "Negative correlations",
    **side_title,
)

%matplotlib inline
plt.figure()

---
label: fig-nbs
cell-offset: -1

---
Networks associated with negative symptoms in the TOPSY dataset. Microstructural measures are averaged across sessions. Values for each connection were measured by sampling along the constituent streamlines. The recovery score is computed as in @fig-topsy-roi. The PANSS8-N slope was computed for each subject by fitting a first-order model to the symptom measurements across sessions. In each network diagram, lines represent connections significantly correlated with the corresponding PANSS8-N derivative measure, as determined using NBS (10,000 samples, $T_{thresh}=3$, $p<0.05$). For instance, the top-left diagram represents connections with significantly higher FA in patients with a PANSS8-N score of 3 (the lowest possible score) at their follow-up session. Gyral abbreviations are given in @tbl-XXX.