In [None]:
import scipy.io
import numpy as np
import networkx as nx
from networkx.algorithms import community
import pandas as pd
import itertools as it
import copy
import re
from colour import Color
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import imageio.v3 as iio
import IPython.display as D

from notebooks.utils import filter_logile, hex_to_rgb, get_lut, lut_label, titleize, distribution_plot, figures_to_html, plotly_tabulate, NbCache, listify
from notebooks.adjacency_matrix import AdjacencyMatrix, group_outer

In [None]:
def np_str_join(__delim: str, /, arrays, axis=None):
    """Join string arrays element-wise with a delimiter.

    Parameters
    ----------
    __delim : str
        Separator inserted between joined items (positional-only).
    arrays : array-like of str
        Either a single array, or an iterable of equally shaped arrays.
    axis : int or None
        ``None`` flattens ``arrays`` and returns one plain ``str``.
        Otherwise the slices along ``axis`` are concatenated element-wise,
        so the result has the shape of ``arrays`` with ``axis`` removed.

    Returns
    -------
    str or numpy.ndarray
        The joined string (``axis=None``) or an array of joined strings.
        An empty input along ``axis`` yields an empty string array, and a
        single slice is returned unchanged.
    """
    if axis is None:
        # numpy has no top-level ``flatten``; ``ravel`` is the function form.
        return __delim.join(np.ravel(arrays))
    if axis:
        # Normalise so that the axis being joined is always the leading one.
        arrays = np.moveaxis(arrays, axis, 0)

    pieces = iter(arrays)
    try:
        joined = next(pieces)
    except StopIteration:
        # ``np.array()`` without arguments raises TypeError.
        return np.array([], dtype=str)

    # The delimiter is embedded in a %-format template, so a literal "%"
    # inside it has to be escaped.
    template = "%s" + __delim.replace("%", "%%")
    for piece in pieces:
        joined = np.char.add(np.char.mod(template, joined), piece)
    return joined


def autopair_labels(labels):
    """Build the table of all ordered label pairs.

    Returns an array of shape ``(n, n, 2)`` where entry ``[i, j]`` holds
    ``(labels[i], labels[j])`` and ``n == len(labels)``.
    """
    size = len(labels)
    # "ij" indexing makes the first grid vary along rows and the second
    # along columns, so stacking them on a trailing axis gives (row, col).
    row_labels, col_labels = np.meshgrid(labels, labels, indexing="ij")
    pairs = np.stack([row_labels, col_labels], axis=-1)
    return pairs.reshape((size, size, 2))

In [None]:
def get_metadata(path="resources/brainnetome-regions.csv"):
    """Load the region table and reshape it to one row per hemisphere.

    The CSV has one row per region with separate left/right columns for the
    MNI coordinate and the label id. This function splits the combined
    "abbreviation, long name" columns, then stacks a left and a right copy
    of the table so that every label id gets its own row.

    Rows built from the ``lh.*`` / ``Label ID.L`` columns are tagged
    ``hemisphere="L"`` and rows built from the ``rh.*`` / ``Label ID.R``
    columns are tagged ``hemisphere="R"``. (These tags used to be swapped.)

    Parameters
    ----------
    path : str or file-like
        Location of the regions CSV; anything ``pandas.read_csv`` accepts.

    Returns
    -------
    pandas.DataFrame
        Indexed by integer "Label ID" (sorted), with columns including
        "Name", "Long Name", "Gyrus Abbr", "Gyrus", "Lobe", "MNI" and
        "hemisphere".
    """
    raw = pd.read_csv(path)

    # "<abbr>, <long name>[, <extra>]" -> separate columns. Only the first
    # two parts are kept; any further part is discarded.
    name_parts = raw["Modified cyto-architectonic"].str.split(r',\s?', expand=True)
    gyrus_parts = raw["Gyrus"].str.split(r',\s?', expand=True)

    regions = (
        raw
        .assign(**{
            "Name": name_parts[0],
            "Long Name": name_parts[1],
            "Gyrus Abbr": gyrus_parts[0],
            "Gyrus": gyrus_parts[1],
        })
        .drop(
            columns=[
                "Modified cyto-architectonic",
                "Left and right hemispheres",
            ]
        )
    )

    # Hemisphere tag -> (MNI column, label id column) as spelled in the CSV.
    hemisphere_columns = {
        "L": ("lh.MNI (X,Y,Z)", "Label ID.L"),
        "R": ("rh.MNI (X, Y, Z)", "Label ID.R"),
    }

    frames = []
    for hemisphere, (mni_column, id_column) in hemisphere_columns.items():
        other_columns = [
            column
            for tag, columns in hemisphere_columns.items()
            if tag != hemisphere
            for column in columns
        ]
        frames.append(
            regions
            .drop(columns=other_columns)
            .rename(columns={mni_column: "MNI", id_column: "Label ID"})
            .assign(hemisphere=hemisphere)
        )

    return (
        pd.concat(frames)
        .astype({"Label ID": int})
        .assign(Lobe=lambda df: df["Lobe"].str.strip())
        .set_index("Label ID")
        .sort_index()
    )

In [None]:
class NewData(AdjacencyMatrix):
    """Adjacency matrix with hemisphere-based views.

    Every view below hands a predicate to the inherited ``mask_where_meta``.
    The predicates index ``df["hemisphere"]`` by 0 and 1, so the metadata is
    presumably presented pairwise (source, destination) -- confirm against
    ``AdjacencyMatrix``.
    """

    @classmethod
    def from_matlab(cls, matrix: str, metadata):
        """Build an instance from the ``"test_stat"`` variable of a .mat file."""
        contents = scipy.io.loadmat(matrix)
        return cls(contents["test_stat"], metadata)

    @property
    def community_sort(self):
        """Copy whose metadata is reordered community by community.

        Communities come from greedy modularity maximisation on
        ``self.graph`` using the ``"weight"`` edge attribute.
        """
        communities = community.greedy_modularity_communities(
            self.graph, weight="weight"
        )
        node_order = [node for members in communities for node in members]
        reordered = self.metadata.reindex(index=np.array(node_order))
        return self.with_metadata(reordered)

    @property
    def l_to_l(self):
        """View built from the predicate "not every hemisphere entry is L"."""
        def not_all_left(df):
            return ~np.all(df["hemisphere"] == "L", axis=0)

        return self.mask_where_meta(not_all_left)

    @property
    def r_to_r(self):
        """View built from the predicate "not every hemisphere entry is R"."""
        def not_all_right(df):
            return ~np.all(df["hemisphere"] == "R", axis=0)

        return self.mask_where_meta(not_all_right)

    @property
    def hems_same(self):
        """View built from the predicate "the two hemispheres differ"."""
        def hemispheres_differ(df):
            return df["hemisphere"][0] != df["hemisphere"][1]

        return self.mask_where_meta(hemispheres_differ)

    @property
    def hems_different(self):
        """View built from the predicate "the two hemispheres are equal"."""
        def hemispheres_match(df):
            return df["hemisphere"][0] == df["hemisphere"][1]

        return self.mask_where_meta(hemispheres_match)



In [None]:
# Region metadata with the Label ID index shifted down by one
# (presumably so the ids line up with 0-based matrix positions -- confirm).
df = get_metadata()
df.index = df.index - 1

sift2_path = 'results/nbs/sift2_nbs.mat'
avgfa_path = 'results/nbs/nbs.mat'

# Raw SIFT2 test statistics, kept as a plain array alongside the wrapped data.
stats = scipy.io.loadmat(sift2_path)["test_stat"]

# SIFT2 result: thresholded at 3, then ordered by detected community.
sift2 = NewData.from_matlab(sift2_path, df).threshold(3).community_sort

# Second result: same threshold, then given the SIFT2 index so the two
# datasets share one ordering.
avgfa = (
    NewData.from_matlab(avgfa_path, df)
    .threshold(3)
    .set_index(sift2.metadata.index)
)

In [None]:
# NOTE(review): `data` is not assigned in any cell visible here; this cell
# only runs if it is defined elsewhere (or left over in the kernel).
left_view = data.l_to_l
right_view = data.r_to_r

# Number of positive entries in each intrahemispheric view.
l_count = np.sum(left_view.masked > 0)
r_count = np.sum(right_view.masked > 0)
print("Left -> Left:", l_count)
print("Right -> Right:", r_count)

# Sum of the view's raw values divided by the count above.
print("L->L mean T:", np.sum(left_view.raw) / l_count)
print("R->R mean T:", np.sum(right_view.raw) / r_count)


def pipe(value, funcs):
    """Thread ``value`` through each callable in ``funcs``, left to right.

    Returns ``value`` unchanged when ``funcs`` is empty.
    """
    for step in funcs:
        value = step(value)
    return value


def form_labels(metadata, field):
    """Return a function that labels each edge as "<source> ➜ <dest>".

    The returned callable takes a frame with ``source_id`` and ``dest_id``
    columns, looks both ids up in ``metadata`` and joins the two ``field``
    values with an arrow.
    """
    def labeller(df):
        # The suffix renames a clashing column on the left-hand frame, so
        # ``[field]`` always picks the value that came from ``metadata``.
        source = df.join(metadata, on="source_id", lsuffix="orig")[field]
        dest = df.join(metadata, on="dest_id", lsuffix="org")[field]
        return source + " ➜ " + dest

    return labeller


def select_labels(metadata, labels):
    """Lazily yield ``(label, labeller)`` pairs, one per entry in ``labels``.

    Each labeller is ``form_labels(metadata, label)``, so the pairs can be
    passed to ``dict`` and unpacked into ``DataFrame.assign``.
    """
    for column in labels:
        labeller = form_labels(metadata, column)
        yield column, labeller


# Metadata columns attached to every edge and shown as hover data below.
LABELS = ["Name", "Long Name", "Gyrus", "Lobe"]


# One row per unmasked edge with its t value and "source ➜ dest" labels.
# NOTE(review): `data` is not assigned in any cell visible here.
result = (
    pipe(
        # Start from the (n, n, 2) table of all (row id, column id) pairs.
        autopair_labels(data.metadata.index),
        (
            # Drop the pairs hidden by the r_to_r mask; the mask is stacked
            # twice so it covers both members of each pair.
            lambda _: (
                np.ma.masked_where(np.dstack([data.r_to_r.mask]*2), _)
                .compressed()
                .reshape(-1, 2)
            ),
            # Sort within each pair so (a, b) and (b, a) become identical ...
            np.sort,
            # ... and keep each pair once.
            lambda _: np.unique(_, axis=0),
            pd.DataFrame
        ),
    )
    .rename(columns={0: "source_id", 1: "dest_id"})
    .assign(**{
        # NOTE(review): the -1 offsets treat the ids as 1-based, but the
        # loading cell already shifts df.index down by one -- confirm which
        # base `data.metadata.index` uses.
        "t_value": lambda dff: data.raw[dff["source_id"]-1, dff["dest_id"]-1],
        **dict(select_labels(df, LABELS))
    })
    .sort_values("t_value")
    .reset_index(drop=True)
)

# Line plot of t values in increasing order (x is the rank after sorting).
fig = px.line(
    result,
    x=result.index,
    y="t_value",
    width=800,
    height=600,
    hover_data=LABELS,
    # NOTE(review): "node_size", "betweenness" and "degree" are not columns
    # of `result`, and the "x" text talks about nodes/degree while the x
    # axis is the edge rank -- these look like leftovers; confirm.
    labels={
        "x": "Nodes sorted by increasing degree",
        "node_size": "Node Size (# triangles)",
        "betweenness": "Betweenness",
        "degree": "Degree"
    },
    title="T value ranking",
)
fig.show()

In [None]:
def get_names(df):
    """Return "<Name>_<hemisphere in lower case>" for every row of ``df``."""
    hemisphere_suffix = df["hemisphere"].str.lower()
    return df["Name"] + "_" + hemisphere_suffix

def get_labels(df, fields):
    """Build pairwise hover labels and the matching hover template.

    For each name in ``fields`` a matrix of "<row value> ➜ <column value>"
    strings is built from ``df[field]``; the matrices are stacked along a
    third axis so they can be used as ``customdata``. The template has one
    "<b>field:</b> value" line per field, separated by ``<br>``.

    Returns ``(labels, template)``. With no fields, ``labels`` is an empty
    array and ``template`` an empty string.
    """
    layers = [
        np_str_join(" ➜ ", autopair_labels(df[field]).astype(str), axis=2)
        for field in fields
    ]
    # dstack rejects an empty sequence, so fall back to one empty layer.
    labels = np.dstack(layers if layers else [[]])

    lines = []
    for position, name in enumerate(fields):
        lines.append(f"<b>{name}:</b> %{{customdata[{position}]}}")
    return labels, "<br>".join(lines)


# Metadata columns to show in the heatmap hover text. Every candidate is
# currently commented out, so no extra hover lines are produced.
hover_fields = [
    # "hemisphere",
    # "Lobe",
    # "Gyrus",
    # "Long Name",
]
labels, template = get_labels(sift2.metadata, hover_fields)

def lighten(color, alpha):
    """Return a copy of ``color`` with its luminance moved towards 1.

    The new luminance is ``luminance * (1 - alpha) + alpha``; ``color``
    itself is left untouched.
    """
    lighter = copy.deepcopy(color)
    lighter.luminance = lighter.luminance * (1 - alpha) + alpha
    return lighter

def darken(color, alpha):
    """Return a copy of ``color`` with its luminance scaled by ``alpha``.

    ``color`` itself is left untouched.
    """
    darker = copy.deepcopy(color)
    darker.luminance = darker.luminance * alpha
    return darker

@listify
def get_heatmaps(data):
    """Yield one ``go.Heatmap`` per entry of ``data``, each in its own colour.

    Every entry is either a ``(matrix, name)`` pair or a bare matrix (in
    which case the colour bar gets an empty title). Each trace gets a
    two-stop colour scale from a lightened to a darkened version of its
    base colour, and its colour bar is shifted right by the trace position
    so the bars sit side by side. At most five traces are produced, one per
    palette colour.

    Reads the module-level ``df``, ``labels`` and ``template``.
    NOTE(review): axis names come from ``df`` while ``labels`` is built
    from ``sift2.metadata`` -- confirm both follow the matrix ordering.
    """
    palette = [
        Color(f"#{code}")
        for code in ("21D705", "fc9402", "AF2De6", "5b0026", "8aa900")
    ]
    for position, (entry, base_color) in enumerate(zip(data, palette)):
        scale_low = lighten(base_color, 3/5)
        scale_high = darken(base_color, 1/2)

        try:
            matrix, caption = entry
        except ValueError:
            # Not a (matrix, name) pair: treat the entry itself as the matrix.
            matrix, caption = entry, ""

        hover_lines = [
            "(%{y}, %{x})",
            template,
            "T-value: %{z}",
        ]
        colorbar = dict(
            x=1.02 + position*0.17,
            y=0.98,
            yanchor="top",
            title_text=caption,
            title_side="right",
            len=0.6,
        )
        yield go.Heatmap(
            z=matrix,
            x=get_names(df),
            y=get_names(df),
            hoverongaps=False,
            customdata=labels,
            hovertemplate="<br>".join(hover_lines),
            colorscale=[[0, scale_low.hex], [1, scale_high.hex]],
            colorbar=colorbar,
        )

def get_figs(*data):
    """Lazily yield one heatmap figure per positional argument.

    Each argument is passed to ``get_heatmaps`` and its traces are layered
    into a single 1300x1000 figure with a reversed y axis, the "seaborn"
    template and a "T value" caption near the colour bars.

    NOTE(review): the title text mentions avgFA for every figure,
    whichever dataset the traces came from.
    """
    for trace_specs in data:
        layout = dict(
            width=1300,
            height=1000,
            yaxis={"autorange": "reversed"},
            template="seaborn",
            title=dict(
                text="NBS Edges with a significant reduction in avgFA",
                xanchor="left",
                x=0,
                xref="paper",
            ),
            legend_title_text="T value",
        )
        # Caption placed in paper coordinates, just right of the plot area.
        caption = dict(
            x=1.05,
            y=0.98,
            text="T value",
            xanchor="left",
            yanchor="bottom",
            xref="paper",
            yref="paper",
            showarrow=False,
            font_size=16,
        )

        fig = go.Figure()
        fig.add_traces(get_heatmaps(trace_specs))
        fig.update_layout(**layout)
        fig.add_annotation(**caption)
        yield fig


def hemisphere_panels(matrix):
    """Return the three (filled view, caption) pairs plotted for one dataset."""
    return [
        (matrix.hems_different.filled, "Interhemispheric"),
        (matrix.l_to_l.filled, "Intrahemispheric Left"),
        (matrix.r_to_r.filled, "Intrahemispheric Right"),
    ]


# One figure per dataset: SIFT2 first, avgFA second.
figs = get_figs(hemisphere_panels(sift2), hemisphere_panels(avgfa))

# Materialise the generator and show the second figure as a static image.
rendered = list(figs)
D.Image(rendered[1].to_image())
# img = np.vstack([
#     iio.imread(fig.to_image(format='png'))
#     for fig in figs
# ])
# with open("plots/nbs_results.png", 'wb') as f:
#     f.write(iio.imwrite("<bytes>", img, extension=".png"))