In [None]:
import scipy.io
import numpy as np
import networkx as nx
from networkx.algorithms import community
import pandas as pd
import itertools as it
import copy
import re

from notebooks.utils import filter_logile, hex_to_rgb, get_lut, lut_label, titleize, distribution_plot, figures_to_html, plotly_tabulate, NbCache


In [None]:
# NOTE(review): `stats` was never assigned before this cell, so it raised a
# NameError on a fresh kernel. The "test_stat" matrix that the heatmap cell reads
# is presumably what was intended here -- confirm.
stats = scipy.io.loadmat("results/nbs/nbs.mat")["test_stat"]

# `from_numpy_matrix` was removed in networkx 3.0; `from_numpy_array` builds the
# same weighted graph from a 2-D ndarray.
G = nx.from_numpy_array(stats)

# Collect the NaN-weighted edges first and remove them afterwards: deleting edges
# while iterating over `G.edges` mutates the adjacency dicts being iterated.
nan_edges = [
    (i, j)
    for i, j, weight in G.edges(data="weight")
    if np.isnan(weight)
]
G.remove_edges_from(nan_edges)

components = community.greedy_modularity_communities(G, weight="weight")

In [None]:
# Flatten the communities into one node ordering, keeping each community's
# members next to each other.
order = [node for members in components for node in members]

In [None]:
def np_str_join(__delim: str, /, arrays, axis=None):
    """Join string arrays with a delimiter, like ``str.join`` along an axis.

    Parameters
    ----------
    __delim : str
        Delimiter placed between joined elements (positional-only).
    arrays : array_like or iterable of array_like
        Strings to join. With ``axis=0`` any iterable of equally shaped
        string arrays is accepted.
    axis : int or None
        ``None`` joins every element of the flattened input into a single
        ``str``. Otherwise the slices along ``axis`` are joined element-wise
        and that axis is removed from the result.

    Returns
    -------
    str or numpy.ndarray
        A single string for ``axis=None``; otherwise an array of joined
        strings. An input with no slices along ``axis`` yields an empty
        string array, and a single slice is returned as-is.
    """
    if axis is None:
        return __delim.join(map(str, np.ravel(arrays)))
    if axis:
        # Normalise to "join along the first axis" (also covers negative axes).
        arrays = np.moveaxis(arrays, axis, 0)

    slices = iter(arrays)
    try:
        joined = next(slices)
    except StopIteration:
        return np.array([], dtype=np.str_)

    # The delimiter is embedded in a %-format string, so a literal "%" has to
    # be escaped or it would be read as a conversion specifier.
    fmt = "%s" + __delim.replace("%", "%%")
    # Iterative fold: no recursion limit regardless of how many slices exist.
    for piece in slices:
        joined = np.char.add(np.char.mod(fmt, joined), piece)
    return joined

def autopair_labels(labels):
    """Build every ordered pair of ``labels`` as a string array.

    Returns an array of shape ``(len(labels), len(labels), 2)`` whose entry
    ``[i, j]`` is ``(labels[i], labels[j])``, with all values cast to str.
    """
    count = len(labels)
    row_labels, col_labels = np.meshgrid(labels, labels, indexing="ij")
    pairs = np.stack([row_labels, col_labels], axis=-1)
    return pairs.reshape((count, count, 2)).astype(np.str_)


In [None]:
def get_metadata(path="resources/brainnetome-regions.csv", node_order=None):
    """Load the region table as one row per hemisphere-specific region.

    The CSV has one row per bilateral region, with separate left/right
    columns for the label ID and the MNI coordinate. Each row is split into a
    left and a right record sharing the remaining columns.

    Parameters
    ----------
    path : str
        Location of the region CSV.
    node_order : sequence of int, optional
        Zero-based node order used to arrange the rows (label ID = node + 1).
        Defaults to the notebook-level ``order`` variable.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``Label ID`` in ``node_order`` order, with the columns
        ``Name``, ``Long Name``, ``Gyrus Abbr``, ``Gyrus``, ``Lobe``, ``MNI``
        and ``hemisphere`` ("L" or "R"), plus any other CSV columns.
    """
    if node_order is None:
        node_order = order  # notebook-level ordering from community detection

    raw = pd.read_csv(path)

    # "<abbr>, <long name>[, <rest>]" -> Name / Long Name. The third piece is
    # discarded, as before.
    names = (
        raw["Modified cyto-architectonic"]
        .str.split(r',\s?', expand=True)
        .rename({0: "Name", 1: "Long Name"}, axis=1)
        .drop(columns=[2])
    )
    # "<abbr>, <gyrus name>" -> Gyrus Abbr / Gyrus (replaces the combined column).
    gyri = (
        raw["Gyrus"]
        .str.split(r',\s?', expand=True)
        .rename({0: "Gyrus Abbr", 1: "Gyrus"}, axis=1)
    )
    regions = (
        raw
        .assign(**dict(names), **dict(gyri))
        .drop(
            columns=[
                "Modified cyto-architectonic",
                "Left and right hemispheres"
            ]
        )
    )

    # The lh / ".L" columns describe the left hemisphere, so those rows are
    # tagged "L" (they were previously tagged "R", and the reverse).
    left = (
        regions
        .drop(columns=["rh.MNI (X, Y, Z)", "Label ID.R"])
        .rename(columns={"lh.MNI (X,Y,Z)": "MNI", "Label ID.L": "Label ID"})
        .assign(hemisphere="L")
    )
    right = (
        regions
        .drop(columns=["lh.MNI (X,Y,Z)", "Label ID.L"])
        .rename(columns={"rh.MNI (X, Y, Z)": "MNI", "Label ID.R": "Label ID"})
        .assign(hemisphere="R")
    )

    return (
        pd.concat([left, right])
        .astype({"Label ID": int})
        .assign(
            Lobe=lambda df: df["Lobe"].str.strip()
        )
        .set_index("Label ID")
        .sort_index()
        .reindex(index=np.asarray(node_order) + 1)
    )

In [None]:
class Data:
    """A square region-by-region statistic matrix paired with region metadata.

    ``raw`` is kept in its original row/column order. Row/column ``i`` is
    matched to the ``i``-th entry of ``metadata`` sorted by index.
    NOTE(review): this presumably relies on the metadata index being the
    contiguous label IDs 1..N -- confirm.

    Every filtering method returns a new ``Data`` whose ``raw`` is a masked
    array; masks accumulate when filters are chained.
    """

    def __init__(self, raw, metadata):
        # hems[i, j] == (hemisphere of row i, hemisphere of column j)
        self.hems = autopair_labels(metadata.sort_index()["hemisphere"])
        self.metadata = metadata
        self.raw = raw

    @classmethod
    def from_matlab(cls, matrix: str, metadata):
        """Build from the ``test_stat`` variable of the MATLAB file ``matrix``."""
        stats = scipy.io.loadmat(matrix)["test_stat"]
        return cls(stats, metadata)

    def update(self, __new):
        """Return a new instance wrapping ``__new`` with the same metadata."""
        return self.__class__(__new, self.metadata)

    def threshold(self, threshold):
        """Mask every entry smaller than ``threshold``."""
        return self.update(np.ma.masked_less(self.raw, threshold))

    def _mask(self, mask):
        """Mask the entries where ``mask`` is true, on top of any existing mask."""
        return self.update(np.ma.masked_where(mask, self.raw))

    @property
    def l_to_l(self):
        """Only the entries whose row and column are both in hemisphere "L"."""
        # Hide every row that is NOT "L"; `hems_same` has already hidden the
        # cross-hemisphere pairs. (Masking `== "L"` kept the R-R pairs instead.)
        return self.hems_same._mask(self.hems[:, :, 0] != "L")

    @property
    def hems_same(self):
        """Only the entries whose row and column share a hemisphere."""
        return self._mask(self.hems[:, :, 0] != self.hems[:, :, 1])

    @property
    def hems_different(self):
        """Only the entries whose row and column are in different hemispheres."""
        return self._mask(self.hems[:, :, 0] == self.hems[:, :, 1])

    @property
    def masked(self):
        """Plain float array with masked entries as NaN, in ``metadata`` row order."""
        stat_index = np.asarray(self.metadata.index) - 1
        # `np.ma.asarray` also accepts an unmasked ndarray (no filter applied
        # yet). `np.nan` replaces `np.NaN`, which NumPy 2.0 removed.
        values = np.ma.filled(np.ma.asarray(self.raw).astype(float), np.nan)
        return values[np.ix_(stat_index, stat_index)]

In [None]:
import plotly.graph_objects as go
import plotly.express as px

# Region metadata table; it supplies the axis names and hover labels below and is
# passed to `Data` alongside the statistic matrix.
df = get_metadata()
def get_names(df):
    """Return one display name per row: ``<Name>_<hemisphere, lower-cased>``."""
    hemisphere_suffix = df["hemisphere"].str.lower()
    region_name = df["Name"]
    return region_name + "_" + hemisphere_suffix

def get_labels(df, fields):
    """Build pairwise hover labels and the matching hover-template snippet.

    For each column in ``fields`` every (row, column) pair of values is
    rendered as ``"<row value> ➜ <column value>"``. The per-field matrices are
    stacked along the last axis, so ``labels[i, j, k]`` belongs to
    ``fields[k]``.

    Returns ``(labels, template)``, where ``template`` has one
    ``<b>field:</b> %{customdata[k]}`` entry per field, joined with ``<br>``.
    """
    per_field = []
    for field in fields:
        pairs = autopair_labels(df[field])
        per_field.append(np_str_join(" ➜ ", pairs, axis=2))
    labels = np.dstack(per_field)

    rows = [
        f"<b>{name}:</b> %{{customdata[{i}]}}"
        for i, name in enumerate(fields)
    ]
    return labels, "<br>".join(rows)


# Node attributes shown in the hover text, in display order.
labels, template = get_labels(
    df,
    [
        "hemisphere",
        "Lobe",
        "Gyrus",
        "Long Name",
    ]
)

# Entries below T_MIN are masked out; the same bounds fix the colour range so both
# layers share one scale.
T_MIN = 3
T_MAX = 5

data = Data.from_matlab('results/nbs/nbs.mat', df).threshold(T_MIN)

region_names = get_names(df)
hover_template = "<br>".join([
    "(%{y}, %{x})",
    template,
    "T-value: %{z}",
])


def _edge_heatmap(z, colorscale, **trace_kwargs):
    """Build one heatmap layer; the layers differ only in data, colours and extras."""
    return go.Heatmap(
        z=z,
        x=region_names,
        y=region_names,
        hoverongaps=False,  # masked (NaN) cells show no hover box
        customdata=labels,
        hovertemplate=hover_template,
        colorscale=colorscale,
        zmin=T_MIN,
        zmax=T_MAX,
        **trace_kwargs,
    )


fig = go.Figure()
fig.add_traces([
    # Connections between the two hemispheres (blue).
    _edge_heatmap(
        data.hems_different.masked,
        [[0, "rgba(0, 0, 255, 0.2)"], [1, "#0000FF"]],
    ),
    # Connections selected by `Data.l_to_l` (green); its colour bar is shifted
    # right so it does not sit on top of the first one.
    _edge_heatmap(
        data.l_to_l.masked,
        [[0, "rgba(0,255,0,0.2)"], [1, "#00FF00"]],
        colorbar=dict(
            x=1.1,
        ),
    ),
])
fig.update_layout(
    width=1000,
    height=1000,
    yaxis={
        # Put the first row at the top, like a matrix.
        "autorange": "reversed"
    }
)
fig.show()