In [None]:
# Atlas identifier: passed as the `atlas` entity in the BIDS queries below and
# used as the first sub-directory of the notebook cache.
ATLAS = "bn246"
# Connectome weighting: passed as the `desc` entity in the BIDS queries below
# and used as the second sub-directory of the notebook cache.
WEIGHT = "minFA"

In [None]:
import pickle
import itertools as it
import numpy as np
import pandas as pd
import networkx as nx
import os
import subprocess as sp
from pathlib import Path
import copy
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import more_itertools as itx
import shutil
from typing import Any, List
import plotly.graph_objects as go
import ipywidgets as widgets

In [None]:
# Group code of every subject, indexed by the integer subject id
# (subject_graphs reads `categories[sub]`). Codes are translated to labels
# through `cats`; a code with no entry in `cats` (e.g. 0) makes
# subject_graphs skip that subject.
categories = [
    0,
    1,
    1,
    1,
    1,
    1,
    3,
    2,
    2,
    3,
    1,
    1,
    2,
    2,
    2,
    3,
    2,
    2,
    2,
    2,
    2,
    3,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    3,
    3,
    2,
    2,
    1,
    1,
    2,
    2,
    2,
    1,
    1,
    2,
    1,
    2,
    2,
    1,
    2,
    2,
    1,
    2,
    2,
    2,
    1,
    1,
    1,
    1,
    2,
    2,
    2,
    2,
    1,
    1,
    1,
    1,
    1,
    1,
    2,
    1,
    1,
    2,
    3,
    2,
    2,
    2,
    2,
    2,
    1,
    2,
    2,
    3,
    2,
    2,
    2,
    2,
    1,
    2,
    2,
    1,
    2,
    4,
    1,
    1,
    2,
    2,
    2,
    4,
    2,
    4,
    2,
    4,
    2,
    2,
    2,
    2,
    2,
    2,
    4,
    4,
    4,
    4,
    4,
    2,
    4,
    2,
    4,
    2,
    2,
    4,
    4,
    2,
    3,
    3,
    2,
    2,
    2,
    2,
    4,
    3,
    3,
    3,
    3,
    4,
    1,
    3,
    4,
    2,
    0,
    4,
    1,
    1,
    3,
    1,
    1,
    2,
    3,
    2,
    4,
    4,
    4,
    2,
    1,
    0,
    1,
    1,
    2,
    3,
]

In [None]:
# Human-readable label for each group code stored in `categories`.
# Codes missing from this mapping are treated as "no diagnosis assigned"
# by subject_graphs.
cats = {
    1: "HC",
    2: "FEP",
    3: "Treatment 3+ yr",
    4: "High risk"
}

In [None]:
from bids import BIDSLayout

# Index the dataset two directories above the notebook, derivatives included;
# `database_path` points pybids at the index database in ../../.pybids.
# layout = BIDSLayout("results/prepdwi_recon", validate=False)
layout = BIDSLayout("../..", derivatives=True, database_path="../../.pybids")

### Utility Functions

In [None]:
def filter_logile(matrix, bin: int, num_bins: int = 10):
    """Zero out the weakest entries of ``matrix`` on a log10 scale.

    The log10 values of the non-zero entries span ``[min, max]``. That span is
    divided into ``num_bins`` equal-width bins and every entry whose log10
    value is at or below the lower edge of bin ``bin`` (0-based) is set to 0.
    With ``bin=0`` only the entries equal to the minimum are removed.

    Args:
        matrix: Array of weights; zeros are ignored when computing the range.
        bin: Index of the first bin to keep; must be smaller than ``num_bins``.
        num_bins: Number of equal-width bins on the log10 scale.

    Returns:
        A deep copy of ``matrix`` with the filtered entries set to 0. The
        input is not modified.

    Raises:
        ValueError: If ``bin >= num_bins``.
    """
    if bin >= num_bins:
        raise ValueError("bin must be less than num_bins")
    # Mask the zeros so they reach neither log10 nor the min/max.
    log = np.ma.log10(np.ma.masked_equal(matrix, 0))
    log_min = log.min()
    threshold = ((log.max() - log_min) * bin / num_bins) + log_min
    cp = copy.deepcopy(matrix)
    cp[log <= threshold] = 0
    return cp

def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints.

    Args:
        hex_color: Colour such as ``"#636EFA"`` or the 3-digit shorthand
            ``"#abc"``; the leading ``#`` is optional.

    Returns:
        Tuple of three integers in the range 0-255.
    """
    digits = hex_color.lstrip("#")
    if len(digits) == 3:
        # Shorthand "abc" stands for "aabbcc": double every digit in place
        # (repeating the whole string would give "abcabc", a different colour).
        digits = "".join(ch * 2 for ch in digits)
    return int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)

def get_lut(path):
    """Read a whitespace-separated lookup table into ``{index: label}``.

    The first field of every line is parsed as an integer key and the second
    field is its label; any further fields are ignored. Blank lines and lines
    starting with ``#`` are skipped.

    Args:
        path: Path of the LUT text file.

    Returns:
        Dict mapping the integer index to the label string.

    Raises:
        ValueError: If a data line has fewer than two fields or its first
            field is not an integer.
    """
    lut = {}
    with open(path) as f:
        for line in f:
            fields = line.split()
            # A blank line used to make the column-wise zip empty, silently
            # producing an empty LUT; skip blanks and comments explicitly.
            if not fields or fields[0].startswith("#"):
                continue
            if len(fields) < 2:
                raise ValueError(f"malformed LUT line: {line!r}")
            lut[int(fields[0])] = fields[1]
    return lut

def lut_label(data, path):
    """Wrap ``data`` in a DataFrame labelled through the LUT stored at ``path``.

    Row and column labels that appear as keys in the LUT are renamed to the
    corresponding LUT value.
    """
    mapping = get_lut(path)
    frame = pd.DataFrame(data)
    return frame.rename(index=mapping, columns=mapping)


def titleize(label):
    """Turn a snake_case identifier into a display title.

    Underscores become spaces, the first character is upper-cased and the
    rest is lower-cased.
    """
    spaced = " ".join(label.split("_"))
    return spaced.capitalize()

import plotly.express as px
import plotly.graph_objs as go
def distribution_plot(df, x, y: str):
    """Plot the per-category mean of ``y`` against ``x`` with a +/- 1 std band.

    Rows are grouped by ``"category"`` and ``x``; the mean and standard
    deviation of ``y`` are computed per group. Each category gets one line
    (from ``px.line``) plus two zero-width Scatter traces (mean + std and
    mean - std) filled with ``tonexty`` to form the band. A row of buttons,
    one per category, toggles the visibility of that category's three traces.

    Args:
        df: Frame providing ``"category"`` (column or index level), ``x``
            and ``y``.
        x: Name of the column plotted on the x axis.
        y: Name of the column whose distribution is plotted.

    Returns:
        The plotly figure.
    """
    std_col = y + "_std"
    # Aggregate only `y`: aggregating every column fails on non-numeric
    # columns with recent pandas and produced duplicate column names.
    grouper = df.groupby(["category", x])[y]
    ddf = pd.concat([
        grouper.mean(),
        grouper.std().rename(std_col),
    ], axis=1).reset_index().set_index("category")
    fig = px.line(
        ddf,
        x=x,
        y=y,
        color=ddf.index,
        width=800,
        height=600,
        # NOTE(review): the "x" key only applies to a column literally named
        # "x" - confirm whether it was meant to label the `x` argument.
        labels={
            "x": "Nodes sorted by increasing degree",
            "node_size": "Node Size (# triangles)",
            "betweenness": "Betweenness",
            "degree": "Degree"
        },
        title=f"{titleize(y)} distribution",
    )
    buttons = []
    groups = ddf.index.unique()
    num_traces = len(groups)
    palette = px.colors.qualitative.Plotly
    for i, cat in enumerate(groups):
        # Wrap around the palette so more categories than colours cannot
        # raise an IndexError.
        red, green, blue = hex_to_rgb(palette[i % len(palette)])
        fig.add_traces([
            go.Scatter(
                x=ddf.loc[cat, x],
                y=ddf.loc[cat, std_col]+ddf.loc[cat,y],
                mode="lines",
                line=dict(width=0),
            ),
            go.Scatter(
                x=ddf.loc[cat, x],
                y=ddf.loc[cat,y]-ddf.loc[cat, std_col],
                mode="lines",
                line=dict(width=0),
                fill='tonexty',
                fillcolor=f'rgba({red}, {green}, {blue}, 0.2)'
            )
        ])
        # Trace layout: the first `num_traces` traces are the mean lines,
        # followed by two band traces per category in the same order.
        trace_ids = [i, num_traces + i*2, num_traces + i*2 + 1]
        buttons.append({
            "method": 'restyle',
            "visible": True,
            "label": cat,
            "args": [{
                "visible": False
            }, trace_ids],
            "args2": [{
                "visible": True,
            }, trace_ids]
        })
    fig.update_layout(
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=False,
        updatemenus=[
            dict(
                type="buttons",
                direction="right",
                x=1,
                y=-0.2,
                showactive=True,
                buttons=buttons,
            )
        ]
    )
    return fig

def figures_to_html(figs, filename="dashboard.html"):
    """Write several figures into a single HTML file.

    For every figure the text between ``<body>`` and ``</body>`` of its
    ``to_html()`` output is copied, in order, into one shared document.

    Args:
        figs: Iterable of objects exposing ``to_html()``.
        filename: Destination file; missing parent directories are created.
    """
    # parents=True so nested targets such as "pages/bn246/x.html" work even
    # when none of the intermediate directories exist yet.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    with open(filename, 'w') as dashboard:
        dashboard.write("<html><head></head><body>" + "\n")
        for fig in figs:
            inner_html = fig.to_html().split('<body>')[1].split('</body>')[0]
            dashboard.write(inner_html)
        dashboard.write("</body></html>" + "\n")

def plotly_tabulate(figs):
    """Show figures side by side in an ipywidgets ``Tab`` container.

    Every figure is wrapped in a ``FigureWidget`` and placed in its own tab;
    the tab caption is the figure's layout title text.
    """
    figures = list(figs)
    captions = [figure.layout["title"]["text"] for figure in figures]
    tab = widgets.Tab()
    tab.children = [go.FigureWidget(figure) for figure in figures]
    for position, caption in enumerate(captions):
        tab.set_title(position, caption)
    return tab

class NbCache:
    """Pickle-backed result cache for expensive notebook functions.

    Results are stored under ``<root>/.ipynb_cache/<indices...>/<name>.pyc``.
    The cache key is the ``name`` alone: arguments passed to the decorated
    function do not influence which file is read or written.
    """

    def __init__(self, *indices: str, root="."):
        """Create a cache rooted at ``root`` and namespaced by ``indices``."""
        # The attribute name (with its spelling) is kept as-is because
        # renaming it would change the object's public attributes.
        self.indicies = indices
        self.root = root

    def __call__(self, name, reset_cache=False):
        """Return a decorator that caches the wrapped function's result.

        Args:
            name: File stem of the cache file inside ``cache_dir``.
            reset_cache: When true, ignore an existing cache file, recompute
                and overwrite it.

        Returns:
            A decorator. The decorated function loads the pickled result when
            the cache file exists, otherwise calls the original function and
            pickles its return value.
        """
        import functools

        self.cache_dir.mkdir(parents=True, exist_ok=True)
        cache_file = (self.cache_dir / name).with_suffix(".pyc")

        def wrapper(func) -> Any:
            # functools.wraps keeps the name and docstring of the wrapped
            # function visible on the cached version.
            @functools.wraps(func)
            def inner(*args, **kwargs):
                if cache_file.exists() and not reset_cache:
                    with cache_file.open('rb') as f:
                        # NOTE: unpickling executes arbitrary code; only use
                        # cache directories you trust.
                        return pickle.load(f)
                result = func(*args, **kwargs)
                with cache_file.open('wb') as f:
                    pickle.dump(result, f)
                return result
            return inner
        return wrapper

    @property
    def cache_dir(self):
        """Directory holding this cache's files."""
        return Path(self.root, ".ipynb_cache", *self.indicies)

    def migrate(self, new: "NbCache", dry: bool = False):
        """Move every cache file of this cache into ``new``'s directory.

        Sub-directories are skipped. With ``dry=True`` the planned moves are
        only printed.
        """
        new.cache_dir.mkdir(parents=True, exist_ok=True)
        for path in self.cache_dir.iterdir():
            if path.is_dir():
                continue
            dest = new.cache_dir / path.name
            if dry:
                print(path, "->", dest)
                continue
            shutil.move(str(path), dest)

# Cache shared by the cells below; its files live under
# ./.ipynb_cache/<ATLAS>/<WEIGHT>/ so each atlas/weight pair is kept apart.
nb_cache = NbCache(ATLAS, WEIGHT)


### Single Subject

Here, we test a single subject and view its connection density as a histogram.

In [None]:
# Load the comma-separated connectome of subject 034 for the selected atlas
# and weighting; the first file matching the BIDS query is used.
c = np.genfromtxt(
    layout.get(
        subject="034", datatype="dwi", atlas=ATLAS, suffix="connectome", desc=WEIGHT
    )[0].path, 
    delimiter=","
)

In [None]:
import math

# Remove the weakest connections (log10 bin 3 of 10) and any self-loops
# before building the graph.
filtered = filter_logile(c, 3)
np.fill_diagonal(filtered, 0)
# `nx.from_numpy_matrix` was removed in NetworkX 3.0; `from_numpy_array`
# builds the same weighted graph from a 2-D ndarray.
G = nx.from_numpy_array(filtered)
# Edge "distance" is -log10(weight).
# NOTE(review): subject_graphs uses 1/weight as the distance instead -
# confirm the difference is intended.
for edge in G.edges:
    G.edges[edge]["distance"] = -math.log10(G.edges[edge]["weight"])

In [None]:
# Label rows/columns of the filtered matrix with the atlas region names and
# draw it as a square heatmap.
lut = get_lut("resources/BN_Atlas_freesurfer/BN_Atlas_246_LUT.txt")
df = pd.DataFrame(filtered).rename(index=lut, columns=lut)
plt.figure(figsize=(20,20))
sns.heatmap(df, cmap="rocket", square=True)

In [None]:
# Weight distribution plot
# Use sqrt(number of matrix entries) bin edges, rescaled to span [0, 1].
bins = np.arange(np.sqrt(len(np.concatenate(filtered))))
bins = (bins - np.min(bins))/np.ptp(bins)
fig, axes = plt.subplots(1,2, figsize=(15,5))

# Distribution of raw weights
# NOTE(review): the x label says "Correlation Values" although the data are
# connectome weights - confirm the wording.
rawdist = sns.histplot(filtered.flatten(), bins=bins, kde=False, ax=axes[0])
rawdist.set(xlabel='Correlation Values', ylabel = 'Density Frequency')

# Probability density of log10
# Only strictly positive weights have a finite log10; taking the log of the
# zeros produced -inf values and divide-by-zero warnings.
positive_weights = filtered[filtered > 0]
log10dist = sns.histplot(np.log10(positive_weights), kde=False, ax=axes[1])
log10dist.set(xlabel='log(weights)')

### Multiple Subjects

Loop through all subjects and gather various metrics

In [None]:
def subject_graphs(filter_level=1, drop=None):
    """Yield one weighted graph per subject that has a known group.

    For every connectome file matching the current ``ATLAS``/``WEIGHT`` the
    matrix is loaded, its diagonal zeroed, weak edges removed with
    ``filter_logile`` and the regions labelled through the atlas LUT. The
    ``"Unknown"`` region and the requested regions are dropped before the
    graph is built. Every edge gets a ``"distance"`` attribute equal to
    ``1 / weight``.

    Subjects whose id is outside ``categories`` or whose group code is not in
    ``cats`` are reported with ``print`` and skipped.

    Args:
        filter_level: ``bin`` argument forwarded to ``filter_logile``.
        drop: Region labels to remove, or a callable taking the subject id
            and returning such labels. Defaults to dropping nothing.

    Yields:
        Tuples ``(subject_id, group_code, graph, dropped_regions)``.
    """
    if drop is None:
        drop = []
    # The LUT is identical for every subject, so parse the file only once.
    lut = get_lut("resources/BN_Atlas_freesurfer/BN_Atlas_246_LUT.txt")
    for bidsfile in layout.get(datatype="dwi", suffix="connectome", atlas=ATLAS, desc=WEIGHT):
        sub = int(bidsfile.entities['subject'])
        if len(categories) <= sub:
            print(f"Dropped {sub} (out of range)")
            continue
        cat = categories[sub]
        if cat not in cats:
            print(f"Dropped {sub} (no diagnosis assigned)")
            continue
        c = np.genfromtxt(bidsfile.path, delimiter=",")
        np.fill_diagonal(c, 0)
        filtered = filter_logile(c, filter_level)
        df = pd.DataFrame(filtered).rename(index=lut, columns=lut)
        df.drop(index="Unknown", columns="Unknown", inplace=True)
        dropped = drop(sub) if callable(drop) else drop
        df.drop(index=dropped, columns=dropped, inplace=True)
        G = nx.from_pandas_adjacency(df)
        for edge in G.edges:
            G.edges[edge]["distance"] = 1/G.edges[edge]["weight"]
        yield sub, cat, G, dropped


In [None]:
def subject_properties(sub, cat, G, drop_regions = []):
    """Collect whole-graph metrics of one subject into a dict.

    The dict holds the subject id, the group label looked up in ``cats``,
    the mean node degree, the number of nodes, the dropped regions, the
    number and the largest size of connected components, and the graph's
    density, transitivity and global efficiency.

    A ``KeyError`` raised while collecting the values is re-raised after the
    subject id has been printed.
    """
    try:
        props = {"subject": sub, "category": cats[cat]}
        # G.degree yields (node, degree) pairs; after transposing, the
        # second row holds the degrees.
        props["degree"] = np.mean([*zip(*G.degree)][1])
        props["num_regions"] = len(G.nodes)
        props["dropped_regions"] = list(drop_regions)
        props["num_connected_comps"] = nx.number_connected_components(G)
        props["largest_connected_comp"] = len(
            max(nx.connected_components(G), key=len)
        )
        props["density"] = nx.density(G)
        props["transitivity"] = nx.transitivity(G)
        props["efficiency"] = nx.global_efficiency(G)
    except KeyError as err:
        print(sub)
        raise err
    return props

def subject_df(drop_regions = []):
    """Build a DataFrame with one row of graph metrics per subject.

    ``drop_regions`` is forwarded to ``subject_graphs`` as its ``drop``
    argument. A ``KeyError`` from ``subject_properties`` is re-raised after
    the subject id has been printed.
    """
    records = []
    graphs = subject_graphs(drop=drop_regions)
    for subject, group, graph, removed in graphs:
        try:
            record = subject_properties(subject, group, graph, removed)
        except KeyError as err:
            print(subject)
            raise err
        records.append(record)

    return pd.DataFrame(records)

In [None]:
import plotly.express as px
@nb_cache("subject_full_df")
def subject_full_df():
    """Return the per-subject graph metrics with no regions dropped.

    Decorated with ``nb_cache``: the result is pickled under the name
    ``"subject_full_df"`` and later calls load that file instead of
    recomputing.
    """
    return subject_df()

def subject_distributions(col):
    """Draw a violin plot of metric ``col`` per group.

    The data come from ``subject_full_df()``. Groups are placed on the x axis
    and coloured individually, only outlier points are drawn, and the subject
    id is added to the hover text. The legend is hidden.
    """
    axis_labels = {
        "num_regions": "# Regions",
        "category": "Group"
    }
    fig = px.violin(
        subject_full_df(),
        x="category",
        color="category",
        y=col,
        points="outliers",
        width=584,
        height=400,
        labels=axis_labels,
        title=titleize(col),
        hover_data=["subject"]
    )
    fig.update_layout(
        margin={"l": 50, "r": 50, "t": 50, "b": 50},
        showlegend=False
    )
    return fig
# Whole-graph metrics to display, one violin plot per tab.
cols = [
    "transitivity",
    "efficiency",
    "density",
    "num_connected_comps",
    "largest_connected_comp",
    "degree",
]
plotly_tabulate([subject_distributions(col) for col in cols])
# figures_to_html([
#     subject_distributions(col)
#     for col in [
#         "transitivity",
#         "efficiency",
#         "density",
#         "num_connected_comps",
#         "largest_connected_comp",
#         "degree",
#     ]
# ], "pages/bn246/subject_distributions.html")

In [None]:
@nb_cache("nodal_properties")
def nodal_properties():
    """Compute node-level metrics for every node of every subject graph.

    For each graph from ``subject_graphs(1)`` one row per node is produced
    with the node's degree, clustering coefficient, mean shortest-path length
    to the nodes reachable from it (weighted by ``"distance"``, the node
    itself included) and betweenness centrality (weighted by ``"distance"``).

    Decorated with ``nb_cache``: the result is pickled under the name
    ``"nodal_properties"`` and reused by later calls.

    Returns:
        DataFrame with columns ``node``, ``subject``, ``category``,
        ``degree``, ``clust_coeff``, ``path_length`` and ``betweenness``.
    """
    rows = []
    for sub, cat, G, _ in subject_graphs(1):
        b = nx.betweenness_centrality(G, weight="distance")
        # One call for the whole graph instead of one call per node.
        clustering = nx.clustering(G)
        group = cats[cat]
        for node in G:
            lengths = nx.shortest_path_length(G, source=node, weight="distance")
            rows.append({
                "node": node,
                "subject": sub,
                "category": group,
                "degree": G.degree[node],
                "clust_coeff": clustering[node],
                "path_length": np.mean(list(lengths.values())),
                "betweenness": b[node],
            })
    df = pd.DataFrame(rows)
    return df

def property_rank(df, columns=[], inverse_columns=[]):
    """Add a ``<column>_rank`` column for each requested column.

    Rows are grouped by every index level except the last one. Within a
    group the rows are ordered by the column value and each row receives its
    position in that order divided by the group size. ``columns`` are ordered
    ascending, ``inverse_columns`` descending.

    ``df`` is modified in place (it is also sorted by its index) and
    returned. A progress line is printed for every ranked column.
    """
    group_levels = df.index.names[:-1]

    def assign_rank(frame, column, inverse=False):
        frame.sort_index(inplace=True)
        ordered = frame.sort_values(
            by=[*group_levels, column], ascending=not inverse
        )
        rank_col = f"{column}_rank"
        group_keys = frame.reset_index().set_index(group_levels).index.unique()
        for key in group_keys:
            members = ordered.loc[key].index
            total = len(members)
            for position, member in enumerate(members):
                # A key is a scalar for one grouping level and a tuple for
                # several; always_iterable lets both be unpacked.
                row = (*itx.always_iterable(key), member)
                frame.loc[row, rank_col] = position/total

    for column in columns:
        print(f"Ranking {column}")
        assign_rank(df, column)
    for column in inverse_columns:
        print(f"Ranking {column}")
        assign_rank(df, column, inverse=True)
    return df

def hubness(df, threshold = None, ivars=["category"]):
    """Score every node's hubness from four ranked nodal metrics.

    ``df`` is averaged per ``(*ivars, "node")`` group, then ranked with
    ``property_rank``: ``betweenness`` and ``degree`` ascending,
    ``path_length`` and ``clust_coeff`` descending. Without a ``threshold``
    the hubness is the sum of the four ranks; with one it is the number of
    ranks strictly above ``threshold``.

    Returns the ranked frame with the additional ``"hubness"`` column.
    """
    node_means = df.groupby([*ivars, "node"]).mean()
    direct = ["betweenness", "degree"]
    inverted = ["path_length", "clust_coeff"]
    ranked = property_rank(node_means, columns=direct, inverse_columns=inverted)
    ranked["hubness"] = 0
    for metric in [*direct, *inverted]:
        contribution = ranked[f"{metric}_rank"]
        if threshold is not None:
            contribution = (contribution > threshold).astype(int)
        ranked["hubness"] += contribution
    return ranked
    

In [None]:
# Nodal metrics to rank and plot.
cols= [
    "betweenness",
    "clust_coeff",
    "path_length",
    "degree",
]
# Rank every metric within each (category, subject) pair, then plot each
# metric against its own rank, one tab per metric.
indexed = property_rank(nodal_properties().set_index(["category", "subject", "node"]), cols)
plotly_tabulate(distribution_plot(indexed, x=col+"_rank", y=col) for col in cols)
# figures_to_html(
#     [
#         distribution_plot(indexed, x=col+"_rank", y=col) 
#         for col in cols
#     ],
#     filename="pages/bn246/nodal_distributions.html"
# )

In [None]:
def hub_disruption_index(group, cols):
    """Compare one group's nodal values with the per-node mean of the "HC" rows.

    Args:
        group: Category label used to select rows (``indexed.loc[group]``).
        cols: Metric column(s) forwarded to ``pivot(values=...)`` and used to
            select from the HC means; the calls in this file pass a single
            column name.

    Returns:
        Long-format DataFrame (index reset) with the columns ``"delta"``
        (subject value minus HC mean), ``"avg"`` (HC mean of the row's node)
        and ``"Delta Total"`` (sum of ``delta`` over the row's subject).
    """
    df = nodal_properties()
    indexed = df.set_index(["category"])
    hc = indexed.loc["HC"]
    # Per-node mean over the HC rows.
    avg = hc.groupby("node").mean()[cols]
    # One row per subject, one column per node.
    group_values = indexed.loc[group].reset_index().pivot(columns="node", index="subject", values=cols)
    # assumes pandas aligns `avg` (indexed by node) with the pivot's node
    # columns - TODO confirm this still holds if `cols` is a list
    delta = group_values - avg
    # Back to long format: one row per (subject, node).
    delta = pd.DataFrame(delta.stack(), columns=["delta"])
    # Attach the HC mean of each row's node.
    delta["avg"] = avg.reindex(delta.index, level="node")
    # Per-subject sum of delta, repeated on every row of that subject.
    delta["Delta Total"] = delta.reset_index().groupby("subject").sum()["delta"].reindex(delta.index, level="subject")
    return delta.reset_index()

In [None]:
# Nodal metrics for which the FEP group is compared against the HC mean.
cols= [
    "betweenness",
    "clust_coeff",
    "path_length",
    "degree",
]
# One tab per metric: difference from the HC mean ("delta") against the HC
# mean ("avg"), coloured by the subject's summed delta, with an OLS trendline.
plotly_tabulate(
    px.scatter(
        hub_disruption_index("FEP", col),
        x="avg",
        y="delta",
        color="Delta Total",
        trendline="ols",
        title=titleize(col),
        hover_data=["subject", "node"]
    )
    for col in cols
)

In [None]:
@nb_cache("hubness")
def get_hubs():
    """Return the hubness scores computed per (category, subject, node).

    Decorated with ``nb_cache``: the result is pickled under the name
    ``"hubness"`` and reused by later calls.
    """
    return hubness(nodal_properties(), ivars=["category", "subject"])
# Per-subject hubness table used by the heatmaps and the attack analysis below.
hubs = get_hubs()

In [None]:
# cols = ["degree", "clust_coeff", "path_length", "betweenness"]

# Node-by-subject hubness heatmaps: HC in the left third of the figure,
# FEP in the remaining two thirds.
plt.figure(figsize=(30, 50))
plt.subplot(1, 3, 1)
table1 = (
    hubs
    .loc["HC"]
    .reset_index()
    .pivot(index="node", columns="subject", values="hubness")
)
sns.heatmap(table1, cmap="viridis", square=True, cbar=False)
plt.subplot(1, 3, (2,3))
table2 = (
    hubs
    .loc["FEP"]
    .reset_index()
    .pivot(index="node", columns="subject", values="hubness")
)
sns.heatmap(table2, cmap="viridis", square=True, cbar=False)

In [None]:
def get_drop_func(i, hub_df):
    """Build a function returning the ``i`` highest-hubness nodes of a subject.

    ``hub_df`` must provide ``subject``, ``node`` and ``hubness`` (as columns
    or index levels). The returned callable takes a subject id and returns
    the index of that subject's first ``i`` nodes in descending hubness order.
    """
    ranked = hub_df.reset_index().set_index(["subject", "node"])
    ranked = ranked.sort_values(["subject", "hubness"], ascending=False)

    def drop_func(sub):
        per_subject = ranked.loc[sub]
        return per_subject.index[:i]

    return drop_func
    

def _get_subject_df(x):
    """Return the subject metrics after dropping each subject's top ``x`` hubs.

    Reads the module-level ``hubs`` table. This is the function handed to
    ``Pool.map`` in ``attack_analysis`` - presumably kept at module level so
    it can be sent to the worker processes.
    """
    return subject_df(drop_regions=get_drop_func(x, hubs))
import multiprocessing as mp
@nb_cache("attack_analysis")
def attack_analysis():
    """Recompute the subject metrics while removing more and more hubs.

    For ``i`` from 0 up to (but excluding) the largest number of nodes any
    subject has in ``hubs``, the per-subject metrics are computed with each
    subject's ``i`` highest-hubness nodes dropped. The runs are spread over
    a pool of 32 worker processes.

    Decorated with ``nb_cache``: the result is pickled under the name
    ``"attack_analysis"`` and reused by later calls.

    Returns:
        The concatenation of all per-run DataFrames.
    """
    # Largest per-subject row count, read from the first counted column.
    # `.iloc[0]` is used because plain `[0]` on a label-indexed Series is a
    # deprecated positional lookup.
    max_nodes = hubs.reset_index().groupby("subject").count().max().iloc[0]
    with mp.Pool(processes=32) as pool:
        dfs = pool.map(_get_subject_df, range(max_nodes))
    # for i in range(hubs.reset_index().groupby("subject").count().max()[0]):
    #     dfs.append(subject_df(drop_regions = get_drop_func(i, hubs)))
    return pd.concat(dfs, axis=0)

In [None]:
# Number of removed regions per row, relative to the largest graph seen,
# then one tab per whole-graph metric plotted against that number.
df = attack_analysis()
df["num_dropped"] = df["num_regions"].max() - df["num_regions"]
plotly_tabulate(
    distribution_plot(df, x="num_dropped", y=col) 
    for col in [
        "transitivity",
        "efficiency",
        "density",
        "num_connected_comps",
        "largest_connected_comp",
        "degree",
    ]
)

In [None]:
# cols = ["degree", "clust_coeff", "path_length", "betweenness"]
df = nodal_properties()

# Group-level hubness (ranked within each category), shown as a
# node-by-category heatmap with a fixed column order.
table = (
    hubness(df)
    .reset_index()
    .pivot(index="node", columns="category", values="hubness")
    .reindex(columns=["HC", "FEP", "Treatment 3+ yr", "High risk"])
)
plt.figure(figsize=(20,40))
sns.heatmap(table, cmap="viridis", square=True)