# Spatiotemporal layer dynamics underlying MERF in a specific ROI

In this tutorial:
- retrieve source time series
- define the region of interest (ROI)
- compute the CSDs for this region
- bin and extract layer-specific CSD activity
- define the time windows of interest (WOI)
- average (using max or mean) the activity within these time windows
- define the surface and interpolate this activity
- plot the activity over time
- find the clusters in each region and assess their significance

In [None]:
import os
import shutil
import numpy as np
import nibabel as nib
from joblib import Parallel, delayed
import k3d
from scipy import stats
from matplotlib import colors
import matplotlib.pyplot as plt
import tempfile
import glob
from scipy.signal import resample
from scipy.ndimage import gaussian_filter1d

from lameg.laminar import compute_csd
from lameg.simulate import run_current_density_simulation, run_dipole_simulation
from lameg.invert import invert_ebb, coregister, load_source_time_series
from lameg.util import get_fiducial_coords
from lameg.viz import plot_csd
import spm_standalone

from utils import convert_native_to_fsaverage, get_bigbrain_layer_boundaries, get_roi_idx, find_clusters, compute_activity_over_time, extract_layer_csd

from lameg.surf import interpolate_data
from lameg.viz import show_surface, color_map
import io
from PIL import Image
from base64 import b64decode
import matplotlib.colors as mcolors
import matplotlib.cm as cm

In [None]:
# Send temporary files to /scratch. OMPI_TMPDIR is presumably read by an
# Open MPI runtime used by the SPM standalone tools — TODO confirm.
%env OMPI_TMPDIR=/scratch
%env TMPDIR=/scratch

In [None]:
# --- Subject / session / epoch selection ---
subj_id='sub-001'
ses_id = 'ses-01'
epoch='motor'
c_idx=3  # NOTE(review): meaning not shown here; only used in the out_dir name
subj_dir=os.path.join('/home/common/bonaiuto/cued_action_meg/derivatives/processed',subj_id)
subj_dir_sss=os.path.join('/home/common/bonaiuto/cued_action_meg/derivatives/processed_sss',subj_id)
subj_surf_dir=os.path.join(subj_dir,'surf')
# 11-layer multilayer mesh (used for the inversion) and the single pial mesh
multilayer_mesh_fname = os.path.join(subj_surf_dir, 'multilayer.11.ds.link_vector.fixed.gii')
pial_mesh_fname = os.path.join(subj_surf_dir,'pial.ds.link_vector.fixed.gii')

# External inputs: T1 MRI, smoothing matrix for the multilayer mesh, SPM epoch file
mri_fname = os.path.join(subj_dir, 't1w.nii')
smooth_file = os.path.join(subj_surf_dir, 'FWHM5.00_multilayer.11.ds.link_vector.fixed.mat')
data_file=os.path.join(subj_dir_sss, ses_id, f'spm/pmcspm_converted_autoreject-{subj_id}-{ses_id}-{epoch}-epo.mat')
# Working directory for this subject/session/epoch; created (with parents) below
out_dir=os.path.join('./data', subj_id, ses_id, f'{subj_id}_{ses_id}_c{c_idx}_{epoch}_model_inv')

# Per-chunk CSD results are written here
out_dir_chunks = os.path.join(out_dir, 'csd_chunks_signif')
os.makedirs(out_dir_chunks, exist_ok=True)

# Fiducial coordinates (nas, lpa, rpa), passed to coregister() below
fiducial_fname='/home/common/bonaiuto/cued_action_meg/raw/participants.tsv'
nas, lpa, rpa=get_fiducial_coords(subj_id, fiducial_fname)

# Presumably read by the FreeSurfer-based helpers in utils — confirm
%env SUBJECTS_DIR=/home/common/bonaiuto/cued_action_meg/derivatives/processed/fs/
# NOTE(review): pial_ds is not referenced elsewhere in this notebook
pial_ds = 'pial.ds.gii'

In [None]:
# Extract base name and path of data file
data_path, data_file_name = os.path.split(data_file)
data_base = os.path.splitext(data_file_name)[0]

# Copy data files to tmp directory
shutil.copy(
    os.path.join(data_path, f'{data_base}.mat'), 
    os.path.join(out_dir, f'{data_base}.mat')
)
shutil.copy(
    os.path.join(data_path, f'{data_base}.dat'), 
    os.path.join(out_dir, f'{data_base}.dat')
)
shutil.copy(
    mri_fname, 
    os.path.join(out_dir, 't1w.nii')
)
shutil.copy(
    smooth_file, 
    os.path.join(out_dir, 'FWHM5.00_multilayer.11.ds.link_vector.fixed.mat')
)
shutil.copy(
    multilayer_mesh_fname, 
    os.path.join(out_dir, 'multilayer.11.ds.link_vector.fixed.gii')
)

# Construct base file name for simulations
mri_fname = os.path.join(out_dir, 't1w.nii')
smooth_file = os.path.join(out_dir, 'FWHM5.00_multilayer.11.ds.link_vector.fixed.mat')
multilayer_mesh_fname = os.path.join(out_dir, 'multilayer.11.ds.link_vector.fixed.gii')
base_fname = os.path.join(out_dir, f'{data_base}.mat')

In [None]:
# Start one SPM standalone instance; it is reused via spm_instance=spm below
spm = spm_standalone.initialize()

In [None]:
# Compute the number of vertices per layer
mesh = nib.load(multilayer_mesh_fname)
pial_mesh = nib.load(pial_mesh_fname)
n_layers = 11
verts_per_surf = int(mesh.darrays[0].data.shape[0]/n_layers)

In [None]:
layer_verts = [l*int(verts_per_surf) for l in range(n_layers)]
layer_coords = mesh.darrays[0].data[layer_verts,:]
thickness = np.sqrt(np.sum((layer_coords[0,:]-layer_coords[-1,:])**2))

In [None]:
# Sampling rate (Hz) of the epoched data — presumably matches the data file; confirm
s_rate = 600

In [None]:
# Patch size to use for inversion
patch_size = 5
# Number of temporal modes to use for EBB inversion
n_temp_modes = 4

# Coregister data to multilayer mesh (using the fiducials loaded above)
coregister(
    nas, 
    lpa, 
    rpa, 
    mri_fname, 
    multilayer_mesh_fname, 
    base_fname,
    spm_instance=spm
)

# Run the EBB inversion over all n_layers surfaces.
# return_mu_matrix=True also returns MU, which is reused later to project
# single-trial data (mu_matrix=MU in load_source_time_series).
[_,_,MU] = invert_ebb(
    multilayer_mesh_fname, 
    base_fname, 
    n_layers, 
    patch_size=patch_size, 
    n_temp_modes=n_temp_modes,
    return_mu_matrix=True,
    spm_instance=spm
)

In [None]:
MU_fname = os.path.join(out_dir, 'MU.npy') 
# np.save(MU_fname, MU)

MU = np.load(MU_fname, allow_pickle = True) 

In [None]:
# Get source time series for each layer and vertex
mean_layer_ts, time, _ = load_source_time_series(base_fname)

In [None]:
mean_layer_ts_fname = os.path.join(out_dir, 'layer_ts.npy') 
#np.save(mean_layer_ts_fname, mean_layer_ts)

mean_layer_ts = np.load(mean_layer_ts_fname) 

##### Get the indexes of the vertices of the region of interest

In [None]:
# roi_idx = get_roi_idx(subj_id, subj_surf_dir, 'lh', ['precentral','paracentral','postcentral', 
#                                                      'rostralmiddlefrontal','caudalmiddlefrontal',
#                                                      'superiorfrontal','parsopercularis',
#                                                      'parstriangularis',
#                                                      'caudalanteriorcingulate',
#                                                     'frontalpole'], pial_mesh)

# Vertex indices of the left-hemisphere precentral, paracentral and postcentral
# regions. Presumably these are indices into the downsampled single-layer mesh
# (they are used to index its vertices later) — confirm.
roi_idx = get_roi_idx(subj_id, subj_surf_dir, 'lh', ['precentral','paracentral','postcentral'], pial_mesh)

Get the BigBrain layer boundaries in this region of interest

In [None]:
#bb_layer_bound = get_bigbrain_layer_boundaries(subj_id, subj_surf_dir, subj_coord=None)

bb_layer_bound_fname = os.path.join(out_dir, 'bb_layer_bound.npy') 
#np.save(bb_layer_bound_fname, bb_layer_bound)

In [None]:
bb_layer_bound = np.load(bb_layer_bound_fname, allow_pickle = True) 
bb_lb_roi = bb_layer_bound[:,roi_idx]

In [None]:
bb_lb_roi.shape

In [None]:
bb_lb_roi = bb_lb_roi.T

In [None]:
#extract mean_layer_ts from these vertices
mean_layer_ts_roi = mean_layer_ts[roi_idx]

### Compute the CSD for only these vertices
Here we compute the CSD in chunks and write each chunk to disk to limit memory usage.

In [None]:
def compute_csd_for_vertex(vertex, sfreq=None):
    """Compute the raw and smoothed CSD across all layers at one vertex.

    Parameters
    ----------
    vertex : int
        Vertex index within a single layer surface.
    sfreq : float, optional
        Sampling rate passed to ``compute_csd``. Defaults to the notebook-level
        ``s_rate``, so the CSD uses the same rate as the rest of the analysis.

    Returns
    -------
    tuple
        ``(csd, sm_csd)`` as returned by ``compute_csd``.

    Notes
    -----
    Reads the notebook globals ``mesh``, ``mean_layer_ts``, ``verts_per_surf``
    and ``n_layers``.
    """
    if sfreq is None:
        sfreq = s_rate
    # The same vertex in each of the n_layers stacked surfaces
    vert = [layer * int(verts_per_surf) + vertex for layer in range(n_layers)]
    layer_coords = mesh.darrays[0].data[vert, :]
    # Per-vertex thickness: distance between the first and last layer here
    thickness = np.linalg.norm(layer_coords[0, :] - layer_coords[-1, :])

    csd, sm_csd = compute_csd(
        mean_layer_ts[vert, :],
        thickness,
        sfreq=sfreq,
        smoothing='cubic'
    )
    return csd, sm_csd

In [None]:
# Resume support: collect the chunk indices already on disk. Only the smoothed
# CSD is saved (smooth_csd_{chunk_idx:04d}.npy), so that is the pattern to look
# for; the index is the last "_"-separated token of the file stem.
# NOTE(review): cached chunks are reused even if roi_idx or chunk_size changed;
# clear out_dir_chunks after changing either.
saved_chunks = {
    int(os.path.splitext(os.path.basename(f))[0].rsplit("_", 1)[-1])
    for f in glob.glob(os.path.join(out_dir_chunks, "smooth_csd_*.npy"))
}

chunk_size = 500
for i in range(0, len(roi_idx), chunk_size):
    chunk_idx = i // chunk_size
    if chunk_idx in saved_chunks:
        print(f"Skipping chunk {chunk_idx} (already saved)")
        continue

    print(f"Processing chunk {chunk_idx}...")

    chunk = roi_idx[i:i + chunk_size]
    # One CSD computation per vertex, spread over all available cores
    results = Parallel(n_jobs=-1)(
        delayed(compute_csd_for_vertex)(vertex) for vertex in chunk
    )

    results = [res for res in results if res[0] is not None]
    n_dropped = len(chunk) - len(results)
    if n_dropped:
        # Saved rows are later concatenated and indexed in roi_idx order, so a
        # dropped vertex shifts every following row.
        print(f"Warning: {n_dropped} vertices in chunk {chunk_idx} returned no CSD; "
              "saved rows will not align with roi_idx")
    if results:
        csd_chunk, sm_csd_chunk = zip(*results)
        #np.save(os.path.join(out_dir_chunks, f"csd_{chunk_idx:04d}.npy"), np.stack(csd_chunk))
        np.save(os.path.join(out_dir_chunks, f"smooth_csd_{chunk_idx:04d}.npy"), np.stack(sm_csd_chunk))
        print(f"Saved chunk {chunk_idx} to disk")
    else:
        print(f"Chunk {chunk_idx} returned no valid results")

In [None]:
#csd_files = sorted(glob.glob(f"{out_dir_chunks}/csd_*.npy"))
smooth_files = sorted(glob.glob(f"{out_dir_chunks}/smooth_csd_*.npy"))

#csd_emp_roi = np.concatenate([np.load(f) for f in csd_files])
sm_csd_roi = np.concatenate([np.load(f) for f in smooth_files])

#csd_emp_roi = np.stack(csd_emp_roi)
sm_csd_roi = np.stack(sm_csd_emp_roi)

In [None]:
sm_csd_roi.shape

### Extract layer-specific CSD
Based on the BigBrain atlas layer boundaries

In [None]:
# Layer-specific CSD for each ROI vertex, using that vertex's BigBrain boundaries
csd_L5 = extract_layer_csd(sm_csd_roi, bb_lb_roi, roi_idx, 'L5')
csd_L2_3 = extract_layer_csd(sm_csd_roi, bb_lb_roi, roi_idx, 'L2_3')

In [None]:
# NOTE(review): this unconditionally overwrites the `time` returned by
# load_source_time_series. It assumes 1201 samples from -1 to 1 (spacing
# 1/600, matching s_rate) — confirm against the data.
time = np.linspace(-1, 1, 1201) #in case you don't have it

### Extract in a specific time window

You can either use specific time windows or create sliding time windows (with or without overlap).

In [None]:
wd_size = 0.05  # 50 ms
step_size = 0.05  # for non-overlapping windows; reduce this for overlap
start_time = time[300] #start_time = time[0] 
end_time = time[900] #end_time = time[-1]

#woi = [(-0.5, -0.1), (-0.1, 0), (0, 0.12), (0.12, 0.3)]
# Baseline window as a (start, end) pair of sample indices
baseline_woi_idx = [(0,300)] 

# Sliding windows [s, s + wd_size] every step_size that lie fully inside
# [start_time, end_time]. Starts are start_time + k*step_size from an integer
# count. Repeatedly adding step_size accumulates rounding error, which can
# push the last window just past end_time and silently drop it.
n_windows = int(np.floor((end_time - start_time - wd_size) / step_size + 1e-9)) + 1
woi = [
    (start_time + k * step_size, start_time + k * step_size + wd_size)
    for k in range(max(n_windows, 0))
]

# Convert window bounds to the nearest sample indices of `time`
woi_idx = [(
    (np.abs(time - start)).argmin(),
    (np.abs(time - end)).argmin()
) for start, end in woi]

In [None]:
# Per-window activity (method='mean_abs') for each ROI vertex; each result is
# (n_windows, n_roi_vertices), per the shape note below
mcsd_L5 = compute_activity_over_time(csd_L5, woi_idx, roi_idx, method='mean_abs')
mcsd_L2_3 = compute_activity_over_time(csd_L2_3, woi_idx, roi_idx, method='mean_abs')
#get the average activity
m_pial_ts = compute_activity_over_time(mean_layer_ts_roi, woi_idx, roi_idx, method='mean_abs')

#compute baseline activity (single window -> one row per quantity)
baseline_mcsd_L5 = compute_activity_over_time(csd_L5, baseline_woi_idx, roi_idx, method='mean_abs')
baseline_mcsd_L2_3 = compute_activity_over_time(csd_L2_3, baseline_woi_idx, roi_idx, method='mean_abs')
baseline_pial_ts = compute_activity_over_time(mean_layer_ts_roi, baseline_woi_idx, roi_idx, method='mean_abs')

#baseline correct the values: should probably do this before with the time series? 
# The single baseline row broadcasts over all windows
mcsd_L5 = mcsd_L5 - baseline_mcsd_L5
mcsd_L2_3 = mcsd_L2_3 - baseline_mcsd_L2_3
m_pial_ts = m_pial_ts - baseline_pial_ts

In [None]:
mcsd_L5.shape # number of time windows x vertices mean activity in this time window

In [None]:
# create a dictionary with the absolute values of the baseline-corrected
# activity; columns follow roi_idx order
group_data = {
    'v_mcsd_L5': {'data': np.abs(mcsd_L5)},
    'v_mcsd_L2_3': {'data': np.abs(mcsd_L2_3)},
    'v_m_pial_ts': {'data': np.abs(m_pial_ts)},
}

In [None]:
group_data['v_m_pial_ts']['data'].shape

### Interpolate activity and plot it
For this you need to define both the original and the downsampled surfaces. All non-ROI vertices are padded with 0 so the activity can then be plotted on the whole surface.

In [None]:
# Original-resolution inflated surface (for display) and its downsampled version
orig_inflated=nib.load(os.path.join(subj_surf_dir, 'inflated.gii'))
ds_inflated=nib.load(os.path.join(subj_surf_dir, 'inflated.ds.gii'))

In [None]:
# Vertex count of the downsampled surface (length of the zero-padded data below)
nb_vertices_ds = ds_inflated.darrays[0].data.shape[0]

(Optional) find the coordinates of the maximum activity 

In [None]:
pial_coords = ds_inflated.darrays[0].data

In [None]:
for key in group_data:
    coords_list = []
    data_array = group_data[key]['data']
    for t_wd_ix in range(len(woi_idx)):
        max_ix = np.nanargmax(data_array[t_wd_ix])
        coords_i = pial_coords[max_ix]
        coords_list.append(coords_i)

    group_data[key]['max_coords'] = np.array(coords_list)

In [None]:
group_data['v_m_pial_ts']['max_coords'][0]

Interpolate for display on the original inflated surface (can choose not to display on inflated)

In [None]:
# For each group: place ROI values on the full downsampled mesh (zeros
# elsewhere), then call interpolate_data per time window, presumably to move
# the values onto the original-resolution inflated surface — confirm.
for group_key in group_data:
    data = group_data[group_key]['data']
    # Pad the data with 0s when not in the selected vertices 
    data_to_inter = np.zeros((data.shape[0],nb_vertices_ds))
    # fill in ROI absolute values for each time window
    data_to_inter[:,roi_idx] = data
    # One interpolation per time window, run in parallel processes
    results = Parallel(n_jobs=-1, prefer="processes")(
        delayed(interpolate_data)(
            orig_inflated,
            ds_inflated,
            data_to_inter[t_wd_ix,:]
        )
        for t_wd_ix in range(data.shape[0])
    )
    group_data[group_key]['interpolated'] = np.array(results)

In [None]:
group_data['v_m_pial_ts']['interpolated'][2].shape

Plot a specific group and time window (could loop this later)

In [None]:
# Which group (key of group_data) and which time window to display
chosen_group = 'v_mcsd_L5'
chosen_interp_data = group_data[chosen_group]['interpolated']
chosen_t_ix = 0

In [None]:
#Specify the bounds: should be the same to compare between layers
# The data-driven limits are immediately replaced by fixed limits, so colour
# scales stay comparable across groups
vmin, vmax = np.min(chosen_interp_data), np.max(chosen_interp_data)
vmin, vmax = 0, 0.12

In [None]:
# Per-vertex colours for the chosen window.
# NOTE(review): this rebinds `colors`, shadowing `from matplotlib import colors`
colors,_ = color_map(
        chosen_interp_data[chosen_t_ix], 
        "Spectral_r", 
        vmin, 
        vmax,
        norm='N' #or 'TS' if centered around 0 (so not the absolute value)
        )

In [None]:
# #if want to do all: 
# colors_all = []
# for t_wd_ix in range(len(woi_idx)):
#         colors,_ = color_map(
#         chosen_interp_data[t_wd_ix], ### PROBLEM HERE
#         "Spectral_r", 
#         vmin, 
#         vmax,
#         norm='N' #or 'TS' if centered around 0 (so not the absolute value)
#         )
#         colors_all.append(colors)
# group_data[group_key]['color'] = np.array(colors_all)

In [None]:
# Camera parameters for show_surface (9 values, presumably position / target /
# up vectors); a new view can be read from plot.camera (see below)
precentral_view = [-178.9392420191627,-45.354421787933546,128.49784196033622,
                   8, 2.5, -32,
                   0.7760071820687688, -0.2981949773300084, 0.5641742717218174]

In [None]:
# Hard-coded vertex whose inflated coordinates are marked on the surface
v = 3187
coords_v = ds_inflated.darrays[0].data[v,:]

In [None]:
# To mark the per-window peak instead (key set by the max-coords cell):
# plot = show_surface(orig_inflated, vertex_colors=colors, info=True, 
#                     camera_view=precentral_view, 
#                     coords = group_data[chosen_group]['max_coords'][chosen_t_ix], coord_size=1, coord_color=[0,0,255])

plot = show_surface(orig_inflated, vertex_colors=colors, info=True, 
                    camera_view=precentral_view, 
                    coords = coords_v, coord_size=4, coord_color=[0,0,255])

In [None]:
#plot.camera #to get the orientation we want and can replace precentral_view
# Request a screenshot; it is read from plot.screenshot in a later cell
plot.fetch_screenshot()

In [None]:
def add_subplot_label(ax, label, x=-.21, y=1.225, fontsize=26):
    """Draw a panel label (e.g. 'a') at (x, y) in axes coordinates.

    Because the position uses ``ax.transAxes``, the defaults put the label to
    the left of and above the axes. The text is anchored by its top-right corner.
    """
    text_style = dict(
        transform=ax.transAxes,
        fontsize=fontsize,
        va='top',
        ha='right',
    )
    ax.text(x, y, label, **text_style)

In [None]:
# Decode the k3d screenshot (base64-encoded image) into a pixel array
image_data = b64decode(plot.screenshot)
image = Image.open(io.BytesIO(image_data))
image_array = np.array(image)

# Set up figure and axis
fig, ax = plt.subplots(figsize=(24, 16))
ax.imshow(image_array)
ax.axis('off')  # Hide axes

# Colorbar with the same limits and colormap used to colour the surface
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
scalar_mappable = cm.ScalarMappable(norm=norm, cmap="Spectral_r")

# Add colorbar
cbar = plt.colorbar(scalar_mappable, ax=ax, shrink=0.5, aspect=20, pad=0.02)
# Plotted values are |window mean_abs activity - baseline| (method='mean_abs'),
# not a maximum, so the label says so
cbar.set_label(f"{chosen_group} mean abs. activity on tw{woi_idx[chosen_t_ix]}", fontsize=20)

# Set two ticks: at min (vmin) and max (vmax)
cbar.set_ticks([vmin, vmax])
cbar.set_ticklabels([f"Low CSD:{round(vmin,2)}", f"High CSD:{round(vmax,2)}"])

add_subplot_label(ax, 'a', fontsize=54)

# Show and save figure
#plot_fname = os.path.join(out_dir, f'fig_{chosen_group}_csd_wd{chosen_t_ix}_roi_mean_abs.png')
#plt.savefig(plot_fname)

### Find clusters
Find clusters and assess their significance. One important point: to find clusters we need to use the faces of the downsampled inflated surface; otherwise the clusters will not be found, or will be inaccurate.

In [None]:
# Keep only faces of the downsampled inflated mesh whose three vertices all
# belong to the ROI
faces = ds_inflated.darrays[1].data
mask = np.all(np.isin(faces, roi_idx), axis=1)
roi_faces_global = faces[mask]

# Build mapping from global (mesh) vertex index to local index (position in roi_idx)
global_to_local = {v: i for i, v in enumerate(roi_idx)}

# Reindex faces from global to local indices so they match the ROI data columns
roi_faces_local = np.array([[global_to_local[v] for v in tri] for tri in roi_faces_global])

In [None]:
# Percentile (0-100) of ROI values used as the cluster-forming threshold
thresh = 98

In [None]:
# to assess significance of clusters: on mean using spatial permutations
# !pip install scikit-learn
# !pip install brainspace
# from brainspace.null_models import SpinPermutations

# n_permutations = 100
# sphere_coords = ds_inflated.darrays[0].data[roi_idx,:]
# spin_model = SpinPermutations(n_rep=n_permutations, random_state=42)
# spin_model.fit(sphere_coords)

In [None]:
# for group_name, group_info in group_data.items():
#     data = group_info['data']  # (time windows x vertices)
#     group_info['clusters'] = {}  # add cluster subdict

#     for t_i, time_idx in enumerate(woi_idx):
#         vals = np.abs(data[t_i][roi_idx])
#         vals = np.nan_to_num(vals)

#         cluster_thresh = np.nanpercentile(vals, thresh)
#         mask = np.where(vals >= cluster_thresh)[0]

#         clusters = find_clusters(roi_faces_local, mask, n_hops=3)
#         obs_masses = [np.sum(vals[c]) for c in clusters]

#         null_vals = spin_model.randomize(vals)

#         null_masses = []
#         for i in range(n_permutations):
#             permuted_vals = null_vals[:, i]

#             s_thresh = np.nanpercentile(permuted_vals, thresh)
#             s_mask = np.where(permuted_vals >= s_thresh)[0]

#             s_clusters = find_clusters(roi_faces_local, s_mask, n_hops=3)
#             s_masses = [np.sum(permuted_vals[c]) for c in s_clusters] if s_clusters else [0]
#             null_masses.append(np.max(s_masses))

#         null_masses = np.array(null_masses)
#         pvals = [np.mean(null_masses >= m) for m in obs_masses]

#         sig_clusters = [c for c, p in zip(clusters, pvals) if p < 0.05]

#         max_v_cluster = []
#         for cluster in sig_clusters:
#             cluster_vals = vals[cluster]
#             max_c_idx = np.argmax(cluster_vals)
#             max_v_idx = roi_idx[cluster[max_c_idx]]
#             max_v_cluster.append(max_v_idx)

#         # Store in group_data[group_name]['clusters']
#         group_info['clusters'][t_i] = {
#             "sig_clusters": sig_clusters,
#             "pvals": pvals,
#             "max_vertex_ids": max_v_cluster
#         }

In [None]:
# Supra-threshold masks for every (group, window), appended in loop order into
# one flat list
highest_vert_idx = []

for group_name, group_info in group_data.items():
    data = group_info['data']  # (time windows x roi_vertices)
    group_info['clusters'] = {}  # add cluster subdict

    for t_i, time_idx in enumerate(woi_idx):
        vals = data[t_i]
        vals = np.nan_to_num(vals)

        # ROI-local indices of vertices at or above the thresh-th percentile
        cluster_thresh = np.nanpercentile(vals, thresh)
        mask = np.where(vals >= cluster_thresh)[0]
        
        highest_vert_idx.append(mask)

        # Group supra-threshold vertices into clusters on the ROI mesh
        clusters = find_clusters(roi_faces_local, mask, n_hops=3)
        
        # Peak vertex of each cluster, mapped back to a mesh vertex index
        max_v_cluster = []
        for cluster in clusters:
            cluster_vals = vals[cluster]
            max_c_idx = np.argmax(cluster_vals)
            max_v_idx = roi_idx[cluster[max_c_idx]]
            max_v_cluster.append(max_v_idx)

        # Store in group_data[group_name]['clusters']
        group_info['clusters'][t_i] = {
            "clusters": clusters,
            "max_vertex_ids": max_v_cluster
        }

In [None]:
group_data['v_mcsd_L2_3']['clusters'][chosen_t_ix]

### Assess significance of clusters using single trials 

For this we need to extract the source activity per trial. Doing this for all of roi_idx at once is too memory-intensive, so it is done per vertex. We then compute the CSD, extract layer-specific activity and write it to a dictionary. This is done in a separate Python script.

In [None]:
# Single-trial epoch file (pcspm_... rather than the pmcspm_... file inverted
# above; presumably the non-averaged version — confirm)
base_fname_t=os.path.join(subj_dir_sss, ses_id, f'spm/pcspm_converted_autoreject-{subj_id}-{ses_id}-{epoch}-epo.mat')

In [None]:
# Project the single-trial data through MU for one hard-coded vertex.
# NOTE(review): this also rebinds the notebook-level `time`.
layer_ts, time, _ = load_source_time_series(
        base_fname_t,
        mu_matrix=MU, #we base ourselves on the inversion matrix from averaged data
        vertices=5195
    )

In [None]:
h5_filename_path = f'{out_dir_chunks}/group_data_st.h5'

In [None]:
import os
print(os.path.getsize(h5_filename_path))

In [None]:
with h5py.File(h5_filename_path, 'r+') as h5f:
    subset = h5f['v_mcsd_L5'][:]

In [None]:
# NOTE(review): n_bins is not defined in this notebook (presumably the number of
# single-trial bins from the separate script). subset[3, b, :, 0] is assumed to
# be one value per window in woi — confirm against the writer script.
# woi_idx holds (start, end) index pairs, which matplotlib would treat as two x
# columns, so plot against the window centres (seconds, from woi) instead.
woi_centres = [(start + end) / 2 for start, end in woi]
for b in range(n_bins):
    plt.plot(woi_centres, subset[3, b, :, 0], label=f'bin {b}')
plt.xlabel('Window centre (s)')
plt.legend()

In [None]:
faces = ds_inflated.darrays[1].data
mask = np.all(np.isin(faces, unique_vertices[0:271]), axis=1)
roi_faces_global = faces[mask]

# Build mapping from global to local index
global_to_local = {v: i for i, v in enumerate(unique_vertices[0:271])}

# Reindex faces from global to local indices
roi_faces_local = np.array([[global_to_local[v] for v in tri] for tri in roi_faces_global])

In [None]:
roi_faces_local.shape

In [None]:
import numpy as np
from scipy.sparse import lil_matrix

n_vertices = len(unique_vertices[0:271])
adjacency = lil_matrix((n_vertices, n_vertices), dtype=int)

for tri in roi_faces_local:
    # tri is an array of 3 vertex indices [v0, v1, v2]
    for i in range(3):
        for j in range(i + 1, 3):
            v1, v2 = tri[i], tri[j]
            adjacency[v1, v2] = 1
            adjacency[v2, v1] = 1  # symmetric adjacency

Procedure: compute the single-trial (here, binned) mean power or layer-specific power, baseline-corrected, then test where it deviates from 1.

In [None]:
data_stats = subset.transpose(1, 2, 0)

In [None]:
import scipy.stats
from mne.stats import spatio_temporal_cluster_1samp_test
# on tailed (as we work here with asbolute values)
tail = 1
p_threshold = 0.05

df = len(subset[0:271]) - 1
t_thresh = scipy.stats.t.ppf(1 - p_threshold, df=df)

n_permutations = 50

# Run the analysis
T_obs, clusters, cluster_p_values, H0 = clu = spatio_temporal_cluster_1samp_test(
    data_stats,
    adjacency=adjacency,
    n_jobs=None,
    threshold=t_thresh,
    buffer_size=None,
    verbose=True,
)

In [None]:
# Select the clusters that are statistically significant at p < 0.05
good_clusters_idx = np.where(cluster_p_values < 0.05)[0]
good_clusters = [clusters[idx] for idx in good_clusters_idx]

In [None]:
good_clusters_idx.shape

### Plot the time series at the maximum peaks/clusters 
Can be clusters or a list/dictionary of (n_vertices x n_time_windows) for each layer/condition.

In [None]:
#define the woi you want to plot
woi_index = [0]

In [None]:
vmin_vmax = [-0.1, 0.1]

for t_ix in woi_index:
    for group_name, group_info in group_data.items():
        max_vertex_ids = group_info['clusters'][t_ix]["max_vertex_ids"]
        n_plots = len(max_vertex_ids)

        if n_plots == 0:
            continue

        # Layout: 1 row per vertex, 2 columns per row (Time Series | CSD)
        nrows = n_plots
        ncols = 2
        figsize = (12, 4 * nrows)

        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
        if n_plots == 1:
            axes = np.array([axes])  # Ensure axes is 2D even for 1 row

        for i, max_ix in enumerate(max_vertex_ids):
            # Time Series
            ax_ts = axes[i, 0]
            pial_layer_ts_mean = mean_layer_ts[max_ix, :]
            start, end = woi[t_ix]
            ax_ts.axvspan(start, end, color='gray', alpha=0.3)
            ax_ts.plot(time, pial_layer_ts_mean, color='k')
            ax_ts.set_title(f'{key} - Vertex {max_ix}\nTime Series')
            ax_ts.set_xlabel('Time (ms)')
            ax_ts.set_ylabel('Amplitude')

            # CSD
            ax_csd = axes[i, 1]
            roi_max_ix = np.where(roi_idx == max_ix)[0][0]
            csd_data = sm_csd_roi[roi_max_ix]
            plot_csd(csd_data, time, ax_csd, vmin_vmax=vmin_vmax, n_layers=n_layers)

            for pos in bb_lb_roi[roi_max_ix]:
                ax_csd.axhline(y=pos, color='b', linestyle='-.')
            start, end = woi[t_ix]
            ax_csd.axvspan(start, end, color='gray', alpha=0.3)
            ax_csd.set_title(f'CSD - Vertex {max_ix}')
            ax_csd.set_xlabel('Time (ms)')
            ax_csd.set_ylabel('Layer')

        plt.tight_layout()
        plt.show()