In [None]:
from __future__ import annotations
import pickle
import itertools as it
import numpy as np
import pandas as pd
import xarray as xr
import networkx as nx

import os
import subprocess as sp
from pathlib import Path
import copy
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import more_itertools as itx
import multiprocessing as mp
import tqdm
import attrs
# import sklearn
import shutil
import plotly.express as px
import functools as ft
from typing import Any, List
import plotly.graph_objects as go
import ipywidgets as widgets
import IPython.display as D
import imageio.v3 as iio

In [None]:
# Move one directory up from where the kernel started (presumably the project
# root, since the ``notebooks.*`` imports and ``../..`` paths below rely on it).
# The starting directory is remembered in the kernel namespace so that
# re-running this cell does not keep climbing the directory tree.
_NOTEBOOK_DIR = globals().get("_NOTEBOOK_DIR", Path.cwd())
os.chdir(_NOTEBOOK_DIR.parent)

In [None]:
# Mapping from the "Patient Cat" code in the participant table to a group label.
cats = {1: "HC", 2: "FEP", 3: "Treatment 3+ yr", 4: "High risk"}


def get_subj_metadata(path="../../part", categories=None):
    """Load the participant metadata table, indexed by subject.

    Parameters
    ----------
    path : str or path-like
        Tab-separated participant table containing ``TOPSY ID`` and
        ``Patient Cat`` columns. Defaults to the original hard-coded location.
    categories : dict, optional
        Mapping from ``Patient Cat`` code to group label. Defaults to the
        module-level ``cats`` (looked up at call time).

    Returns
    -------
    pandas.DataFrame
        The table with ``TOPSY ID`` renamed to ``subject`` and used as the
        index, plus a ``group`` column. Codes missing from ``categories``
        map to NaN.
    """
    if categories is None:
        categories = cats
    return (
        pd.read_csv(path, sep="\t")
        .rename(columns={"TOPSY ID": "subject"})
        .set_index("subject")
        .assign(group=lambda df: df["Patient Cat"].map(categories))
    )

from bids import BIDSLayout

def get_participants(participant_file, exclude=("080",)):
    """Return the subject labels listed in a BIDS-style ``participants.tsv``.

    Each line starting with ``sub-<label>`` contributes ``<label>``. The label
    stops at the first whitespace, so any further tab-separated columns are
    ignored. Lines that do not start with ``sub-`` (e.g. the header) are
    skipped.

    Parameters
    ----------
    participant_file : str or path-like
        Path to the participants table.
    exclude : collection of str
        Labels to drop. The default ``("080",)`` keeps the original hard-coded
        exclusion. Pass a tuple/set, not a bare string, because membership is
        tested with ``in``.

    Returns
    -------
    list of str
        Subject labels in file order.
    """
    pattern = re.compile(r"sub-(\S+)")
    with open(participant_file) as f:
        return [
            match.group(1)
            for line in f
            if (match := pattern.match(line)) and match.group(1) not in exclude
        ]

# Index the BIDS dataset two levels up, derivatives included. The pybids index
# is stored under ``database_path`` (per pybids docs, presumably reused later).
layout = BIDSLayout(
    "../..",
    database_path="../../.pybids",
    derivatives=True,
)

# Query helper pre-bound to dwi connectomes on the bn246 atlas, restricted to
# the subjects returned by get_participants for the snakedwi derivative.
# Callers may still override any of these keywords (e.g. subject="001").
layout_get = ft.partial(
    layout.get,
    atlas="bn246",
    datatype="dwi",
    subject=get_participants("../../derivatives/snakedwi-0.1.0/participants.tsv"),
    suffix="connectome",
)

### Utility Functions

In [None]:
from notebooks.utils import hex_to_rgb, get_lut, lut_label, titleize, distribution_plot, figures_to_html, plotly_tabulate, NbCache, underscore as __, error_line
from notebooks.adjacency_matrix import group_outer, AdjacencyMatrix, filter_edges, community_sort
from notebooks.bn_metadata import read_metadata
from notebooks.graph_metrics import global_efficiency, local_efficiency
from notebooks.plotly_helpers import plotly_grid, matplotlib_to_plotly
# Notebook helper from notebooks.utils, constructed with the name "connectomics"
# (presumably a cache namespaced per notebook — see NbCache).
# NOTE(review): ``nb_cache`` is not referenced in the cells shown here; confirm
# it is still needed.
nb_cache = NbCache("connectomics")

### Single Subject

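Here, we test a single subject (`sub-001`, `avgFA` weights) and plot its adjacency matrix, sorted by community and hemisphere. The distribution of its edge weights is shown as a histogram further below.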

In [None]:
# NOTE(review): ``Subject`` and ``filter_logile`` are neither defined nor
# imported anywhere in this notebook. This cell only runs if they leaked in
# from an earlier session or a deleted cell. TODO: import them explicitly so
# the notebook survives Restart & Run All.
adj = community_sort(
    Subject.from_bids_entry(itx.one(layout_get(subject="001", desc="avgFA")))
    .adj
    .mask_where(filter_logile(0))
)

# Group nodes by hemisphere. A stable sort keeps the existing row order (which
# presumably reflects the community ordering above) within each hemisphere.
# pandas' default quicksort does not guarantee this for ties.
adj = adj.with_metadata(adj.metadata.sort_values("hemisphere", kind="stable"))
adj.plot()

#### Connection Density

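Graph density of each subject's connectome as the mask threshold (`filter_logile(0)` to `filter_logile(9)`) increases, shown as the group mean ± standard deviation.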

In [None]:
def connection_density(desc="sift2", thresholds=range(10)):
    """Graph density of each subject's connectome across mask thresholds.

    For every connectome returned by ``layout_get(desc=desc)``, edges are
    masked with ``filter_logile(threshold)`` for each value in ``thresholds``.
    ``nx.density`` of the resulting graph is then recorded.

    NOTE(review): ``Subject`` and ``filter_logile`` are not defined or imported
    in this notebook; they must be in scope for this to run.

    Parameters
    ----------
    desc : str
        BIDS ``desc`` entity passed to ``layout_get``. Presumably selects the
        edge weighting; default ``"sift2"``.
    thresholds : iterable of int
        Values passed to ``filter_logile``; default ``range(10)``.

    Returns
    -------
    pandas.DataFrame
        Column ``density`` indexed by (``subject``, ``threshold``).
    """
    densities = []
    for entry in layout_get(desc=desc):
        subj = Subject.from_bids_entry(entry)
        if not subj:
            # Entries for which from_bids_entry returns a falsy value are skipped.
            continue
        for threshold in thresholds:
            adj = subj.adj.mask_where(filter_logile(threshold))
            densities.append((subj.id, threshold, nx.density(adj.graph)))
    return pd.DataFrame(
        densities,
        columns=["subject", "threshold", "density"],
    ).set_index(["subject", "threshold"])


densities = connection_density()

In [None]:
# Attach participant metadata (including ``group``) to the per-subject densities.
ds = xr.merge(
    [
        get_subj_metadata().to_xarray(),
        densities.to_xarray(),
    ]
)

# Group mean ± standard deviation of graph density at each threshold.
error_line(
    xr.merge(
        [
            {
                "density": ds.groupby("group").mean("subject")["density"],
                "std": ds.groupby("group").std("subject")["density"],
            }
        ]
    )
    .to_dataframe()
    .reset_index(),
    x="threshold",
    y="density",
    color="group",
    err="std",
)

#### Weight distribution

In [None]:

# Weight distribution of the single-subject matrix ``adj`` built above.
# The bin count follows the square-root rule on the number of matrix entries.
# Passing it as an int (rather than an array of edges) makes seaborn spread
# that many equal-width bins over the actual data range.
bins = __.pipe(
    adj.filled,
    np.concatenate,
    len,
    np.sqrt,
    int,
)
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Distribution of raw weights. ``compressed()`` drops masked entries
# explicitly instead of relying on seaborn to honour the mask.
rawdist = sns.histplot(adj.masked.compressed(), bins=bins, kde=False, ax=axes[0])
rawdist.set(xlabel="Edge weight", ylabel="Count", title="Raw weights")

# Distribution of log10 weights. numpy.ma masks the log of non-positive
# values, so those are dropped together with the masked edges.
log10dist = sns.histplot(np.log10(adj.masked).compressed(), kde=False, ax=axes[1])
log10dist.set(xlabel="log10(edge weight)", ylabel="Count", title="log10 weights");

#### Edge Mass Function

The total weight held by all edges falling in each weight bin, shown as the mean ± standard deviation across subjects.

In [None]:

def _binned_weight_mass(arr, n_bins=246):
    """Total weight of ``arr`` in each of ``n_bins`` evenly spaced bins.

    Bin edges run linearly from ``arr.min()`` to ``arr.max()``, and
    ``np.digitize`` assigns each entry a 1-based bin index. Only occupied
    bins are returned, as ``(bin_index, summed_weight)`` pairs in ascending
    bin order. ``n_bins`` must be at least 2.

    NOTE(review): with ``np.digitize``'s default ``right=False``, entries equal
    to the last edge (normally the maximum) get index ``n_bins`` and so form a
    bin of their own.
    """
    steps = np.arange(n_bins)
    edges = steps / steps.max() * np.ptp(arr) + np.min(arr)
    digitized = np.digitize(arr, edges)
    return [(i, arr[digitized == i].sum()) for i in np.unique(digitized)]


def edge_mass_function(desc="sift2", threshold=0, n_bins=246):
    """Per-subject edge mass function: total edge weight within each weight bin.

    For every connectome returned by ``layout_get(desc=desc)``:

    - edges are masked with ``filter_logile(threshold)``;
    - masked entries are filled with 0;
    - weights are summed per bin with ``_binned_weight_mass``.

    NOTE(review): bin edges come from each subject's own min/max, so bin ``i``
    covers a different weight range in each subject. Confirm that averaging a
    bin across subjects is intended. Filling masked entries with 0 also
    pulls the lowest edge down to 0 when the weights are positive.
    ``Subject`` and ``filter_logile`` must be in scope (they are not imported
    in this notebook).

    Parameters
    ----------
    desc : str
        BIDS ``desc`` entity passed to ``layout_get``; default ``"sift2"``.
    threshold : int
        Value passed to ``filter_logile``; default ``0``.
    n_bins : int
        Number of bin edges. The default 246 matches the number of bn246
        atlas regions (presumably why it was chosen).

    Returns
    -------
    pandas.DataFrame
        Column ``weight`` indexed by (``subject``, ``bin``); only occupied bins
        appear.
    """
    cum_weight = []
    for entry in layout_get(desc=desc):
        subj = Subject.from_bids_entry(entry)
        if not subj:
            continue
        adj = subj.adj.mask_where(filter_logile(threshold))
        arr = adj.masked.filled(0)
        cum_weight.extend(
            (subj.id, bin_index, mass)
            for bin_index, mass in _binned_weight_mass(arr, n_bins)
        )
    return pd.DataFrame(
        cum_weight,
        columns=["subject", "bin", "weight"],
    ).set_index(["subject", "bin"])


cum_weight = edge_mass_function()


In [None]:
# Attach participant metadata to the per-subject, per-bin edge mass.
ds = xr.merge(
    [
        get_subj_metadata().to_xarray(),
        cum_weight.to_xarray(),
    ]
)

# Mean ± standard deviation of edge mass per bin with all groups pooled.
# (To split by group, aggregate with ds.groupby("group") and pass color="group".)
error_line(
    xr.merge(
        [
            {
                "weight": ds.mean("subject")["weight"],
                "std": ds.std("subject")["weight"],
            }
        ]
    )
    .to_dataframe()
    .reset_index(),
    x="bin",
    y="weight",
    err="std",
)