In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
# Create an empty figure and axes; the returned (fig, ax) pair is not kept.
plt.subplots()

In [None]:
from rsbids import BidsLayout
import ciftipy as cp
import numpy as np
import pandas as pd
from brainspace.plotting import plot_hemispheres
from brainspace.mesh.mesh_io import read_surface
from brainstat.datasets import fetch_template_surface, fetch_mask
from lib.mesh import mesh_smooth
import tqdm
import seaborn as sns
import scipy.stats as scs
import itertools as it
import xarray as xr
import statsmodels.api as sm
import statsmodels.formula.api as smf
from lib.bidsarray import layout_map
from lib.plotting import move_legend_fig_to_ax
from pathlib import Path
from matplotlib import font_manager
import templateflow.api as tflow

from styles import styles as Styles

# Enable IPython's autoreload extension in mode 2.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# Apply the project's presentation style sheet to all figures.
plt.style.use("styles/presentation.mplstyle")
# Register every font file found under ~/.fonts with matplotlib's font manager.
font_dirs = [Path.home() / ".fonts"]
font_files = font_manager.findSystemFonts(fontpaths=font_dirs)

for font_file in font_files:
    font_manager.fontManager.addfont(font_file)

In [None]:
# Disabled: rebuilds the layout from the two derivative trees and resets the
# cache in ".cache".
# layout = BidsLayout(
#     ["../../derivatives/surfsample-0.1.0/", "../../derivatives/snakeanat-diffusion-v0.0.1/"],
#     cache=".cache",
#     reset_cache=True,
# )
# Load the previously cached layout instead of re-indexing.
layout = BidsLayout.load(".cache")
# Inflated fslr32k template: once as separate hemispheres (lh, rh) and once
# with the default joining behaviour (mesh).
lh, rh = fetch_template_surface("fslr32k", layer="inflated", join=False)
mesh = fetch_template_surface("fslr32k", layer="inflated")
# Template mask for fslr32k; used below for masking, smoothing and the SLM fits.
mask = fetch_mask("fslr32k")

### Gather all the surface sample files

In [None]:
# Every subject known to the layout except the two excluded IDs.
sub_list = set(layout.entities["subject"]).difference({"001", "003"})

# Stack the pre/post PANSS tables (each tagged with its session), derive
# `subject` by dropping the first three characters of `participant_id`, and
# keep only the rows whose subject is in `sub_list`.
md = (
    pd.concat(
        [
            pd.read_csv(csv_name).assign(session=session_label)
            for csv_name, session_label in (
                ("panss-pre.csv", "pre"),
                ("panss-post.csv", "post"),
            )
        ]
    )
    .assign(subject=lambda df: df["participant_id"].map(lambda s: s[3:]))
    .pipe(lambda df: df[df["subject"].isin(sub_list)])
)

In [None]:


def get_hems(img):
    """Return the left and right cortex data of ``img`` as a ``(left, right)`` tuple.

    For each hemisphere the structure is looked up in ``img.struc``, used to
    index ``img``, projected with ``project(0)`` and flattened.
    """
    structures = (
        "CIFTI_STRUCTURE_CORTEX_LEFT",
        "CIFTI_STRUCTURE_CORTEX_RIGHT",
    )
    return tuple(
        img[img.struc[structure]].project(0).flatten() for structure in structures
    )


def load_data(layout):
    """Load every file in ``layout`` into a single labelled array.

    Each entity of ``layout`` that takes more than one value becomes one
    dimension of the result; a trailing ``"vertex"`` dimension holds the left
    hemisphere values followed by the right hemisphere values of each file
    (as returned by ``get_hems``). Entity combinations with no file stay NaN.

    Parameters
    ----------
    layout
        Iterable of paths exposing ``entities`` both on itself (entity name ->
        collection of values) and on each path (entity name -> value).

    Returns
    -------
    xarray.DataArray
        Array with dims ``(*varying_entities, "vertex")`` and one coordinate
        per varying entity.

    Raises
    ------
    ValueError
        If ``layout`` yields no files, since the vertex count is then unknown.
    """
    # Only entities with more than one value become array dimensions.
    wcards = [wcard for wcard, vals in layout.entities.items() if len(vals) > 1]

    # Map each entity value to its position along that entity's axis.
    idx = {
        wcard: {value: pos for pos, value in enumerate(layout.entities[wcard])}
        for wcard in wcards
    }
    counts = [len(layout.entities[wcard]) for wcard in wcards]

    data = None
    bound = None
    for path in tqdm.tqdm(layout):
        loc = [idx[wcard][path.entities[wcard]] for wcard in wcards]
        img = cp.load(path)
        lh, rh = get_hems(img)
        if data is None:
            # The first file fixes the vertex count; ``np.nan`` is used because
            # the ``np.NaN`` alias no longer exists in NumPy 2.0.
            data = np.full((*counts, lh.shape[0] + rh.shape[0]), np.nan)
            bound = lh.shape[0]
        data[(*loc, slice(None, bound))] = lh
        data[(*loc, slice(bound, None))] = rh

    if data is None:
        raise ValueError("layout contains no files to load")

    # Coordinate labels in axis order (values sorted by their axis position).
    coords = {wcard: sorted(idx[wcard], key=idx[wcard].get) for wcard in wcards}
    return xr.DataArray(data, dims=(*wcards, "vertex"), coords=coords)


# Assemble the analysis dataset: the odi/ndi/fw maps, the FA/MD/L1/RD maps and
# the 32k thickness files are each loaded with load_data, concatenated along
# "desc", stored as the "surface" variable, merged with the PANSS table `md`
# (indexed by subject and session), and a fixed list of subjects is dropped.
ds = (
    xr.concat(
        [
            load_data(
                layout.get(
                    desc=["odi", "ndi", "fw"],
                    sub=sub_list,
                )
            ),
            load_data(
                layout.get(
                    desc=["FA", "MD", "L1", "RD"],
                    sub=sub_list,
                )
            ),
            # The thickness selection is given an explicit "desc" label so it
            # can be concatenated with the other maps.
            load_data(
                layout.get(
                    suffix="thickness",
                    den="32k",
                    sub=sub_list,
                )
            ).expand_dims(desc=["thickness"]),
        ],
        dim="desc",
    )
    .to_dataset(name="surface")
    .merge(md.set_index(["subject", "session"]).to_xarray())
    .drop_sel(subject=["005", "012", "026", "037", "044", "047"])
)
# Checkpoint the unsmoothed dataset before the smoothing step.
ds.to_netcdf("checkpoint2.h5")
# Smooth the masked surface data once per FWHM value in 5..14 and stack the
# results along a new "smoothing" dimension.
ds["surface_smooth"] = xr.concat(
    [
        mesh_smooth(
            ds["surface"].where(mask),
            surf=mesh,
            FWHM=smoothing,
            mask=mask,
            axis="vertex",

        ).expand_dims(smoothing=[smoothing])
        for smoothing in np.r_[5:15]
    ],
    dim="smoothing"
)
# Persist the dataset including the smoothed variable.
ds.to_netcdf("surface_data.2.h5")

### Data

In [None]:
# Reopen the saved dataset; chunks={} requests dask-backed (lazy) variables.
# NOTE(review): the previous cell writes "surface_data.2.h5" but this reads
# "surface_data.2.nc" — confirm which file is intended.
ds = xr.open_dataset("surface_data.2.nc", chunks={})

In [None]:
# Smoothed surface maps at smoothing=6; read by sample_dk below.
surf_smooth = ds.sel(smoothing=6)["surface_smooth"]
# Atlas label table (tab-separated); its "label" column is compared against
# the atlas values in sample_dk.
dkmd = pd.read_csv("atlas-dkt_labels.tsv", sep="\t")


@layout_map(dims={"param": ds["desc"].rename(desc="param"), "roi": dkmd["label"]})
def sample_dk(path):
    """Average each smoothed map inside every atlas label for one atlas file.

    ``path`` is an atlas file whose ``subject`` and ``session`` entities select
    the matching slice of the module-level ``surf_smooth``. The left and right
    cortex labels are written into one vector of ``32492 * 2`` vertices, and for
    every ``desc`` in ``ds`` and every label in ``dkmd`` the mean of the map
    over the vertices carrying that label is computed.

    Returns an array of shape ``(len(ds["desc"]), len(dkmd))``.
    """
    atlas_img = cp.load(path)
    vertex_labels = np.empty(32492 * 2)
    for hemi_slice, structure in (
        (slice(None, 32492), "CIFTI_STRUCTURE_CORTEX_LEFT"),
        (slice(32492, None), "CIFTI_STRUCTURE_CORTEX_RIGHT"),
    ):
        vertex_labels[hemi_slice] = (
            atlas_img[atlas_img.struc[structure]].project().ravel()
        )

    sampled = np.empty((len(ds["desc"]), len(dkmd)))
    subject_maps = surf_smooth.sel(
        subject=path.entities["subject"], session=path.entities["session"]
    )
    for row, param in enumerate(ds["desc"]):
        for col, label in enumerate(dkmd["label"]):
            in_label = vertex_labels == label
            sampled[row, col] = np.mean(subject_maps.sel(desc=param)[in_label])
    return sampled


# Run sample_dk over every fsLR 32k "dkt" dparc file of the subjects kept in
# `ds`, store the result as variable "dk", and attach the label table columns
# along the "roi" dimension.
# NOTE(review): this result is not written to disk here, and the next cell
# replaces `dk_sample` with the contents of "dk_sample.nc" — confirm.
dk_sample = (
    sample_dk(
        layout.get(
            suffix="dparc",
            subject=ds["subject"].data,
            atlas="dkt",
            space="fsLR",
            den="32k",
        ),
        wildcards=["subject", "session"],
    )
    .to_dataset(name="dk")
    .merge(dkmd.rename(columns={"label": "roi"}).set_index("roi").to_xarray())
)

In [None]:
# Load the sampled atlas data from disk, replacing any in-memory `dk_sample`.
dk_sample = xr.load_dataset("dk_sample.nc")

In [None]:
import scipy.stats as scs
# Z-score, across the rows of the resulting frame, the vertex-mean absolute
# post-minus-pre difference of the unsmoothed "RD" surface maps.
scs.zscore((
    ds.sel(desc="RD",  session="post", smoothing=6)["surface"] - 
    ds.sel(desc="RD",  session="pre", smoothing=6)["surface"]
).pipe(np.abs).mean("vertex").to_dataframe()["surface"])

In [None]:
# Plot the post-minus-pre difference of the unsmoothed "FA" surface maps,
# labelling the rows with the subject IDs and clamping colours to [-0.5, 0.5].
plot_hemispheres(
    lh,
    rh,
    (
        ds.sel(desc="FA",  session="post", smoothing=6)["surface"]
        .load()
        .data -
        ds.sel(desc="FA",  session="pre", smoothing=6)["surface"]
        .load()
        .data
    ),
    color_bar=True,
    color_range=(-0.5,0.5),
    label_text=list(ds["subject"].data),
    cmap="coolwarm",
    embed_nb=True,
    size=(1400, 5000),
    zoom=1.45,
    nan_color=(0.7, 0.7, 0.7, 1),
    cb__labelTextProperty={"fontSize": 12},
    interactive=False,
    background=(1, 1, 1),
    transparent_bg=False,
)

In [None]:
import dask.bag as db

from brainstat.stats.terms import FixedEffect, MixedEffect
from brainstat.stats.SLM import SLM
# Restrict to the "post" session (kept as a length-1 dimension).
ds_ = ds.sel(session=["post"])#.drop_sel(subject=["008", "029"])
# One row per (subject, session) observation, vertices last.
surface = (
    ds_["surface_smooth"]
    .stack(obs=("subject", "session"))
    .transpose("desc", "smoothing", "obs", "vertex")
)
# Ratio of the two PANSS columns, one row per observation.
df = (ds_["PANSS-N"] / ds_["PANNS-P"]).to_dataframe(name="panss-ratio").reset_index()
term_ses = FixedEffect(df["session"])
term_sub = MixedEffect(df["subject"])
# 0/1 indicators for each session, used to build contrasts.
ses_post = (df["session"] == "post").astype(int)
ses_pre = (df["session"] == "pre").astype(int)
term_panss = FixedEffect(df["panss-ratio"])

# Candidate model/contrast pairs; only `panss` is fitted below.
ses = {
    "model": term_ses + term_sub,
    "contrast": ses_post - ses_pre,
    "label": "ses",
}
interact = {
    "model": term_ses * term_panss + term_sub,
    "contrast": (ses_post * df["panss-ratio"]) - (ses_pre * df["panss-ratio"]),
    "label": "intr",
}
panss = {
    "model": term_panss,
    "contrast": df["panss-ratio"] * -1,
    "label": "panss",
}
slm = SLM(
    panss["model"],
    panss["contrast"],
    mask=mask,
    surf="fslr32k",
    correction=["rft", "fdr"],
    cluster_threshold=0.01,
    two_tailed=False,
)
# Fit on the smoothed "FA" maps at smoothing=7.
slm.fit(np.asanyarray(surface.sel(desc="FA", smoothing=7)))
# Copies of the "C" and "P" p-values and of slm.Q, in that order.
cluster_p = [*(np.copy(slm.P["pval"][idx]).T for idx in ["C", "P"]), np.copy(slm.Q)]
for clust in cluster_p:
    # Blank out entries above 0.05 or outside the mask ...
    np.place(clust, np.logical_or(clust > 0.05, ~mask), np.nan)
    # ... and replace the surviving entries with the t-statistic.
    np.copyto(clust, slm.t[0], where=~np.isnan(clust))


# One row per entry of cluster_p.
plot_hemispheres(
    lh,
    rh,
    np.vstack(cluster_p),
    color_bar=True,
    # label_text=["Cluster p-values", "Peak p-values", "Vertex p-values"], cmap="autumn_r",
    cmap="autumn_r",
    embed_nb=True,
    size=(1400, 800),
    nan_color=(0.7, 0.7, 0.7, 1),
    cb__labelTextProperty={"fontSize": 12},
    interactive=False,
    transparent_bg=False,
)

In [None]:
# Display the fitted SLM object from the previous cell.
slm

In [None]:
import dask.bag as db

from brainstat.stats.terms import FixedEffect, MixedEffect
from brainstat.stats.SLM import SLM

# All sessions this time: one row per (subject, session) observation.
surface = (
    ds["surface_smooth"]
    .stack(obs=("subject", "session"))
    .transpose("desc", "smoothing", "obs", "vertex")
)

df = ds[["PANSS-N", "PANNS-P"]].to_dataframe().reset_index()
# Column of `df` used as the clinical covariate in every model below.
panss_var = "PANNS-P"
term_ses = FixedEffect(df["session"])
term_sub = MixedEffect(df["subject"])
# 0/1 indicators for each session, used to build the interaction contrast.
ses_post = (df["session"] == "post").astype(int)
ses_pre = (df["session"] == "pre").astype(int)
term_panss = FixedEffect(df[panss_var])

# Model/contrast pairs looked up by name in compute_stats. Both share the same
# design and differ only in the contrast.
models = {
    "panss": {
        "model": term_ses * term_panss + term_sub,
        "contrast": df[panss_var],
        "label": "panss",
    },
    "intr": {
        "model": term_ses * term_panss + term_sub,
        "contrast": (ses_post * df[panss_var]) - (ses_pre * df[panss_var]),
        "label": "intr",
    }
}


def compute_stats(ds):
    """Fit one surface linear model and store its statistics in ``ds``.

    ``ds`` is a single block of the template dataset, holding exactly one
    value each for ``smoothing``, ``suffix``, ``sign`` and ``model``. The entry
    of the module-level ``models`` dict named by ``ds["model"]`` supplies the
    design and the contrast (multiplied by ``sign``); the fit runs on the
    matching ``desc``/``smoothing`` slice of the module-level ``surface``.

    Returns ``ds`` with ``C``, ``P``, ``Q``, ``t`` and ``clusid`` overwritten.
    When the fit raises ``numpy.linalg.LinAlgError`` every variable is filled
    with NaN; an ``IndexError`` is re-raised as ``Exception`` whose message
    lists the block's parameters.
    """
    smoothing, suffix, sign = (
        ds[name].item() for name in ("smoothing", "suffix", "sign")
    )
    model = models[ds["model"].item()]
    slm = SLM(
        model["model"],
        model["contrast"] * sign,
        mask=mask,
        surf="fslr32k",
        correction=["rft", "fdr"],
        cluster_threshold=0.01,
        two_tailed=True,
    )
    try:
        slm.fit(np.asanyarray(surface.sel(desc=suffix, smoothing=smoothing)))
    except np.linalg.LinAlgError:
        # No usable fit: mark every output as missing.
        entries = dict.fromkeys(("C", "P", "Q", "t", "clusid"))
    except IndexError as err:
        raise Exception(f"{smoothing=} {suffix=} {sign=} {model=}") from err
    else:
        pvals = slm.P["pval"]
        entries = {
            "C": pvals["C"],
            "P": pvals["P"],
            "Q": slm.Q,
            "t": slm.t,
            "clusid": slm.P["clusid"][0],
        }

    # Cluster ids are unwrapped one level further than the other outputs.
    cluster_ids = entries["clusid"]
    if cluster_ids is not None:
        entries["clusid"] = cluster_ids[0]

    for key, val in entries.items():
        if val is not None:
            ds[key].data = val.reshape(ds[key].shape)
        else:
            ds[key].data = np.full(ds[key].shape, np.nan)
    return ds



import dask.array as da
from dask.diagnostics import ProgressBar

# Parameter grid: one compute_stats call per combination of these values.
coords = {
    "smoothing": [7],  # surface["smoothing"].values[::2]
    "suffix": ["ndi", "odi", "fw", "thickness"],  # surface["desc"].values
    "sign": [1, -1],
    "model": ["panss", "intr"],
}
# Length of each non-grid axis.
axes = {"vertex": 64984}
# Output variables: four float maps plus the integer cluster ids, all per vertex.
variables = {
    **{name: {"dims": ["vertex"], "dtype": float} for name in ("C", "P", "Q", "t")},
    "clusid": {"dims": ["vertex"], "dtype": int},
}
dims = list(coords.keys())
comp_vars = {}
for label, v in variables.items():
    # Full shape is the grid followed by the variable's own axes; each chunk
    # holds exactly one grid cell with its complete trailing axes.
    shape = (*map(len, coords.values()), *(axes[d] for d in v["dims"]))
    chunks = (*[1] * len(coords), *(axes[d] for d in v["dims"]))

    comp_vars[label] = xr.DataArray(
        da.empty(shape=shape, chunks=chunks, dtype=v["dtype"]),
        dims=dims + v["dims"],
        coords=coords,
    )
template = xr.Dataset(comp_vars)

# Lazily apply compute_stats to every chunk of the template.
stats = xr.map_blocks(compute_stats, template, template=template)
# stats.to_csv("stats.csv")

In [None]:
# Trigger the lazy map_blocks computation and show a dask progress bar.
# with dask.config.set(scheduler='processes', num_workers=4):
with ProgressBar():
    stats.load()

In [None]:
# Persist the computed statistics.
stats.to_netcdf("panssp-stats.nc")

In [None]:
# Replace `stats` with a dataset read from disk.
# NOTE(review): this reads "stats.nc", not the "panssp-stats.nc" written by the
# previous cell — confirm this is a deliberately different results file.
stats = xr.load_dataset("stats.nc")

In [None]:
# Count, per parameter combination, the vertices whose Q, C or P value is
# below 0.05, as a flat table.
df = (
    stats[["Q", "C", "P"]]
    .pipe(lambda pvals: pvals < 0.05)
    .sum("vertex")
    .to_dataframe()
    .reset_index()
)

In [None]:
# Show only the rows with at least one vertex counted in column "C".
df[df["C"] > 0]

In [None]:
# suffix="FA"
# for l, clust in zip(("C", "P", "Q"), cluster_p):
#     if not clust.shape and clust.item() is None:
#         continue
#     np.place(clust, np.logical_or(clust > 0.05, ~mask), np.nan)
#     np.copyto(clust, slm.t[0], where=~np.isnan(clust))
#     print(l, "=", np.sum(~np.isnan(clust)))


def do_plot(data):
    """Render ``data`` on the ``lh``/``rh`` template surfaces.

    ``data`` is passed straight to ``plot_hemispheres`` as the array to plot.
    The plot object is returned so that a cell ending in ``do_plot(...)``
    actually displays the figure; without the ``return`` the call evaluated to
    ``None`` and showed nothing.
    """
    return plot_hemispheres(
        lh,
        rh,
        # np.where(slm.t > 2, slm.t, np.nan),
        # slm.t,
        data,
        color_bar=True,
        # label_text=["ndi"],
        cmap="viridis",
        embed_nb=True,
        size=(1400, 1600),
        zoom=1.45,
        nan_color=(0.7, 0.7, 0.7, 1),
        cb__labelTextProperty={"fontSize": 12},
        transparent_bg=False,
    )


# One row per smoothing value 5..14: the "C" values below 0.05 for the
# negative-sign "intr" model on "ndi" (everything else becomes NaN).
# NOTE(review): requires `stats` to contain all smoothing values 5..14 —
# confirm against the loaded results file.
do_plot(
    np.vstack(
        [
            stats.sel(smoothing=smoothing, suffix="ndi", sign=-1, model="intr")["C"]
            .where(lambda da: da < 0.05)
            .data
            for smoothing in range(5, 15)
        ]
    )
)

In [None]:
# One row per parameter: the "C" values below 0.05 for the negative-sign
# "panss" model at smoothing=5 (everything else becomes NaN).
plot_hemispheres(
    lh,
    rh,
    # np.where(slm.t > 2, slm.t, np.nan),
    # slm.t,
    stats.sel(
        model="panss", suffix=["ndi", "odi", "fw", "thickness"], sign=-1, smoothing=5)["C"]
    .where(lambda da: da < 0.05)
    .data,
    color_bar=True,
    # label_text=["ndi"],
    label_text=["NDI", "ODI", "$f_{FW}$", "Thickness"],
    cmap="viridis",
    embed_nb=True,
    size=(1400, 200 * 4),
    zoom=1.45,
    nan_color=(0.7, 0.7, 0.7, 1),
    cb__labelTextProperty={"fontSize": 12},
    transparent_bg=False,
)

In [None]:
# For each (parameter, sign) pair, collect the clusters of the "panss" model
# at smoothing=5 whose mean "C" value is at most 0.05, and average the
# unsmoothed surface data over each such cluster's vertices.
# Note: `coords` is reused here as a list, replacing the earlier grid dict.
coords = []
result = []
for suffix, sign in [
    ("ndi", -1),
    ("odi", -1),
    ("fw", 1),
    ("thickness", -1),
]:
    x = stats.sel(smoothing=5, suffix=suffix, sign=sign, model="panss")
    clusids = np.unique(x["clusid"])
    for i in clusids:
        # Vertices belonging to cluster id `i`.
        vertices = np.nonzero(x["clusid"].data == i)
        if np.mean(x["C"][vertices]) <= 0.05:
            coords.append((suffix, i.astype(int)))
            result.append(
                ds.sel(
                    smoothing=6,
                    desc=suffix,
                )["surface"]
                .transpose("vertex", ...)[vertices]
                .mean("vertex")
            )

# Stack the per-cluster means along an index built from (parameter, cluster id).
clusters = xr.concat(result, dim=pd.Index(coords, name=("param", "clusid"))).chunk("auto")

In [None]:
# Cluster ids 1..3 of the positive-sign "fw" / "panss" result at smoothing=5;
# every other vertex becomes NaN.
clusids = stats.sel(smoothing=5, suffix="fw", sign=1, model="panss")["clusid"].where(
    lambda da: ((da > 0) & (da < 4))
)
from nilearn import plotting, datasets

lhem, rhem = tflow.get(template="fsLR", density="32k", suffix="inflated")
# Show the second half of the vertex axis on the `rhem` surface.
# NOTE(review): assumes the right hemisphere occupies the second half of the
# vertex axis and that tflow.get returns (left, right) in that order — confirm.
plotting.view_surf(
    str(rhem),
    clusids[clusids.shape[0] // 2 :].data,
    # hemi="right",
    cmap="tab10",
    vmax=10,
    vmin=1,
    symmetric_cmap=False,
).resize(1500, 900)

In [None]:
import statsmodels.formula.api as smf


# Parameter whose clusters are previewed at the end of this cell. It is set
# here so the cell runs on a fresh kernel; previously `param` was only
# assigned by a later cell.
param = "ndi"

# Flatten the per-cluster means into a table and attach each observation's
# PANNS-P score by (subject, session); the inner join keeps only observations
# present in both tables.
df = clusters.to_dataframe().drop(columns=["clusid", "param"]).reset_index()
df = (
    df.set_index(["subject", "session"])
    .join(
        md[["PANNS-P", "subject", "session"]].set_index(["subject", "session"]),
        how="inner",
    )
    .reset_index()
)
# result = smf.mixedlm("surface ~ session", df, groups=df["subject"]).fit()
# result.summary()
df[df["desc"].isin([param])]

In [None]:
import seaborn.objects as so

# fig, ax = plt.subplots(1, 1, figsize=(5, 3))

# Display names for lobes (not referenced by the plot in this cell).
rois = {
    "frontal": "Frontal Lobe",
    "parietal": "Parietal Lobe",
    "occipital": "Occipital Lobe",
    "temporal": "Temporal Lobe",
    "cingulate": "Cingulate Gyrus",
    "insula": "Insula",
}
# Axis-label text for each parameter / score column.
titles = {
    "FA": "FA",
    "ndi": "NDI",
    "odi": "ODI",
    "fw": r"$f_{fw}$",
    "thickness": "Thickness",
    "panss_n": "PANSS30-N",
    "PANNS-P": "PANSS30-P",
}
# for score, param in it.product(
#     ["panss_p", "panss_n"], ["odi", "ndi", "fw", "thickness"]
# ):
score="PANNS-P"
param="ndi"
ses_order = so.Nominal(order=["post", "pre"])
# One facet per cluster id: cluster mean against the score, with dots and a
# first-order fit line per session; the figure is saved as a PNG.
(
    so.Plot(
        df[df["desc"].isin([param])],
        x=score,
        y="surface",
        group="session",
        color="clusid",
    )
    .facet(col="clusid", wrap=2)
    .add(so.Dot(color="#666666"), fill="session", marker="session", legend=False)
    .add(
        so.Line(linewidth=3), so.PolyFit(order=1), linestyle="session", legend=False
    )
    .layout(size=(10, 10))
    .scale(linestyle=ses_order, marker=ses_order, fill=ses_order, color="tab10")
    .label(x=titles[score], y=titles[param], title="")
    .theme(plt.rcParams)
    .save(f"~/tsclient/khangrp4/Downloads/{score}-{param}-rft.png")
)

In [None]:
# Selected label ids, then the same ids offset by 35 appended after them.
ROIS_L = np.array([2, 3, 8, 9, 10, 15, *range(17, 21), *range(22, 32), 35])
ROIS = np.concatenate((ROIS_L, ROIS_L + 35))
ROIS

In [None]:
# Absolute z-score, computed across subjects, of the absolute between-session
# "FA" difference in the sampled atlas data.
zscores = xr.apply_ufunc(
    scs.zscore,
    dk_sample["dk"].sel(param="FA").diff("session").pipe(np.abs),
    input_core_dims=[["subject"]],
    output_core_dims=[["subject"]],
    vectorize=True,
).pipe(np.abs)
# Subjects whose largest per-ROI z-score exceeds 3.
outliers = zscores.max("roi").where(lambda da: da > 3, drop=True)["subject"].data
# For those subjects, show the absolute differences at the (subject, roi)
# pairs with z-score above 3, keeping only differences larger than 0.05.
dk_sample["dk"].sel(param="FA").diff("session").pipe(np.abs).stack(
    subroi=("subject", "roi")
).sel(
    subroi=zscores.sel(subject=outliers)
    .where(lambda da: da > 3, drop=True)
    .stack(subroi=("subject", "roi"))
    .dropna("subroi")["subroi"]
).where(
    lambda da: da > 0.05, drop=True
)

In [None]:
# Average the sampled atlas data over the ROIs of each lobe.
hemi_grouped = dk_sample.groupby("lobe").mean("roi")
# result = np.empty(len(hemi_grouped["name"]))
result = []
# for i, name in enumerate(hemi_grouped["name"]):
# Join the renamed PANSS columns and their ratio with the lobe averages, then
# flatten to a table and write it to CSV.
df = (
    ds[["PANSS-N", "PANNS-P"]]
    .rename({"PANSS-N": "panss_n", "PANNS-P": "panss_p"})
    .assign(panss_ratio=lambda ds: ds["panss_n"] / ds["panss_p"])
    .merge(hemi_grouped)
    .to_dataframe()
    .reset_index()
)
df.to_csv("sampled-dk-grouped-lobe.csv")
# lm1 = smf.mixedlm("dk ~ session*panss_ratio", data=df, groups="subject").fit()
# lm2 = smf.mixedlm("dk ~ session+panss_ratio", data=df, groups="subject").fit()
# lm2.random_effects_cov
#     result.append(lm.pvalues)
# result


In [None]:
import seaborn.objects as so

# fig, ax = plt.subplots(1, 1, figsize=(5, 3))

# Lobe key -> facet title; the key order also fixes the facet order.
rois = {
    "frontal": "Frontal Lobe",
    "parietal": "Parietal Lobe",
    "occipital": "Occipital Lobe",
    "temporal": "Temporal Lobe",
    "cingulate": "Cingulate Gyrus",
    "insula": "Insula",
}
# Parameter key -> y-axis label.
titles = {
    "FA": "FA",
    "ndi": "NDI",
    "odi": "ODI",
    "fw": r"$f_{fw}$",
    "thickness": "Thickness",
}

# One figure per parameter, faceted by lobe: per-subject dots and dashed lines
# across sessions plus a single first-order fit line; each figure is saved as
# a PNG.
for param in (
     ["odi", "ndi", "fw", "thickness"]
):
    (
        so.Plot(
            df[df["param"].isin([param])],
            x="session",
            y="dk",
            group="subject",
            color="lobe",
        )
        .facet(col="lobe",  wrap=2, order=list(rois))
        .add(so.Dot(fill=None), legend=False)
        .add(so.Line(linewidth=5), so.PolyFit(order=1), group=None, legend=False)
        .add(so.Line(linestyle="dashed", alpha=0.3), legend=False)
        .layout(size=(10, 10))
        .scale(x=so.Nominal(order=["pre", "post"]))
        .label(x="Session", y=titles[param])
        .theme(plt.rcParams)
        .label(title=rois.get)
        .save(f"~/tsclient/khangrp4/Downloads/session-{param}.png")
    )

In [None]:
import seaborn.objects as so

# fig, ax = plt.subplots(1, 1, figsize=(5, 3))

# Lobe key -> facet title; the key order also fixes the facet order.
rois = {
    "frontal": "Frontal Lobe",
    "parietal": "Parietal Lobe",
    "occipital": "Occipital Lobe",
    "temporal": "Temporal Lobe",
    "cingulate": "Cingulate Gyrus",
    "insula": "Insula",
}
# Parameter / score key -> axis label.
titles = {
    "FA": "FA",
    "ndi": "NDI",
    "odi": "ODI",
    "fw": r"$f_{fw}$",
    "thickness": "Thickness",
    "panss_n": "PANSS30-N",
    "panss_p": "PANSS30-P",
}
# Shared ordering of the session levels for every session-mapped property.
# (A dangling `.scale(...)` line used to follow this assignment; it was a
# syntax error, and the same scale call is already applied inside the plot
# chain below.)
ses_order = so.Nominal(order=["post", "pre"])

# One figure per (score, parameter) pair, faceted by lobe: dots and a
# first-order fit line per session; each figure is saved as a PNG.
for score, param in it.product(
    ["panss_p", "panss_n"], ["odi", "ndi", "fw", "thickness"]
):
    (
        so.Plot(
            df[df["param"].isin([param])],
            x=score,
            y="dk",
            group="session",
            color="lobe",
        )
        .facet(col="lobe", wrap=2, order=list(rois))
        .add(
            so.Dot(color="#666666", pointsize=3),
            marker="session",
            fill="session",
            legend=False,
        )
        .add(
            so.Line(linewidth=1.5),
            so.PolyFit(order=1),
            linestyle="session",
            legend=False,
        )
        .layout(size=(5, 5))
        .scale(linestyle=ses_order, marker=ses_order, fill=ses_order)
        .label(x=titles[score], y=titles[param])
        .theme(plt.rcParams | {"axes.titlesize": 14, "axes.labelsize": 12})
        .label(title=rois.get, linestyle=str.capitalize, marker=str.capitalize)
        .save(f"~/tsclient/khangrp4/Downloads/{score}-{param}.png")
        # .show()
    )

In [None]:
import seaborn.objects as so

# fig, ax = plt.subplots(1, 1, figsize=(5, 3))

# ROI key -> facet title; the key order also fixes the facet order.
# NOTE(review): this cell filters `df` on a "name" column — confirm that the
# `df` in scope at this point actually has one.
rois = {
    "superiorfrontal": "Superior Frontal",
    "rostralmiddlefrontal": "Rostral Middle Frontal",
    "caudalmiddlefrontal": "Caudal Middle Frontal",
    "parstriangularis": "Pars Triangularis",
    "parsopercularis": "Pars Opercularis",
    "parsorbitalis": "Pars Orbitalis",
    "precentral": "Precentral",
    "postcentral": "Postcentral",
    "insula": "Insula",
    "superiortemporal": "Superior Temporal",
    "middletemporal": "Middle Temporal",
    "precuneus": "Precuneus"
}
# "FA" against panss_n for the selected ROIs, one facet per ROI, with dots and
# a first-order fit line per session.
_ = (
    so.Plot(
        df[df["name"].isin(rois) & df["param"].isin(["FA"])],
        x="panss_n",
        y="dk",
        group="session",
        linestyle="session",
        color="name",
    )
    .facet(col="name",  wrap=2, order=list(rois))
    .add(so.Dot(fill=None, color="#2f2f2f"), marker = "session", legend=False)
    .add(so.Line(), so.PolyFit(order=1), legend=False)
    # .add(so.Line(linestyle="dashed", alpha=0.5), legend=False)
    .layout(size=(10, 15))
    # .scale(x=so.Nominal(order=["pre", "post"]))
    .label(x="PANSS-N", y="FA")
    .theme(plt.rcParams)
    .label(title=rois.get)
    .plot(True)
)

In [None]:
from matplotlib import cm, colors
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D


def add_legend(fig, labels, cmap=None, size=None, cmax=None, **kwargs):
    """Add a fixed two-entry "Pre" / "Post" legend to ``fig``.

    Parameters
    ----------
    fig
        Figure that receives the legend.
    labels, cmap, cmax
        Accepted for backward compatibility and currently unused: the legend
        entries are hard-coded. The previous colormap lookup went through
        ``cm.get_cmap``, which no longer exists in Matplotlib 3.9+, and its
        result never reached the handles, so it has been dropped.
    size
        Marker size of both handles (``None`` keeps matplotlib's default).
    **kwargs
        Forwarded to ``fig.legend``.

    Returns
    -------
    The legend object returned by ``fig.legend``.
    """
    handles = [
        # "Pre": dashed line with an unfilled X marker.
        Line2D(
            [0],
            [0],
            linestyle="dashed",
            marker="X",
            markeredgecolor="#666666",
            fillstyle="none",
            label="Pre",
            markersize=size,
        ),
        # "Post": default line style with a filled circle marker.
        Line2D(
            [0],
            [0],
            marker="o",
            label="Post",
            markeredgecolor="#666666",
            markerfacecolor="#666666",
            fillstyle="full",
            markersize=size,
        ),
    ]

    return fig.legend(
        handles=handles,
        **kwargs,
    )


# Small standalone figure that carries only the legend.
fig, ax = plt.subplots(1, 1, figsize=(2, 1))
add_legend(fig, ["Pre", "Post"], loc="center")
ax.axis(False)

# Save below the current user's home directory, matching the
# "~/tsclient/khangrp4/Downloads" location used by the other plotting cells
# (fig.savefig does not expand "~" itself).
fig.savefig(
    Path.home() / "tsclient/khangrp4/Downloads/clin-legend.png",
    transparent=True,
)


In [None]:

# Inspect the "pre" session values of cluster 2 for parameter "FA" as a table.
# NOTE(review): the cell that builds `clusters` only lists ndi/odi/fw/thickness
# — confirm "FA" is available in the `clusters` object in scope.
clusters.sel(param="FA", clusid=2, session="pre").load().to_dataframe()

In [None]:
# suffix="FA"
# for l, clust in zip(("C", "P", "Q"), cluster_p):
#     if not clust.shape and clust.item() is None:
#         continue
#     np.place(clust, np.logical_or(clust > 0.05, ~mask), np.nan)
#     np.copyto(clust, slm.t[0], where=~np.isnan(clust))
#     print(l, "=", np.sum(~np.isnan(clust)))

# (suffix, sign) pairs to display, one plot row each.
keys = [
    ("ndi", -1),
    ("fw", 1),
    ("FA", -1),
    ("L1", 1),
    ("MD", 1),
    ("RD", 1),
    ("thickness", -1),
]
# Combine "suffix" and "sign" into one "effect" index and select those pairs.
# NOTE(review): this rebinds `da`, which an earlier cell uses as the alias for
# dask.array — re-running that cell's template code afterwards depends on its
# own `import dask.array as da` being executed again.
da = stats.stack(effect=["suffix", "sign"]).sel(effect=keys)

def do_plot(data):
    """Plot each row of ``data`` on the ``lh``/``rh`` surfaces and return the plot.

    Rows are labelled with the string form of the module-level ``keys`` and the
    figure height is 200 pixels per row of ``data``.
    """
    render_options = dict(
        color_bar=True,
        label_text=[str(key) for key in keys],
        cmap="viridis",
        embed_nb=True,
        size=(1400, 200 * data.shape[0]),
        zoom=1.45,
        nan_color=(0.7, 0.7, 0.7, 1),
        cb__labelTextProperty={"fontSize": 12},
        transparent_bg=False,
    )
    return plot_hemispheres(lh, rh, data, **render_options)


# "C" values below 0.05 for the "intr" model at smoothing=11, one row per
# selected effect (everything else becomes NaN).
do_plot(
    da.sel(smoothing=11, model="intr")["C"]
    .where(lambda da: da < 0.05)
    .transpose("effect", "vertex")
    .data
)

In [None]:
# Mean of the t-statistics of the `slm` in scope, with NaN entries masked out.
np.ma.masked_where(np.isnan(slm.t), slm.t).mean()