<a href="https://colab.research.google.com/github/safari-mohammadreza/MWL_GCN_GAT/blob/main/GCN_MWL_Paper_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initialize

In [None]:
!pip install torch_geometric torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
# !pip install optuna
# !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \
#     -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
# !pip install captum

In [None]:
# Standard library
import os
import pickle
import warnings

# Numerical & data handling
import numpy as np
import pandas as pd

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx

# PyTorch core
import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch Geometric
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool
from torch_geometric.utils import to_networkx

# Scikit‑learn
from sklearn.model_selection import KFold
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    roc_auc_score,
    precision_recall_curve,
)

# Hyperparameter tuning
# import optuna

In [None]:
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides warnings that can point at real problems
# (deprecated APIs, dtype issues); consider narrowing it by category/module.
warnings.filterwarnings("ignore")

In [None]:
# variables

# Google Drive locations (valid once Drive is mounted at /content/gdrive).
# `path` keeps its trailing slash: file names are appended to it directly.
path = '/content/gdrive/My Drive/Thesis/connectivities_mat_files/'
save_path = '/content/gdrive/My Drive/GCN/save'
save_fig = '/content/gdrive/My Drive/GCN/figure'

# Frequency bands, in the order used throughout the notebook.
bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']
# Electrode names; index k corresponds to row/column k of each 14x14 matrix.
channels = [
    'AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
    'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4',
]
# Class names; label 0 = 'low', label 1 = 'high'.
category_labels = ['low', 'high']

In [None]:
# Mount Google Drive so the '/content/gdrive/My Drive/...' paths used in this
# notebook resolve. Only works inside Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# functions

def prepare_graphs(data_dict, labels, bands, cache_path=None):
    """
    Build (and optionally cache) one graph ``Data`` object per connectivity matrix.

    For every band in ``bands`` and every sample ``i`` of that band:
      * the diagonal (self-connectivity) of the matrix is zeroed;
      * every strictly positive entry ``adj[r, c]`` becomes a directed edge
        ``r -> c`` whose ``edge_attr`` is that entry's value;
      * the zero-diagonal matrix itself is the node-feature matrix ``x``
        (row ``r`` is the connectivity profile of node ``r``);
      * ``y`` is the integer label ``labels[band][i]``.

    Args:
        data_dict (dict): {band: np.ndarray of shape (N, C, C)}
        labels (dict): {band: np.ndarray of shape (N,)}
        bands (list): band names to include, processed in this order.
        cache_path (str, optional): pickle file. If it already exists its
            contents are returned unchanged -- the other arguments are NOT
            compared against it, so delete the file after changing data or
            labels. Otherwise the freshly built list is written there.

    Returns:
        List[Data]: all graphs, band-major (all samples of bands[0] first).

    Raises:
        ValueError: if a band's label array length differs from its number
            of matrices.
    """
    if cache_path and os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code -- only load caches you wrote.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    graphs = []
    for band in bands:
        arr = data_dict[band]
        lab = labels[band]
        if len(lab) != arr.shape[0]:
            raise ValueError(
                f"band {band!r}: {arr.shape[0]} matrices but {len(lab)} labels"
            )
        for i in range(arr.shape[0]):
            adj = arr[i].copy()
            np.fill_diagonal(adj, 0)
            mask = adj > 0
            # np.nonzero and boolean indexing both walk the matrix in row-major
            # order, so column k of edge_index and edge_attr[k] describe the same entry.
            edge_index = torch.tensor(np.vstack(np.nonzero(mask)), dtype=torch.long)
            edge_attr = torch.tensor(adj[mask], dtype=torch.float)
            x = torch.tensor(adj, dtype=torch.float32)
            y = torch.tensor([lab[i]], dtype=torch.long)
            graphs.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y))

    if cache_path:
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(graphs, f)
    return graphs


def plot_connectivity_sample(data_dict, bands, channel_labels, categories):
    """
    Display one heatmap per (band, category) pair.

    The matrix shown is the first one stored under the key
    ``"<band>_<category>"``, with its diagonal set to zero. Pairs whose key
    is missing from ``data_dict`` are skipped without a message.
    """
    pairs = ((b, c) for b in bands for c in categories)
    for band, cat in pairs:
        key = f"{band}_{cat}"
        if key in data_dict:
            sample = data_dict[key][0].copy()
            np.fill_diagonal(sample, 0)

            plt.figure(figsize=(6, 6))
            sns.heatmap(
                sample,
                cmap='viridis',
                xticklabels=channel_labels,
                yticklabels=channel_labels,
                square=True,
            )
            plt.title(f"{band.capitalize()} Connectivity ({cat})")
            plt.tight_layout()
            plt.show()


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool

class AblationGCN_GAT(nn.Module):
    """
    Configurable GCN/GAT graph classifier used for the ablation study.

    Pipeline:
      GCN -> BatchNorm -> ReLU -> dropout,
      then, if ``use_gat``, GAT -> BatchNorm -> ReLU -> dropout,
      then a second GCN whose output is added to its input when ``use_res``.
    Each graph is summarised by mean pooling, concatenated with max pooling
    when ``use_max``, and classified by a two-layer MLP (hidden width ``fc1``).
    ``forward`` returns log-probabilities of shape (num_graphs, out_ch).

    ``edge_attr`` is accepted as an alias for ``edge_weight``. The weights are
    used by both GCN layers, not by the GAT layer. Extra keyword arguments are
    ignored, so explainer wrappers can call the model with additional inputs.
    """

    def __init__(self, in_ch, hid_ch, out_ch,
                 use_gat=True, use_res=True, use_max=True,
                 fc1=128, dropout=0.3, heads=4):
        super().__init__()
        self.use_gat = use_gat
        self.use_res = use_res
        self.use_max = use_max

        # Sub-modules are created in a fixed order; keep it so that seeded
        # initialisation and state_dict keys stay stable.
        self.gcn1 = GCNConv(in_ch, hid_ch)
        self.bn1 = nn.BatchNorm1d(hid_ch)
        if use_gat:
            # Attention heads are averaged (concat=False), so width stays hid_ch.
            self.gat1 = GATConv(hid_ch, hid_ch, heads=heads, concat=False)
            self.bn2 = nn.BatchNorm1d(hid_ch)
        self.gcn2 = GCNConv(hid_ch, hid_ch)

        # Mean pooling alone gives hid_ch features; adding max pooling doubles it.
        pooled_width = hid_ch * (1 + int(use_max))
        self.fc1 = nn.Linear(pooled_width, fc1)
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Linear(fc1, out_ch)

    def forward(self, x, edge_index, batch=None,
                edge_weight=None, edge_attr=None, **kwargs):
        # Explicit edge_weight wins; otherwise fall back to edge_attr (may be None).
        weights = edge_weight if edge_weight is not None else edge_attr

        h = self.drop(F.relu(self.bn1(self.gcn1(x, edge_index, weights))))
        if self.use_gat:
            h = self.drop(F.relu(self.bn2(self.gat1(h, edge_index))))

        out = self.gcn2(h, edge_index, weights)
        if self.use_res:
            out = h + out

        # Without a batch vector, treat every node as part of graph 0.
        if batch is None:
            batch = out.new_zeros(out.size(0), dtype=torch.long)

        pooled = global_mean_pool(out, batch)
        if self.use_max:
            pooled = torch.cat([pooled, global_max_pool(out, batch)], dim=1)

        logits = self.fc2(self.drop(F.relu(self.fc1(pooled))))
        return F.log_softmax(logits, dim=1)



def objective(trial, graphs):
    """
    Optuna objective: mean best-epoch validation accuracy of AblationGCN_GAT
    (GAT + residual + mean/max pooling) under shuffled K-fold CV.

    Searched hyper-parameters:
      * ``hidden``: GNN width, also reused as the FC hidden width.
      * ``dropout``.
      * ``lr``: sampled on a log scale.
      * ``folds``: number of CV splits.
    Each fold trains for at most 50 epochs and stops early after 6 consecutive
    epochs without a new best validation accuracy. The fold contributes its
    best validation accuracy.

    Optuna calls an objective with the trial only, so bind the data first, e.g.
    ``study.optimize(lambda t: objective(t, graphs), n_trials=15)``.

    Args:
        trial (optuna.trial.Trial): trial used to sample hyper-parameters.
        graphs (list[Data]): graphs exposing ``x``, ``edge_index`` and ``y``.

    Returns:
        float: mean over folds of the best validation accuracy.
    """
    # Fall back to CPU so the objective also runs without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_ch = graphs[0].x.size(1)

    hidden = trial.suggest_categorical('hidden', [64, 128, 256])
    dropout = trial.suggest_float('dropout', 0.1, 0.5)
    # suggest_loguniform is deprecated since Optuna 3.0; this is its equivalent.
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    # NOTE(review): the fold count is part of the evaluation protocol rather
    # than the model; tuning it lets the search favour whichever split scores
    # highest.
    n_splits = trial.suggest_int('folds', 8, 11)

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    loss_fn = nn.CrossEntropyLoss()
    fold_best = []
    for train_idx, val_idx in kf.split(graphs):
        train_loader = DataLoader([graphs[i] for i in train_idx], batch_size=32, shuffle=True)
        val_loader = DataLoader([graphs[i] for i in val_idx], batch_size=32)

        model = AblationGCN_GAT(in_ch, hidden, 2,
                                use_gat=True, use_res=True, use_max=True,
                                fc1=hidden, dropout=dropout).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr)

        best, stale = 0.0, 0
        for _ in range(50):
            model.train()
            for b in train_loader:
                b = b.to(device)
                opt.zero_grad()
                loss_fn(model(b.x, b.edge_index, b.batch), b.y).backward()
                opt.step()

            model.eval()
            preds, trues = [], []
            with torch.no_grad():
                for b in val_loader:
                    b = b.to(device)
                    preds += model(b.x, b.edge_index, b.batch).argmax(1).cpu().tolist()
                    trues += b.y.cpu().tolist()
            acc = accuracy_score(trues, preds)

            if acc > best:
                best, stale = acc, 0
            else:
                stale += 1
                if stale > 5:  # 6 epochs in a row without improvement
                    break
        fold_best.append(best)
    return float(np.mean(fold_best))

class GCN_GAT_Model(nn.Module):
    """
    Baseline graph classifier: GCN -> GAT -> GCN, global mean pooling, then a
    two-layer MLP head with a fixed 64-unit hidden layer.

    ``edge_weight`` (optional) is consumed only by the first GCN layer. The
    GAT layer and the second GCN layer see the unweighted graph.
    ``forward`` returns log-probabilities of shape (num_graphs, out_channels).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5, heads=4):
        super().__init__()
        # Creation order is kept fixed so seeded initialisation is reproducible.
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.dropout = dropout  # probability for the functional dropout calls
        # Attention heads are averaged (concat=False): output width = hidden_channels.
        self.gat1 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
        self.bn2 = nn.BatchNorm1d(hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, hidden_channels)
        self.fc1 = nn.Linear(hidden_channels, 64)
        self.fc2 = nn.Linear(64, out_channels)

    def forward(self, x, edge_index, batch, edge_weight=None):
        p, active = self.dropout, self.training

        def _bn_relu_drop(h, norm):
            # BatchNorm -> ReLU -> dropout (dropout active only in training mode)
            return F.dropout(F.relu(norm(h)), p=p, training=active)

        h = _bn_relu_drop(self.gcn1(x, edge_index, edge_weight=edge_weight), self.bn1)
        # GATConv is called without edge weights.
        h = _bn_relu_drop(self.gat1(h, edge_index), self.bn2)
        graph_emb = global_mean_pool(self.gcn2(h, edge_index), batch)
        return F.log_softmax(self.fc2(F.relu(self.fc1(graph_emb))), dim=1)

#####################################################
# Enhanced Model Definition with Residual & Dual Pooling
#####################################################
class EnhancedGCN_GAT_Model(nn.Module):
    """
    GCN -> GAT -> GCN graph classifier with a residual connection around the
    second GCN and dual (mean + max) global pooling.

    The concatenated pooled vector (2 * hidden_channels wide) feeds an MLP
    with a ``fc1_size``-unit hidden layer. ``edge_weight`` is used only by the
    first GCN layer. ``forward`` returns log-probabilities of shape
    (num_graphs, out_channels).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, fc1_size, dropout=0.3, heads=4):
        super().__init__()
        # Creation order is kept fixed so seeded initialisation is reproducible.
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.dropout = dropout  # probability for the functional dropout calls
        # Attention heads are averaged (concat=False): output width = hidden_channels.
        self.gat1 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
        self.bn2 = nn.BatchNorm1d(hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, hidden_channels)
        # Mean- and max-pooled vectors are concatenated, hence 2 * hidden_channels.
        self.fc1 = nn.Linear(2 * hidden_channels, fc1_size)
        self.fc2 = nn.Linear(fc1_size, out_channels)

    def forward(self, x, edge_index, batch, edge_weight=None):
        p, active = self.dropout, self.training

        h = self.gcn1(x, edge_index, edge_weight=edge_weight)
        h = F.dropout(F.relu(self.bn1(h)), p=p, training=active)

        h = self.gat1(h, edge_index)
        h = F.dropout(F.relu(self.bn2(h)), p=p, training=active)

        # Residual: the second GCN refines, rather than replaces, h.
        h = h + self.gcn2(h, edge_index)

        graph_repr = torch.cat(
            [global_mean_pool(h, batch), global_max_pool(h, batch)], dim=1
        )
        hidden = F.relu(self.fc1(graph_repr))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [None]:
# Load the ten pickled connectivity stacks (hi_/lo_ for each of five bands)
# from `path` into module-level variables named hi_delta, lo_delta, ...,
# printing each shape as it is loaded (hi_delta was observed as (1776, 14, 14)).
# NOTE: unpickling can execute arbitrary code -- only load files you produced.
for _band in ('delta', 'theta', 'alpha', 'beta', 'gamma'):
    for _level in ('hi', 'lo'):
        _name = f'{_level}_{_band}'
        with open(path + _name, 'rb') as f:
            globals()[_name] = pickle.load(f)
        print(np.shape(globals()[_name]))

In [None]:
# All ten stacks, hi/lo alternating per band. When every stack has the same
# shape the printed shape is (10, N, C, C); NumPy raises for ragged stacks.
data = [
    hi_delta, lo_delta,
    hi_theta, lo_theta,
    hi_alpha, lo_alpha,
    hi_beta, lo_beta,
    hi_gamma, lo_gamma,
]
print(np.asarray(data).shape)

# Optimization

In [None]:
#####################################################
# Load Your Data
#####################################################
import optuna  # not imported at the top of the notebook; requires `pip install optuna`

# hi_delta, lo_delta, hi_theta, etc. are the connectivity stacks loaded earlier.
band_pairs = {
    "delta": (hi_delta, lo_delta),
    "theta": (hi_theta, lo_theta),
    "alpha": (hi_alpha, lo_alpha),
    "beta": (hi_beta, lo_beta),
    "gamma": (hi_gamma, lo_gamma),
}
data_dict = {band: np.concatenate(pair) for band, pair in band_pairs.items()}

labels = {}
for band, (hi, lo) in band_pairs.items():
    # hi -> 1, lo -> 0, sized from the real groups: splitting the concatenation
    # in half silently mislabels samples when the two groups differ in size.
    labels[band] = np.concatenate((np.ones(len(hi)), np.zeros(len(lo)))).astype(int)
    print(f"Label distribution for {band}:", np.bincount(labels[band]))

bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']
graphs = prepare_graphs(data_dict, labels, bands)

#####################################################
# Run Hyperparameter Optimization with Optuna
#####################################################
# NOTE(review): `graphs` holds one graph per (band, sample). If the bands are
# views of the same recordings, objective()'s shuffled KFold can put views of
# one sample in both train and validation folds. Grouping folds by sample
# index (GroupKFold) would prevent that.
study = optuna.create_study(direction="maximize")
# Optuna calls the objective with the trial only; bind the graph list here.
study.optimize(lambda trial: objective(trial, graphs), n_trials=15)  # Adjust number of trials as needed

# Print the best trial and its hyperparameters:
print("Best trial:")
best_trial = study.best_trial
print(f"  Value: {best_trial.value}")
print("  Params:")
for key, value in best_trial.params.items():
    print(f"    {key}: {value}")

#####################################################
# Optionally, Plot the Optimization History
#####################################################
# plot_optimization_history returns a Plotly figure, which plt.show() does not
# render, so show it explicitly.
optuna.visualization.plot_optimization_history(study).show()


In [None]:
#####################################################
# Load Your Data
#####################################################
import optuna  # not imported at the top of the notebook; requires `pip install optuna`

# NOTE(review): this cell repeats the preceding optimisation cell; running
# both starts two independent studies.
# Each shape: (1776, 14, 14) as observed for hi_delta.
# hi_delta, lo_delta, hi_theta, etc. are the connectivity stacks loaded earlier.
band_pairs = {
    "delta": (hi_delta, lo_delta),
    "theta": (hi_theta, lo_theta),
    "alpha": (hi_alpha, lo_alpha),
    "beta": (hi_beta, lo_beta),
    "gamma": (hi_gamma, lo_gamma),
}
data_dict = {band: np.concatenate(pair) for band, pair in band_pairs.items()}

labels = {}
for band, (hi, lo) in band_pairs.items():
    # hi -> 1, lo -> 0, sized from the real groups: splitting the concatenation
    # in half silently mislabels samples when the two groups differ in size.
    labels[band] = np.concatenate((np.ones(len(hi)), np.zeros(len(lo)))).astype(int)
    print(f"Label distribution for {band}:", np.bincount(labels[band]))

bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']
graphs = prepare_graphs(data_dict, labels, bands)

#####################################################
# Run Hyperparameter Optimization with Optuna
#####################################################
study = optuna.create_study(direction="maximize")
# Optuna calls the objective with the trial only; bind the graph list here.
study.optimize(lambda t: objective(t, graphs), n_trials=15)  # Adjust number of trials as needed

print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

#####################################################
# Optionally, Plot the Optimization History
#####################################################
# plot_optimization_history returns a Plotly figure, which plt.show() does not
# render, so show it explicitly.
optuna.visualization.plot_optimization_history(study).show()


# Train

In [None]:
#####################################################
# Load Data
#####################################################
# hi_delta, lo_delta, hi_theta, lo_theta, hi_alpha, lo_alpha, hi_beta, lo_beta,
# hi_gamma, lo_gamma are the connectivity stacks loaded earlier
# (hi_delta was observed as (1776, 14, 14)).
band_pairs = {
    "delta": (hi_delta, lo_delta),
    "theta": (hi_theta, lo_theta),
    "alpha": (hi_alpha, lo_alpha),
    "beta": (hi_beta, lo_beta),
    "gamma": (hi_gamma, lo_gamma),
}
data_dict = {band: np.concatenate(pair) for band, pair in band_pairs.items()}

labels = {}
for band, (hi, lo) in band_pairs.items():
    # hi -> 1, lo -> 0, sized from the real groups: splitting the concatenation
    # in half silently mislabels samples when the two groups differ in size.
    labels[band] = np.concatenate((np.ones(len(hi)), np.zeros(len(lo)))).astype(int)
    print(f"Label distribution for {band}:", np.bincount(labels[band]))

bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']
graphs = prepare_graphs(data_dict, labels, bands)

#####################################################
# Cross-Validation and Training
#####################################################
# NOTE(review): `graphs` holds one graph per (band, sample). If the bands are
# views of the same recordings, a shuffled KFold puts views of one sample in
# both train and test; GroupKFold on the sample index would prevent that.
# NOTE(review): the best epoch is selected on the same fold that is reported
# below, so the final metrics are optimistic. An inner validation split would
# remove that bias.
n_folds = 13  # 13-fold cross-validation
kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
num_epochs = 120  # Increased epochs
early_stopping_patience = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_acc_history_all = []
val_acc_history_all = []
fold_preds_all = []
fold_trues_all = []
fold_probs_all = []  # P(class 1) for each test graph; needed for AUC / PR curve

# Using tuned hyperparameters
hidden_channels = 512
fc1_size = 128
dropout = 0.3
learning_rate = 0.0007
weight_decay = 1e-4  # added weight decay
in_channels = graphs[0].x.size(1)  # node-feature width (one column per channel)


def _predict(model, loader, device):
    """Run `model` over `loader` in eval mode; return (predicted labels, true labels, P(class 1))."""
    preds, trues, probs = [], [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch, edge_weight=batch.edge_attr)
            preds.extend(out.argmax(dim=1).cpu().numpy())
            trues.extend(batch.y.cpu().numpy())
            # The model returns log-probabilities; exp() recovers probabilities.
            probs.extend(out.exp()[:, 1].cpu().numpy())
    return preds, trues, probs


for fold, (train_idx, test_idx) in enumerate(kf.split(graphs)):
    print(f"\n=== Fold {fold + 1} ===")
    train_loader = DataLoader([graphs[i] for i in train_idx], batch_size=32, shuffle=True)
    test_loader = DataLoader([graphs[i] for i in test_idx], batch_size=32)

    # GCN_GAT_Model has no `fc1_size` argument (the call raised TypeError);
    # EnhancedGCN_GAT_Model (residual + mean/max pooling) is the model whose
    # signature matches these hyper-parameters.
    model = EnhancedGCN_GAT_Model(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        out_channels=2,
        fc1_size=fc1_size,
        dropout=dropout,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Halve the LR when validation accuracy plateaus. `verbose` is omitted:
    # it is deprecated for LR schedulers in recent PyTorch releases.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)

    # For balanced data
    class_weights = torch.tensor([1.0, 1.0]).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)

    best_val_acc = 0
    patience = 0
    train_acc_history = []
    val_acc_history = []
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        train_preds, train_trues = [], []
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch, edge_weight=batch.edge_attr)
            loss = loss_fn(out, batch.y)
            loss.backward()
            optimizer.step()

            train_preds.extend(out.argmax(dim=1).cpu().numpy())
            train_trues.extend(batch.y.cpu().numpy())
        train_acc = accuracy_score(train_trues, train_preds)

        val_preds, val_trues, _ = _predict(model, test_loader, device)
        val_acc = accuracy_score(val_trues, val_preds)
        scheduler.step(val_acc)

        train_acc_history.append(train_acc)
        val_acc_history.append(val_acc)
        print(f"Epoch {epoch + 1}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live parameter tensors,
            # which keep changing as training continues. Snapshot detached
            # copies so the best epoch is really what gets restored.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience = 0
        else:
            patience += 1
            if patience > early_stopping_patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    final_val_preds, final_val_trues, final_val_probs = _predict(model, test_loader, device)

    train_acc_history_all.append(train_acc_history)
    val_acc_history_all.append(val_acc_history)
    fold_preds_all.append(final_val_preds)
    fold_trues_all.append(final_val_trues)
    fold_probs_all.append(final_val_probs)

#####################################################
# Plot Per-Fold Accuracy Curves and Confusion Matrices
#####################################################
n_done = len(train_acc_history_all)
# squeeze=False keeps `axs` 2-D even when only one fold was run.
fig, axs = plt.subplots(nrows=n_done, ncols=2, figsize=(12, 4 * n_done), squeeze=False)

for i, (train_accs, val_accs) in enumerate(zip(train_acc_history_all, val_acc_history_all)):
    epochs = range(1, len(train_accs) + 1)
    axs[i, 0].plot(epochs, train_accs, label='Train Acc')
    axs[i, 0].plot(epochs, val_accs, label='Val Acc')
    axs[i, 0].set_xlabel('Epoch')
    axs[i, 0].set_ylabel('Accuracy')
    axs[i, 0].set_title(f'Fold {i+1} Accuracy')
    axs[i, 0].legend()
    cm = confusion_matrix(fold_trues_all[i], fold_preds_all[i])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Low', 'High'],
                yticklabels=['Low', 'High'],
                ax=axs[i, 1])
    axs[i, 1].set_title(f'Fold {i+1} Confusion Matrix')
    axs[i, 1].set_xlabel('Predicted')
    axs[i, 1].set_ylabel('Actual')

plt.tight_layout()
plt.show()

#####################################################
# Overall Results Across Folds
#####################################################
all_preds = np.concatenate(fold_preds_all, axis=0)
all_labels = np.concatenate(fold_trues_all, axis=0)
all_probs = np.concatenate(fold_probs_all, axis=0)

accuracy_final = accuracy_score(all_labels, all_preds)
precision_final = precision_score(all_labels, all_preds)
recall_final = recall_score(all_labels, all_preds)
f1_final = f1_score(all_labels, all_preds)
# AUC needs a continuous score. With hard 0/1 predictions roc_auc_score
# degenerates to balanced accuracy.
auc_final = roc_auc_score(all_labels, all_probs)

print(f"\n=== Final Metrics Across All Folds ===")
print(f"Accuracy: {accuracy_final:.4f}")
print(f"Precision: {precision_final:.4f}")
print(f"Recall: {recall_final:.4f}")
print(f"F1 Score: {f1_final:.4f}")
print(f"AUC Score: {auc_final:.4f}")

# Likewise, a PR curve built from hard predictions has only three points.
precisions, recalls, _ = precision_recall_curve(all_labels, all_probs)
plt.figure(figsize=(6, 5))
plt.plot(recalls, precisions, marker='.')
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Overall Precision-Recall Curve")
plt.show()

conf_matrix = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(6, 5))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Low', 'High'], yticklabels=['Low', 'High'])
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Overall Confusion Matrix")
plt.show()

#####################################################
# Plot Average Training/Validation Accuracy
#####################################################
# Folds stop at different epochs; pad each curve with its last value so they
# can be averaged element-wise.
max_epochs = max(len(acc) for acc in train_acc_history_all)
train_acc_padded = np.array([np.pad(acc, (0, max_epochs - len(acc)), 'edge')
                             for acc in train_acc_history_all])
val_acc_padded = np.array([np.pad(acc, (0, max_epochs - len(acc)), 'edge')
                           for acc in val_acc_history_all])
train_acc_mean = np.mean(train_acc_padded, axis=0)
val_acc_mean = np.mean(val_acc_padded, axis=0)

plt.figure(figsize=(10, 5))
plt.plot(train_acc_mean, label='Training Accuracy')
plt.plot(val_acc_mean, label='Validation Accuracy')
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Average Training and Validation Accuracy Curve Across Folds")
plt.legend()
plt.show()


In [None]:
# Ablation purpose and per band
#
# Trains AblationGCN_GAT for every (scenario, ablation) pair. A scenario is
# one frequency band, or 'all' bands pooled. Each pair uses 11-fold CV with
# early stopping. Mean best-fold accuracies go to `results`, per-epoch curves
# to `history`; both are plotted at the end.

# hi_delta, lo_delta, etc. are the stacks loaded earlier (hi_delta was
# observed as (1776, 14, 14)).
bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']
data_dict = {}
labels = {}
for b in bands:
    hi, lo = globals()[f"hi_{b}"], globals()[f"lo_{b}"]
    data_dict[b] = np.concatenate((hi, lo))
    # hi -> 1, lo -> 0 from the real group sizes. Halving the total length
    # mislabels unequal groups and drops a label when the total is odd.
    labels[b] = np.concatenate((np.ones(len(hi)), np.zeros(len(lo)))).astype(int)

# Create dirs
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('plots', exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------------------------------------------
# 3) Main training loop with ablations + history
# ------------------------------------------------------
scenarios = bands + ['all']
ablations = [
    dict(name='full',      use_gat=True,  use_res=True,  use_max=True),
    dict(name='no_gat',    use_gat=False, use_res=True,  use_max=True),
    dict(name='no_res',    use_gat=True,  use_res=False, use_max=True),
    dict(name='mean_only', use_gat=True,  use_res=True,  use_max=False),
]

history = {}   # key -> {'train': [list per fold], 'val': [list per fold]}
results = {}   # key -> mean accuracy

for scenario in scenarios:
    # Graph caches are reused as-is. Delete checkpoints/graphs_*.pkl after
    # changing the data or the labelling.
    cache = f'checkpoints/graphs_{scenario}.pkl'
    if scenario == 'all':
        gs = prepare_graphs(data_dict, labels, bands, cache)
    else:
        gs = prepare_graphs({scenario: data_dict[scenario]},
                            {scenario: labels[scenario]},
                            [scenario], cache)

    for abl in ablations:
        key = f"{scenario}/{abl['name']}"
        history[key] = {'train': [], 'val': []}

        # NOTE: ckpt_w is overwritten whenever the current fold improves on its
        # own best, so after the loop it holds the best epoch of the LAST fold.
        ckpt_w = f'checkpoints/{scenario}_{abl["name"]}.pt'
        ckpt_a = ckpt_w + '_acc'
        if os.path.exists(ckpt_w) and os.path.exists(ckpt_a):
            # Skipped runs leave history[key] empty; the curve plot skips them.
            results[key] = torch.load(ckpt_a)
            print(f"Skipping {key}, loaded acc {results[key]:.4f}")
            continue

        fold_acc = []
        kf = KFold(n_splits=11, shuffle=True, random_state=42)

        for fold, (ti, vi) in enumerate(kf.split(gs), 1):
            train_set = [gs[i] for i in ti]
            test_set  = [gs[i] for i in vi]
            train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
            test_loader  = DataLoader(test_set,  batch_size=32)

            # fc1/dropout/heads must be keywords: passed positionally they landed
            # in use_gat/use_res/use_max and clashed with the keyword flags
            # (TypeError: multiple values for argument 'use_gat').
            model = AblationGCN_GAT(14, 256, 2,
                                    use_gat=abl['use_gat'],
                                    use_res=abl['use_res'],
                                    use_max=abl['use_max'],
                                    fc1=128, dropout=0.274, heads=4).to(device)
            opt     = torch.optim.Adam(model.parameters(), lr=7e-4, weight_decay=1e-4)
            sched   = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'max', patience=3)
            loss_fn = nn.CrossEntropyLoss()

            best, pat = 0, 0
            train_hist, val_hist = [], []

            for ep in range(1, 121):
                # --- train (no edge weights are passed to the model) ---
                model.train()
                train_preds, train_trues = [], []
                for b in train_loader:
                    b = b.to(device)
                    opt.zero_grad()
                    out = model(b.x, b.edge_index, b.batch)
                    loss = loss_fn(out, b.y)
                    loss.backward()
                    opt.step()
                    train_preds.extend(out.argmax(1).cpu().numpy())
                    train_trues.extend(b.y.cpu().numpy())
                train_acc = accuracy_score(train_trues, train_preds)

                # --- val ---
                model.eval()
                val_preds, val_trues = [], []
                with torch.no_grad():
                    for b in test_loader:
                        b = b.to(device)
                        out = model(b.x, b.edge_index, b.batch)
                        val_preds.extend(out.argmax(1).cpu().numpy())
                        val_trues.extend(b.y.cpu().numpy())
                val_acc = accuracy_score(val_trues, val_preds)
                sched.step(val_acc)

                # record histories
                train_hist.append(train_acc)
                val_hist.append(val_acc)

                # early-stop & checkpoint
                if val_acc > best:
                    best, pat = val_acc, 0
                    torch.save(model.state_dict(), ckpt_w)
                else:
                    pat += 1
                if pat > 10:
                    break

            history[key]['train'].append(train_hist)
            history[key]['val'].append(val_hist)
            fold_acc.append(best)
            print(f"{key} fold{fold} best-val-acc={best:.4f}")

        mean_acc = float(np.mean(fold_acc))
        results[key] = mean_acc
        torch.save(mean_acc, ckpt_a)
        print(f"=> {key} mean acc={mean_acc:.4f}")

# 4) Save history for later plotting
with open('checkpoints/history.pkl', 'wb') as f:
    pickle.dump(history, f)

# ----------------------------------------------------------------
# 5) Plot bar chart of overall scenario accuracies
# ----------------------------------------------------------------
import pandas as pd
df = pd.DataFrame.from_dict(results, orient='index', columns=['Mean Accuracy'])
df.index.name = 'Scenario'
df = df.reset_index().sort_values('Mean Accuracy', ascending=False)

plt.figure(figsize=(12, 6))
sns.barplot(data=df, x='Scenario', y='Mean Accuracy', palette='magma')
plt.xticks(rotation=45, ha='right')
plt.ylim(0, 1)
plt.title('Validation Accuracy per Scenario & Ablation')
plt.tight_layout()
plt.savefig('plots/accuracy_bar.png')
plt.show()

# ----------------------------------------------------------------
# 6) Plot train/val accuracy curves for each key
# ----------------------------------------------------------------
with open('checkpoints/history.pkl', 'rb') as f:
    history = pickle.load(f)

for key, h in history.items():
    if not h['val'] or not h['train']:
        continue

    # Folds stop at different epochs; pad with each curve's last value.
    max_ep = max(max(len(l) for l in h['train']),
                 max(len(l) for l in h['val']))
    train_arr = np.array([np.pad(l, (0, max_ep - len(l)), 'edge') for l in h['train']])
    val_arr   = np.array([np.pad(l, (0, max_ep - len(l)), 'edge') for l in h['val']])

    tm = train_arr.mean(axis=0)
    vm = val_arr.mean(axis=0)

    plt.figure(figsize=(6, 4))
    plt.plot(range(1, len(tm) + 1), tm, label='Train Acc')
    plt.plot(range(1, len(vm) + 1), vm, label='Val Acc')
    plt.title(f'Train vs Val — {key}')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'plots/{key.replace("/","_")}_curve.png')
    plt.close()


# Interpretability

In [None]:
import os
import pickle
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx
from torch_geometric.explain.algorithm import GNNExplainer
from torch_geometric.explain import Explainer

# 1) Paths on your Drive
save_path    = '/content/gdrive/My Drive/GCN/save/checkpoints'
model_path   = os.path.join(save_path, 'all_full.pt')
graph_path   = os.path.join(save_path, 'graphs_all.pkl')

# 2) Validate existence (explicit raises: `assert` is stripped under python -O)
for file_path, what in ((model_path, 'model file'), (graph_path, 'graph cache')):
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"Missing {what}: {file_path}")
print("✅ Found both checkpoint files.")

# 3) Load model -- constructor arguments must match the run that saved it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#                        in  hid out   gat    res   max  fc1   drop heads
model = AblationGCN_GAT( 14, 256,  2,  True, True, True, 128, 0.274,  4 ).to(device)
# weights_only=True: a state_dict needs nothing beyond tensors, so refuse
# arbitrary pickled objects.
state = torch.load(model_path, map_location=device, weights_only=True)
print("State dict keys (first 5):", list(state.keys())[:5], "…")
model.load_state_dict(state)
model.eval()

# 4) Load graph cache
# NOTE: unpickling can execute arbitrary code -- only load caches you wrote.
with open(graph_path, 'rb') as f:
    graphs = pickle.load(f)
if not graphs:
    raise ValueError("Graph cache is empty!")
# NOTE(review): graphs[0] may have been in the checkpoint's training fold.
data = graphs[0].to(device)
print(f"Loaded graph: x{tuple(data.x.shape)}, edges={data.edge_index.size(1)}")

# 5) Build the high‑level Explainer
explainer = Explainer(
    model            = model,
    algorithm        = GNNExplainer(epochs=200),
    explanation_type = 'model',
    node_mask_type   = 'object',
    edge_mask_type   = 'object',
    model_config     = dict(
        mode        = 'multiclass_classification',
        task_level  = 'graph',
        return_type = 'log_probs',            # forward() returns log_softmax
    ),
)

# 6) Generate an Explanation
# The ablation loop that saves `all_full.pt` calls model(x, edge_index, batch)
# without edge weights, so that checkpoint was trained on unweighted graphs.
# Passing edge_attr here would switch on weighted GCN propagation (forward()
# maps edge_attr to edge_weight) and explain a different function. Set the
# flag to True only for a checkpoint trained WITH edge weights.
use_edge_weights = False
extra_inputs = {'edge_attr': data.edge_attr} if use_edge_weights else {}
explanation = explainer(
    x          = data.x,
    edge_index = data.edge_index,
    batch      = torch.zeros(data.x.size(0), dtype=torch.long, device=device),
    **extra_inputs,
)

# 7) Grab the learned masks
edge_mask = explanation.edge_mask   # shape [num_edges]
node_mask = explanation.node_mask   # node_mask_type='object' -> shape [num_nodes, 1]

In [None]:
import os
import numpy as np
import pandas as pd

# Edge importances from the GNNExplainer run above (`edge_mask`), scaled so the
# most important edge scores 1.0. The optional ">= 0.8" threshold below
# assumes this scaling.
weights = edge_mask.detach().cpu().numpy().astype(float)
peak = weights.max() if weights.size else 0.0
if peak > 0:
    weights = weights / peak

# pull out the raw edge list
edges = data.edge_index.t().cpu().numpy()    # shape [num_edges, 2]

# build a DataFrame mapping raw node indices -> importance
# (use channels[i] on src_idx / tgt_idx for electrode names)
df = pd.DataFrame({
    'src_idx':    edges[:, 0],
    'tgt_idx':    edges[:, 1],
    'importance': weights,
})

# sort by importance descending
df = df.sort_values('importance', ascending=False).reset_index(drop=True)

# OPTION A: keep only the very top-30 edges
df_filtered = df.head(30)

# OPTION B: keep only edges above a threshold (e.g. importance ≥ 0.8)
# df_filtered = df[df['importance'] >= 0.8].reset_index(drop=True)

# display the filtered DataFrame (to_markdown needs the `tabulate` package)
print(df_filtered.to_markdown(index=False))

# --- build the node-by-node matrix ---
num_nodes = data.x.size(0)  # one node per channel (14 here)
mat = np.zeros((num_nodes, num_nodes), dtype=float)

for _, row in df_filtered.iterrows():
    i, j, w = int(row.src_idx), int(row.tgt_idx), row.importance
    # (i, j) and (j, i) are separate directed edges with separate mask values.
    # Keep the larger one so a lower-ranked reverse edge cannot overwrite it.
    # Write only mat[i, j] instead if the plot should stay directed.
    mat[i, j] = mat[j, i] = max(mat[i, j], w)

# --- write it out to your Google Drive ---
# Space-separated rows, 6 decimals: the .edge file read by the BrainNet tool in MATLAB.
out_path = '/content/gdrive/My Drive/GCN/save/important_edges_top30.edge'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'w') as f:
    for row in mat:
        f.write(' '.join(f"{val:.6f}" for val in row) + '\n')

print(f"Wrote {num_nodes}×{num_nodes} importance matrix to {out_path}")
