This notebook copies and modifies [code written by Matteo Mancini](https://github.com/matteomancini/neurosnippets), which at the time of this writing was released under the [MIT license](https://github.com/matteomancini/neurosnippets/blob/master/LICENSE.md).  It was modified to use a [networkx](https://networkx.org/documentation/stable/index.html) Graph object as input for the network being visualized.

In [1]:
# Run this cell in Google Colab to retrieve the data
# Fetches the four input files used by the cells below into the current
# working directory: the surface mesh (lh.pial.obj), the connectivity matrix
# (icbm_fiber_mat.txt), the region centre coordinates
# (fs_region_centers_68_sort.txt) and the region labels
# (freesurfer_regions_68_sort_full.txt).
# NOTE(review): plain wget does not overwrite; re-running the cell is expected
# to save numbered copies (e.g. lh.pial.obj.1) next to the originals — confirm.
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/lh.pial.obj
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/icbm_fiber_mat.txt
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/fs_region_centers_68_sort.txt
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/freesurfer_regions_68_sort_full.txt

--2025-10-22 07:29:35--  https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/lh.pial.obj
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/matteomancini/neurosnippets/master/brainviz/interactive-network/lh.pial.obj [following]
--2025-10-22 07:29:36--  https://raw.githubusercontent.com/matteomancini/neurosnippets/master/brainviz/interactive-network/lh.pial.obj
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9891429 (9.4M) [text/plain]
Saving to: ‘lh.pial.obj’


2025-10-22 07:29:37 (103 MB/s) - ‘lh.pial.obj’ saved [9891429/9891429]

--2025-10-22 07:29:37--  https:/

In [2]:
import numpy as np
import plotly.graph_objects as go
import networkx as nx # New dependency

In [3]:
def obj_data_to_mesh3d(odata):
    """Parse the text of a Wavefront OBJ file into vertex and face arrays.

    Parameters
    ----------
    odata : str
        Full contents of an .obj file.

    Returns
    -------
    tuple of numpy.ndarray
        ``(vertices, faces)``. Each ``v`` record contributes one row holding
        all of its numeric fields; each ``f`` record contributes 0-based
        vertex indices, with faces of more than three corners split into a
        triangle fan anchored at the first corner.
    """
    vertices, faces = [], []

    for record in odata.splitlines():
        tokens = record.split()
        if not tokens:
            continue

        tag, fields = tokens[0], tokens[1:]
        if tag == 'v':
            vertices.append(np.array(fields, dtype=float))
        elif tag == 'f':
            # Every slash-separated component is converted (so a malformed
            # token still raises), but only the leading vertex index is kept.
            corners = [
                [int(part) for part in field.replace('//', '/').split('/')][0] - 1
                for field in fields
            ]
            if len(corners) > 3:
                # Fan triangulation: (c0, c1, c2), (c0, c2, c3), ...
                faces.extend(
                    [corners[0], corners[n], corners[n + 1]]
                    for n in range(1, len(corners) - 1)
                )
            else:
                faces.append(corners)
        # any other record type (comments, normals, ...) is ignored

    return np.array(vertices), np.array(faces)

In [4]:
# Read the downloaded surface file and parse it into vertex / face arrays
with open("lh.pial.obj", "r") as f:
    obj_data = f.read()
vertices, faces = obj_data_to_mesh3d(obj_data)

# Per-axis vertex coordinates (first three columns only) and per-corner
# vertex indices, in the layout go.Mesh3d expects
vert_x, vert_y, vert_z = vertices.T[:3]
face_i, face_j, face_k = np.transpose(faces)

In [5]:
# Connectivity matrix and region centre coordinates
cmat = np.loadtxt('icbm_fiber_mat.txt')
nodes = np.loadtxt('fs_region_centers_68_sort.txt')

# One region label per line of the labels file (trailing newline removed)
labels = []
with open("freesurfer_regions_68_sort_full.txt", "r") as f:
    labels.extend(raw.strip('\n') for raw in f)

In [6]:
# Build the graph: one node per region centre, carrying its coordinates
G = nx.Graph()
G.add_nodes_from((idx, {"coord": node}) for idx, node in enumerate(nodes))

# Made-up node colors stored as a node attribute: the first eleven nodes
# (ids 0-10) are red, every other node is gray
colors_data = {node: ('red' if node <= 10 else 'gray') for node in G.nodes}
nx.set_node_attributes(G, colors_data, name="color")

# Edges: entries of the upper triangle of cmat that exceed 0.01
source, target = np.nonzero(np.triu(cmat) > 0.01)
edges = list(zip(source, target))
G.add_edges_from(edges)

In [8]:
# Get node coordinates from node attribute
nodes_x = [data['coord'][0] for node, data in G.nodes(data=True)]
nodes_y = [data['coord'][1] for node, data in G.nodes(data=True)]
nodes_z = [data['coord'][2] for node, data in G.nodes(data=True)]

# Get edge coordinates from the graph itself rather than from the temporary
# `edges` list, so the figure always reflects G.
# Each edge contributes (source, target, None): the None entry breaks the
# line, otherwise a single 'lines' trace joins the end of one edge to the
# start of the next and draws connections that are not in the graph.
edge_x = []
edge_y = []
edge_z = []
for s, t in G.edges():
    src, dst = G.nodes[s]['coord'], G.nodes[t]['coord']
    edge_x += [src[0], dst[0], None]
    edge_y += [src[1], dst[1], None]
    edge_z += [src[2], dst[2], None]

# Get node colors from node attribute
node_colors = [data['color'] for node, data in G.nodes(data=True)]

In [9]:
# Traces are drawn in list order: brain surface, region markers, connections
fig = go.Figure(data=[
    # Changed color and opacity kwargs
    go.Mesh3d(
        x=vert_x, y=vert_y, z=vert_z,
        i=face_i, j=face_j, k=face_k,
        color='gray', opacity=0.1, name='', showscale=False, hoverinfo='none',
    ),
    go.Scatter3d(
        x=nodes_x, y=nodes_y, z=nodes_z,
        text=labels, mode='markers', hoverinfo='text', name='Nodes',
        marker={
            "size": 5,             # Changed node size...
            "color": node_colors,  # ...and color
        },
    ),
    go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        mode='lines', hoverinfo='none', name='Edges',
        opacity=0.3,              # Added opacity kwarg
        line={"color": 'pink'},   # Added line color
    ),
])

# Hide all three axes and fix the canvas size
fig.update_layout(
    width=800,
    height=600,
    scene={
        axis: {"showticklabels": False, "visible": False}
        for axis in ("xaxis", "yaxis", "zaxis")
    },
)

fig.show()

# **My part**

In [10]:
"""
Create a GIF of dynamic connectivity on a 3D brain mesh by aggregating
channel-level connectivity (from an .npz) to 68 Freesurfer regions.

Notes:
- This script DOES NOT recompute connectivity. It uses your existing .npz file
  with keys: 'ch_names', 'times_s', 'PLI', 'wPLI', 'AEC'.
- The channel->region mapping is approximate: channels are assigned to lobes
  (F, C, P, O, T) and then evenly mapped to the Freesurfer regions that
  mention that lobe in their label. Replace `auto_map_channels_to_regions`
  with exact mapping if you have electrode coords or a custom mapping.
- Plotly's write_image (Kaleido) is used to save frames. Install with:
    pip install kaleido
- For large numbers of frames, set frame_downsample > 1 to reduce GIF size/time.
"""

import os
import numpy as np
import networkx as nx
import plotly.graph_objects as go
import imageio
import math
from tqdm import tqdm

# --------------------------
# User settings
# --------------------------
# Input files
npz_file = "/content/sub-001_ses-1_task-EyesClosed_acq-post_eeg_Alpha_connectivity.npz"  # connectivity archive
mesh_obj = "/content/lh.pial.obj"  # brain surface mesh
centers_file = "/content/fs_region_centers_68_sort.txt"  # region centre coordinates
labels_file = "/content/freesurfer_regions_68_sort_full.txt"  # region labels

# Analysis options
metric = "AEC"          # one of "AEC", "PLI", "wPLI"
edge_percentile = 90    # global percentile used as edge threshold (lower = more edges)
frame_downsample = 1    # keep every nth window (1 = all windows); raise to shrink the GIF

# Output locations (created up front so later writes cannot fail on a missing directory)
out_dir = "results/brain_gif"
tmp_frames_dir = os.path.join(out_dir, "tmp_frames")
gif_name = os.path.join(out_dir, f"sub-001_{metric}_brain_dynamic.gif")
os.makedirs(out_dir, exist_ok=True)
os.makedirs(tmp_frames_dir, exist_ok=True)

# --------------------------
# Helpers: load mesh & region centers
# --------------------------
def obj_data_to_mesh3d_text(filename):
    """Read a Wavefront OBJ file and return its vertices and triangles.

    Parameters
    ----------
    filename : str
        Path of the .obj file to read.

    Returns
    -------
    tuple of numpy.ndarray
        ``(verts, faces)``: the first three coordinates of every ``v`` record,
        and 0-based vertex index triples for every ``f`` record. Faces with
        more than three corners are split into a triangle fan; faces with
        fewer than three corners produce no triangle.
    """
    with open(filename, "r") as handle:
        records = handle.read().splitlines()

    vertices, triangles = [], []
    for record in records:
        fields = record.split()
        if not fields:
            continue
        kind = fields[0]
        if kind == "v":
            vertices.append([float(fields[axis]) for axis in (1, 2, 3)])
        elif kind == "f":
            # keep only the vertex index of "v", "v//vn" or "v/vt/vn" tokens
            corners = [int(token.split("/")[0]) - 1 for token in fields[1:]]
            # A fan anchored at the first corner; for a triangle this yields
            # the face itself, for shorter faces it yields nothing.
            triangles.extend(
                [corners[0], corners[n], corners[n + 1]]
                for n in range(1, len(corners) - 1)
            )

    return np.array(vertices), np.array(triangles)

# --------------------------
# Map channels to lobes -> regions (approximate automatic mapping)
# --------------------------
def lobe_from_channel(ch):
    """Guess a lobe name from the leading letters of a channel name.

    The comparison is case-insensitive and the prefix groups are tried in
    order, so a name such as "FC1" is classified by its first matching group
    ("frontal").

    Returns one of "frontal", "central", "parietal", "occipital", "temporal",
    or "other" when no prefix matches.
    """
    name = ch.upper()
    prefix_table = (
        (("FP", "AF", "F"), "frontal"),
        (("C", "CZ"), "central"),
        (("P", "PO"), "parietal"),
        (("O", "OP", "OZ"), "occipital"),
        (("T",), "temporal"),
    )
    for prefixes, lobe in prefix_table:
        if name.startswith(prefixes):
            return lobe
    return "other"

def group_region_indices_by_lobe(labels):
    """Bucket region indices by lobe using keyword matching on their labels.

    Each label is lower-cased and assigned to the first lobe (in the order
    frontal, central, parietal, occipital, temporal) that has a keyword
    occurring in it; labels matching no keyword go to "other".

    Returns a dict mapping every lobe name to a list of 0-based region
    indices (possibly empty).
    """
    lobe_keywords = {
        "frontal": ["frontal", "gyrus rectus", "insula"],  # rough
        "central": ["precentral", "postcentral", "paracentral"],
        "parietal": ["parietal", "supramarginal", "precuneus"],
        "occipital": ["occipital", "cuneus", "lingual"],
        "temporal": ["temporal", "entorhinal", "fusiform"],
        "other": []
    }
    mapping = {lobe: [] for lobe in lobe_keywords}
    for region_idx, label in enumerate(labels):
        lowered = label.lower()
        # first lobe with any keyword contained in the label, else "other"
        matched_lobe = next(
            (
                lobe
                for lobe, keywords in lobe_keywords.items()
                if any(keyword in lowered for keyword in keywords)
            ),
            "other",
        )
        mapping[matched_lobe].append(region_idx)
    return mapping

def auto_map_channels_to_regions(ch_names, labels):
    """Assign every channel to one region of its (guessed) lobe.

    Channels are grouped by the lobe returned by `lobe_from_channel`; within
    a lobe they are dealt out round-robin over that lobe's regions as given
    by `group_region_indices_by_lobe`. A lobe without regions falls back to
    the "other" regions, and then to all regions.

    Returns a dict mapping channel index -> region index.
    """
    regions_by_lobe = group_region_indices_by_lobe(labels)

    # channel indices grouped by lobe, in order of first appearance
    channels_by_lobe = {}
    for ch_idx, name in enumerate(ch_names):
        channels_by_lobe.setdefault(lobe_from_channel(name), []).append(ch_idx)

    ch2reg = {}
    for lobe, members in channels_by_lobe.items():
        # first non-empty candidate list: own lobe -> "other" -> every region
        candidates = (
            regions_by_lobe.get(lobe, [])
            or regions_by_lobe.get("other", [])
            or list(range(len(labels)))
        )
        for position, ch_idx in enumerate(members):
            ch2reg[ch_idx] = candidates[position % len(candidates)]

    # safety net: any channel still missing is sent to region 0
    for ch_idx in range(len(ch_names)):
        ch2reg.setdefault(ch_idx, 0)
    return ch2reg

# --------------------------
# Build region-level connectivity
# --------------------------
def aggregate_channel_to_region(conn3d, ch2reg, n_regions):
    """Average channel-level connectivity into region-level connectivity.

    Parameters
    ----------
    conn3d : array-like, shape (n_windows, n_ch, n_ch)
        Channel-by-channel connectivity for every time window.
    ch2reg : dict
        Maps channel index -> region index. Every region index must be in
        ``range(n_regions)`` (a KeyError is raised otherwise).
    n_regions : int
        Number of regions in the output.

    Returns
    -------
    numpy.ndarray, shape (n_windows, n_regions, n_regions), dtype float32
        Entry ``[t, i, j]`` is the NaN-ignoring mean of ``conn3d[t, a, b]``
        over all channels ``a`` assigned to region ``i`` and ``b`` assigned
        to region ``j``. It is 0 when either region has no channel, and NaN
        when every contributing value is NaN.
    """
    conn3d = np.asarray(conn3d)
    n_windows = conn3d.shape[0]
    region_conn = np.zeros((n_windows, n_regions, n_regions), dtype=np.float32)

    # channel lists per region
    reg2ch = {r: [] for r in range(n_regions)}
    for ch_idx, reg_idx in ch2reg.items():
        reg2ch[reg_idx].append(ch_idx)

    # Only regions that own at least one channel can be non-zero; their index
    # arrays are built once instead of once per window and region pair.
    occupied = {r: np.asarray(chs, dtype=int) for r, chs in reg2ch.items() if chs}

    for i, ch_i in occupied.items():
        rows = ch_i[:, None]
        for j, ch_j in occupied.items():
            # (n_windows, len(ch_i), len(ch_j)) block of channel pairs,
            # averaged for all windows at once
            pair_block = conn3d[:, rows, ch_j[None, :]]
            region_conn[:, i, j] = np.nanmean(pair_block, axis=(1, 2))
    return region_conn

# --------------------------
# MAIN: load files and produce frames
# --------------------------
print("Loading connectivity (.npz)...")
# NOTE(security): allow_pickle=True lets NumPy unpickle object arrays, which
# can execute arbitrary code. Only load archives you produced yourself; drop
# the flag if 'ch_names' is stored as a plain string array.
# The context manager closes the archive's file handle once the arrays that
# are needed have been read into memory.
with np.load(npz_file, allow_pickle=True) as data:
    ch_names = [str(x) for x in data["ch_names"].tolist()]
    times_s = data["times_s"]
    conn_all = data[metric]   # shape (n_windows, n_ch, n_ch)
n_windows, n_ch, _ = conn_all.shape
print(f"Connectivity loaded: metric={metric}, shape={conn_all.shape}")

print("Loading brain mesh and region info...")
# Surface geometry
verts, faces = obj_data_to_mesh3d_text(mesh_obj)

# Region centres: one row per region (expected to be x, y, z)
centers = np.loadtxt(centers_file)
n_regions = centers.shape[0]

# Region labels: one per line, surrounding whitespace removed
labels = []
with open(labels_file, "r") as f:
    labels.extend(line.strip() for line in f)

print(f"Mesh vertices: {verts.shape}, faces: {faces.shape}, regions: {n_regions}")

# Derive the channel -> region assignment
print("Mapping channels to regions (automatic approximate mapping)...")
ch2reg = auto_map_channels_to_regions(ch_names, labels)

# Invert the assignment (region -> channels) to print a short summary
from collections import defaultdict
reg2chs = defaultdict(list)
for ch_idx, reg_idx in ch2reg.items():
    reg2chs[reg_idx].append(ch_idx)
print("Example mapping (region index -> #channels):")
for r in list(reg2chs)[:12]:  # at most the first twelve regions seen
    print(r, labels[r], len(reg2chs[r]))

# Collapse channel-level matrices to region-level ones,
# giving an array of shape (n_windows, n_regions, n_regions)
print("Aggregating channel-level connectivity to region-level...")
region_conn = aggregate_channel_to_region(conn_all, ch2reg, n_regions)

# Compute a global threshold on connection *magnitude*.
# Edges are later selected with abs(w) >= thr_global, so the percentile is
# taken over absolute values too; otherwise negative connections would be
# compared against a threshold derived from positive values only.
# Zeros (region pairs without channels) and non-finite values are excluded.
all_offdiag = region_conn[:, ~np.eye(n_regions, dtype=bool)].reshape(-1)
offdiag_magnitude = np.abs(all_offdiag)
usable_magnitude = offdiag_magnitude[np.isfinite(offdiag_magnitude) & (offdiag_magnitude > 0)]
if usable_magnitude.size:
    thr_global = np.percentile(usable_magnitude, edge_percentile)
else:
    # Nothing to show: an infinite threshold draws no edges, whereas 0.0
    # would make every pair (including zero-weight ones) pass the test.
    thr_global = np.inf
print(f"Global edge percentile threshold ({edge_percentile}%) = {thr_global:.4f}")

# Region centre coordinates, one array per axis
nodes_x, nodes_y, nodes_z = centers[:, 0], centers[:, 1], centers[:, 2]

# Mesh geometry in the layout go.Mesh3d expects (reused for every frame):
# per-axis vertex coordinates and per-corner vertex indices
vert_x = verts[:, 0]
vert_y = verts[:, 1]
vert_z = verts[:, 2]
face_i, face_j, face_k = np.transpose(faces)

# Color nodes by lobe, using the same label -> lobe grouping that drives the
# channel-to-region mapping (group_region_indices_by_lobe). A separate keyword
# chain here used to disagree with it: regions such as the insula or
# entorhinal cortex received channels as frontal/temporal but were drawn gray.
lobe_color_map = {
    "frontal": "red", "central": "orange", "parietal": "green",
    "occipital": "blue", "temporal": "yellow", "other": "gray"
}
node_colors = [lobe_color_map["other"]] * n_regions
for lobe, region_indices in group_region_indices_by_lobe(labels).items():
    for r in region_indices:
        if r < n_regions:  # ignore labels beyond the number of region centres
            node_colors[r] = lobe_color_map[lobe]

# Create frames: one PNG per selected time window, written to tmp_frames_dir
frame_files = []
frames_to_process = list(range(0, n_windows, frame_downsample))
print(f"Writing {len(frames_to_process)} frames (downsample={frame_downsample}) to {tmp_frames_dir} ...")

# Pieces that do not change between frames are built once, outside the loop
pair_i, pair_j = np.triu_indices(n_regions, k=1)  # every region pair i < j, row by row
region_text = [f"{i}: {labels[i]}" for i in range(n_regions)]
hidden_axis = dict(showticklabels=False, visible=False)

for idx, t in tqdm(enumerate(frames_to_process), total=len(frames_to_process)):
    t_int = int(t)
    mat = region_conn[t_int]

    # Region pairs whose finite connection magnitude reaches the global threshold
    pair_w = mat[pair_i, pair_j]
    with np.errstate(invalid="ignore"):  # NaN weights simply fail the comparison
        keep = np.isfinite(pair_w) & (np.abs(pair_w) >= thr_global)

    # Build Plotly figure
    fig = go.Figure()
    # Brain mesh (transparent)
    fig.add_trace(go.Mesh3d(
        x=vert_x, y=vert_y, z=vert_z, i=face_i, j=face_j, k=face_k,
        color='lightgrey', opacity=0.15, showscale=False
    ))
    # Region nodes
    fig.add_trace(go.Scatter3d(
        x=nodes_x, y=nodes_y, z=nodes_z, mode='markers+text',
        marker=dict(size=6, color=node_colors),
        text=region_text,
        textposition="top center",
        hoverinfo='text',
        name='regions'
    ))

    # Edges as one line trace; the None after each pair breaks the line so
    # consecutive edges are not joined to each other.
    edge_x = []
    edge_y = []
    edge_z = []
    for i, j in zip(pair_i[keep], pair_j[keep]):
        edge_x += [nodes_x[i], nodes_x[j], None]
        edge_y += [nodes_y[i], nodes_y[j], None]
        edge_z += [nodes_z[i], nodes_z[j], None]

    if edge_x:
        # All edges share one trace and therefore one fixed width. Per-edge
        # widths or colors would need one trace per edge (slower).
        fig.add_trace(go.Scatter3d(
            x=edge_x, y=edge_y, z=edge_z, mode='lines',
            line=dict(color='darkred', width=2),
            opacity=0.7,
            hoverinfo='none',
            name='edges'
        ))

    # Layout
    fig.update_layout(
        scene=dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis),
        margin=dict(l=0, r=0, t=30, b=0),
        title=f"{metric} dynamic connectivity — {times_s[t_int]:.1f}s (frame {t_int})",
        width=900, height=700
    )

    # Save static image frame (requires kaleido)
    frame_path = os.path.join(tmp_frames_dir, f"frame_{idx:05d}.png")
    try:
        fig.write_image(frame_path, engine="kaleido", scale=2)
    except Exception:
        # Give a hint about the most likely cause, then re-raise the original
        # error with its traceback intact.
        print("Warning: fig.write_image failed (kaleido missing?). Try `pip install kaleido`.")
        raise
    frame_files.append(frame_path)

print(f"Frames written: {len(frame_files)}")

# --------------------------
# Build GIF
# --------------------------
print("Building GIF (this may take a while)...")
# Append the rendered frames to the GIF in the order they were written
with imageio.get_writer(gif_name, mode='I', fps=6) as writer:
    for fname in tqdm(frame_files):
        writer.append_data(imageio.imread(fname))

print("GIF saved at:", gif_name)
print("Done. You can remove the tmp_frames directory to free space.")


Loading connectivity (.npz)...


FileNotFoundError: [Errno 2] No such file or directory: 'sub-001_connectivity.npz'