In [None]:
from __future__ import annotations
import pickle
import itertools as it
import numpy as np
import pandas as pd
import xarray as xr
import networkx as nx

import os
import subprocess as sp
from pathlib import Path
import copy
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import more_itertools as itx
import multiprocessing as mp
import tqdm
import attrs
# import sklearn
import shutil
import plotly.express as px
import functools as ft
from typing import Any, List
import plotly.graph_objects as go
import ipywidgets as widgets
import IPython.display as D
import imageio.v3 as iio

In [None]:
# The `notebooks.*` helper modules (and the relative `resources/`, `results/` paths)
# are resolved from the repository root, so the kernel must run from there. Only step
# up while the `notebooks` package is not visible yet; a bare chdir("..") would climb
# one more level every time this cell is re-run.
if not Path("notebooks").is_dir():
    os.chdir("..")

In [None]:
%matplotlib inline
# Global matplotlib styling shared by every static figure in this notebook.
plt.rc('ytick', labelsize=11, color="#101010")
plt.rc('xtick', labelsize=11, color="#101010")
# "Roboto" is registered with matplotlib's font manager from ~/.fonts later in this cell.
plt.rc("font", family="Roboto")
plt.rc("figure", titlesize=14, titleweight='bold', edgecolor='black', dpi=200, facecolor="white")
plt.rc("axes", edgecolor="#a0a0a0", titlecolor='#404040', facecolor="#f0f0f0")
# Hide all four axis spines.
plt.rc("axes.spines", top=False, left=False, bottom=False, right=False)

def letter_label(panel, letter, x=0, y=1, transform=None):
    """Draw a large bold panel letter (e.g. "A") on ``panel`` and return the text artist.

    ``x``/``y`` are forwarded positionally to ``panel.text`` and are interpreted in the
    coordinate system selected by ``transform``.
    """
    label_style = dict(size=20, weight='bold', color='#303030')
    return panel.text(x, y, letter, transform=transform, **label_style)

# Register every font file found under ~/.fonts with matplotlib's font manager
# (presumably this is where the "Roboto" family requested above lives).
from matplotlib import font_manager
font_dirs = [Path.home() / '.fonts']
font_files = font_manager.findSystemFonts(fontpaths=font_dirs)

for font_file in font_files:
    font_manager.fontManager.addfont(font_file)

In [None]:
# BIDS dataset root. Set the NEWTOPSY_ROOT environment variable to run somewhere other
# than the original cluster; without it the original hard-coded path is used.
ROOT = Path(os.environ.get("NEWTOPSY_ROOT", "/scratch/knavynde/newtopsy"))
def get_subj_metadata():
    """Load ``ROOT/participants.tsv`` as a DataFrame indexed by integer subject id.

    Two columns are derived:

    * ``subject`` -- ``participant_id`` with its first four characters stripped
      (presumably the ``sub-`` prefix), converted to int; used as the index.
    * ``group`` -- ``phenotype`` mapped through the module-level ``cats`` dict;
      phenotypes missing from ``cats`` become NaN.
    """
    return (
        pd.read_csv(ROOT / "participants.tsv", sep="\t")
        .assign(subject=lambda df: df['participant_id'].map(lambda s: int(s[4:])))
        .set_index("subject")
        .assign(group=lambda df: df["phenotype"].map(cats))
    )

# Map raw `phenotype` values in participants.tsv to the study-group labels used in the
# plots. Both "HC" and "Control" map to "HC", so `cats.values()` contains "HC" twice.
cats = {
    "3": "Treatment 3+ yr",
    "CHR": "High risk",
    "HC": "HC",
    "FEP": "FEP",
    "Control": "HC",
    "Patient": "Patient"
}

from bids import BIDSLayout

def get_participants(participant_file, exclude=("080",)):
    """Return the subject labels (the text after ``sub-``) listed in a participants file.

    Only lines that start with ``sub-`` count, so a header row is skipped. The label
    runs up to the first whitespace, so extra tab-separated columns in the file are
    ignored. The previous ``sub-(.*)`` pattern pulled them into the label.

    Parameters
    ----------
    participant_file : str or path-like
        Path to a BIDS-style ``participants.tsv``.
    exclude : collection of str, default ("080",)
        Labels to leave out. Subject 080 has always been excluded here. The reason is
        not recorded in this notebook (TODO: document it).

    Returns
    -------
    list of str
        Labels in file order, e.g. ``["001", "002"]``.
    """
    with open(participant_file) as f:
        lines = f.read().splitlines()
    subs = []
    for line in lines:
        match = re.match(r"sub-(\S+)", line)
        if match and match.group(1) not in exclude:
            subs.append(match.group(1))
    return subs

# pybids index of the dataset including derivatives. The index database is kept in
# ROOT/.pybids (presumably reused on later runs instead of re-crawling the tree).
layout = BIDSLayout(ROOT, derivatives=True, database_path=ROOT / ".pybids")
# `layout.get` pre-bound to the Brainnetome-246 DWI connectomes of the subjects that
# snakedwi lists. Callers add e.g. `desc=` (edge weighting); keyword arguments given
# at call time (such as `subject=`) override the bound ones.
layout_get = ft.partial(
    layout.get,
    subject=get_participants(ROOT / 'derivatives/snakedwi-0.1.0/participants.tsv'),
    datatype="dwi",
    suffix="connectome",
    atlas="bn246",
)

### Utility Functions

In [None]:
from notebooks.utils import hex_to_rgb, get_lut, lut_label, titleize, distribution_plot, figures_to_html, plotly_tabulate, NbCache, underscore as __, error_line
from notebooks.adjacency_matrix import group_outer, AdjacencyMatrix, filter_edges, community_sort
from notebooks.bn_metadata import read_metadata
from notebooks.graph_metrics import global_efficiency, local_efficiency
from notebooks.plotly_helpers import plotly_grid, matplotlib_to_plotly
# Cache helper from notebooks.utils under the "connectomics" namespace; used below as
# the `@nb_cache("<key>")` decorator for expensive results.
nb_cache = NbCache("connectomics")

In [None]:
def filter_logile(bin: int, num_bins: int = 10):
    """Build a mask function that flags entries below a log-spaced weight threshold.

    The log10 range of a matrix's nonzero entries is split into ``num_bins``
    equal-width bins. The returned callable marks as True every entry that lies below
    the lower edge of bin ``bin``. ``bin=0`` therefore flags only values below the
    smallest nonzero weight (and zeros).

    Parameters
    ----------
    bin : int
        Bin index in ``[0, num_bins)``.
    num_bins : int, default 10
        Number of log-spaced bins.

    Returns
    -------
    callable
        ``inner(matrix) -> bool array`` (True = entry to mask).

    Raises
    ------
    ValueError
        If ``bin`` is outside ``[0, num_bins)``. This is checked when the filter is
        built, so a bad level fails at the call site rather than deep inside a
        masking chain.
    """
    if not 0 <= bin < num_bins:
        raise ValueError("bin must be in the range [0, num_bins)")

    def inner(matrix):
        # Ignore zeros so log10 is defined and they do not pull the range down.
        masked = np.ma.masked_equal(matrix, 0)
        log = np.ma.log10(masked)
        lo, hi = log.min(), log.max()
        threshold = 10 ** ((hi - lo) * bin / num_bins + lo)
        return matrix < threshold
    return inner

            
def community_sort(adj: AdjacencyMatrix):
    """Reorder ``adj``'s nodes so that members of each modularity community are contiguous.

    Communities come from networkx's greedy modularity maximisation on ``adj.graph``
    (edge attribute ``weight``). They are concatenated in the order networkx returns
    them, and the node metadata is reindexed to that order.

    NOTE(review): this definition shadows the ``community_sort`` imported from
    ``notebooks.adjacency_matrix`` in the utility-imports cell.
    """
    communities = nx.community.greedy_modularity_communities(adj.graph, weight="weight")
    node_order = [node for community in communities for node in community]
    return adj.with_metadata(adj.metadata.reindex(index=np.array(node_order)))


def metadata_ds():
    """Merge subject metadata and Brainnetome node metadata into one xarray Dataset.

    Subject-level variables (participants.tsv) are restricted to the subjects listed in
    the snakedwi participants file. Node-level variables come from ``read_metadata()``
    with its "Label ID" column renamed to ``node`` and used as the index.

    NOTE(review): this reads ``snakedwi-v0.1.0/participants.tsv``, while ``layout_get``
    reads ``snakedwi-0.1.0/participants.tsv``. Confirm which directory is intended.
    """
    participant_file = ROOT / 'derivatives/snakedwi-v0.1.0/participants.tsv'
    participants = [int(label) for label in get_participants(participant_file)]
    subjects = get_subj_metadata()
    subject_ds = subjects[subjects.index.isin(participants)].to_xarray()
    node_ds = (
        read_metadata()
        .reset_index()
        .rename(columns={"Label ID": "node"})
        .set_index("node")
        .to_xarray()
    )
    return xr.merge([subject_ds, node_ds])


@attrs.frozen
class Subject:
    """One subject's filtered connectome together with its identifying metadata."""

    adj: AdjacencyMatrix  # masked/filtered adjacency matrix with node metadata
    id: int  # BIDS "subject" entity parsed as int
    group: str  # study-group label from participants.tsv (mapped through `cats`)
    weight: str  # BIDS "desc" entity of the connectome, e.g. "sift2" or "avgFA"

    @classmethod
    def from_bids_entry(cls, __bids_entry, /, filter_level=1):
        """Build a Subject from a pybids connectome file.

        The CSV connectome is loaded with the Brainnetome region metadata. Its
        diagonal, its zero entries and the entries flagged by
        ``filter_logile(filter_level)`` are masked. A ``distance`` edge property
        (``1 / weight``, NaN where that ratio is masked) is attached for
        distance-weighted metrics such as global efficiency.

        Returns None when the subject is not listed in participants.tsv.
        """
        sub = int(__bids_entry.entities['subject'])
        metadata = get_subj_metadata()
        if sub not in metadata.index:
            return None
        adj = (
            AdjacencyMatrix(
                raw=np.genfromtxt(__bids_entry.path, delimiter=","),
                metadata=read_metadata("resources/brainnetome-regions.csv"),
            )
            .mask_diagonal()
            .mask_equal(0)
            .mask_where(filter_logile(filter_level))
        )
        # np.NaN was removed in NumPy 2.0; np.nan works on every NumPy version.
        adj.props["distance"] = np.ma.filled(1/adj.raw, np.nan)
        return cls(
            adj = adj,
            id = sub,
            group = metadata.loc[sub, "group"],
            weight = __bids_entry.entities["desc"],
        )
    
from matplotlib import cm
# Plotly colour scale built from matplotlib's nipy_spectral colormap. The original
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the
# supported lookup and returns the same Colormap object.
spectral = matplotlib_to_plotly(plt.get_cmap('nipy_spectral'), 255)

### Single Subject

Here, we test the pipeline on a single subject and view its connectivity matrix; the distribution of its edge weights is shown as a histogram further below.

In [None]:
# Load sub-001's FA-weighted connectome (masked at from_bids_entry's default filter
# level), additionally apply filter_logile(0), and order the nodes by modularity
# community.
adj = community_sort(
    Subject.from_bids_entry(itx.one(layout_get(subject="001", desc="avgFA")))
    .adj
    .mask_where(filter_logile(0))
)

# Regroup nodes by lobe, then hemisphere, for display.
# NOTE(review): whether the community order survives inside each lobe/hemisphere
# group depends on sort_values' sort stability. Confirm if that ordering matters.
adj = adj.with_metadata(adj.metadata.sort_values(["Lobe", "hemisphere"]))
# adj = adj.update(np.ma.log10(adj.raw))
adj.plot(labels=["Lobe", "Long Name"])

In [None]:
def get_conmat(group):
    """Stack the log10 sift2 connectomes of every subject in ``group``.

    Each subject's sift2 connectome is filtered with ``filter_logile(1)``, its nodes
    are sorted by lobe then hemisphere, and its weights are log10-transformed.

    Parameters
    ----------
    group : str
        Study-group label as produced by ``cats`` (e.g. "HC", "Patient").

    Returns
    -------
    numpy.ma.MaskedArray
        The subjects' ``adj.masked`` matrices stacked along a new last axis
        (presumably ``(node, node, subject)``), with their masks preserved.

    Raises
    ------
    ValueError
        If no sift2 connectome belongs to ``group``. The previous version failed
        with an opaque IndexError on ``adjs[0]``.
    """
    adjs = []
    for entry in layout_get(desc="sift2"):
        subj = Subject.from_bids_entry(entry)
        # None means the subject is missing from participants.tsv. subj.group is the
        # same participants.tsv label that was previously looked up a second time.
        if subj is None or subj.group != group:
            continue
        adj = subj.adj.mask_where(filter_logile(1))
        adj = adj.with_metadata(adj.metadata.sort_values(["Lobe", "hemisphere"]))
        adj = adj.update(np.ma.log10(adj.raw))
        adjs.append(adj)

    if not adjs:
        raise ValueError(f"No sift2 connectomes found for group {group!r}")
    return np.ma.stack([adj.masked for adj in adjs], axis=-1)

# Stacked log10 sift2 connectomes (subjects along the last axis) for healthy controls
# and patients.
hc = get_conmat("HC")
pt = get_conmat("Patient")

In [None]:
# Participant ids used as hover labels in the PCA plot, sorted by group then id.
# NOTE(review): this assumes (1) every sift2 subject is listed in participants.tsv
# (otherwise `md.loc[subjects]` raises KeyError), (2) only the "HC" and "Patient" groups
# occur, and (3) get_conmat returns subjects in participant_id order, so this array
# lines up row-for-row with `combined` (HC first, then Patient). Confirm.
md = get_subj_metadata()
subjects = [int(entry.get_entities()['subject']) for entry in layout_get(desc="sift2")]
part_ids = md.loc[subjects].sort_values(["group", "participant_id"])["participant_id"].to_numpy()

In [None]:
# Should equal combined.shape[-1] (HC + Patient subjects) for the PCA labels to line up.
part_ids.shape

In [None]:
# Concatenate both groups along the subject (last) axis. `groups` encodes HC as 1 and
# Patient as 0, in the same order.
combined = np.ma.dstack([hc, pt])
groups = np.hstack([np.ones(hc.shape[-1]), np.zeros(pt.shape[-1])])
combined.shape

In [None]:
def scree_plot(eigenvalues):
    """Seaborn line plot of ``eigenvalues`` against their component index 0..n-1."""
    n_components = eigenvalues.shape[0]
    component_index = np.arange(n_components)
    return sns.lineplot(x=component_index, y=eigenvalues)

from sklearn.decomposition import PCA

# The randomized SVD solver is stochastic. Without a fixed seed the components (and
# the 3-D scatter below) can change from run to run.
PCA_SEED = 0

# Masked (filtered) edges count as zero weight for PCA.
data = combined.filled(0)

pca = PCA(n_components=10, svd_solver='randomized', random_state=PCA_SEED)
# One row per subject, one column per (src, dest) edge.
pt_square = data.reshape(-1, data.shape[-1]).T
L = pca.fit_transform(pt_square)
print(pca.explained_variance_ratio_)
df = pd.DataFrame(L)
df['groups'] = groups  # 1 = HC, 0 = Patient
df['id'] = part_ids
# Loadings of component 4 reshaped back to a node x node matrix (see commented heatmap).
comps = pca.components_[4].reshape(*data.shape[:2])

# scree_plot(pca.singular_values_)
# sns.heatmap(comps, cmap="vlag")
px.scatter_3d(df, x=2, y=3, z=4, color='groups', height=1000, hover_name='id')
# sns.scatterplot(x=L[:, 0], y=L[:, 1], hue=groups)
# sns.scatterplot(x=L[:, 0], y=L[:, 1], hue=groups)


In [None]:
# Two side-by-side panels for the HC / Patient consensus heatmaps drawn below.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

def consensus_dist(conmat):
    """Per-edge coefficient of variation (std / mean) across the last (subject) axis.

    Masked entries are treated as zero weight. The result is a masked array. It is
    masked wherever the mean across subjects is 0, and wherever the median across
    subjects is 0.
    """
    # np.ma.filled treats plain ndarrays and masked arrays (with or without masked
    # entries) the same way and always returns a plain ndarray. The np.median below is
    # not mask-aware, and the old `np.ma.is_masked` check let masked arrays with no
    # masked entries through unconverted.
    values = np.ma.filled(conmat, 0)
    mean = np.ma.masked_equal(np.mean(values, -1), 0)
    std = np.std(values, -1)
    return np.ma.masked_where(np.median(values, -1) == 0, std / mean)


# Per-edge coefficient of variation across subjects; masked edges are drawn as 0.
sns.heatmap(consensus_dist(hc).filled(0), ax=axes[0])
axes[0].set_title("HC: edge coefficient of variation")

sns.heatmap(consensus_dist(pt).filled(0), ax=axes[1])
axes[1].set_title("Patient: edge coefficient of variation")


#### Connection Density

Connection density (the fraction of possible edges that are present) for each group as the log-bin threshold increases.

In [None]:
def connection_density():
    """Graph density of each subject's sift2 connectome at filter levels 0-9.

    Returns a DataFrame indexed by (subject, threshold) with one ``density`` column.
    ``threshold`` is the ``filter_logile`` level applied on top of the filtering
    already done by ``Subject.from_bids_entry`` (default level 1). Subjects missing
    from participants.tsv are skipped.

    NOTE(review): because the connectome is already filtered at level 1, whether
    levels 0 and 1 differ depends on whether ``mask_where`` hands the filter the
    masked or the raw matrix. Confirm.
    """
    subjects = (Subject.from_bids_entry(entry) for entry in layout_get(desc="sift2"))
    records = [
        (subj.id, level, nx.density(subj.adj.mask_where(filter_logile(level)).graph))
        for subj in subjects
        if subj
        for level in range(10)
    ]
    frame = pd.DataFrame(records, columns=["subject", "threshold", "density"])
    return frame.set_index(["subject", "threshold"])
densities = connection_density()

In [None]:
# `ds` is only assigned in the next cell, so referencing it here raised NameError on a
# fresh kernel. The group labels come from the participant metadata, so read them
# directly.
get_subj_metadata().to_xarray()['group']

In [None]:
# Subject metadata (including `group`) merged with the per-(subject, threshold) densities.
ds = xr.merge([
    get_subj_metadata().to_xarray(),
    densities.to_xarray(),
])

# Mean connection density per group at each threshold level. The across-subject
# standard deviation is passed to error_line as `err`.
error_line(
    __.pipe(
        dict(
            density=ds.groupby("group").mean("subject")["density"],
            std=ds.groupby("group").std("subject")["density"],
        ),
        lambda _: xr.merge([_]).to_dataframe().reset_index()
    ),
    x="threshold",
    y="density",
    color="group",
    err="std",

)

#### Weight distribution

In [None]:

# Weight distribution plot
bins = __.pipe(
    adj.filled,
    np.concatenate,
    len,
    np.sqrt,
    np.arange,
)
# bins = (bins - np.min(bins))/np.ptp(bins)
fig, axes = plt.subplots(1,2, figsize=(15,5))

# Distribution of raw weights
rawdist = sns.histplot(adj.masked.flatten(), bins=bins, kde=False, ax=axes[0])
rawdist.set(xlabel='Correlation Values', ylabel = 'Density Frequency')

# Probability density of log10
log10dist = sns.histplot(np.log10(adj.masked).flatten(), kde=False, ax=axes[1])
log10dist.set(xlabel='log(weights)')

#### Edge Mass Function

The total weight held by all edges within each weight bin, averaged across subjects.

In [None]:

def edge_mass_function():
    """Total edge weight falling in each of 246 equal-width weight bins, per subject.

    Each subject's sift2 connectome is filtered with ``filter_logile(0)`` and masked
    entries are filled with 0. Bin edges are 246 evenly spaced values from the
    matrix's minimum to its maximum, and entries are assigned with ``np.digitize``.

    Returns a DataFrame indexed by (subject, bin) with one ``weight`` column (the sum
    of the entries in that bin). Subjects missing from participants.tsv are skipped.
    """
    records = []
    for entry in layout_get(desc="sift2"):
        subj = Subject.from_bids_entry(entry)
        if not subj:
            continue
        arr = subj.adj.mask_where(filter_logile(0)).masked.filled(0)
        positions = np.arange(246)
        fractions = np.divide(positions, np.max(positions))
        edges = np.add(np.multiply(fractions, np.ptp(arr)), np.min(arr))
        digitized = np.digitize(arr, edges)
        records.extend(
            (subj.id, bin_idx, arr[digitized == bin_idx].sum())
            for bin_idx in np.unique(digitized)
        )
    frame = pd.DataFrame(records, columns=["subject", "bin", "weight"])
    return frame.set_index(["subject", "bin"])
cum_weight = edge_mass_function()


In [None]:
# Subject metadata merged with the per-(subject, bin) edge mass table.
ds = xr.merge([
    get_subj_metadata().to_xarray(),
    cum_weight.to_xarray()
])

# Mean edge mass per bin across *all* subjects (groups are pooled; the per-group
# colouring is commented out), with the across-subject std passed as `err`.
error_line(
    __.pipe(
        dict(
            weight=ds.mean("subject")["weight"],
            std=ds.std("subject")["weight"],
        ),
        lambda _: xr.merge([_]).to_dataframe().reset_index()
    ),
    x="bin",
    y="weight",
    # color="group",
    err="std",

)

## Multiple Subjects

Loop through all subjects and gather various metrics

### Setup

In [None]:
def subject_graphs(filter_level = 1, drop = [], subject = None):
    """Yield a Subject for every sift2 / avgFA / medR1 connectome in the layout.

    Parameters
    ----------
    filter_level : int, default 1
        Passed to ``Subject.from_bids_entry``.
    drop : list
        Accepted for backward compatibility; currently unused.
    subject : str, optional
        Restrict to one BIDS subject label.

    Subjects missing from participants.tsv are reported with a print and skipped.
    """
    extra_filters = {} if subject is None else {"subject": subject}
    for bidsfile in layout_get(desc=["sift2", "avgFA", "medR1"], **extra_filters):
        subj = Subject.from_bids_entry(bidsfile, filter_level)
        if subj:
            yield subj
        else:
            print(f"Dropped {bidsfile.entities['subject']} (Not found in metadata)")
            


def subject_properties(subj: Subject):
    """Summarise one subject's graph as a flat dict of global properties.

    Keys: subject, category (study group), weight (edge weighting), degree (mean
    weighted node degree), num_regions (node count) and global_efficiency (computed
    on the ``distance`` edge attribute).

    Further metrics (connected components, density, transitivity, local efficiency,
    shortest path on the largest component) are currently disabled. The
    largest-component computation they would need was previously done on every call
    even though nothing used it, and it raised ValueError for an empty graph.

    Raises
    ------
    KeyError
        Re-raised unchanged after printing the offending subject (presumably a missing
        edge/node attribute such as ``distance``).
    """
    try:
        G = subj.adj.graph
        degrees = [degree for _, degree in G.degree(weight="weight")]
        return {
            "subject": subj.id,
            "category": subj.group,
            "weight": subj.weight,
            "degree": np.mean(degrees),
            "num_regions": len(G.nodes),
            "global_efficiency": global_efficiency(G, weight="distance"),
        }
    except KeyError:
        print(subj)
        # Bare raise keeps the original traceback intact.
        raise


def subject_df(drop_regions = []):
    """One DataFrame row of ``subject_properties`` per subject graph, computed in parallel.

    ``drop_regions`` is forwarded as ``drop`` to ``subject_graphs``.
    """
    graphs = list(subject_graphs(drop=drop_regions))
    with mp.Pool() as pool:
        property_rows = pool.map(subject_properties, graphs)
    return pd.DataFrame(property_rows)


def xproperty_rank(ds, columns=[], inverse_columns=[]):
    """Add a ``<column>_rank`` variable to ``ds`` for every requested column.

    Ranks are computed along each variable's last axis (ties get their average rank)
    and divided by that axis' length, so scores lie in (0, 1].

    Parameters
    ----------
    ds : xarray.Dataset
    columns : list of str
        Variables where a larger value ranks higher (largest value -> 1).
    inverse_columns : list of str
        Variables where a smaller value ranks higher (smallest value -> 1). The
        ``inverse`` flag was previously accepted but ignored, so these were ranked
        exactly like ``columns``.

    Returns
    -------
    xarray.Dataset
        A new Dataset with the rank variables added.
    """
    from scipy.stats import rankdata

    def rank(ds, column, inverse=False):
        da = ds[column]
        values = -da.data if inverse else da.data
        ranked = rankdata(values, axis=-1) / values.shape[-1]
        # Label the ranks with the variable's own dims: the Dataset-level dim order
        # need not match the variable's axis order.
        return ds.assign({column + "_rank": (da.dims, ranked)})

    for column in columns:
        ds = rank(ds, column)
    for column in inverse_columns:
        ds = rank(ds, column, inverse=True)
    return ds

def xhubness(ds):
    """Add per-metric rank variables and a composite ``hubness`` score to ``ds``.

    ``betweenness`` and ``degree`` are passed to ``xproperty_rank`` as ``columns``.
    ``path_length`` and ``clust_coeff`` are passed as ``inverse_columns``.
    ``hubness`` is the sum of the four resulting ``*_rank`` variables.
    """
    cols = ["betweenness", "degree"]
    inv_cols = ["path_length", "clust_coeff"]
    ranked = xproperty_rank(ds, columns=cols, inverse_columns=inv_cols)
    # Sum the DataArrays themselves so xarray aligns them by dimension name. The old
    # version summed raw arrays and relabelled the result with the Dataset-level dim
    # order, which need not match the variables' axis order.
    rank_vars = [ranked[name + "_rank"] for name in cols + inv_cols]
    return ranked.assign(hubness=sum(rank_vars[1:], rank_vars[0]))

@nb_cache("data")
def get_data(thresholds = range(10)):
    """Collect every subject's filtered adjacency matrix into one xarray Dataset.

    For each level in ``thresholds``, every graph yielded by
    ``subject_graphs(filter_level=level)`` is turned into a DataFrame indexed by
    (subject, weight, threshold, src) with one column per ``dest`` node. The frames
    are concatenated, stacked into a Series and converted to a DataArray over
    (subject, weight, threshold, src, dest). That DataArray is merged as ``adj`` with
    the per-subject ``group`` label using an inner join.

    Wrapped by ``nb_cache("data")`` from notebooks.utils.
    """
    return __.pipe(
        (
            (threshold, subject_graphs(filter_level=threshold))
            for threshold in thresholds
        ),
        # For each (threshold, graphs) pair, lazily build one DataFrame per graph.
        __.starmap(lambda threshold, graphs: (
            pd.DataFrame(
                # NOTE(review): masked entries presumably become NaN here. Confirm.
                graph.adj.masked,
                index=[
                    # The walrus binds `index` (node labels) in this first item; the
                    # following items and `columns=` below reuse it.
                    pd.Index([graph.id], name="subject").repeat(
                        len(index := graph.adj.metadata.index)
                    ),
                    pd.Index([graph.weight], name="weight").repeat(len(index)),
                    pd.Index([threshold], name="threshold").repeat(len(index)),
                    index.rename("src"),
                ],
                columns=index.rename("dest")
            ) for graph in graphs
        )),
        # Flatten the per-threshold generators into a single stream of DataFrames.
        itx.flatten,
        pd.concat,
        lambda df: (
            df
            .stack()
            .to_xarray()
        ),
        lambda da: xr.merge(
            [
                dict(
                    adj=da,
                    group=get_subj_metadata()['group'],
                )
            ],
            join="inner",
        )
    )


def col_row_insert(matrix, pos, value):
    """Insert a row and then a column filled with ``value`` at index ``pos``.

    A square ``n x n`` input becomes ``(n+1) x (n+1)``.
    """
    with_row = np.insert(matrix, pos, value, axis=0)
    return np.insert(with_row, pos, value, axis=1)

def _get_params(avg, weight, group, threshold):
    """Graph parameters of one group-averaged connectome (pool worker for group_avg_params).

    The (weight, group) matrix is selected from ``avg``, NaNs are set to 0, and a zero
    row/column is inserted at position 0. The diagonal, zeros and entries flagged by
    ``filter_logile(threshold)`` are then masked before the result is passed to
    ``graph_params``.

    NOTE(review): ``graph_params`` is not defined or imported anywhere in this
    notebook, so this raises NameError unless it is provided elsewhere.
    NOTE(review): the inserted row/column presumably restores a node-0 slot so the
    matrix lines up with ``read_metadata()``. Confirm.
    """
    adj = (
        AdjacencyMatrix(
            raw=avg
                .sel(weight=weight, group=group)
                .squeeze()
                .fillna(0)
                .pipe(lambda _: col_row_insert(_, 0, 0))
                .data,
            metadata=read_metadata(),
        )
        .mask_diagonal()
        .mask_equal(0)
        .mask_where(filter_logile(threshold))
    )
    # Inverse-weight "distance" edge property for distance-based metrics.
    # np.NaN was removed in NumPy 2.0; np.nan works on every NumPy version.
    adj.props["distance"] = np.ma.filled(1/adj.raw, np.nan)
    return graph_params(
        adj, weight=str(weight.data), group=str(group.data), threshold=threshold
    )

@nb_cache('group_avg_params')
def group_avg_params():
    """Graph parameters of the group-mean connectomes for every weighting, group and level.

    ``get_data()`` at filter level 0 is averaged within each group. ``_get_params`` is
    then evaluated in a process pool for every combination of (variable slice,
    weight, group, threshold 0-9), and the results are concatenated.
    """
    # `to_array` stacks the Dataset's data variables along a new leading "variable"
    # dimension; iterating `avg` below yields one slice per variable.
    avg = get_data().sel(threshold=0).groupby("group").mean().to_array()
    
    with mp.Pool() as pool:
        return pd.concat(
            pool.starmap(
                _get_params,
                it.product(avg, avg['weight'], avg['group'], range(10))
            )
        )
    # dropped = drop(sub) if callable(drop) else drop
    # df.drop(index=dropped, columns=dropped, inplace=True)
    
def nodal_properties():
    """Load the precomputed per-node graph properties (tab-separated; first column is the index)."""
    table_path = Path("results") / "nodal_props.tsv"
    return pd.read_csv(table_path, sep="\t", index_col=0)

### Subject wide summary

In [None]:
@nb_cache("subject_full_df")
def subject_full_df():
    """``subject_df()`` wrapped by ``nb_cache("subject_full_df")``: one row of global graph properties per subject graph."""
    return subject_df()

In [None]:
# One row per (subject, edge weighting) with global graph properties.
df = subject_full_df()

In [None]:
#| label: global-props
#| fig-cap: |
#|   Global network properties compared across study groups. The top row shows the
#|   average network-wide degree, and the bottom row shows the average global
#|   efficiency. In the left column, networks are weighted by sift2 corrected streamline
#|   count, and in the right column, by average FA along the fiber.
def subject_distributions(col, weight, exclude_subjects=(43,)):
    """Box plot of subject-level property ``col`` by study group for one edge weighting.

    Reads the notebook-level ``df`` (output of ``subject_full_df()``).

    Parameters
    ----------
    col : str
        Column of ``df`` to plot, e.g. "degree" or "global_efficiency".
    weight : str
        Edge weighting to show (matched against ``df["weight"]``), e.g. "sift2".
    exclude_subjects : iterable of int, default (43,)
        Subject ids left out of the plot. Subject 43 was previously excluded by a
        hard-coded comparison; the reason is not recorded here (TODO: document it).
    """
    shown = df[(df["weight"] == weight) & ~df["subject"].isin(list(exclude_subjects))]
    fig = px.box(
        shown,
        x="category",
        color="category",
        y=col,
        points="outliers",
        width=584,
        height=400,
        labels={
            "num_regions": "# Regions",
            "category": "Group"
        },
        title=titleize(col),
        hover_data=["subject"]
    )
    fig.update_layout(
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=False
    )
    return fig
# Subject-level properties to plot (the others are currently disabled).
cols = [
    # "transitivity",
    # "efficiency",
    # "density",
    # "num_connected_comps",
    # "largest_connected_comp",
    "degree",
    "global_efficiency"
]
# plotly_tabulate(lambda _: subject_distributions(_, "sift2"), cols)
# Grid of box plots over `cols` x edge weightings; per the figure caption, presumably
# one row per property and one column per weighting (sift2 left, avgFA right).
fig = plotly_grid(subject_distributions, cols, ["sift2", "avgFA"], vspacing=0.15)
# fig = subject_distributions("degree", "avgFA")
fig.update_layout(
    showlegend=False,
    # title_text="Global Graph properties",
    title = dict(
        # text=f'<b>{titleize(fig.layout["title"]["text"])}</b>',
        text="",
        font_family="roboto",
        font_color="#58595b",
        font_size=18,
        xanchor="center",
        x = 0.5,
        xref="paper",
    ),
    height=600,
    width=1000,
    margin=dict(
        t=50,
        b=50,
        l=50,
        r=50,
    ),
    template="seaborn"
)
# Display as a static PNG (presumably so the figure survives static exports).
D.Image(fig.to_image(format='png'))


# figures_to_html([
#     subject_distributions(col)
#     for col in [
#         "transitivity",
#         "efficiency",
#         "density",
#         "num_connected_comps",
#         "largest_connected_comp",
#         "degree",
#     ]
# ], "pages/bn246/subject_distributions.html")

### Subject wide distributions

In [None]:
# Per-node graph properties loaded from results/nodal_props.tsv.
node_props = nodal_properties()

In [None]:
# Nodal properties to rank for the distribution plots below.
cols = [
    "betweenness",
    "clust_coeff",
    "path_length",
    "degree",
]
# NOTE(review): `property_rank` is not defined or imported anywhere in this notebook
# (only the xarray variant `xproperty_rank` is), so this raises NameError on a fresh
# kernel. It presumably came from a since-deleted cell.
indexed = property_rank(
    node_props.set_index(["weight", "category", "subject", "node"]),
    columns=cols
)
def hubness_property_rank_distribution(col, weight):
    """Distribution plot of ``<col>_rank`` against ``col`` for one edge weighting."""
    return distribution_plot(indexed.loc[weight], x=col+"_rank", y=col)
# figures_to_html(
#     [
#         distribution_plot(indexed, x=col+"_rank", y=col) 
#         for col in cols
#     ],
#     filename="pages/bn246/nodal_distributions.html"
# )

In [None]:
cols = [
    "betweenness",
    "degree",
    "clust_coeff",
    "path_length",
]
# sift2-weighted rank distributions, one plot per property in `cols`, laid out by
# plotly_tabulate.
plotly_tabulate(lambda _: hubness_property_rank_distribution(_, "sift2"), cols)
# fig = plotly_grid(hubness_property_rank_distribution, cols, ["sift2", "avgFA"])
# fig.update_layout(
#     showlegend=False,
#     title_text="Node Ranking Distributions",
#     title_xanchor="left",
#     title_x = 0,
#     title_xref="paper",
#     height=800,
#     width=1000,
#     template="seaborn"
# )
# D.Image(fig.to_image())

In [None]:
def hub_disruption_index(group, cols):
    """Per-node deviation of ``group`` subjects' nodal property from the HC node means.

    Returns a long DataFrame with columns subject, node and ``delta`` (subject value
    minus the HC mean for that node), plus ``avg`` (the HC mean) and ``Delta Total``
    (the sum of ``delta`` over all nodes of that subject). As called below, ``cols``
    is a single column name.

    NOTE(review): nodal_props.tsv also has ``weight``/``threshold`` columns (see
    get_hubs). They are neither filtered nor used as pivot keys here, so the HC mean
    pools all weightings/thresholds, and ``pivot`` raises on duplicate
    (subject, node) pairs if more than one is present. Confirm.
    """
    df = nodal_properties()
    indexed = df.set_index(["category"])
    hc = indexed.loc["HC"]
    avg = hc.groupby("node").mean()[cols]
    # subject x node table of the property for the requested group.
    group_values = indexed.loc[group].reset_index().pivot(columns="node", index="subject", values=cols)
    delta = group_values - avg
    delta = pd.DataFrame(delta.stack(), columns=["delta"])
    delta["avg"] = avg.reindex(delta.index, level="node")
    delta["Delta Total"] = delta.reset_index().groupby("subject").sum()["delta"].reindex(delta.index, level="subject")
    return delta.reset_index()

In [None]:
cols= [
    "betweenness",
    # "clust_coeff",
    # "path_length",
    # "degree",
]
# For each study group: each node's deviation from the HC mean against that HC mean,
# coloured by the subject's total deviation, with an OLS trendline.
# NOTE(review): the inner `plotly_tabulate` gets a single generator argument, while
# the outer call passes (function, items). Confirm the helper supports both forms.
# `cats.values()` lists "HC" twice ("HC" and "Control" both map to it).
plotly_tabulate(
    lambda group: plotly_tabulate(
        px.scatter(
            hub_disruption_index(group, col),
            x="avg",
            y="delta",
            color="Delta Total",
            trendline="ols",
            title=titleize(col),
            hover_data=["subject", "node"]
        )
        for col in cols
    ),
    cats.values()
)

### Hubness

In [None]:
def get_hubs():
    """Per-subject nodal properties as xarray, with rank variables and ``hubness`` added.

    Dimensions: (weight, threshold, subject, node).
    """
    nodal = nodal_properties().set_index(['weight', 'threshold', 'subject', 'node'])
    return xhubness(nodal.to_xarray())


def get_avg_hubs():
    """Hubness of the group-averaged connectomes; dims (weight, group, threshold, node)."""
    group_params = group_avg_params().set_index(['weight', 'group', 'threshold', 'node'])
    return xhubness(group_params.to_xarray())
    

In [None]:
# NOTE(review): later cells call `scipy.io.loadmat`. Confirm that the installed SciPy
# exposes `scipy.io` after a bare `import scipy` (older releases need `import scipy.io`).
import scipy
# Subject-level and node-level (atlas) metadata in one Dataset.
mds = metadata_ds()
# Per-subject hubness joined with metadata. `node_uid` is "<Name>_<hemisphere>", a
# readable per-node label; `avg_hubness` averages hubness over subjects.
hub_data = (
    mds
    .merge(get_hubs())
    .assign(
        node_uid=lambda ds: ds["Name"] + "_" + ds["hemisphere"],
        avg_hubness = lambda ds: ds["hubness"].mean(dim="subject"),
    )
)

# Hubness of the group-averaged connectomes with the same node labels; the metadata's
# subject dimension is dropped because this data has none.
avg_hub_data = (
    mds
    .drop_dims("subject")
    .merge(get_avg_hubs())
    .assign(node_uid=lambda ds: ds["Name"] + "_" + ds['hemisphere'])
)

def get_labels(ds, fields, repeats):
    """Plotly hover labels: each node's ``ds[fields]`` values repeated ``repeats`` times.

    Returns ``(labels, template)``. ``labels`` has shape
    ``(n_nodes, repeats, len(fields))``. ``template`` is a hovertemplate fragment that
    shows each field, in order, from ``customdata``.
    """
    # Use the `ds` argument; the previous version silently read the global `hub_data`.
    table = ds[fields].to_pandas().to_numpy()
    labels = np.repeat(table, repeats, axis=0).reshape((-1, repeats, len(fields)))
    template = "<br>".join(
        f"<b>{name}:</b> %{{customdata[{i}]}}" for i, name in enumerate(fields)
    )
    return labels, template
# cols = ["degree", "clust_coeff", "path_length", "betweenness"]
# NOTE(review): redundant; plotly.express is already imported as `px` in the first cell.
import plotly.express as px

# Per-subject hover labels; currently only consumed by the commented-out code below.
labels, template = get_labels(hub_data, ["Long Name", "Gyrus", "Lobe", "hemisphere"], len(hub_data["subject"]))


# fig.data[0]["hovertemplate"] += "<br>" + template
# fig.update_traces(
#     customdata=labels,
# )
# fig.show()

We defined a composite hubness score using 4 individual measures of hubness: degree, betweenness, clustering coefficient, and shortest path length. Each node in each individual subject was rank-ordered with respect to each of the four metrics and assigned a score between 0 and 1 corresponding to the rank. These scores were summed together to get a hubness score between 0 and 4. Rank-ordering the nodes of each subject by hubness score gives the hubness rank for each subject. The hubness ranking varied little across the subjects (figure TBD).

The hubness scores from each subject were averaged to get the group hubness scores for each node. The group-wise hubness rankings are compared below.

Note that if, instead of the subject-wise analysis described above, the subject connectomes are first combined into group-consensus connectivity matrices, large differences emerge between groups, especially when FA is used as the edge weight (figure TBD).

In [None]:
#| label: hubness-ranking
#| fig-cap: |
#|   Average hubness scores across group members of brainnetome atlas nodes. Nodes are
#|   sorted according to increasing hubness in healthy controls. Nodes in disease groups
#|   are shown in the same order for comparison.
def template_metadata(ds, fields):
    """Build plotly ``customdata`` and a matching hovertemplate fragment from ``ds[fields]``.

    ``labels`` is ``np.dstack`` of the selected arrays; 1-D inputs of length n give
    shape ``(1, n, len(fields))``. ``template`` shows each field as a bold name
    followed by its ``customdata`` value, one field per line.
    """
    columns = [ds[field] for field in fields]
    labels = np.dstack(columns if columns else [[]])
    lines = [f"<b>{name}:</b> %{{customdata[{pos}]}}" for pos, name in enumerate(fields)]
    return labels, "<br>".join(lines)

def hubness_score(weight, sort_weight=None, threshold=0):
    """Heatmap of group-averaged hubness per node, with columns sorted by HC hubness.

    Parameters
    ----------
    weight : str
        Edge weighting whose hubness is shown (e.g. "sift2", "avgFA").
    sort_weight : str, optional
        Weighting whose HC hubness defines the column order; defaults to ``weight``.
    threshold : int, default 0
        Filter level to display. NOTE(review): the previous code selected
        ``threshold=abs0``, a name that is not defined anywhere in this notebook
        (NameError on a fresh kernel). 0 is assumed here; confirm.

    Rows are the groups HC, FEP, High risk and Treatment 3+ yr. Hover text shows the
    node's hemisphere, lobe, gyrus and long name.
    """
    hubness_col = "hubness"
    group_order = ["HC", "FEP", "High risk", "Treatment 3+ yr"]

    def get_ds(_weight):
        return (
            hub_data
            .swap_dims({"node": "node_uid"})
            .groupby("group")
            .mean()
            .sel(weight=_weight, threshold=threshold)
            .reindex(group=group_order)
        )

    _ds = get_ds(weight)
    sort_ds = _ds if sort_weight is None or sort_weight == weight else get_ds(sort_weight)
    ordered = _ds.sortby(sort_ds.sel(group="HC")[hubness_col])[hubness_col]
    fig = px.imshow(
        ordered,
        aspect="auto",
        color_continuous_scale=itx.nth(zip(*spectral), 1),
    )

    # Hover metadata must follow the *sorted* column order. Building it from the
    # unsorted `hub_data`, as before, attached each cell's hover text to a different
    # region.
    node_meta = (
        hub_data
        .swap_dims({"node": "node_uid"})
        .sel(node_uid=ordered["node_uid"].values)
    )
    labels, template = template_metadata(node_meta, ["hemisphere", "Lobe", "Gyrus", "Long Name"])
    fig.data[0].customdata = labels
    fig.data[0].hovertemplate += "<br>" + template
    return fig

# fig.update_xaxes(scaleanchor="y", scaleratio=1, type="category")
# One hubness heatmap per edge weighting (sift2, avgFA), each sorted by its own HC hubness.
fig = plotly_grid(
    lambda _: hubness_score(_),
    ["sift2", "avgFA"],
    vspacing=0.35
)
fig.update_layout(
    showlegend=False,
    # coloraxis_showscale=False,
    title=dict(
        text="Hubness Scores",
        xanchor="left",
        x = 0,
        xref="paper",
    ),
    height=600,
    width=1500,
    margin_r=20,
    # paper_bgcolor="rgba(0,0,0,0)",
)
# fig.update_xaxes(
#     tickfont_size=5,
# )
# Static image render of the grid.
D.Image(fig.to_image())

In [None]:
import nibabel as nib
from nilearn import datasets, plotting
def show_atlas(hem, weight="sift2", group="HC", threshold=3):
    """Interactive surface view of group-mean hubness on sub-001's inflated surface.

    Parameters
    ----------
    hem : str
        Hemisphere entity passed to pybids (e.g. "L" or "R").
    weight, group, threshold
        Which slice of ``hub_data`` to show. The defaults reproduce the original
        hard-coded choice (sift2, HC, filter level 3).

    Each surface vertex carries a bn210 parcel label from sub-001's ``dparc`` file and
    is coloured by that parcel's hubness; label 0 maps to 0.
    """
    surface = itx.one(layout.get(subject="001", suffix="inflated", hemi=hem))
    dparc_file = itx.one(
        layout.get(subject="001", suffix="dparc", atlas="bn210", hemi=hem)
    )
    dparc = nib.load(dparc_file.path)
    vertex_labels = dparc.get_arrays_from_intent('NIFTI_INTENT_LABEL')[0]
    group_mean = hub_data.sel(weight=weight).groupby("group").mean()
    # Prepend a node-0 entry (value 0) so label 0 also resolves in the lookup below.
    node_hubness = xr.concat(
        [
            xr.DataArray([0], dims=("node",), coords={"node": [0]}),
            group_mean.sel(threshold=threshold, group=group)['hubness'],
        ],
        dim="node"
    )
    # Map every vertex's parcel label to that parcel's hubness.
    per_vertex = node_hubness.loc[vertex_labels.data]
    # No fsaverage fetch: the previous version downloaded it on every call but never
    # used the result.
    return plotting.view_surf(
        surface.path,
        per_vertex.data,
        cmap='nipy_spectral',
        symmetric_cmap=False,
    )

show_atlas("L")

In [None]:
def _nbs_test_stat(mat_path, weight):
    """Load an NBS ``test_stat`` matrix as a DataArray over (weight, src, dest)."""
    node_ids = mds["node"].data
    return xr.DataArray(
        scipy.io.loadmat(mat_path)["test_stat"],
        coords={"src": node_ids, "dest": node_ids},
    ).expand_dims(weight=[weight])


# Mean sift2 hubness of the FEP group, reduced to a plain node-indexed array.
hubs = (
    hub_data.sel(weight="sift2").groupby("group").mean()["hubness"]
    .sel(group="FEP").squeeze().reset_coords(drop=True)
)
# Hubness computed on the group-averaged sift2 connectomes.
avg_hubs = (
    avg_hub_data.sel(weight="sift2")["hubness"]
    .squeeze().reset_coords(drop=True)
)
# Subject 80 is left out (it is also excluded by default in get_participants).
data = get_data().drop_sel(subject=80)
# NBS test statistics for both weightings, stacked along a new "weight" dimension.
nbs_arr = xr.concat(
    [
        _nbs_test_stat("results/nbs/sift2_nbs.mat", "sift2"),
        _nbs_test_stat("results/nbs/nbs.mat", "avgFA"),
    ],
    dim="weight",
)


In [None]:
WEIGHT = "avgFA"
# NOTE(review): despite its name, `ctrl` is the mean over *all* subjects, not only
# healthy controls, restricted to the avgFA weighting.
ctrl = data.mean(dim='subject').sel(weight=WEIGHT, drop=True)["adj"]
# Deviation of each subject's edge weight from the subject mean *of the same edge
# weighting*. Subtracting the avgFA-only `ctrl` from every weighting made the sift2
# slice meaningless (streamline weights minus mean FA); the avgFA slice is unchanged.
diff = data.assign(
    diff=data["adj"] - data["adj"].mean(dim="subject"),
    nbs=nbs_arr,
)

# def dim_map(func, ds, dims=None):
#     for 

def _pairwise_hubness(hubs, outer):
    """Apply ``outer(v)`` to every (threshold, subject) node-hubness vector ``v``.

    ``outer`` maps a length-N vector to an N x N matrix. The result is a DataArray over
    (threshold, subject, src, dest), built by looping over ``hubs['threshold']`` and
    ``hubs['subject']`` in order.
    """
    node_ids = hubs["node"].data
    coords = {"src": node_ids, "dest": node_ids}
    per_threshold = []
    for thresh in hubs["threshold"]:
        per_subject = [
            xr.DataArray(
                outer(hubs.sel(threshold=thresh, subject=subject).to_numpy()),
                dims=("src", "dest"),
                coords=coords,
            ).expand_dims({"threshold": [thresh.data], "subject": [subject.data]})
            for subject in hubs["subject"]
        ]
        per_threshold.append(xr.concat(per_subject, dim="subject"))
    return xr.concat(per_threshold, dim="threshold")


def edge_hubness(hubs):
    """Edge hubness: hubness[src] + hubness[dest] for every subject and threshold."""
    return _pairwise_hubness(hubs, lambda v: np.add.outer(v, v))


def edge_hub_diff(hubs):
    """Edge hub difference: |hubness[src] - hubness[dest]| for every subject and threshold."""
    return _pairwise_hubness(hubs, lambda v: np.abs(np.subtract.outer(v, v)))

def mean_top(var: str, number: int, dim, *args, **kwargs):
    """Return a reducer that averages variables over the top-``number`` positions of ``var``.

    ``inner(ds)`` argsorts ``ds[var]`` (NaN filled with 0) along ``dim`` and keeps the
    positions of the ``number`` largest values. Every data variable whose dims are the
    same set as ``var``'s is transposed to ``var``'s dim order, gathered at those
    positions and nan-averaged over ``dim``. The per-variable results, with ``dim`` and
    the coordinate named ``dim`` removed, are merged into one Dataset.
    ``*args``/``**kwargs`` are accepted but unused.
    """
    def inner(ds):
        dim_order = ds[var].dims
        axis = ds[var].get_axis_num(dim)
        # argsort is ascending, so the last `number` indices are the largest values.
        sel = ds[var].fillna(0).argsort(axis=axis).isel({dim: slice(-number, None)})
        return __.pipe(
            ds,
            # __.filter(lambda _: _ == "diff"),
            __.filter(lambda _: set(ds[_].dims) == set(dim_order)),
            __.map(lambda _var: (_var, ds[_var].transpose(*dim_order))),
            __.starmap(lambda name, da: xr.DataArray(
                np.nanmean(np.take_along_axis(da.values, sel, axis=axis), axis=axis),
                coords = __.pipe(
                    da.coords.items(),
                    __.filter(lambda _: _[0] != dim),
                    dict,
                ),
                dims = __.pipe(
                    da.dims,
                    __.filter(lambda _: _ != dim),
                    tuple
                ),
                name=name,
            )),
            xr.merge,
        )
    return inner

In [None]:
import warnings
def quiet_std(x):
    """Return ``x.std("subject")`` with every warning raised during the call suppressed."""
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        result = x.std("subject")
    return result


def edge_distribution_prefilter():
    """Per-group across-subject std of every edge in ``data``, averaged over ``src``.

    The ``dest`` dimension is renamed ``node`` and the ``adj`` variable ``std_prefilter``.
    """
    per_group_std = data.groupby("group").std()
    return per_group_std.mean(dim="src").rename({"dest": "node", "adj": "std_prefilter"})

def edge_distribution():
    """Across-group coefficient of variation of per-group median edge weights, per node.

    The median ``adj`` of each group is taken over subjects. The ratio of its std to
    its mean across groups is averaged over ``src`` and returned as a DataArray named
    ``std`` with ``dest`` renamed to ``node``.
    """
    # The grouped median was previously computed twice; once is enough.
    group_medians = diff.groupby("group").median()['adj']
    spread = group_medians.std(dim='group')
    centre = group_medians.mean(dim="group")
    return (spread / centre).mean(dim='src').rename({"dest": "node"}).rename('std')

def tstat():
    """Edge-wise two-sample t-statistic between HC and FEP, averaged over ``src``.

    Runs ``scipy.stats.ttest_ind`` (default equal-variance test, NaNs omitted,
    two-sided) along the subject axis of ``diff['adj']``, then averages the statistic
    over ``src``. Returns a DataArray named ``tstat`` with ``dest`` renamed to ``node``.
    p-values are discarded.
    """
    from scipy.stats import ttest_ind
    grouped = diff.groupby('group')
    hc = grouped['HC']['adj'].data
    fep = grouped['FEP']['adj'].data
    adj = diff['adj']
    sub_axis = adj.dims.index('subject')
    t, _p = ttest_ind(hc, fep, axis=sub_axis, nan_policy='omit', alternative='two-sided')
    # Label `t` with adj's own subject-reduced dims/coords. The Dataset-level dims used
    # before may be ordered differently or include dims (from other variables) that
    # adj does not have.
    template = adj.mean('subject')
    return (
        xr.DataArray(t, dims=template.dims, coords=template.coords, name='tstat')
        .mean('src')
        .rename(dest='node')
    )
    

def nbs_probability(threshold=3):
    """Fraction of each node's ``src`` entries whose NBS test statistic exceeds ``threshold``.

    ``diff['nbs']`` holds the NBS ``test_stat`` matrices. An entry counts when its
    statistic is strictly greater than ``threshold`` (default 3, the value previously
    hard-coded). The fraction is taken over ``src``, ``dest`` is renamed ``node``, and
    the result is named ``nbs_prob``.
    """
    return (
        (diff['nbs'] > threshold)
        .mean(dim='src')
        .rename({"dest": "node"})
        .rename('nbs_prob')
    )

def abs_deviation():
    """Per-group mean absolute edge deviation, averaged over ``src`` (named ``abs_diff``).

    Uses the group mean of ``diff['diff']``, each subject's edge weight minus a
    subject-averaged reference.
    """
    group_diff = diff.groupby("group").mean()['diff']
    return abs(group_diff).mean(dim="src").rename(dest="node").rename("abs_diff")

def positive_deviation():
    return (
        diff
        .groupby("group")
        .mean()
        ['diff']
        .pipe(lambda _: np.maximum(_, 0))
        .where(lambda ds: ds!= 0)
        .mean(dim="src")
        .pipe(abs)
        .rename(dest="node")
        .rename("pos_diff")
    )

def negative_deviation():
    """Per-node mean magnitude of the negative part of the group-mean `diff`.

    After averaging `diff` within each group, values >= 0 are clipped to 0.
    Those zeros are masked to NaN, the remainder is averaged over `src`, and
    the absolute value is taken. Returned as ``neg_diff`` indexed by `node`
    (formerly `dest`).
    """
    group_mean_diff = diff.groupby("group").mean()["diff"]
    negative_only = np.minimum(group_mean_diff, 0)
    negative_only = negative_only.where(negative_only != 0)
    node_mean = abs(negative_only.mean(dim="src"))
    return node_mean.rename(dest="node").rename("neg_diff")

def node_wise_std():
    """Per-node mean of the within-group std of `data["adj"]`.

    The std is taken within each group over the grouped dimension, then
    averaged over `src`. Returned as ``node_wise_std`` indexed by `node`
    (formerly `dest`).
    """
    adj_std = data.groupby("group").std()["adj"]
    return adj_std.mean(dim="src").rename(dest="node").rename("node_wise_std")

def negative_deviation_prefilter():
    """Like negative_deviation, but clips each subject's `diff` before averaging.

    Per edge and subject, only negative `diff` values are kept (everything
    else becomes NaN). These are averaged within each group, then over `src`,
    and made absolute. Returned as ``neg_diff_prefilter`` indexed by `node`
    (formerly `dest`).
    """
    def _negative_only(ds):
        # keep strictly negative entries; zeros and positives become NaN
        clipped = np.minimum(ds["diff"], 0)
        return clipped.where(clipped != 0)

    group_mean = diff.assign(diff=_negative_only).groupby("group").mean()["diff"]
    node_mean = abs(group_mean.mean(dim="src"))
    return node_mean.rename(dest="node").rename("neg_diff_prefilter")

def edge_deviation(hub_weight='sift2'):
    """Return `diff` augmented with edge-level hub metrics.

    Hubness is taken from `hub_data` at ``weight=hub_weight``. It is passed
    to edge_hubness and edge_hub_diff (defined elsewhere in the notebook),
    whose results are added as ``edge_hubness`` and ``edge_hub_diff``.
    """
    hubs = hub_data.sel(weight=hub_weight)["hubness"]
    hub_columns = dict(
        edge_hubness=edge_hubness(hubs),
        edge_hub_diff=edge_hub_diff(hubs),
    )
    return diff.assign(**hub_columns)
    
def mean_edge_deviation(hub_weight="sift2"):
    """Node-level Dataset of group-averaged edge deviations plus hub metrics.

    Built from the notebook-level `diff` and `hub_data`:

    1. From `hub_data` at ``weight=hub_weight``, ``hubness`` is used to add
       ``edge_hubness`` (via edge_hubness) to `diff`. Also added are
       ``dest_hubness`` and ``dest_id``: the hubness and node id relabelled
       onto the ``src`` dimension and broadcast along ``dest``.
    2. Everything is averaged within each group over ``subject`` and ``src``,
       and ``dest`` is renamed to ``node``.
    3. The group-mean ``hubness`` (as ``src_hubness``) and ``degree_rank``
       from `hub_data` are attached, plus ``edge_check`` =
       src_hubness + dest_hubness.
    4. The node-level metrics from edge_distribution, node_wise_std, tstat
       and nbs_probability are merged in.

    Other node metrics defined above (negative_deviation,
    negative_deviation_prefilter, edge_distribution_prefilter, abs_deviation)
    are not merged here.

    Parameters
    ----------
    hub_weight : value of the ``weight`` coordinate of `hub_data` used for
        the hub metrics.

    Returns
    -------
    xarray.Dataset
    """
    weighted_hubs = hub_data.sel(weight=hub_weight)
    hubs = weighted_hubs["hubness"]
    # Group means of the hub metrics are computed once and reused for both
    # src_hubness and degree_rank (previously this groupby ran twice).
    hub_group_means = weighted_hubs.groupby("group").mean()

    # NOTE(review): this relabels `node` as `src` and broadcasts along `dest`.
    # After the mean over `src` below, `dest_hubness`/`dest_id` therefore look
    # like averages over all src nodes rather than properties of each dest
    # node -- confirm this is intended.
    map_to_dest = lambda ds: ds.rename({"node": "src"}).expand_dims(dest=diff["dest"])
    return (
        diff
        .assign(
            edge_hubness=edge_hubness(hubs),
            dest_hubness=hubs.pipe(map_to_dest),
            dest_id=hubs["node"].pipe(map_to_dest),
        )
        .groupby("group")
        .mean(dim=["subject", "src"])
        .rename({"dest": "node"})
        .assign(
            src_hubness=hub_group_means["hubness"],
            degree_rank=hub_group_means["degree_rank"],
        )
        .assign(edge_check=lambda ds: ds["src_hubness"] + ds["dest_hubness"])
        .merge(edge_distribution())
        .merge(node_wise_std())
        .merge(tstat())
        .merge(nbs_probability())
    )


f = mean_edge_deviation()

In [None]:
# `f` (from mean_edge_deviation) does not contain `neg_diff_prefilter`, so
# selecting it directly raised a KeyError; merge the metric in explicitly.
px.violin(
    f
    .merge(negative_deviation_prefilter())
    .merge(mds.drop_dims('subject'))
    .sel(weight='avgFA', threshold=0, group='FEP')
    [['hemisphere', 'neg_diff_prefilter']]
    .to_dataframe()
    .reset_index(),
    x='hemisphere',
    y='neg_diff_prefilter',
)

In [None]:
(
    edge_deviation()
    .groupby('group')
    .mean('subject')
    # equal-weight average of the group means (each group counts once)
    .mean('group')
    .sel(threshold=0, weight='avgFA')
    .to_dataframe()
    .pipe(px.scatter, x='edge_hubness', y='edge_hub_diff', height=1000)
)

In [None]:
from sklearn.linear_model import LinearRegression

# Edge-level table at threshold 0 / avgFA.
# - Variables are first averaged over subjects within each group, then
#   across groups.
# - `adj_std` is the across-group std of the group-mean `adj`.
# - Rows with any missing value are dropped. The drop is chained rather than
#   done with `df.dropna(inplace=True)` so the cell builds `df` in one
#   idempotent step.
df = (
    edge_deviation()
    .groupby('group')
    .mean('subject')
    .assign(adj_std=lambda ds: ds['adj'].std('group'))
    .mean('group')
    .sel(threshold=0, weight='avgFA')
    .to_dataframe()
    .dropna()
)

# Regress the mean edge weight on edge hubness and hub difference.
X = df[["edge_hubness", "edge_hub_diff"]]
y = df["adj"]
reg = LinearRegression().fit(X, y)

In [None]:
# Fitted slopes, in the column order of X: [edge_hubness, edge_hub_diff].
reg.coef_

In [None]:
(
    edge_deviation()
    .sel(threshold=2, weight='avgFA')
    .groupby('group')
    .mean('subject')
    # coefficient of variation of the group-mean edge weight across groups
    .assign(adj_std=lambda ds: ds['adj'].std('group') / ds['adj'].mean('group'))
    .mean('group')
    .to_dataframe()
    .pipe(
        px.scatter_3d,
        x='edge_hubness',
        y='edge_hub_diff',
        z='adj_std',
        color="adj",
        height=1000,
    )
)

In [None]:
#| label: hubness-deviation-corr
#| fig-cap: |
#|   Correlation between hubness and mean negative deviation. Each row shows the
#|   analysis performed at a different graph threshold. Edge deviations are calculated
#|   relative to the HC mean weight, using average FA as the weight. Hubness scores are
#|   calculated using sift2-corrected streamline count as the weight. FEP deviations are
#|   shown in red; HC deviations are shown in blue. The R^2^ value for FEP deviations
#|   is shown above each graph.
# NOTE(review): the fig-cap above describes FEP (red) and HC (blue)
# deviations at several thresholds. This function plots a single group's
# `std` against `src_hubness` -- reconcile the caption and the code before
# publishing.
def max_edge_deviation_fig(thresh, group="HC"):
    """Scatter of node hubness vs `std` from `f` for one group and threshold.

    Parameters
    ----------
    thresh : value of the ``threshold`` coordinate of `f` to plot.
    group : group label to plot and whose OLS R^2 is annotated. The default
        "HC" reproduces the previous hard-coded behaviour.

    Returns
    -------
    plotly Figure with an OLS trendline and an "R^2 = ..." annotation above
    the plot area.
    """
    fig = px.scatter(
        f.sel(weight="avgFA", threshold=thresh, group=[group]).to_dataframe().reset_index(),
        x="src_hubness",
        y="std",
        color="group",
        trendline="ols",
    )
    fig.update_layout(title_text="")
    fits = px.get_trendline_results(fig)
    r_val = fits[fits["group"] == group].iloc[0]["px_fit_results"].rsquared
    fig.add_annotation(
        x=0.05,
        y=1.01,
        xref="x domain",
        yref="y domain",
        text=f"R<sup>2</sup> = {r_val:.5f}",
        xanchor="left",
        yanchor="bottom",
        showarrow=False,
        font_size=14,
    )
    return fig


# One panel per threshold in the list (here only threshold 0), rendered as a
# static image.
fig = plotly_grid(max_edge_deviation_fig, [0], vspacing=0.05).update_layout(
    height=400,
    width=1200,
)
D.Image(fig.to_image())

In [None]:
import nibabel as nib
from nilearn import datasets, plotting
def show_atlas(hem, threshold=3, group="FEP", var="std"):
    """Paint a node-level variable of `f` onto subject 001's inflated surface.

    Parameters
    ----------
    hem : hemisphere entity passed to the BIDS layout query (e.g. "L").
    threshold, group : selection applied to `f` (with ``weight="avgFA"``).
        The defaults reproduce the previous hard-coded values.
    var : variable of `f` to display (default "std").

    Returns
    -------
    The nilearn ``view_surf`` viewer.
    """
    surface = itx.one(layout.get(subject="001", suffix="inflated", hemi=hem))
    dparc = nib.load(
        itx.one(layout.get(subject="001", suffix="dparc", atlas="bn210", hemi=hem))
    )
    # per-vertex parcel labels of the bn210 parcellation
    labels = dparc.get_arrays_from_intent('NIFTI_INTENT_LABEL')[0]
    node_values = xr.concat(
        [
            # label 0 gets value 0 (presumably unlabelled vertices -- confirm)
            xr.DataArray([0], dims=("node",), coords={"node": [0]}),
            f.sel(weight='avgFA', threshold=threshold, group=group)[var],
        ],
        dim="node"
    )
    # map each vertex's parcel label to that parcel's value
    vertex_values = node_values.loc[labels.data]
    # The unused `datasets.fetch_surf_fsaverage()` call was removed: it only
    # triggered a template download whose result was never used.
    return plotting.view_surf(
        surface.path,
        vertex_values.data,
        cmap='nipy_spectral',
        symmetric_cmap=False,
    )


show_atlas("L")

In [None]:
# Group-mean edge data; `diff` is the relative deviation from `ctrl`
# (presumably the control reference weights -- defined elsewhere), i.e.
# (adj - ctrl) / ctrl.
avg_data = data.groupby("group").mean("subject")
avg_diff = avg_data.assign(diff=lambda ds: (ds["adj"] - ctrl) / ctrl)
rng = np.random.default_rng(2)


def get_window(args, n_random=1000):
    """Compare the deviation inside one hubness core with random subnetworks.

    Parameters
    ----------
    args : tuple ``(diff, hubs, group, hubness, threshold)``. It is packed
        into one argument so the function can be used with ``Pool.imap``. The
        core is every node whose hubness for `group`/`threshold` is strictly
        greater than `hubness`.
    n_random : number of random node sets (same size as the core) used for
        the baseline. The default 1000 matches the previous fixed value.

    Returns
    -------
    One-row DataFrame with the core deviation (see get_deviation) and the
    mean / std of the random-baseline deviations.
    """
    diff, hubs, group, hubness, threshold = args
    ds = diff.sel(
        threshold=threshold,
        weight="avgFA",
        group=group,
    )
    core = hubs.sel(group=group, threshold=threshold).where(
        lambda h: h > hubness, drop=True
    )
    core_nodes = core["node"].data
    main = get_deviation(ds, core_nodes)
    # Plain loop instead of np.fromfunction(np.vectorize(...)): np.vectorize
    # evaluates the first element an extra time to infer the output dtype,
    # wasting one random draw and obscuring the intent.
    randoms = np.array(
        [float(get_random_deviations(ds, len(core_nodes))) for _ in range(n_random)]
    )
    return pd.DataFrame(
        [
            {
                "group": group.item(),
                "hubness": hubness.item(),
                "threshold": threshold,
                # a plain float instead of a 0-d array, so the column is numeric
                "deviation": float(main),
                "rand_deviation": randoms.mean(),
                "rand_deviation_std": randoms.std(),
            }
        ]
    )


def get_deviation(ds, nodes):
    """Mean magnitude of negative `diff` on edges among `nodes`.

    The subnetwork is restricted to `nodes` on both `src` and `dest`.
    Non-negative entries are clipped to 0 and then masked out, so only
    negative deviations (as absolute values) enter the mean.
    """
    sub = ds.sel(dest=nodes, src=nodes)["diff"]
    negative = np.minimum(sub, 0)
    return abs(negative.where(negative != 0)).mean()


def get_random_deviations(ds, num):
    """get_deviation for `num` nodes drawn without replacement from avg_hubs["node"].

    Draws from the notebook-level `rng`.
    NOTE(review): when called from worker processes of a fork-started
    multiprocessing Pool, each worker starts from a copy of the same `rng`
    state, so random draws can repeat across workers and the results depend
    on task scheduling -- consider per-task seeding.
    """
    return get_deviation(ds, rng.choice(avg_hubs["node"], size=num, replace=False))


@nb_cache("windowed")#, reset_cache=True)
def get_windowed():
    """Run get_window for every (group, hubness level, threshold) combination.

    - Hubness levels are every 20th value of the sorted FEP hubness at
      threshold 1.
    - Graph thresholds 0-8 are evaluated.
    - Tasks run on a multiprocessing Pool with a tqdm progress bar, and the
      one-row frames are concatenated.
    - The result is cached under "windowed" by nb_cache (defined elsewhere).
    """
    threshold = 1
    hubness_levels = avg_hubs.sel(group="FEP", threshold=threshold).pipe(np.sort)[::20]
    tasks = list(
        it.product(
            [avg_diff],
            [avg_hubs],
            avg_hubs["group"],
            hubness_levels,
            range(9),
        )
    )
    with mp.Pool() as pool:
        rows = list(tqdm.tqdm(pool.imap(get_window, tasks), total=len(tasks)))
    return pd.concat(rows)


# The core deviation is a single observed value per window, so it gets a
# zero-width error bar (`deviation_std` = 0) alongside the random-baseline std.
windowed = get_windowed().assign(deviation_std=0)
windowed = windowed.set_index(["group", "hubness", "threshold"]).to_xarray()

In [None]:
#| label: hubness-window
#| fig-cap: |
#|   Deviation within the hubness core. The x-axis tracks the minimum hubness of the core
#|   being considered. The blue line tracks the mean negative deviation within the core;
#|   the red line tracks the average negative deviation within random subnetworks of the
#|   same size as the core. All deviations are shown as absolute negative deviations.
_windowed = windowed.sel(group="FEP")
# Core and random-baseline deviations, stacked along a "var" dimension.
# drop_vars replaces the deprecated Dataset.drop (same behaviour for
# variable names).
means = (
    _windowed
    .drop_vars(['rand_deviation_std', 'deviation_std'])
    .to_array(dim="var", name="deviation")
)
# The matching spreads, renamed so their "var" labels line up with `means`.
std = (
    _windowed
    .drop_vars(['rand_deviation', 'deviation'])
    .rename({'rand_deviation_std': 'rand_deviation', 'deviation_std': 'deviation'})
    .to_array(dim='var', name="std")
)


fig = error_line(
    xr.merge([means, std]).sel(threshold=5).to_dataframe().reset_index(),
    x="hubness",
    y="deviation",
    color="var",
    err="std",
    markers=True,
)

fig.update_layout(
    showlegend=False,
    title=dict(
        text="",
        font_family="roboto",
        font_color="#58595b",
        font_size=18,
        xanchor="center",
        x=0.5,
        xref="paper",
    ),
    height=400,
    width=800,
    template="seaborn",
)
D.Image(fig.to_image(format='png'))

In [None]:
import statsmodels.api as sm

# Per-subject OLS fits for FEP subjects at threshold 5. The result is one row
# per subject holding the fitted parameters and R^2.
# `__` appears to be a toolz-style pipe/filter/map helper namespace defined
# elsewhere -- confirm.
d = mean_edge_deviation()
__.pipe(
    # it.product(diff_by_hubs['threshold'], ["HC", "FEP"]),
    diff.sel(subject=diff["group"] == "FEP")["subject"],
    # NOTE(review): subject 80 is excluded without explanation -- document why.
    __.filter(lambda _: _!=80),
    __.map(lambda subj : __.pipe(
        d.sel(subject=subj, threshold=5)
        .to_dataframe(),

        # NOTE(review): endog is src_hubness and exog is diff, so the slope
        # stored in the "hubness" column below is the coefficient on `diff`.
        # Confirm whether diff ~ src_hubness was intended.
        # NOTE(review): mean_edge_deviation() averages most variables over
        # `subject`; check which columns of `d` actually vary per subject before
        # interpreting these fits.
        lambda _: sm.OLS(_["src_hubness"], sm.add_constant(_["diff"]), missing='drop').fit(),
        lambda _: pd.DataFrame(
            [[*_.params, _.rsquared]],
            columns=["const", "hubness", "rsquared"],
            # index=pd.Index([(group, int(thresh))], name=("group", "threshold"))
            index=pd.Index([subj], name="subject")
        )
    )),
    pd.concat,
    # pd.DataFrame.sort_index,
)


In [None]:
# fig = px.scatter(
#     grouper
#     .median()
#     .assign(std=grouper.map(quiet_std)["diff"])
#     .stack({"edge": ["src", "dest"]})
#     .sel(group="FEP")
#     .to_dataframe()
#     ,
#     x="diff",
#     y="hubness",
#     # trendline="ols",
#     color="std",
# )


In [None]:
# Rich-club coefficient (phi) for a single subject (001). Notes:
# - The result is stored in its own name instead of overwriting the
#   notebook-level `data`, which the helpers above (node_wise_std,
#   edge_distribution_prefilter) read.
# - join="left" keeps only subject 1. The default outer join padded phi with
#   NaN for every other participant in the metadata.
richclub_sub1 = xr.merge(
    [
        pd.read_csv(
            itx.one(layout_get(desc="avgFA", suffix="richclub", subject="001")), index_col=0
        )
        .assign(subject=1)
        .set_index(["subject", "threshold", "k"], drop=True)
        .to_xarray(),

        get_subj_metadata().to_xarray(),
    ],
    join="left",
)
# With one subject there is no between-subject spread. Previously `std` was
# set to phi itself, which drew error bars as large as the values.
phi = xr.merge([dict(
    phi=richclub_sub1["phi"],
    std=xr.zeros_like(richclub_sub1["phi"]),
)]).to_dataframe().reset_index()
error_line(phi, x="k", y="phi", err="std", color="threshold", markers=True)

In [None]:

# Rich-club coefficient (phi) across all subjects, summarised per group at
# threshold 4 as mean +/- std over subjects.
# NOTE(review): this overwrites the notebook-level `data` read by the edge
# helpers above (node_wise_std, edge_distribution_prefilter); re-running
# those after this cell will use rich-club data instead.
data = xr.merge(
    [
        pd.concat(
            [
                pd.read_csv(bids_file, index_col=0)
                .assign(subject=int(bids_file.entities["subject"]))
                .set_index(["subject", "threshold", "k"], drop=True)
                for bids_file in layout_get(desc="avgFA", suffix="richclub")
            ]
        ).to_xarray(),
        get_subj_metadata().to_xarray(),
    ]
)
grouper = data.sel(threshold=4).groupby("group")
phi = (
    xr.merge([{
        "mean": grouper.mean("subject")["phi"],
        "std": grouper.std("subject")["phi"],
    }])
    .to_dataframe()
    .reset_index()
)
error_line(phi, x="k", y="mean", err="std", color="group", markers=True)
