In [None]:
import pickle
import itertools as it
import numpy as np
import pandas as pd
import networkx as nx
import os
import subprocess as sp
from pathlib import Path
from intersection.cortical_intersections import Connectome

In [None]:
from bids import BIDSLayout

# Build a BIDSLayout over the prepdwi_recon results directory (validation disabled).
layout = BIDSLayout("results/prepdwi_recon", validate=False)
# NOTE(review): a later cell calls `raw_layout.get(...)` to locate the surface
# mesh, but the only assignment of `raw_layout` visible here is this
# commented-out line, so that cell raises NameError on a fresh kernel.
# raw_layout = BIDSLayout("../..", validate=False)

In [None]:
import tarfile
import fury.io as fio
import fury.utils as futil

# Read the tract file names, one per line; get_sorted_bundle_sizes looks each
# of them up under the extracted "tracts_left_hemisphere" directory.
with open("resources/tract-assignments/hemispheric") as f:
    paths = f.read().splitlines()
def get_sorted_bundle_sizes(file):
    """Count the lines (fibers) in each listed tract bundle of one subject.

    The subject's cluster archive is unpacked once into
    ``$SLURM_TMPDIR/clusters/<subject>`` and reused by later calls.

    Parameters
    ----------
    file
        Object exposing ``entities["subject"]`` and ``path``; ``path`` must
        point to a gzip-compressed tar archive.

    Returns
    -------
    pandas.DataFrame
        One row per entry of the module-level ``paths`` list, with the columns
        ``bundle`` (the entry) and ``bundle_size`` (number of polydata lines
        loaded from ``tracts_left_hemisphere/<entry>``).

    Raises
    ------
    KeyError
        If the ``SLURM_TMPDIR`` environment variable is not set.
    """
    import shutil

    subject = file.entities["subject"]
    tmpdir = Path(os.environ["SLURM_TMPDIR"])
    clusters = tmpdir / "clusters" / subject
    if not clusters.exists():
        clusters.mkdir(parents=True)
        try:
            # NOTE(review): extractall trusts the member paths stored in the
            # archive; only use this on archives produced by the pipeline.
            with tarfile.open(file.path, "r:gz") as tar:
                tar.extractall(clusters)
        except BaseException:
            # The existence of the directory is what marks the archive as
            # already unpacked, so a failed or interrupted extraction must not
            # leave it behind or later calls would read a partial tree.
            shutil.rmtree(clusters, ignore_errors=True)
            raise

    bundle_sizes = {}
    for path in paths:
        pld = fio.load_polydata(str(clusters / "tracts_left_hemisphere" / path))
        bundle_sizes[path] = len(futil.get_polydata_lines(pld))
    return (
        pd.DataFrame({"bundle_size": bundle_sizes})
        .reset_index()
        .rename(columns={"index": "bundle"})
    )

# Bundle sizes for every sorted, T1w-space cluster archive known to `layout`,
# stacked into one long table and tagged with the subject each row came from.
cluster_files = layout.get(desc="sorted", space="T1w", suffix="clusters")
df = pd.concat([
    get_sorted_bundle_sizes(cluster_file).assign(
        subject=cluster_file.entities["subject"]
    )
    for cluster_file in cluster_files
])

In [None]:
# Display the per-subject bundle-size table (columns: bundle, bundle_size, subject).
df

In [None]:
def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hexadecimal colour string to an ``(r, g, b)`` tuple of ints.

    Parameters
    ----------
    hex_color
        Colour such as ``"#636EFA"`` or the three-digit shorthand ``"#abc"``.
        Leading ``#`` characters are optional. Any characters after the first
        six hex digits are ignored.

    Returns
    -------
    tuple
        Red, green and blue components, each in ``0..255``.

    Raises
    ------
    ValueError
        If fewer than six hex digits remain after shorthand expansion, or if
        the string contains non-hexadecimal characters.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) == 3:
        # Shorthand doubles each digit ("abc" -> "aabbcc"); repeating the whole
        # string ("abcabc") would yield a different colour.
        hex_color = "".join(digit * 2 for digit in hex_color)
    if len(hex_color) < 6:
        raise ValueError(f"invalid hex colour: {hex_color!r}")
    return int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)

In [None]:
# Mean and standard deviation of bundle size per (category, bundle), ordered by
# increasing mean size.
# Only `bundle_size` is aggregated: the frame also carries the string `subject`
# column, and a blanket groupby mean()/std() over it raises on pandas >= 2.
# NOTE(review): `categories` is defined in a later cell of this notebook, so
# that cell has to be executed before this one.
ddf = (
    df.assign(category=[categories[int(x)] for x in df["subject"]])
    .groupby(["category", "bundle"])["bundle_size"]
    .agg(["mean", "std"])
    .rename(columns={"mean": "bundle_size"})
    .sort_values("bundle_size")
    .reset_index()
)

In [None]:
# Display the per-(category, bundle) summary table built in the previous cell.
ddf

In [None]:

import plotly.express as px
import plotly.graph_objs as go

# One line per category showing mean bundle size, with a +/- 1 std band.
# No x column is passed, so the lines are drawn against ddf's index, i.e. the
# position of each (category, bundle) row after sorting by mean bundle size.

# px assigns colours in order of first appearance in the data, which need not
# be 1, 2, 3, 4 once ddf is sorted by size. Pin each category to a palette
# entry so the shaded bands always share the colour of their line.
category_colors = {
    category: px.colors.qualitative.Plotly[category - 1]
    for category in range(1, 5)
}

fig = px.line(
    ddf,
    y="bundle_size",
    color="category",
    color_discrete_map=category_colors,
    width=584,
    height=400,
    labels={"bundle_size": "Bundle Size (# fibers)"},
    title="Bundle size distribution",
)
fig.update_layout(
    margin=dict(l=50, r=50, t=50, b=50),
    showlegend=False,
    xaxis_title="Bundles sorted by increasing mean size",
)

band_traces = []
for category, hex_color in category_colors.items():
    category_rows = ddf[ddf["category"] == category]
    # Invisible upper edge of the band ...
    band_traces.append(
        go.Scatter(
            x=category_rows.index,
            y=category_rows["std"] + category_rows["bundle_size"],
            mode="lines",
            line=dict(width=0),
        )
    )
    # ... and the lower edge, filled up to the trace added just before it.
    band_traces.append(
        go.Scatter(
            x=category_rows.index,
            y=category_rows["bundle_size"] - category_rows["std"],
            mode="lines",
            line=dict(width=0),
            fill='tonexty',
            fillcolor=f'rgba{(*hex_to_rgb(hex_color), 0.3)}',
        )
    )
fig.add_traces(band_traces)

In [None]:
from intersection import main
from intersection import patch_merge
from intersection.cortical_intersections import Mesh, Connectome
# Build the single-subject intersection, parcellation and connectome objects
# that the following cells inspect.
# Scratch directory; raises KeyError when SLURM_TMPDIR is not set.
tmpdir = Path(os.environ["SLURM_TMPDIR"])
subject = "001"
# NOTE(review): `raw_layout` is not assigned by any active line in the visible
# notebook (its only assignment is commented out next to `layout`), so this
# line raises NameError on a fresh kernel - confirm where it should come from.
# Takes the first left-hemisphere "smoothwm" .surf.gii file found for the subject.
mesh_path = Path(raw_layout.get(subject=subject, suffix="smoothwm", extension=".surf.gii", hemi="L")[0].path)
# Second argument is the subject's "L" directory under
# prepdwi-recon/get_hemispheric_tracts in the scratch directory.
intersection = main.get_intersection(mesh_path, tmpdir/"prepdwi-recon/get_hemispheric_tracts"/subject/"L", threads=8)
mesh = Mesh(mesh_path)
# NOTE(review): the meaning of the argument 2 to get_globbed_graph is not
# visible from this notebook - document it once confirmed.
parcellation = patch_merge.get_parcellation(intersection.get_globbed_graph(2), mesh)
connectome = Connectome(intersection, parcellation)

In [None]:
import fury.io as fio
from intersection.patch_merge import merge_parcels
# Combine the parcellation with the mesh into one polydata object and write it
# to the scratch directory as atlas.vtk.
atlas = merge_parcels(parcellation, mesh)
atlas_path = tmpdir / "atlas.vtk"
fio.save_polydata(atlas, str(atlas_path))

In [None]:
# Category code per subject, indexed by the integer subject number
# (`categories[int(subject)]`). Ten entries per row; the trailing comment gives
# the index range covered by that row.
categories = [
    0, 1, 1, 1, 1, 1, 3, 2, 2, 3,  # 0-9
    1, 1, 2, 2, 2, 3, 2, 2, 2, 2,  # 10-19
    2, 3, 2, 2, 2, 2, 2, 2, 2, 2,  # 20-29
    2, 2, 2, 3, 3, 2, 2, 1, 1, 2,  # 30-39
    2, 2, 1, 1, 2, 1, 2, 2, 1, 2,  # 40-49
    2, 1, 2, 2, 2, 1, 1, 1, 1, 2,  # 50-59
    2, 2, 2, 1, 1, 1, 1, 1, 1, 2,  # 60-69
    1, 1, 2, 3, 2, 2, 2, 2, 2, 1,  # 70-79
    2, 2, 3, 2, 2, 2, 2, 1, 2, 2,  # 80-89
    1, 2, 4, 1, 1, 2, 2, 2, 4, 2,  # 90-99
    4, 2, 4, 2, 2, 2, 2, 2, 2, 4,  # 100-109
    4, 4, 4, 4, 2, 4, 2, 4, 2, 2,  # 110-119
    4, 4, 2, 3, 3, 2, 2, 2, 2, 4,  # 120-129
    3, 3, 3, 3, 4, 1, 3, 4, 2,     # 130-138
]

In [None]:
def betweenness(b_vals, threshold=.5):
    """Tabulate the entries of ``b_vals`` that lie above a relative cut-off.

    The cut-off is ``min + threshold * (max - min)`` over all values; only
    entries strictly greater than it are kept.

    NOTE(review): a later cell of this notebook defines another function with
    the same name that returns the selected keys instead of a DataFrame;
    whichever cell ran last is the one in effect.

    Parameters
    ----------
    b_vals : mapping
        Keys (e.g. edge tuples) mapped to numeric values.
    threshold : float, default 0.5
        Position of the cut-off between the smallest (0) and largest (1) value.

    Returns
    -------
    pandas.DataFrame
        Single column ``betweenness`` indexed by ``str(key)``. Empty when
        ``b_vals`` is empty or nothing exceeds the cut-off.
    """
    if not b_vals:
        # min()/max() raise on an empty mapping (e.g. a graph without edges).
        return pd.DataFrame({"betweenness": {}})

    lowest = min(b_vals.values())
    highest = max(b_vals.values())
    margin = (highest - lowest) * threshold + lowest
    above = {str(key): val for key, val in b_vals.items() if val > margin}
    return pd.DataFrame({"betweenness": above})

In [None]:
import plotly.express as px
import copy

# Edge betweenness of the single-subject connectome, keeping the edges above
# the default relative cut-off of `betweenness`.
# Work on a copy so zeroing the diagonal does not modify connectome.matrix.
matrix = copy.deepcopy(connectome.matrix)
np.fill_diagonal(matrix, 0)  # drop self-connections
# nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array builds the
# same graph with each non-zero entry stored as the edge's "weight".
G = nx.from_numpy_array(matrix)
for edge in G.edges:
    # Shortest-path routines minimise the weight, so use the reciprocal of the
    # connection weight as the path length.
    G.edges[edge]["distance"] = 1/G.edges[edge]["weight"]
# NOTE(review): this expects the DataFrame-returning `betweenness`; the notebook
# later redefines that name to return a list of keys.
df = betweenness(nx.edge_betweenness_centrality(G, weight="distance"))
print(df)
px.scatter(df,  y="betweenness")

In [None]:
# For a hand-picked list of edges, record every bundle that accounts for more
# than 10% of the edge's total, then print the matching .vtp file names as a
# comma-separated list.
f = {}
edges_of_interest = [(0, 14), (0, 21), (11, 17), (11, 20), (11, 21), (14, 42)]
for edge_key in edges_of_interest:
    bundle_counts = connectome.get_bundles_of_edge(edge_key)
    edge_total = sum(bundle_counts.values())
    for bundle in bundle_counts:
        share = bundle_counts[bundle] / edge_total
        if share > 0.1:
            # A bundle seen on several edges keeps the share of the last one.
            f[bundle] = share
print(f)
vtp_names = [
    Path(intersection.get_bundle_path(bundle)).with_suffix(".vtp").name
    for bundle in f
]
print(",".join(vtp_names))

In [None]:
def graphs():
    """Yield one weighted graph per connectome file known to ``layout``.

    Each file returned by ``layout.get(suffix="connectome")`` is unpickled,
    its diagonal is zeroed in place (removing self-connections) and the result
    is converted to a NetworkX graph.

    Yields
    ------
    networkx.Graph
        Graph whose edges carry the matrix entries as ``weight``.
    """
    for bidsfile in layout.get(suffix="connectome"):
        # pickle.load can execute arbitrary code; only read files written by
        # this pipeline.
        with open(bidsfile.path, 'rb') as f:
            data = pickle.load(f)
        np.fill_diagonal(data, 0)
        # nx.from_numpy_matrix was removed in NetworkX 3.0.
        yield nx.from_numpy_array(data)

In [None]:
def betweenness(b_vals, threshold=.5):
    """Return the keys of ``b_vals`` whose value lies above a relative cut-off.

    The cut-off is ``min + threshold * (max - min)`` over all values; only
    entries strictly greater than it are selected.

    NOTE(review): an earlier cell of this notebook defines a function with the
    same name that returns a DataFrame; this definition replaces it once run.

    Parameters
    ----------
    b_vals : mapping
        Keys (e.g. edge tuples) mapped to numeric values.
    threshold : float, default 0.5
        Position of the cut-off between the smallest (0) and largest (1) value.

    Returns
    -------
    list
        Selected keys in the iteration order of ``b_vals``; empty when
        ``b_vals`` is empty or nothing exceeds the cut-off.
    """
    if not b_vals:
        # min()/max() raise on an empty mapping (e.g. a graph without edges).
        return []

    lowest = min(b_vals.values())
    highest = max(b_vals.values())
    margin = (highest - lowest) * threshold + lowest
    return [key for key, val in b_vals.items() if val > margin]

In [None]:
# Per-subject graph summary: mean degree, number of regions and transitivity.
# NOTE(review): the box-plot cell further down reads `transitivity` from `df`,
# but the threshold-sweep cell in between rebinds `df`; run in top-to-bottom
# order that plot no longer sees this table.
rows = []

for bidsfile in layout.get(suffix="connectome"):
    # pickle.load can execute arbitrary code; only read files written by this
    # pipeline.
    with open(bidsfile.path, 'rb') as f:
        data = pickle.load(f)
    np.fill_diagonal(data, 0)  # drop self-connections
    # nx.from_numpy_matrix was removed in NetworkX 3.0.
    G = nx.from_numpy_array(data)
    sub = int(bidsfile.entities['subject'])
    cat = categories[sub]
    rows.append({
        "subject": sub,
        "category": cat,
        "degree": np.mean([degree for _, degree in G.degree]),
        "num_regions": len(G.nodes),
        "transitivity": nx.transitivity(G),
    })

df = pd.DataFrame(rows)

df["category"] = df["category"].astype("uint32")

In [None]:
# For every subject and every relative betweenness cut-off in 0.0, 0.1, ... 0.9,
# sum the connection weight carried by the edges above the cut-off.
# NOTE(review): this needs the list-returning `betweenness` (the second
# definition in this notebook), and it rebinds `df`, replacing the per-subject
# metrics table built by the previous cell.
rows = []
thresholds = np.arange(0, 1, 0.1)
for bidsfile in layout.get(suffix="connectome"):
    # pickle.load can execute arbitrary code; only read files written by this
    # pipeline.
    with open(bidsfile.path, 'rb') as f:
        data = pickle.load(f)
    np.fill_diagonal(data, 0)  # drop self-connections
    # nx.from_numpy_matrix was removed in NetworkX 3.0.
    G = nx.from_numpy_array(data)
    sub = int(bidsfile.entities['subject'])
    cat = categories[sub]
    for edge in G.edges:
        # Shortest-path routines minimise the weight, so use the reciprocal of
        # the connection weight as the path length.
        G.edges[edge]["distance"] = 1/G.edges[edge]["weight"]
    b_vals = nx.edge_betweenness_centrality(G, weight="distance")
    for threshold in thresholds:
        high_edges = betweenness(b_vals, threshold)
        # Raw summed weight; it is not normalised by the graph's total weight.
        weight_above = sum(G.edges[high_edge]["weight"] for high_edge in high_edges)
        rows.append({
            "sub": sub,
            "category": cat,
            "threshold": threshold,
            "weight_above": weight_above
        })

df = pd.DataFrame(rows)
df["category"] = df["category"].astype("uint32")

In [None]:
# Iterator over the per-subject graphs. graphs() is a generator function, so
# its result is already an iterator and iter() returns it unchanged.
# NOTE(review): `Gs` is not used by any other cell visible in this notebook.
Gs = iter(graphs())

In [None]:
import plotly.express as px

# Box plot of transitivity per category, with every subject drawn as a point.
# NOTE(review): this needs the per-subject metrics table (column
# "transitivity"), but the threshold-sweep cell above rebinds `df` to a table
# without that column, so this fails when the notebook runs top to bottom.
px.box(df, x="category", y="transitivity", points="all")

In [None]:
import plotly.express as px

# Average the threshold sweep over the subjects of each category and plot the
# mean weight above the cut-off against the threshold, coloured by category.
grouped = df.groupby(["category", "threshold"]).mean()
threshold_level = grouped.index.get_level_values("threshold")
category_level = grouped.index.get_level_values("category")
px.scatter(grouped, x=threshold_level, y="weight_above", color=category_level)