dataset preparation

In [6]:
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_max_pool
import networkx as nx
import numpy as np
import random
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn import dense_diff_pool
np.random.seed(1)
random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

def load_raw_data(dir_data, rat_name): 
    trlPath = os.path.join(dir_data, rat_name, rat_name.lower() + '_trial_info.npy') 
    spkPath = os.path.join(dir_data, rat_name, rat_name.lower() + '_spike_data_binned.npy') 
    lfpPath = os.path.join(dir_data, rat_name, rat_name.lower() + '_lfp_data_sampled.npy')
    
    trial_info = np.load(trlPath) 
    spike_data = np.load(spkPath)
    lfp_data = np.load(lfpPath)
    lfp_data = np.swapaxes(lfp_data, 1, 2) 
    return trial_info, spike_data, lfp_data



trial_info_superchris, spike_data_superchris, lfp_data_superchris = load_raw_data(dir_data="epoch_data", rat_name="superchris")
trial_info_barat, spike_data_barat, lfp_data_barat = load_raw_data(dir_data="epoch_data", rat_name="barat")
trial_info_stella, spike_data_stella, lfp_data_stella = load_raw_data(dir_data="epoch_data", rat_name="stella")
trial_info_mitt, spike_data_mitt, lfp_data_mitt = load_raw_data(dir_data="epoch_data", rat_name="mitt")
trial_info_buchanan, spike_data_buchanan, lfp_data_buchanan = load_raw_data(dir_data="epoch_data", rat_name="buchanan")


# Assuming the files are correctly loaded, the focus will be on processing the data.

# Update count_labels to count occurrences of all unique labels
def count_labels(trial_info, target_col=3):
    labels = trial_info[:, target_col] - 1  # Adjust for target variable
    unique, counts = np.unique(labels, return_counts=True)
    label_counts = dict(zip(unique, counts))
    return label_counts

# Count labels for each rat dataset
counts = {
    "superchris": count_labels(trial_info_superchris),
    "barat": count_labels(trial_info_barat),
    "stella": count_labels(trial_info_stella),
    "mitt": count_labels(trial_info_mitt),
    "buchanan": count_labels(trial_info_buchanan),
}

# Find the minimum count for each label across datasets
min_label_counts = {}
for label in range(5):  # Adjust for labels 0-4
    min_label_counts[label] = min(counts[rat_name].get(label, 0) for rat_name in counts.keys())

print("Minimum label counts across datasets:", min_label_counts)

# Update select_balanced_trials to balance across all labels
def select_balanced_trials(trial_info, spike_data, lfp_data, target_col=3, n_labels=None):
    if n_labels is None:
        n_labels = {label: 10 for label in range(5)}  # Default to 10 trials per label

    labels = trial_info[:, target_col] - 1  # Adjust for target variable
    selected_indices = []

    for label, n_label in n_labels.items():
        # Get indices of trials with the current label
        indices_label = np.where(labels == label)[0]

        # Handle cases where available trials are less than requested
        n_label = min(len(indices_label), n_label)

        # Randomly select required number of trials for the current label
        selected_indices_label = np.random.choice(indices_label, n_label, replace=False)
        selected_indices.extend(selected_indices_label)

    # Sort selected indices
    selected_indices = np.sort(selected_indices)

    # Subset the trial_info, spike_data, and lfp_data based on selected indices
    trial_info_balanced = trial_info[selected_indices]
    spike_data_balanced = spike_data[selected_indices]
    lfp_data_balanced = lfp_data[selected_indices]

    return trial_info_balanced, spike_data_balanced, lfp_data_balanced

# Apply this function to all datasets independently
balanced_datasets = {}
for rat_name, trial_info, spike_data, lfp_data in [
    ("superchris", trial_info_superchris, spike_data_superchris, lfp_data_superchris),
    ("barat", trial_info_barat, spike_data_barat, lfp_data_barat),
    ("stella", trial_info_stella, spike_data_stella, lfp_data_stella),
    ("mitt", trial_info_mitt, spike_data_mitt, lfp_data_mitt),
    ("buchanan", trial_info_buchanan, spike_data_buchanan, lfp_data_buchanan),
]:
    trial_info_balanced, spike_data_balanced, lfp_data_balanced = select_balanced_trials(
        trial_info, spike_data, lfp_data, n_labels=min_label_counts
    )
    balanced_datasets[rat_name] = {
        "trial_info": trial_info_balanced,
        "spike_data": spike_data_balanced,
        "lfp_data": lfp_data_balanced,
    }

# Check the balanced datasets
for rat_name, data in balanced_datasets.items():
    trial_info_balanced = data["trial_info"]
    labels_balanced = trial_info_balanced[:, 3] - 1
    print(f"{rat_name}: Label counts = {dict(zip(*np.unique(labels_balanced, return_counts=True)))}")



Minimum label counts across datasets: {0: 54, 1: 36, 2: 27, 3: 36, 4: 23}
superchris: Label counts = {0: 54, 1: 36, 2: 27, 3: 36, 4: 23}
barat: Label counts = {0: 54, 1: 36, 2: 27, 3: 36, 4: 23}
stella: Label counts = {0: 54, 1: 36, 2: 27, 3: 36, 4: 23}
mitt: Label counts = {0: 54, 1: 36, 2: 27, 3: 36, 4: 23}
buchanan: Label counts = {0: 54, 1: 36, 2: 27, 3: 36, 4: 23}


MGMT+MAGNET

In [11]:
import json
import torch
import numpy as np 
import os
import random 
from torch_geometric.data import Data
from torch import nn
from sklearn.metrics import accuracy_score, confusion_matrix
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from torch_geometric.nn import GCNConv, GATConv, APPNP, TransformerConv
from torch.nn import Linear, Sequential, BatchNorm1d, Dropout
from custom_pooling import global_max_pool
from convlayer import MaGNetConv
from norm_operation import PairNorm, MeanSubtractionNorm
from gpslayer import * 
import warnings
from itertools import product
np.random.seed(1) # mitt
torch.manual_seed(1)
import torch
torch.use_deterministic_algorithms(False)
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv
from custom_pooling import global_max_pool
import os
import torch
import optuna
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import DataLoader
from sklearn.metrics import classification_report
from torch_geometric.utils import dense_to_sparse
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv, global_max_pool, dense_diff_pool, global_mean_pool, TransformerConv, BatchNorm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv
from custom_pooling import global_max_pool
from torch.utils.data import DataLoader as TorchDataLoader
from sklearn.preprocessing import StandardScaler

def generate_hyperparam_combinations(hyperparameters):
    keys, values = zip(*hyperparameters.items())
    return [dict(zip(keys, v)) for v in product(*values)]


def prepare_gnn_dataset(lfp_data, trial_info, target_col=3, corr_threshold=0.5):
    dataset = []
    for i in range(lfp_data.shape[0]):
        node_features = torch.tensor(lfp_data[i], dtype=torch.float)
        edge_index = create_edge_index_from_correlation(node_features, threshold=corr_threshold)
        y = torch.tensor(trial_info[i, target_col] - 1, dtype=torch.long)
        data = Data(x=node_features, edge_index=edge_index, y=y)
        dataset.append(data)
    return dataset

def compute_loss(model, dataset, edge_index):
    out, feat, emb = [], [], []

    model.eval()
    with torch.no_grad():
        x_data = torch.tensor(dataset.transpose(0, 2, 1), dtype=torch.float, device=device)  
        edge_index = edge_index.to(device)

        for i in range(x_data.shape[0]):  
            batch = torch.zeros(x_data[i].shape[0], dtype=torch.long, device=device)  # Assign nodes to batch
            data = Data(x=x_data[i], edge_index=edge_index, batch=batch)

            output_model = model(data.x, data.edge_index, batch=data.batch)  

            out.append(output_model[0])  # Logits
            feat.append(output_model[1])  

            # ✅ Ensure emb is a tensor, not a tuple
            emb_val = output_model[2]
            if isinstance(emb_val, tuple):
                emb_val = torch.stack(emb_val, dim=0)  # Convert tuple to tensor
            emb.append(emb_val)  

    soft_prob = torch.cat(out, dim=0)  
    feat = torch.stack([f.T for f in feat], dim=0)  

    # ✅ Fix: Ensure emb is stacked properly
    try:
        emb = torch.stack(emb, dim=0)
    except TypeError as e:
        print(f"⚠️ TypeError in stacking emb: {e}")
        print(f"Emb element types: {[type(e) for e in emb]}")
        emb = torch.cat([e.unsqueeze(0) if e.dim() < 3 else e for e in emb], dim=0)

    return [soft_prob, feat, emb]


def model_acc(model, dataset, edge_index, labels):
    model.eval()  # Set model to evaluation mode
    predictions = []

    with torch.no_grad():
        x_data = torch.tensor(dataset.transpose(0, 2, 1), dtype=torch.float, device=device)  
        edge_index = edge_index.to(device)

        for i in range(x_data.shape[0]):  
            batch = torch.zeros(x_data[i].shape[0], dtype=torch.long, device=device)  # Assign nodes to batch
            data = Data(x=x_data[i], edge_index=edge_index, batch=batch)

            # Forward pass
            output_model = model(data.x, data.edge_index, batch=data.batch)[0]  # Extract logits

            # Get predicted class
            pred_label = output_model.argmax(dim=1)
            predictions.append(pred_label[0])

    # Convert predictions to tensor
    pred_tensor = torch.tensor(predictions, dtype=torch.long, device=device)

    # Compute accuracy efficiently
    correct = (pred_tensor == labels).sum().item()
    accuracy = correct / dataset.shape[0]

    return accuracy
   
   

# Compute the weighted loss error rate
def weight_loss(model, dataset, weights, edge_index, training_label):
    pred = []

    # model prediction on graph labels
    for i in range(dataset.shape[0]):
        nodes_fea = [tmp for tmp in np.transpose(dataset[i,:,:])]
        x = torch.tensor(nodes_fea, dtype=torch.float)
        data = Data(x=x, edge_index=edge_index)
        data = data.to(device)
        pred.append(model(data.x, data.edge_index)[0].argmax(dim=1)[0])
    pred = torch.Tensor(pred).to(device)
    
    # compute the weighted error rate for classifier 
    err_rate = (weights.to(device)*(pred != training_label)).sum()/weights.sum()

    return err_rate

# Evaluate the stabilized quality of the classifier 
def quality_update(err_rate):
    # negative logit function
    alpha = torch.log(0.01 + err_rate / (1 - err_rate))
    return torch.max(-alpha, torch.tensor(0.05))

# Update the weights for classifiers 
def weight_update(model, err_rate, dataset, weights, edge_index, training_label):
    alpha = quality_update(err_rate)

    for i in range(dataset.shape[0]):
        nodes_fea = [tmp for tmp in np.transpose(dataset[i,:,:])]
        x = torch.tensor(nodes_fea, dtype=torch.float)
        data = Data(x=x, edge_index=edge_index)
        data = data.to(device)

        # model prediction and weights updating via exponential rule
        curr_pred = torch.Tensor(model(data.x, data.edge_index)[0].argmax(dim=1)[0]).to(device) 
        weights[i] = torch.exp(alpha * (training_label[i] != curr_pred))

    # normalize the weights for all samples
    weights = nn.functional.normalize(weights, p=2, dim=0)
        
    return weights

    # Evaluate the final model performance 
def test_acc_and_confusion_matrix(model):
    predictions = []

    # build evaluation dataset and batch 
    X_test_tensor = test_emb.clone().detach()
    test_dataset = TensorDataset(X_test_tensor)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
    
    # freeze the model parameters and perform inference 
    with torch.no_grad():
        for features in test_loader:
            features = features[0]  # Unpack the single-element tuple
            outputs = model(features)
            predicted = outputs.argmax(dim=1)  # Get the class with the highest probability
            predictions.extend(predicted.tolist())

    # calculate the accuracy
    accuracy = accuracy_score(testing_label.cpu(), predictions)
    
    # calculate the confusion matrix
    cm = confusion_matrix(testing_label.cpu(), predictions)

    return accuracy, cm

class Weaker_First(torch.nn.Module):
    def __init__(self, num_nodes, num_node_features, num_classes):
        super().__init__()
        self.num_nodes = num_nodes  # Store number of nodes
        self.conv0 = nn.Conv1d(num_nodes, num_nodes, 3, 1, padding=1)
        self.conv1 = TransformerConv(num_node_features, 32, heads=4, concat=True, dropout=0.1)
        self.conv2 = TransformerConv(32 * 4, 128, heads=4, concat=True, dropout=0.1)
        self.conv3 = TransformerConv(128 * 4, 32, heads=4, concat=True, dropout=0.1)

        self.local_size = min(7, num_nodes)  # Ensure valid local size
        self.final_feature_dim = 32 * 4  # Ensure correct output size

        self.lr1 = nn.Linear(self.final_feature_dim, 128)
        self.lr2 = nn.Linear(128, 32)
        self.lr3 = nn.Linear(32, num_classes)
        self.attention_scores = None
        
    def forward(self, x, edge_index, batch=None):
        x = self.conv0(x)
        x, attn1 = self.conv1(x, edge_index, return_attention_weights=True)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x, attn2 = self.conv2(x, edge_index, return_attention_weights=True)
        x = F.relu(x)
        x, attn3 = self.conv3(x, edge_index, return_attention_weights=True)
        x = F.relu(x)
        lap = x  # Store intermediate representation

        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        x = global_max_pool(x, batch)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.lr1(x))
        x = F.relu(self.lr2(x))
        x = self.lr3(x)
        
        self.attention_scores = attn1[1] + attn2[1] + attn3[1]
        return [F.log_softmax(x, dim=1), lap, x]

class Weaker_Middle(torch.nn.Module):
    def __init__(self, input_dim, num_classes, heads=4, dropout=0.1):  
        super().__init__()
        self.input_dim = input_dim  # Dynamically set input feature size
        self.num_classes = num_classes  

        self.conv1 = TransformerConv(128, 32, heads=heads, concat=True, dropout=dropout)  
        
        self.hidden_dim = 32 * heads  # Ensure proper feature size
        self.fc_input_dim = self.hidden_dim  # Update based on actual feature size
        
        self.lr1 = nn.Linear(self.fc_input_dim, 128)
        self.lr2 = nn.Linear(128, 48)
        self.lr3 = nn.Linear(48, self.num_classes)  

        self.attention_scores = None  

    def forward(self, x, edge_index, return_node_embeddings=False, batch=None):
        x, attention1 = self.conv1(x, edge_index, return_attention_weights=True)  
        self.attention_scores = attention1[1]  
        lap = x  
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        # Pooling over nodes
        x = global_max_pool(x, batch)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.lr1(x))
        x = F.relu(self.lr2(x))
        x = self.lr3(x)

        if return_node_embeddings:
            return x  

        return [F.log_softmax(x, dim=1), lap, x]



def normalize_features(train_data, test_data, val_data):
    """
    Normalizes the LFP dataset using StandardScaler, ensuring data type remains unchanged.
    """
    train_features = train_data.reshape(-1, train_data.shape[-1])  # Flatten across samples
    scaler = StandardScaler()
    scaler.fit(train_features)  # Fit only on training data

    # Apply transformation without changing data type
    train_data = scaler.transform(train_data.reshape(-1, train_data.shape[-1])).reshape(train_data.shape)
    test_data = scaler.transform(test_data.reshape(-1, test_data.shape[-1])).reshape(test_data.shape)
    val_data = scaler.transform(val_data.reshape(-1, val_data.shape[-1])).reshape(val_data.shape)

    return train_data, test_data, val_data

device = torch.device("cpu")

def set_seed(seed=1):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

set_seed(1)


class GraphLevelPredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=5, heads=4, dropout=0.1):
        super(GraphLevelPredictor, self).__init__()
        self.conv1 = TransformerConv(input_dim, hidden_dim, heads=heads, dropout=dropout)
        self.conv2 = TransformerConv(hidden_dim * heads, hidden_dim, heads=heads, dropout=dropout)

        self.fc = nn.Linear(hidden_dim * heads, output_dim)  

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        x = global_mean_pool(x, batch)  # Aggregates node embeddings into graph embeddings

        return self.fc(x)


def create_edge_index_from_correlation(data, threshold=0.7):
    """
    Create an edge index based on the correlation matrix of the data.

    Args:
        data (torch.Tensor): A 2D tensor of shape (num_nodes, num_features), where rows are nodes and columns are features.
        threshold (float): Correlation threshold. Only edges with |correlation| > threshold are included.

    Returns:
        torch.Tensor: Edge index tensor of shape [2, num_edges].
    """
    data_np = data.numpy()  # Convert to NumPy for correlation computation
    num_nodes = data_np.shape[0]  # Number of nodes
    correlation_matrix = np.corrcoef(data_np.T)  # Compute correlation between features
    
    # Get indices where |correlation| > threshold
    indices = np.where(np.abs(correlation_matrix) > threshold)
    
    # Remove self-loops (correlations between the same nodes)
    valid_edges = indices[0] != indices[1]
    source_nodes = indices[0][valid_edges]
    target_nodes = indices[1][valid_edges]
    
    # Ensure indices are within valid range
    valid_mask = (source_nodes < num_nodes) & (target_nodes < num_nodes)
    source_nodes = source_nodes[valid_mask]
    target_nodes = target_nodes[valid_mask]
    
    # Create edge index tensor
    edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)
    return edge_index


def prepare_gtn_dataset(lfp_data, trial_info, target_col=3, corr_threshold=0.5):
    dataset = []
    for i in range(lfp_data.shape[0]):
        node_features = torch.tensor(lfp_data[i], dtype=torch.float)
        edge_index = create_edge_index_from_correlation(node_features, threshold=corr_threshold)
        y = torch.tensor(trial_info[i, target_col] - 1, dtype=torch.long)
        data = Data(x=node_features, edge_index=edge_index, y=y)
        dataset.append(data)
    return dataset


def set_weights(model):
    for layer in model.modules():
        if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
            torch.manual_seed(1)
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
 
class GTN(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels, num_classes, heads=4, dropout=0.1):
        super(GTN, self).__init__()
        self.conv1 = TransformerConv(num_node_features, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = TransformerConv(hidden_channels * heads, hidden_channels, heads=heads, dropout=dropout)
        self.conv3 = TransformerConv(hidden_channels * heads, 32, heads=1, dropout=dropout)

        self.batch_norm = nn.LayerNorm(32)  
        self.fc = nn.Linear(32, num_classes)
        self.dropout = nn.Dropout(p=dropout)
        self.attention_scores = None
        
    def forward(self, data, return_node_embeddings=False):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x, attention1 = self.conv1(x, edge_index, return_attention_weights=True)
        x = F.relu(x)
        x = self.dropout(x) if self.training else x
        x, attention2 = self.conv2(x, edge_index, return_attention_weights=True)
        x = F.relu(x)
        x = self.dropout(x) if self.training else x
        x, attention3 = self.conv3(x, edge_index, return_attention_weights=True)
        x = F.relu(x)
        x = self.batch_norm(x)  

        if return_node_embeddings:
            return x  # Return per-node embeddings

        # Aggregate node embeddings into graph-level representation
        x = global_mean_pool(x, batch)

        self.attention_scores = attention1[1] + attention2[1] + attention3[1]
        
        return self.fc(x)



def train_gtns(loaders, gtn_models, epochs=5, lr=0.002):
    for gtn in gtn_models:
        gtn.to(device)

    optimizers = [torch.optim.Adam(gtn.parameters(), lr=lr) for gtn in gtn_models]
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for loader, gtn, optimizer in zip(loaders, gtn_models, optimizers):
            for batch in loader:
                batch = batch.to(device)
                optimizer.zero_grad()

                # Use graph-level representation
                out = gtn(batch, return_node_embeddings=False)

                target = batch.y.view(-1).long()  # Ensure correct shape
                loss = criterion(out, target)
                
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                
        print(f"GTN Training Epoch {epoch + 1}, Loss: {total_loss / len(loaders):.4f}")


def extract_features(dataset, gtn_model):
    encoded_features = []
    gtn_model.to(device)
    with torch.no_grad():
        for data in dataset:
            data = data.to(device)
            node_embeddings = gtn_model(data, return_node_embeddings=True)  # Extract node-level features
            encoded_features.append(Data(x=node_embeddings, edge_index=data.edge_index, y=data.y))

    return encoded_features

def compute_node_attention_scores(attention_scores, edge_index, num_nodes):
    """
    Compute node-level weights using extracted multi-head attention scores.

    Args:
        attention_scores (torch.Tensor): Multi-head edge attention scores of shape [num_edges, num_heads].
        edge_index (torch.Tensor): Edge indices of shape [2, num_edges].
        num_nodes (int): Number of nodes in the graph.

    Returns:
        torch.Tensor: Node attention weights of shape [num_nodes].
    """

    num_edges = min(edge_index.shape[1], attention_scores.shape[0])  # Ensure valid indexing
    edge_index = edge_index[:, :num_edges]  # Truncate excess edges
    attention_scores = attention_scores[:num_edges]  # Truncate excess scores

    node_scores = torch.zeros(num_nodes, device=attention_scores.device)

    # Aggregate attention scores across heads
    attention_scores = attention_scores.mean(dim=1)  # Shape: [num_edges]

    # Ensure edges don't reference nodes outside valid range
    valid_mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
    edge_index = edge_index[:, valid_mask]  # Remove invalid edges
    attention_scores = attention_scores[valid_mask]  # Adjust attention scores

    # Aggregate attention scores per node
    for i, (src, dst) in enumerate(edge_index.t()):
        node_scores[src] += attention_scores[i].item()
        node_scores[dst] += attention_scores[i].item()

    return node_scores


def construct_subgraph(graph, significant_nodes, edge_weight_threshold=0.5):
    node_indices = significant_nodes.nonzero(as_tuple=True)[0]  # Get selected node indices
    graph.x = graph.x.T
    
    if node_indices.numel() == 0:
        print("⚠️ No significant nodes found. Selecting a fallback node.")
        node_indices = torch.tensor([0])

    max_index = graph.x.size(0) - 1
    if node_indices.max().item() > max_index:
        raise ValueError(
            f"🚨 Error: Selected node index {node_indices.max().item()} exceeds "
            f"maximum valid index {max_index}."
        )

    # Subset **nodes** (rows), keeping **all features** (columns)
    new_x = graph.x[node_indices, :]  # Ensure features remain intact
    # Adjust edge_index to reference the selected nodes
    new_edge_index = adjust_edge_index(graph.edge_index, node_indices)
    subgraph = Data(x=new_x, edge_index=new_edge_index)
    subgraph.y = graph.y  # Keep the original label

    return subgraph


def construct_subgraphs(node_attention_scores, raw_graphs, encoded_graphs, threshold):
    """
    Construct subgraphs based on node attention scores.

    Args:
        node_attention_scores (torch.Tensor): Node attention weights.
        raw_graphs (list): Original graph data.
        encoded_graphs (list): Encoded graph data with transformed features.
        threshold (float): Threshold for selecting important nodes.

    Returns:
        list: List of constructed subgraphs.
    """
    subgraphs = []
    for raw_graph, encoded_graph in zip(raw_graphs, encoded_graphs):
        normalized_scores = normalize_attention_scores(node_attention_scores)
        significant_nodes = normalized_scores >= threshold
        subgraph = construct_subgraph(encoded_graph, significant_nodes)
        subgraphs.append(subgraph)
    
    return subgraphs

def normalize_attention_scores(node_scores):
    min_score = node_scores.min()
    max_score = node_scores.max()
    normalized_scores = (node_scores - min_score) / (max_score - min_score + 1e-5)
    return normalized_scores

def compute_global_similarity_matrix(graphs):

    all_node_embeddings = torch.cat([graph.x for graph in graphs], dim=0)
    similarity_matrix = F.cosine_similarity(
        all_node_embeddings.unsqueeze(1), all_node_embeddings.unsqueeze(0), dim=2
    )
    return similarity_matrix

# Threshold similarity matrix to create edges
def threshold_similarity(similarity_matrix, threshold=0.5):
    edge_indices = (similarity_matrix > threshold).nonzero(as_tuple=False)
    return edge_indices.t().contiguous()

def construct_meta_graph(graphs, similarity_matrix, threshold=0.5):
    combined_node_features = torch.cat([graph.x for graph in graphs], dim=0)
    
    intra_graph_edges = []
    offset = 0

    for graph in graphs:
        intra_edges = graph.edge_index.clone()
        intra_edges = intra_edges + offset  # Apply offset correctly
        intra_graph_edges.append(intra_edges)
        offset += graph.x.shape[0]  # Increment offset properly

    intra_graph_edges = torch.cat(intra_graph_edges, dim=1) if intra_graph_edges else torch.empty((2, 0), dtype=torch.long)

    inter_graph_edges = threshold_similarity(similarity_matrix, threshold)

    # Ensure inter_graph_edges do not exceed node count
    valid_mask = (inter_graph_edges[0] < combined_node_features.shape[0]) & (inter_graph_edges[1] < combined_node_features.shape[0])
    inter_graph_edges = inter_graph_edges[:, valid_mask]  # Fix shape mismatch


    # Ensure indices are valid before concatenation
    if intra_graph_edges.numel() > 0:
        combined_edges = torch.cat([intra_graph_edges, inter_graph_edges], dim=1)
    else:
        combined_edges = inter_graph_edges
    meta_graph = Data(x=combined_node_features, edge_index=combined_edges)
    return meta_graph

def adjust_edge_index(edge_index, node_subset):
    """Remap edge_index so it only contains nodes from node_subset."""
    node_map = {old_idx.item(): new_idx for new_idx, old_idx in enumerate(node_subset)}
    valid_edges = [edge for edge in edge_index.t().tolist() if edge[0] in node_map and edge[1] in node_map]
    
    if valid_edges:
        new_edge_index = torch.tensor([[node_map[src], node_map[dst]] for src, dst in valid_edges], dtype=torch.long).t()
        return new_edge_index
    else:
        return torch.empty((2, 0), dtype=torch.long)  # Return empty edge index

def construct_meta_graphs(subgraphs_list):
    meta_graphs = []
    for i in range(len(subgraphs_list[0])):
        similarity_matrix = compute_global_similarity_matrix(
            [subgraphs[i] for subgraphs in subgraphs_list]
        )
        meta_graph = construct_meta_graph(
            [subgraphs[i] for subgraphs in subgraphs_list], similarity_matrix, threshold=0.5
        )
        meta_graph.y = subgraphs_list[0][i].y  # Assign label from one of the subgraphs
        meta_graphs.append(meta_graph)
    return meta_graphs

test_predictions = None   

device = torch.device('cpu')  # Force CPU usage
warnings.filterwarnings('ignore')

import time


def objective(trial):
    global test_predictions
    start_time_total = time.time()

    thres =  trial.suggest_float("thres", 0.05, 1, step=0.05)
    layer_num = trial.suggest_categorical("layer_num", [2,3,4,5,6])
    threshold1 = trial.suggest_float("threshold1", 0.1, 1, step=0.1)
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    hidden_dim = 32
    num_train = trial.suggest_categorical("num_train", [5, 10, 20, 30, 40])
    num_epochs = trial.suggest_categorical("num_epochs", [5, 10, 20, 30, 40])
    dropout1 = trial.suggest_float("dropout1", 0.1, 0.9, step=0.05)  # Dropout probability
    lr = trial.suggest_loguniform("lr", 1e-4, 1e-0)  # Learning rate (log scale for finer tuning)
    train_weak = trial.suggest_categorical("train_weak", [5, 10, 15, 20, 30, 40])
    train_middle = trial.suggest_categorical("train_middle", [5, 10, 15, 20, 30, 40])
    weight_decay_weak = trial.suggest_loguniform("weight_decay_weak", 1e-5, 1e-2)
    lr_weak = trial.suggest_loguniform("lr_weak", 1e-4, 1e-0)
    weight_decay_middle = trial.suggest_loguniform("weight_decay_middle", 1e-5, 1e-2)
    lr_middle = trial.suggest_loguniform("lr_middle", 1e-4, 1e-0)
    batch1 = trial.suggest_categorical("batch1", [8, 16, 32])
    best_accuracy = 0
    best_hyperparams = None
    best_cm = None   
    gnn_datasets = {}
    for rat_name, data in balanced_datasets.items():
        trial_info_balanced = data["trial_info"]
        lfp_data_balanced = data["lfp_data"]
        gnn_datasets[rat_name] = prepare_gnn_dataset(lfp_data_balanced, trial_info_balanced)

    
    dataset1 = gnn_datasets["superchris"]
    dataset2 = gnn_datasets["barat"]
    dataset3 = gnn_datasets["stella"]
    dataset4 = gnn_datasets["mitt"]
    dataset5 = gnn_datasets["buchanan"]  
    
    datasets = [dataset1, dataset2, dataset3, dataset4, dataset5]

    
    from sklearn.model_selection import KFold

    
    testing = {}
    training = {}
    validation = {}
    
    training_labels = {}
    testing_labels = {}
    validation_labels = {}
    
    training_edge_indices = {}
    testing_edge_indices = {}
    validation_edge_indices = {}

    adj_matrices = {}
    edge_indices = {}
    train_dataset = {}
    test_datasett = {}
    validation_datasett = {}
    accuracies = []
    test_predictions_all = []
    test_labels_all = []
    val_accuracies = []

    dataset_size = len(dataset1)  # Ensure we're using the right dataset size
    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    
    indices = list(range(dataset_size))  # Ensure indices cover all sample
    for fold, (train_idx, test_idx) in enumerate(kf.split(indices)):
        print(f"Starting Fold {fold + 1} - Time: {time.time() - start_time_total:.4f} seconds")        
        print(f"Fold {fold + 1}:")
        train_idx, validation_idx = train_test_split(train_idx, test_size=0.2, random_state=1)

   
        for i, dataset in enumerate(datasets):
            data_lfp = np.array([data.x.numpy().T for data in dataset])  # Shape: (num_samples, 400, 21)
            data_trial = np.array([data.y.item() for data in dataset])  # Labels
            testing[i] = data_lfp[test_idx]
            training[i] = data_lfp[train_idx]
            validation[i] = data_lfp[validation_idx]
            
            training_labels[i] = data_trial[train_idx]
            testing_labels[i] = data_trial[test_idx]
            validation_labels[i] = data_trial[validation_idx]
            
            training_labels[i] = torch.tensor(training_labels[i], dtype=torch.long)
            testing_labels[i] = torch.tensor(testing_labels[i], dtype=torch.long)
            validation_labels[i] = torch.tensor(validation_labels[i], dtype=torch.long)  # Not used in training
        
            # Preserve corresponding edge_index for each sample
            training_edge_indices[i] = [dataset[j].edge_index for j in train_idx]
            testing_edge_indices[i] = [dataset[j].edge_index for j in test_idx]
            validation_edge_indices[i] = [dataset[j].edge_index for j in validation_idx]  # Not used in training

            
            # Define a sample "data" object to be used in the model initialization
            sample_data = Data(x=training[i][0].T, edge_index=training_edge_indices[i][0])  # Use first training sample
            num_classes = len(torch.unique(torch.tensor(data_trial, dtype=torch.long)))
            train_dataset[i] = []
            for sample in range(training[i].shape[0]):
                edge_index = training_edge_indices[i][sample]
            
                if isinstance(edge_index, list):  
                    edge_index = torch.stack(edge_index, dim=1)  # Convert list to tensor
            
                train_dataset[i].append(Data(x=training[i][sample].T, edge_index=edge_index))
    
            # Create test and validation datasets in the same format as train_dataset
            test_datasett[i] = []
            for sample in range(testing[i].shape[0]):
                edge_index = testing_edge_indices[i][sample]
            
                if isinstance(edge_index, list):  
                    edge_index = torch.stack(edge_index, dim=1)  # Convert list to tensor
            
                test_datasett[i].append(Data(x=testing[i][sample].T, edge_index=edge_index))
            
            validation_datasett[i] = []
            for sample in range(validation[i].shape[0]):
                edge_index = validation_edge_indices[i][sample]
            
                if isinstance(edge_index, list):  
                    edge_index = torch.stack(edge_index, dim=1)  # Convert list to tensor
            
                validation_datasett[i].append(Data(x=validation[i][sample].T, edge_index=edge_index))   
            training_edge_indices[i] = adjust_edge_index(torch.cat(training_edge_indices[i], dim=1), train_idx)
            testing_edge_indices[i] = adjust_edge_index(torch.cat(testing_edge_indices[i], dim=1), test_idx)  
            adj_mat = np.mean([np.corrcoef(data_lfp[sample].T) for sample in range(data_lfp.shape[0])], axis=0)
            adj_mat[adj_mat < thres] = 0
            adj_mat[adj_mat == 1] = 0
            adj_mat[adj_mat > 0] = 1
            
            # Convert to PyTorch tensor
            adj_matrices[i] = torch.tensor(adj_mat, dtype=torch.float)
            print(adj_matrices[i])
            # Create edge_index from adjacency matrix
            edge_indices[i] = adj_matrices[i].nonzero().t().contiguous()
 
            training[i], testing[i], validation[i] = normalize_features(training[i], testing[i], validation[i]) 
           
        models = {i: [] for i in range(len(datasets))} 
        latent_rep = {i: [] for i in range(len(datasets))}
        quality_vec = {i: [] for i in range(len(datasets))}
        final_emb = {}
        test_emb = {}
        validation_emb = {}
    
        torch.manual_seed(1)
        
        start_time = time.time()
        for i, dataset in enumerate(datasets):            
            sample_data = training[i]  # First sample
            num_nodes = sample_data.shape[2]
            num_node_features = sample_data.shape[1]
            model = Weaker_First(num_nodes, num_node_features, num_classes).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr_weak, weight_decay=weight_decay_weak)
            model.train()  
            start_time = time.time()
            for _ in range(train_weak):
                train_loss_compute = compute_loss(model, training[i], edge_indices[i])
                training_softprob = train_loss_compute[0]
                training_softprob = training_softprob.clone().detach().requires_grad_(True)
                optimizer.zero_grad()
                training_loss = F.nll_loss(training_softprob, training_labels[i]) 
                training_loss.backward()
                optimizer.step()
                train_acc = model_acc(model, training[i], edge_indices[i], training_labels[i])
            curr_feat = train_loss_compute[1]
            curr_latent = np.array(train_loss_compute[1])
            print(f"weakfirst time {time.time() - start_time:.4f} seconds")        
            models[i].append(model)
            latent_rep[i].append(curr_latent)
            weights = torch.ones(training[i].shape[0])
            weights = nn.functional.normalize(weights, p=2, dim=0)         
            err_rate = weight_loss(model, training[i], weights, edge_indices[i], training_labels[i])
            quality_vec[i].append(quality_update(err_rate))
            start_time = time.time()            
            for _ in range(layer_num - 1):
                model = Weaker_Middle(256, num_classes).to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr_middle, weight_decay=weight_decay_middle)
                model.train()
                training[i] = np.array(curr_feat) 
                for epoch in range(train_middle):
                    train_loss_compute = compute_loss(model, training[i], edge_indices[i])
                    training_softprob = train_loss_compute[0]
                    training_softprob = training_softprob.clone().detach().requires_grad_(True)
                    optimizer.zero_grad()
                    training_loss = F.nll_loss(training_softprob, training_labels[i]) 
                    training_loss.backward()
                    optimizer.step()                    
                    train_acc = model_acc(model, training[i], edge_indices[i], training_labels[i])
                    curr_feat = train_loss_compute[1]
                models[i].append(model)
                curr_latent = np.array(train_loss_compute[1])
                latent_rep[i].append(curr_latent)
                err_rate = weight_loss(model, training[i], weights, edge_indices[i], training_labels[i])
                quality_vec[i].append(quality_update(err_rate))             
                weights = weight_update(model, err_rate, training[i], weights, edge_indices[i], training_labels[i])                
            final_emb[i] = torch.zeros_like(torch.tensor(latent_rep[i][0])).to(device)
            for j in range(layer_num):
                final_emb[i] += quality_vec[i][j] * torch.tensor(latent_rep[i][j]).to(device)
    
            # Create training dataset
            train_tensor = torch.tensor(training_labels[i], dtype=torch.float32).view(-1, 1)
            dataset = TensorDataset(final_emb[i], train_tensor)
            train_loader = TorchDataLoader(dataset, batch_size=batch1, shuffle=True)        
            test_latent_rep = []         
            for j in range(layer_num):
                model = models[i][j]
                test_loss_compute = compute_loss(model, testing[i], edge_indices[i])
                test_latent_rep.append(np.array(test_loss_compute[1]))
                testing[i] = np.array(test_loss_compute[1])
            test_emb[i] = torch.zeros_like(torch.tensor(test_latent_rep[0])).to(device)          
            for j in range(layer_num):
                test_emb[i] += quality_vec[i][j] * torch.tensor(test_latent_rep[j]).to(device)
            
            # Create test dataset
            test_tensor = torch.tensor(testing_labels[i], dtype=torch.float32).view(-1, 1)
            test_dataset = TensorDataset(test_emb[i], test_tensor)
            test_loader = TorchDataLoader(test_dataset, batch_size=batch1, shuffle=True)
            validation_latent_rep = []
            
            for j in range(layer_num):
                model = models[i][j]
                validation_loss_compute = compute_loss(model, validation[i], edge_indices[i])
                validation_latent_rep.append(np.array(validation_loss_compute[1]))
                validation[i] = np.array(validation_loss_compute[1])
            validation_emb[i] = torch.zeros_like(torch.tensor(validation_latent_rep[0])).to(device)
            for j in range(layer_num):
                validation_emb[i] += quality_vec[i][j] * torch.tensor(validation_latent_rep[j]).to(device)
            
            # Create validation dataset
            validation_tensor = torch.tensor(validation_labels[i], dtype=torch.float32).view(-1, 1)
            validation_dataset = TensorDataset(validation_emb[i], validation_tensor)
            validation_loader = TorchDataLoader(validation_dataset, batch_size=batch1, shuffle=True)
         
        train_subgraphs = []
        val_subgraphs = []
        test_subgraphs = []
        train_subgraphs = []
        test_subgraphs = []
        validation_subgraphs = [] 
        for i in range(len(datasets)):
            encoded_train_data_list = []
            encoded_test_data_list = []
            encoded_validation_data_list = []
        
            # Encode Train Dataset
            for sample in range(len(train_dataset[i])):  # Iterate over Data objects
                node_features = final_emb[i][sample]  # Ensure shape [num_nodes, num_features]
                encoded_data = Data(x=node_features, edge_index=train_dataset[i][sample].edge_index)
                encoded_train_data_list.append(encoded_data)
        
            # Encode Test Dataset
            for sample in range(len(test_datasett[i])):  # Iterate over Data objects
                node_features = test_emb[i][sample]  # Ensure shape [num_nodes, num_features]
                encoded_data = Data(x=node_features, edge_index=test_datasett[i][sample].edge_index)
                encoded_test_data_list.append(encoded_data)
        
            # Encode Validation Dataset
            for sample in range(len(validation_datasett[i])):  # Iterate over Data objects
                node_features = validation_emb[i][sample]  # Ensure shape [num_nodes, num_features]
                encoded_data = Data(x=node_features, edge_index=validation_datasett[i][sample].edge_index)
                encoded_validation_data_list.append(encoded_data)

            attention_scores = torch.zeros_like(models[i][0].attention_scores)
            
            # Iterate over each model layer
            for j in range(layer_num):
                # Retrieve the attention scores from the current model
                current_attention_scores = models[i][j].attention_scores
                # Retrieve the quality score for the current model
                quality_score = quality_vec[i][j]                
                # Accumulate the weighted attention scores
                attention_scores += quality_score * current_attention_scores
            node_attention_scores_train = compute_node_attention_scores(
                attention_scores, train_dataset[i][0].edge_index, train_dataset[i][0].x.shape[0]
            )
        
            # Construct subgraphs for training set
            train_subgraphs.append(
                construct_subgraphs(node_attention_scores_train, train_dataset[i], encoded_train_data_list, threshold=threshold1)
            )
        
            # Compute node attention scores from extracted attention weights (Test)
            node_attention_scores_test = compute_node_attention_scores(
                attention_scores, test_datasett[i][0].edge_index, test_datasett[i][0].x.shape[0]
            )
    
            # Construct subgraphs for test set
            test_subgraphs.append(
                construct_subgraphs(node_attention_scores_test, test_datasett[i], encoded_test_data_list, threshold=threshold1)
            )
        
            # Compute node attention scores from extracted attention weights (Validation)
            node_attention_scores_validation = compute_node_attention_scores(
                attention_scores, validation_datasett[i][0].edge_index, validation_datasett[i][0].x.shape[0]
            )
        
            # Construct subgraphs for validation set
            validation_subgraphs.append(
                construct_subgraphs(node_attention_scores_validation, validation_datasett[i], encoded_validation_data_list, threshold=threshold1)
            )    
    
        train_meta_graphs = construct_meta_graphs(train_subgraphs)
        val_meta_graphs = construct_meta_graphs(validation_subgraphs)
        test_meta_graphs = construct_meta_graphs(test_subgraphs)       
        
        # Assign labels to meta-graphs explicitly
        for i, graph in enumerate(train_meta_graphs):
            graph.y = torch.tensor(training_labels[0][i], dtype=torch.long)  # Assign correct labels
        
        for i, graph in enumerate(val_meta_graphs):
            graph.y = torch.tensor(validation_labels[0][i], dtype=torch.long)  # Assign validation labels
        
        for i, graph in enumerate(test_meta_graphs):
            graph.y = torch.tensor(testing_labels[0][i], dtype=torch.long)  # Assign test labels
        


        for graph in train_meta_graphs:
            if graph.y.dim() == 0:  # If the label is scalar
                graph.y = graph.y.unsqueeze(0)  # Convert to tensor of shape [1]
        for graph in val_meta_graphs:
            if graph.y.dim() == 0:
                graph.y = graph.y.unsqueeze(0)
        
        for graph in test_meta_graphs:
            if graph.y.dim() == 0:
                graph.y = graph.y.unsqueeze(0)
        
        from torch_geometric.loader import DataLoader  # Use PyG's DataLoader
        
        train_loader = DataLoader(train_meta_graphs, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_meta_graphs, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_meta_graphs, batch_size=batch_size, shuffle=False)       
        final_model = GraphLevelPredictor(input_dim=128, hidden_dim=hidden_dim, output_dim=5, dropout= dropout1)
    
        final_model.load_state_dict(final_model.state_dict())
        torch.manual_seed(1)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False 

    
        set_weights(final_model)
        optimizer = torch.optim.Adam(final_model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        
        best_val_loss = float("inf")
        best_model_state = None
        best_val_loss = float("inf")
        best_model_state = None
        
        # Initialize accuracy tracking lists
    
        
        # Training loop
        for epoch in range(num_train):
            final_model.train()
            total_loss = 0
            for batch in train_loader:
                optimizer.zero_grad()
                out = final_model(batch)
                loss = criterion(out, batch.y.view(-1).long())
                loss.backward()
                optimizer.step()
                total_loss += loss.item() 
        correct_val, total_val = 0, 0
    
        with torch.no_grad():
            for batch in val_loader:
                out = final_model(batch)
                pred = torch.argmax(out, dim=1)
                correct_val += (pred == batch.y.view(-1).long()).sum().item()
                total_val += batch.y.size(0)
    
        val_accuracy = correct_val / total_val
        val_accuracies.append(val_accuracy)
    
        print(f"Fold {fold + 1} Validation Accuracy: {val_accuracy:.4f}")
    
    # Instead of test accuracy, return average validation accuracy for Optuna
    final_val_accuracy = np.mean(val_accuracies)
    print(f"Final Average Validation Accuracy across folds: {final_val_accuracy:.4f}")
    return final_val_accuracy
# Run Optuna optimization
study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1))
study.optimize(objective, n_trials=100)

# Best trial
print("Best trial:")
best_trial = study.best_trial
for key, value in best_trial.params.items():
    print(f"{key}: {value}")



def evaluate_best_model(best_params, gnn_datasets, num_classes=5):
    from sklearn.model_selection import KFold
    from torch_geometric.loader import DataLoader
    print("\n🔍 Starting final evaluation using best hyperparameters...\n")

    # Pick first fold as test set
    dataset_size = len(gnn_datasets["superchris"])
    indices = list(range(dataset_size))
    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    train_idx, test_idx = next(kf.split(indices))

    # Prepare full train+val and test datasets (adapt based on your prepare_gnn_dataset logic)
    datasets = [gnn_datasets[name] for name in ["superchris", "barat", "stella", "mitt", "buchanan"]]
    
    training = {}
    testing = {}
    train_dataset = {}
    test_datasett = {}
    training_labels = {}
    testing_labels = {}
    training_edge_indices = {}
    testing_edge_indices = {}
    adj_matrices = {}
    edge_indices = {}

    lr_weak = best_params["lr_weak"]
    weight_decay_weak = best_params["weight_decay_weak"]
    lr_middle = best_params["lr_middle"]
    weight_decay_middle = best_params["weight_decay_middle"]
    train_weak = best_params["train_weak"]
    train_middle = best_params["train_middle"]
    layer_num = best_params["layer_num"]
    threshold1 = best_params["threshold1"]
    batch1 = best_params["batch1"]



    for i, dataset in enumerate(datasets):
        data_lfp = np.array([data.x.numpy().T for data in dataset])  # Shape: (num_samples, 400, 21)
        data_trial = np.array([data.y.item() for data in dataset])  # Labels
        testing[i] = data_lfp[test_idx]
        training[i] = data_lfp[train_idx]
        
        training_labels[i] = data_trial[train_idx]
        testing_labels[i] = data_trial[test_idx]
        
        training_labels[i] = torch.tensor(training_labels[i], dtype=torch.long)
        testing_labels[i] = torch.tensor(testing_labels[i], dtype=torch.long)
        # Preserve corresponding edge_index for each sample
        training_edge_indices[i] = [dataset[j].edge_index for j in train_idx]
        testing_edge_indices[i] = [dataset[j].edge_index for j in test_idx]

        # Define a sample "data" object to be used in the model initialization
        sample_data = Data(x=training[i][0].T, edge_index=training_edge_indices[i][0])  # Use first training sample
        num_classes = len(torch.unique(torch.tensor(data_trial, dtype=torch.long)))
        train_dataset[i] = []
        for sample in range(training[i].shape[0]):
            edge_index = training_edge_indices[i][sample]
        
            if isinstance(edge_index, list):  
                edge_index = torch.stack(edge_index, dim=1)  # Convert list to tensor
        
            train_dataset[i].append(Data(x=training[i][sample].T, edge_index=edge_index))

        test_datasett[i] = []
        for sample in range(testing[i].shape[0]):
            edge_index = testing_edge_indices[i][sample]
        
            if isinstance(edge_index, list):  
                edge_index = torch.stack(edge_index, dim=1)  # Convert list to tensor
        
            test_datasett[i].append(Data(x=testing[i][sample].T, edge_index=edge_index))

        training_edge_indices[i] = adjust_edge_index(torch.cat(training_edge_indices[i], dim=1), train_idx)
        testing_edge_indices[i] = adjust_edge_index(torch.cat(testing_edge_indices[i], dim=1), test_idx)  
        adj_mat = np.mean([np.corrcoef(data_lfp[sample].T) for sample in range(data_lfp.shape[0])], axis=0)
        adj_mat[adj_mat < thres] = 0
        adj_mat[adj_mat == 1] = 0
        adj_mat[adj_mat > 0] = 1
        
        # Convert to PyTorch tensor
        adj_matrices[i] = torch.tensor(adj_mat, dtype=torch.float)
        
        # Create edge_index from adjacency matrix
        edge_indices[i] = adj_matrices[i].nonzero().t().contiguous()

        training[i], testing[i] = normalize_features(training[i], testing[i]) 
       
    models = {i: [] for i in range(len(datasets))} 
    latent_rep = {i: [] for i in range(len(datasets))}
    quality_vec = {i: [] for i in range(len(datasets))}
    final_emb = {}
    test_emb = {}

    torch.manual_seed(1)
    
    start_time = time.time()
    for i, dataset in enumerate(datasets):            
        sample_data = training[i]  # First sample
        num_nodes = sample_data.shape[2]
        num_node_features = sample_data.shape[1]
        model = Weaker_First(num_nodes, num_node_features, num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr_weak, weight_decay=weight_decay_weak)
        model.train()  
        start_time = time.time()
        for _ in range(train_weak):
            train_loss_compute = compute_loss(model, training[i], edge_indices[i])
            training_softprob = train_loss_compute[0]
            training_softprob = training_softprob.clone().detach().requires_grad_(True)
            optimizer.zero_grad()
            training_loss = F.nll_loss(training_softprob, training_labels[i]) 
            training_loss.backward()
            optimizer.step()
            train_acc = model_acc(model, training[i], edge_indices[i], training_labels[i])
        curr_feat = train_loss_compute[1]
        curr_latent = np.array(train_loss_compute[1])
        print(f"weakfirst time {time.time() - start_time:.4f} seconds")        
        models[i].append(model)
        latent_rep[i].append(curr_latent)
        weights = torch.ones(training[i].shape[0])
        weights = nn.functional.normalize(weights, p=2, dim=0)         
        err_rate = weight_loss(model, training[i], weights, edge_indices[i], training_labels[i])
        quality_vec[i].append(quality_update(err_rate))
        start_time = time.time()            
        for _ in range(layer_num - 1):
            model = Weaker_Middle(256, num_classes).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr_middle, weight_decay=weight_decay_middle)
            model.train()
            training[i] = np.array(curr_feat) 
            for epoch in range(train_middle):
                train_loss_compute = compute_loss(model, training[i], edge_indices[i])
                training_softprob = train_loss_compute[0]
                training_softprob = training_softprob.clone().detach().requires_grad_(True)
                optimizer.zero_grad()
                training_loss = F.nll_loss(training_softprob, training_labels[i]) 
                training_loss.backward()
                optimizer.step()                    
                train_acc = model_acc(model, training[i], edge_indices[i], training_labels[i])
                curr_feat = train_loss_compute[1]
            models[i].append(model)
            curr_latent = np.array(train_loss_compute[1])
            latent_rep[i].append(curr_latent)
            err_rate = weight_loss(model, training[i], weights, edge_indices[i], training_labels[i])
            quality_vec[i].append(quality_update(err_rate))             
            weights = weight_update(model, err_rate, training[i], weights, edge_indices[i], training_labels[i])                
        final_emb[i] = torch.zeros_like(torch.tensor(latent_rep[i][0])).to(device)
        for j in range(layer_num):
            final_emb[i] += quality_vec[i][j] * torch.tensor(latent_rep[i][j]).to(device)

        # Create training dataset
        train_tensor = torch.tensor(training_labels[i], dtype=torch.float32).view(-1, 1)
        dataset = TensorDataset(final_emb[i], train_tensor)
        train_loader = TorchDataLoader(dataset, batch_size=batch1, shuffle=True)        
        test_latent_rep = []         
        for j in range(layer_num):
            model = models[i][j]
            test_loss_compute = compute_loss(model, testing[i], edge_indices[i])
            test_latent_rep.append(np.array(test_loss_compute[1]))
            testing[i] = np.array(test_loss_compute[1])
        test_emb[i] = torch.zeros_like(torch.tensor(test_latent_rep[0])).to(device)          
        for j in range(layer_num):
            test_emb[i] += quality_vec[i][j] * torch.tensor(test_latent_rep[j]).to(device)
        
        # Create test dataset
        test_tensor = torch.tensor(testing_labels[i], dtype=torch.float32).view(-1, 1)
        test_dataset = TensorDataset(test_emb[i], test_tensor)
        test_loader = TorchDataLoader(test_dataset, batch_size=batch1, shuffle=True)

    train_subgraphs = []
    test_subgraphs = []
    train_subgraphs = []
    test_subgraphs = []
    for i in range(len(datasets)):
        encoded_train_data_list = []
        encoded_test_data_list = []
    
        # Encode Train Dataset
        for sample in range(len(train_dataset[i])):  # Iterate over Data objects
            node_features = final_emb[i][sample]  # Ensure shape [num_nodes, num_features]
            encoded_data = Data(x=node_features, edge_index=train_dataset[i][sample].edge_index)
            encoded_train_data_list.append(encoded_data)
    
        # Encode Test Dataset
        for sample in range(len(test_datasett[i])):  # Iterate over Data objects
            node_features = test_emb[i][sample]  # Ensure shape [num_nodes, num_features]
            encoded_data = Data(x=node_features, edge_index=test_datasett[i][sample].edge_index)
            encoded_test_data_list.append(encoded_data)


        attention_scores = torch.zeros_like(models[i][0].attention_scores)
        
        # Iterate over each model layer
        for j in range(layer_num):
            # Retrieve the attention scores from the current model
            current_attention_scores = models[i][j].attention_scores
            # Retrieve the quality score for the current model
            quality_score = quality_vec[i][j]                
            # Accumulate the weighted attention scores
            attention_scores += quality_score * current_attention_scores
        node_attention_scores_train = compute_node_attention_scores(
            attention_scores, train_dataset[i][0].edge_index, train_dataset[i][0].x.shape[0]
        )
    
        # Construct subgraphs for training set
        train_subgraphs.append(
            construct_subgraphs(node_attention_scores_train, train_dataset[i], encoded_train_data_list, threshold=threshold1)
        )
    
        # Compute node attention scores from extracted attention weights (Test)
        node_attention_scores_test = compute_node_attention_scores(
            attention_scores, test_datasett[i][0].edge_index, test_datasett[i][0].x.shape[0]
        )

        # Construct subgraphs for test set
        test_subgraphs.append(
            construct_subgraphs(node_attention_scores_test, test_datasett[i], encoded_test_data_list, threshold=threshold1)
        )
    

    train_meta_graphs = construct_meta_graphs(train_subgraphs)
    test_meta_graphs = construct_meta_graphs(test_subgraphs)       
    
    # Assign labels to meta-graphs explicitly
    for i, graph in enumerate(train_meta_graphs):
        graph.y = torch.tensor(training_labels[0][i], dtype=torch.long)  # Assign correct labels
    
    for i, graph in enumerate(test_meta_graphs):
        graph.y = torch.tensor(testing_labels[0][i], dtype=torch.long)  # Assign test labels
    


    for graph in train_meta_graphs:
        if graph.y.dim() == 0:  # If the label is scalar
            graph.y = graph.y.unsqueeze(0)  # Convert to tensor of shape [1]
    
    for graph in test_meta_graphs:
        if graph.y.dim() == 0:
            graph.y = graph.y.unsqueeze(0)
    
    from torch_geometric.loader import DataLoader  # Use PyG's DataLoader

    train_loader = DataLoader(train_meta_graphs, batch_size=best_params["batch_size"], shuffle=True)
    test_loader = DataLoader(test_meta_graphs, batch_size=best_params["batch_size"], shuffle=False)

    # Initialize model
    model = GraphLevelPredictor(
        input_dim=128,
        hidden_dim=32,  # Adjust if tunable
        output_dim=num_classes,
        dropout=best_params["dropout1"]
    ).to(device)

    set_weights(model)
    optimizer = torch.optim.Adam(
        model.parameters(), 
        lr=best_params["lr"], 
        weight_decay=1e-4
    )
    criterion = nn.CrossEntropyLoss()

    # Train on train+val
    model.train()
    for epoch in range(best_params["num_train"]):
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            loss = criterion(out, batch.y.view(-1).long())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

    # Final test evaluation
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in test_loader:
            out = model(batch)
            preds = torch.argmax(out, dim=1)
            correct += (preds == batch.y.view(-1)).sum().item()
            total += batch.y.size(0)
            all_preds.extend(preds.tolist())
            all_labels.extend(batch.y.view(-1).tolist())

    final_test_accuracy = correct / total
    print(f"\n✅ Final Test Accuracy (on held-out fold): {final_test_accuracy:.4f}")
    return final_test_accuracy, all_preds, all_labels


final_accuracy, preds, labels = evaluate_best_model(best_params, gnn_datasets)

[I 2025-04-21 17:57:46,351] A new study created in memory with name: no-name-779e2fd4-016b-44f0-b2d0-9d6389697342


Starting Fold 1 - Time: 2.5522 seconds
Fold 1:
tensor([[0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1.,
         1., 1., 1.],
        [1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
         0., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
         0., 0., 0.],
        [1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
         0., 0., 0.],
        [1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
         0., 0., 0.],
        [1., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
         0., 0., 0.],
        [1., 1., 1., 1., 1

[W 2025-04-21 18:02:22,425] Trial 0 failed with parameters: {'layer_num': 3, 'threshold1': 0.1, 'batch_size': 32, 'num_train': 40, 'num_epochs': 10, 'dropout1': 0.25, 'lr': 0.15957993164212958, 'train_weak': 5, 'train_middle': 40, 'weight_decay_weak': 0.0003976453011700043, 'lr_weak': 0.05854751355295725, 'weight_decay_middle': 8.841926348917726e-05, 'lr_middle': 0.05571905096939266, 'batch1': 8} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/zmoslemi/.local/lib/python3.10/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_2276309/1290689301.py", line 810, in objective
    train_acc = model_acc(model, training[i], edge_indices[i], training_labels[i])
  File "/tmp/ipykernel_2276309/1290689301.py", line 118, in model_acc
    output_model = model(data.x, data.edge_index, batch=data.batch)[0]  # Extract logits
  File "/home/zmoslemi/.local/lib/python3.10/site-packa

KeyboardInterrupt: 