## Within-Subject Classification Pipeline
#### Using top networks per epoch to classify tasks (S1-S7) within a subject

In [None]:
import os, re
import glob, pickle
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from multipec.simulation_utils import direct

from pathlib import Path
from collections import defaultdict
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.preprocessing import label_binarize, StandardScaler
from sklearn.decomposition import PCA

Paths:

In [None]:
def find_repo_root(marker="setup.py"):
    """
    Locate the repository root by searching upwards from the current
    working directory.

    Returns the first directory (starting with the working directory itself)
    that contains `marker`. If no directory on the way up contains it, the
    top-most directory reached (the filesystem root) is returned.
    """
    start = Path.cwd()
    candidates = [start, *start.parents]
    for folder in candidates:
        if (folder / marker).exists():
            return folder
    # Nothing matched on the way up: fall back to the top-most directory.
    return candidates[-1]


# Every data location is resolved relative to the repository root.
project_root = find_repo_root()

data_folder = project_root / "data/output/eeg/"
results_folder = project_root / "data/results/eeg/"
figures_folder = project_root / "data/figures/eeg/"

Helper functions: Load top networks for a subject. Also, make sure to select only the networks with seeds within the lowest percentile.

In [None]:
def lowest_percentile_pairs(pairs_dict, nodes, percentile):
    """
    Translate the PEC_pairs output into node pairs (via `direct`) and keep
    the pairs with the lowest values.

    The number of pairs kept is `percentile` percent of all pairs, rounded
    down but never fewer than one. Each returned pair is a sorted tuple, so
    the orientation of a pair does not matter to callers. An empty set is
    returned when `direct` yields no pairs.
    """
    scored = direct(pairs_dict, nodes)
    if len(scored) == 0:
        return set()

    # ascending by value: the lowest values come first
    ranked = sorted(scored.items(), key=lambda item: item[1])
    n_keep = max(1, int(len(ranked) * percentile / 100.0))

    return {tuple(sorted(pair)) for pair, _value in ranked[:n_keep]}



In [None]:
def _final_pec(pec_values):
    """Return the last PEC value of a progress sequence, or None if it is empty/missing."""
    # len() instead of truthiness so array-like progress sequences work too
    if pec_values is None or len(pec_values) == 0:
        return None
    return pec_values[-1]


def _network_record(entry, label, task, subject, epoch):
    """Turn one (nodes, pec_progress) entry into the record dict used downstream."""
    nodes_tuple, pec_values = entry
    return {
        'sigma': label,
        'task': task,
        'subject': subject,
        'epoch': epoch,
        'net_nodes': nodes_tuple,
        'pec_progress': pec_values,
        'pec': _final_pec(pec_values),
    }


def _filter_by_seed_pairs(nets_list, pairs, nodes, percentile, max_percentile,
                          percentile_step, min_networks, verbose):
    """Keep networks whose seed pair is in the lowest percentile, relaxing the threshold as needed."""
    current = percentile
    while current <= max_percentile:
        seed_pairs = lowest_percentile_pairs(pairs, nodes, current)
        kept = []
        for entry in nets_list:
            nodes_tuple, _ = entry
            if len(nodes_tuple) < 2:
                continue
            # the seed is the first two nodes; sorted so orientation is ignored
            if tuple(sorted(nodes_tuple[:2])) in seed_pairs:
                kept.append(entry)

        if len(kept) >= min_networks:
            if verbose:
                print(f"   selected {len(kept)} nets using {current:.1f}% pairs")
            return kept
        current += percentile_step

    # No threshold produced enough networks: fall back to the unfiltered list.
    if verbose:
        print(
            f"   fewer than {min_networks} nets up to {max_percentile:.1f}% pairs"
            " → keeping all networks"
        )
    return nets_list


def load_subject_top_nets(subject_id, data_folder, verbose=True, percentile=1.0, *,
                          min_networks=10, max_percentile=10.0,
                          percentile_step=0.5, n_channels=62):
    """
    Load the networks of one subject, keeping only networks whose seed pair
    (their first two nodes) belongs to the lowest PEC percentile of that epoch.

    Files named ``nets_<label>_<task>_<subject_id>_epoch<k>.p`` in
    `data_folder` are read in sorted order. For each file:

    * fewer than `min_networks` networks: all of them are kept;
    * the matching ``pairs_<task>_<subject_id>_epoch<k>.p`` is missing:
      all networks are kept;
    * otherwise the filter starts at `percentile` and is relaxed by
      `percentile_step` until at least `min_networks` networks pass or
      `max_percentile` is exceeded (in which case all networks are kept).

    The adaptive threshold restarts at `percentile` for every file, so a
    relaxed threshold in one epoch never affects another epoch.

    Parameters
    ----------
    subject_id : str
        Subject identifier as it appears in the file names.
    data_folder : str or path-like
        Folder containing the pickled networks and pairs.
    verbose : bool
        Print progress information per file.
    percentile : float
        Starting percentile of seed pairs to accept.
    min_networks : int
        Minimum number of networks the filter must leave.
    max_percentile : float
        Largest percentile tried before giving up on filtering.
    percentile_step : float
        Increment applied when relaxing the percentile; must be positive.
    n_channels : int
        Number of nodes passed to `lowest_percentile_pairs`.

    Returns
    -------
    list of dict
        One record per network with keys 'sigma', 'task', 'subject',
        'epoch', 'net_nodes', 'pec_progress' and 'pec' (last value of the
        progress sequence, or None when it is empty).

    Raises
    ------
    ValueError
        If `percentile_step` is not positive.
    """
    if percentile_step <= 0:
        raise ValueError("percentile_step must be positive")

    pattern = os.path.join(data_folder, f"nets_*_S*_{subject_id}_epoch*.p")
    name_re = re.compile(r'nets_(.+?)_(S\d+)_(\d+)_epoch(\d+)\.p$')
    subject = f"sub{subject_id}"
    nodes = list(range(n_channels))
    all_networks = []

    for file in sorted(glob.glob(pattern)):
        match = name_re.search(os.path.basename(file))
        if not match:
            continue

        label, task, _, epoch = match.groups()
        epoch = int(epoch)

        # NOTE: pickle can execute arbitrary code on load; only read trusted files.
        with open(file, "rb") as f:
            nets_list = pickle.load(f)

        n_nets = len(nets_list)
        if verbose:
            print(f"{subject} {task} epoch {epoch}: {n_nets} networks in file")

        if n_nets < min_networks:
            # CASE 1: too few networks to afford filtering
            if verbose:
                print(f"   <{min_networks} networks → skipping pair filtering")
        else:
            # CASE 2: enough networks, filter by seed pairs when possible
            pairs_file = os.path.join(
                data_folder, f"pairs_{task}_{subject_id}_epoch{epoch}.p"
            )
            if not os.path.exists(pairs_file):
                if verbose:
                    print("   pairs file missing → keeping all networks")
            else:
                with open(pairs_file, "rb") as f:
                    pairs = pickle.load(f)
                nets_list = _filter_by_seed_pairs(
                    nets_list, pairs, nodes, percentile, max_percentile,
                    percentile_step, min_networks, verbose,
                )

        all_networks.extend(
            _network_record(entry, label, task, subject, epoch)
            for entry in nets_list
        )
    return all_networks

def select_top_networks(networks, top_k=10, verbose=True):
    """
    Keep the `top_k` lowest-PEC networks of every (subject, task, epoch) group.

    Networks whose 'pec' is None are discarded before ranking. Groups are
    processed in sorted key order, and within a group the result is ordered
    by ascending PEC. With `verbose`, the number of available and retained
    networks is printed for each group.
    """
    by_epoch = defaultdict(list)
    for record in networks:
        by_epoch[record['subject'], record['task'], record['epoch']].append(record)

    if verbose:
        print("\nTOP-K SELECTION ")

    selected = []
    for group_key in sorted(by_epoch):
        _subject, task, epoch = group_key

        # a missing PEC cannot be ranked, so drop those entries first
        candidates = [rec for rec in by_epoch[group_key] if rec['pec'] is not None]
        # ascending order puts the best (lowest) PEC at the front
        candidates.sort(key=lambda rec: rec['pec'])

        winners = candidates[:top_k]
        selected += winners

        if verbose:
            print(
                f"{task} | epoch {epoch}: "
                f"{len(candidates)} available -> {len(winners)} selected"
            )
    return selected



Node activation map for each epoch, to obtain a single spatial functional brain pattern for brain state classification (rather than network classification).

In [None]:
def build_epoch_activation(networks_epoch, n_channels=62):
    """
    Collapse all networks of ONE epoch into a single per-channel activation
    vector of length `n_channels`.

    Every network adds a weight of 1 / (pec + 1e-6) to each of its nodes, and
    the resulting vector is z-scored across channels unless it is flat.
    """
    activation = np.zeros(n_channels)

    for network in networks_epoch:
        members = network['net_nodes']
        # smaller PEC = stronger network, hence the inverse weighting
        weight = 1.0 / (network['pec'] + 1e-6)
        for node in members:
            activation[node] += weight

    # z-score across channels; a flat map is returned untouched to avoid 0/0
    spread = np.std(activation)
    if spread > 0:
        activation = (activation - np.mean(activation)) / spread

    return activation



ROI features. We extract cognitive meaning by computing the mean activation inside cortical systems.

In [None]:
# Channel indices belonging to each region of interest.
# The declaration order fixes the order of the ROI features.
ROIS = {
    "visual":     [27, 28, 29, 30, 33, 34, 35, 36],
    "temporal_L": [16, 17, 21, 22, 37, 42, 45, 47, 49, 51],
    "temporal_R": [18, 19, 23, 24, 38, 43, 46, 48, 50, 53],
    "central":    [11, 12, 13, 14, 17, 18, 19, 20],
    "parietal":   [21, 22, 23, 24, 25, 26],
    "frontal":    [5, 6, 7, 8, 9, 54, 55, 56, 57],
    "prefrontal": [0, 1, 2, 3, 4, 58, 59, 60, 61],
}


def compute_roi_features(activation):
    """Return the mean activation of every ROI, in the order ROIS declares them."""
    return np.array([np.mean(activation[channels]) for channels in ROIS.values()])



Spatial centroid of the brain state.

In [None]:
def activation_centroid(activation, coords):
    """
    Return the activation-weighted centroid of the channel coordinates.

    Parameters
    ----------
    activation : array-like
        One value per channel. Absolute values are used as weights, so
        strongly negative channels pull the centroid as much as strongly
        positive ones.
    coords : mapping
        Maps each channel index ``0 .. len(activation) - 1`` to its
        coordinate tuple.

    Returns
    -------
    numpy.ndarray
        Weighted average coordinate (one entry per coordinate axis).
    """
    # tiny offset keeps the weights from summing to zero on an all-zero map
    weights = np.abs(activation) + 1e-8
    # the channel count follows the activation vector rather than a fixed 62
    coords_array = np.array([coords[i] for i in range(len(weights))])
    return np.average(coords_array, axis=0, weights=weights)



Hemispheric dominance

In [None]:
def hemispheric_balance(activation, channel_names):
    """
    Return mean |activation| over left channels minus mean |activation|
    over right channels.

    The hemisphere is read from the number at the end of each channel name:
    odd numbers are treated as left and even numbers as right. The whole
    trailing number is used, so names such as "TP9" and "TP10" are
    classified as well. Names without a trailing number (e.g. "Fz") are
    ignored.

    Parameters
    ----------
    activation : numpy.ndarray
        One value per channel, in the same order as `channel_names`.
    channel_names : sequence of str
        Channel labels.

    Returns
    -------
    float
        Positive when the left side has the larger mean magnitude. The
        result is NaN if either side has no channels.
    """
    left, right = [], []
    for idx, name in enumerate(channel_names):
        number = re.search(r'[0-9]+$', name)
        if number is None:
            continue  # no trailing number: not assigned to either side
        side = left if int(number.group()) % 2 else right
        side.append(idx)

    L = np.mean(np.abs(activation[left]))
    R = np.mean(np.abs(activation[right]))
    return L - R



64-channel 10–10 system coordinates (x, y, z on unit sphere) from https://github.com/sccn/eeglab/blob/master/sample_locs/GSN64v2_0.sfp. The A1 (31) and A2 (32) reference electrodes were removed during preprocessing, so there are 62 electrodes in total.

In [None]:
# Electrode positions for the full 64-channel layout.
# Keys are 0-based channel indices; values are (x, y, z) tuples.
# NOTE(review): the vectors are not unit length (norms are well above 1), so
# they do not look normalised to a unit sphere — confirm the units.
# NOTE(review): confirm that, once indices 31 and 32 are removed, the remaining
# order matches the order of the channel labels used elsewhere in this file.
coords_64 = {
    0: (4.82147, 8.46376, -0.0639843),
    1: (3.44999, 9.06441, 2.97064),
    2: (1.90275, 7.64198, 6.40285),
    3: (0.0, 5.21914, 8.19540),
    4: (-1.72628, 2.48578, 8.55637),
    5: (1.59679, 10.5184, 0.515771),
    6: (0.0, 9.90900, 4.03918),
    7: (-1.90275, 7.64198, 6.40285),
    8: (-3.44910, 4.88753, 7.22782),
    9: (0.0, 10.2056, -1.61817),
    10: (-1.59679, 10.5184, 0.515771),
    11: (-3.44999, 9.06441, 2.97064),
    12: (-4.77336, 6.69969, 4.57236),
    13: (-4.82147, 8.46376, -0.0639843),
    14: (-6.08393, 6.13847, 1.57273),
    15: (-6.30320, 3.70148, 4.22719),
    16: (-5.28740, 1.30272, 6.65006),
    17: (-2.89239, -1.49571, 8.18118),
    18: (-6.49644, 4.76669, -1.91820),
    19: (-7.25734, 2.59912, 0.903444),
    20: (-7.34336, 0.287481, 3.19674),
    21: (-5.94121, -2.42769, 5.97948),
    22: (-6.70321, 2.34743, -5.19168),
    23: (-7.64626, -1.18501, 0.271229),
    24: (-7.24521, -2.89502, 2.68804),
    25: (-7.06353, -3.24230, -2.70838),
    26: (-6.78539, -4.61900, -0.234471),
    27: (-5.98166, -5.75781, 2.97462),
    28: (-3.38310, -5.87378, 6.25655),
    29: (0.0, -4.05271, 8.13937),
    30: (-5.64139, -6.07135, -3.51430),
    31: (-4.88083, -7.67925, -0.352208),
    32: (-3.16769, -8.12261, 3.13635),
    33: (0.0, -7.14834, 5.87509),
    34: (-4.27348, -6.69422, -6.35512),
    35: (-3.08195, -8.59131, -3.53248),
    36: (-1.79973, -9.42935, -0.256676),
    37: (0.0, -8.99684, 2.56482),
    38: (0.0, -9.12413, -3.87835),
    39: (1.79973, -9.42935, -0.256676),
    40: (3.16769, -8.12261, 3.13635),
    41: (3.38310, -5.87378, 6.25655),
    42: (2.89239, -1.49571, 8.18118),
    43: (3.08195, -8.59131, -3.53248),
    44: (4.88083, -7.67925, -0.352208),
    45: (5.98166, -5.75781, 2.97462),
    46: (5.94121, -2.42769, 5.97948),
    47: (5.64139, -6.07135, -3.51430),
    48: (6.78539, -4.61900, -0.234471),
    49: (7.24521, -2.89502, 2.68804),
    50: (7.06353, -3.24230, -2.70838),
    51: (7.64626, -1.18501, 0.271229),
    52: (7.34336, 0.287481, 3.19674),
    53: (5.28740, 1.30272, 6.65006),
    54: (1.72628, 2.48578, 8.55637),
    55: (7.25734, 2.59912, 0.903444),
    56: (6.30320, 3.70148, 4.22719),
    57: (3.44910, 4.88753, 7.22782),
    58: (6.70321, 2.34743, -5.19168),
    59: (6.49644, 4.76669, -1.91820),
    60: (6.08393, 6.13847, 1.57273),
    61: (4.77336, 6.69969, 4.57236),
    62: (3.82315, 7.90175, -7.06073),
    63: (-3.82315, 7.90175, -7.06073)
}

def exclude_and_reindex_channels(coords_dict, exclude_indices=(31, 32)):
    """
    Remove specific channels and reindex the remaining channels from 0.

    The remaining channels keep the iteration order of `coords_dict`; only
    their keys change.

    Parameters:
    - coords_dict (dict): Original channel coordinates {idx: (x, y, z)}
    - exclude_indices (iterable): Indices to remove

    Returns:
    - new_coords (dict): Reindexed channel coordinates
    """
    # a set gives constant-time membership tests and never shares state
    # between calls (unlike a mutable list default)
    excluded = set(exclude_indices)

    remaining = (
        coord for idx, coord in coords_dict.items() if idx not in excluded
    )

    # consecutive new indices, starting at 0
    return dict(enumerate(remaining))

# Coordinates of the 62 channels that remain after dropping indices 31 and 32,
# re-keyed 0..61.
coords_62 = exclude_and_reindex_channels(coords_64, exclude_indices=[31, 32])

# Hemisphere map: even index -> 'L', odd index -> 'R'.
# This is based on index parity only, not on the channel labels below.
hemisphere_62 = {i: 'L' if i % 2 == 0 else 'R' for i in range(62)}  # simplification

# Labels of the 62 channels, one per channel index.
# NOTE(review): confirm this order matches the index order of coords_62.
channel_names = [
    "Fp1", "Fp2", "F7", "F3", "Fz", "F4", "F8", "FC5", "FC1", "FC2", "FC6", "T7",
    "C3", "C4", "T8", "TP9", "CP5", "CP1", "CP2", "CP6", "TP10", "P7", "P3", "Pz",
    "P4", "P8", "O1", "Oz", "O2", "Iz", "AF7", "AF3", "AFz", "AF4",
    "AF8", "F5", "F1", "F2", "F6", "FT7", "FC3", "FCz", "FC4", "FT8", "C5", "C1",
    "C2", "C6", "TP7", "CP3", "CPz", "CP4", "TP8", "P5", "P1", "P2", "P6", "PO7",
    "PO3", "POz", "PO4", "PO8"
]

Function for building feature matrix X and labels y.

In [None]:
def build_epoch_dataset(subject_networks, coords, hemisphere_map, standardize=True):
    """
    Build one feature vector per (task, epoch) from a subject's networks.

    Each feature vector is the concatenation of the epoch activation map,
    the ROI means, the activation centroid and the hemispheric balance.

    Parameters
    ----------
    subject_networks : iterable of dict
        Network records with at least the keys 'task', 'epoch', 'net_nodes'
        and 'pec'.
    coords : mapping
        Channel index -> coordinate tuple, forwarded to `activation_centroid`.
    hemisphere_map : sequence of str
        Forwarded unchanged to `hemispheric_balance` as its channel-name
        argument (despite the parameter name, this is the list of labels).
    standardize : bool
        When True (default) every feature column is standardized with a
        `StandardScaler` fitted on the whole dataset. Pass False to get the
        raw feature values, e.g. for visualization.

    Returns
    -------
    X : numpy.ndarray
        Feature matrix, one row per (task, epoch).
    y : numpy.ndarray
        Task label of each row.
    groups : numpy.ndarray
        Epoch index of each row, usable as cross-validation groups.

    Raises
    ------
    ValueError
        If `subject_networks` is empty.
    """
    # group networks by task + epoch; each group becomes one sample
    epoch_groups = defaultdict(list)
    for net in subject_networks:
        epoch_groups[(net['task'], net['epoch'])].append(net)

    if not epoch_groups:
        raise ValueError("subject_networks is empty: no epochs to build features from")

    X, y, groups = [], [], []

    for (task, epoch), nets in epoch_groups.items():
        activation = build_epoch_activation(nets)

        roi_feats = compute_roi_features(activation)
        centroid = activation_centroid(activation, coords)
        hemi = hemispheric_balance(activation, hemisphere_map)

        X.append(np.concatenate([
            activation,      # 62
            roi_feats,       # 7
            centroid,        # 3
            [hemi],          # 1
        ]))
        y.append(task)
        groups.append(epoch)

    X = np.array(X)
    y = np.array(y)
    groups = np.array(groups)

    # NOTE: this scaler sees every sample, including those later held out
    # for testing; disable it when raw values are needed.
    if standardize:
        X = StandardScaler().fit_transform(X)

    return X, y, groups



Leave-One-Epoch-Out (LOEO) classification function.

In [None]:
def classify_epochs(X, y, groups):
    """
    Run leave-one-group-out classification and print the results.

    For every held-out group a pipeline (standardization, PCA keeping 95% of
    the variance, class-balanced logistic regression) is fitted on the
    remaining samples. Predictions of all folds are pooled, then a
    classification report and the overall ROC-AUC are printed.

    Parameters
    ----------
    X : numpy.ndarray
        Feature matrix, one row per sample.
    y : numpy.ndarray
        Class label of each sample.
    groups : numpy.ndarray
        Group id of each sample; each distinct id is held out once.

    Returns
    -------
    None
    """
    logo = LeaveOneGroupOut()

    model = Pipeline([
        ("scaler", StandardScaler()),
        ("pca", PCA(n_components=0.95)),  # retain 95% variance
        ("clf", LogisticRegression(
            max_iter=3000,
            class_weight='balanced',
            solver='lbfgs'
        ))
    ])

    # fixed, global class order shared by every fold
    classes = np.unique(y)

    all_true = []
    all_pred = []
    all_prob = []

    for train_idx, test_idx in logo.split(X, y, groups):

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # IMPORTANT: pipeline is fit ONLY on training data
        model.fit(X_train, y_train)

        pred = model.predict(X_test)
        prob = model.predict_proba(X_test)

        # predict_proba has one column per class seen in THIS fold. Place the
        # columns into the global class order so that folds missing a class
        # still line up (the missing class gets probability 0).
        fold_classes = model.named_steps["clf"].classes_
        aligned = np.zeros((len(test_idx), len(classes)))
        aligned[:, np.searchsorted(classes, fold_classes)] = prob

        all_true.extend(y_test)
        all_pred.extend(pred)
        all_prob.append(aligned)

    print(classification_report(all_true, all_pred))

    all_prob = np.vstack(all_prob)
    if len(classes) == 2:
        # binary case: score with the probability of the second class
        auc = roc_auc_score(all_true, all_prob[:, 1])
    else:
        Y_bin = label_binarize(all_true, classes=classes)
        auc = roc_auc_score(Y_bin, all_prob, multi_class='ovr')
    print("Overall ROC-AUC:", auc)



Load subject data.

In [None]:
subject_id = '01'
tasks = ["S" + str(i) for i in range(1, 8)]  # S1-S7
epochs = [*range(10)]  # 0-9

# 1) Load the seed-filtered networks, then keep the 10 best ones per epoch
subject_networks = select_top_networks(
    load_subject_top_nets(subject_id, data_folder, percentile=1.0), top_k=10
)

Build classification dataset.

In [None]:
from collections import Counter

# 2) Build dataset at EPOCH LEVEL (one feature row per task/epoch pair)
X, y, groups = build_epoch_dataset(
    subject_networks, coords=coords_62, hemisphere_map=channel_names
)

# Quick sanity check of the dataset size and the class balance
print("Dataset shape:", X.shape)
print("Samples per task:")
print(Counter(y))


### Classification

In [None]:
# 3) Classify: leave-one-group-out over `groups`; the report and the ROC-AUC
#    are printed by the function.
classify_epochs(X, y, groups)

### Visualization and interpretation

Plotting parameters.

In [40]:
# Column names, in the same order as the feature vector is assembled:
# per-channel activations, ROI means, centroid coordinates, hemispheric balance.
feature_names = [
    *(f"activation_{i}" for i in range(62)),
    *(f"roi: {lobe}" for lobe in ROIS),
    "centroid_x", "centroid_y", "centroid_z",
    "hemispheric_balance",
]

# Plotting order of the tasks
TASK_ORDER = [f"S{i}" for i in range(1, 8)]

# One colour per task: blues for S1-S3, oranges for S4-S6, gray for S7
TASK_COLORS = {
    # blues
    "S1": "#08306b",   # dark blue
    "S2": "#2171b5",   # medium blue
    "S3": "#6baed6",   # light blue
    # oranges
    "S4": "#7f2704",   # dark orange
    "S5": "#d94801",   # medium orange
    "S6": "#fdae6b",   # light orange
    # neutral
    "S7": "#6b6b6b",   # gray
}


Build feature dataframe. X and y need to be non-standardized (comment out `StandardScaler` from `build_epoch_dataset`).

In [None]:
def build_feature_dataframe(X, y, groups, feature_names):
    """
    Wrap the feature matrix in a DataFrame with one column per feature,
    followed by a 'task' column (from `y`) and an 'epoch' column (from `groups`).
    """
    features = pd.DataFrame(data=X, columns=feature_names)
    return features.assign(task=y, epoch=groups)


# Feature table: one row per (task, epoch), plus 'task' and 'epoch' columns.
df = build_feature_dataframe(X, y, groups, feature_names)

# Standard deviation of every feature across rows, smallest first
# (shows which features vary least / most between epochs).
df[feature_names].std().sort_values()


roi: prefrontal    0.161889
roi: temporal_R    0.171981
roi: visual        0.193674
roi: temporal_L    0.197799
roi: frontal       0.202298
                     ...   
activation_43      1.306552
activation_6       1.314970
activation_48      1.373625
activation_10      1.499552
activation_16      1.546583
Length: 73, dtype: float64

Plot a specific feature.

In [None]:
def plot_feature(df, feature):
    """
    Draw `feature` against the epoch index, one line per task.

    Tasks are drawn in TASK_ORDER (tasks absent from `df` are skipped) and
    coloured with TASK_COLORS, falling back to black.
    """
    plt.figure(figsize=(8, 4))

    present = df["task"].unique()
    line_style = dict(marker="o", linewidth=1.8, markersize=4)

    for task in TASK_ORDER:
        if task not in present:
            continue
        # epochs in increasing order so the line runs left to right
        rows = df[df["task"] == task].sort_values("epoch")
        plt.plot(
            rows["epoch"],
            rows[feature],
            color=TASK_COLORS.get(task, "black"),
            label=task,
            **line_style,
        )

    plt.xlabel("Epoch")
    plt.ylabel(feature)
    plt.title(feature)
    plt.legend(ncol=4, fontsize=8)
    plt.tight_layout()
    plt.show()


plot_feature(df, "hemispheric_balance")
# Other features to inspect (uncomment one). ROI columns are named
# "roi: <name>", following the keys of ROIS.
# plot_feature(df, "centroid_x")
# plot_feature(df, "roi: central")
# plot_feature(df, "activation_12")

Plot the distribution of a specific feature.

In [None]:
def plot_feature_distribution(df, feature):
    """
    Overlay one density-normalised histogram of `feature` per task.

    Tasks are drawn in TASK_ORDER (tasks absent from `df` are skipped) and
    coloured with TASK_COLORS, falling back to black.
    """
    plt.figure(figsize=(6, 4))

    present = df["task"].unique()
    hist_style = dict(bins=20, alpha=0.5, density=True)

    for task in TASK_ORDER:
        if task not in present:
            continue
        plt.hist(
            df.loc[df["task"] == task, feature],
            color=TASK_COLORS.get(task, "black"),
            label=task,
            **hist_style,
        )

    plt.title(feature)
    plt.legend(fontsize=8)
    plt.tight_layout()
    plt.show()


plot_feature_distribution(df, "hemispheric_balance")
# Other features to inspect (uncomment one). ROI columns are named
# "roi: <name>", following the keys of ROIS.
# plot_feature_distribution(df, "centroid_x")
# plot_feature_distribution(df, "roi: central")
# plot_feature_distribution(df, "activation_12")

Plot all or groups of features at once.

In [None]:
# Feature groups, selected by the prefix of the column name
activation_feats = [name for name in feature_names if name.startswith("activation")]
roi_feats = [name for name in feature_names if name.startswith("roi")]
global_feats = ["centroid_x", "centroid_y", "centroid_z", "hemispheric_balance"]


def plot_all_features(df, feature_names):
    """Call `plot_feature` once for every name in `feature_names`, in order."""
    for column in feature_names:
        plot_feature(df, column)


plot_all_features(df, activation_feats)
