In [1]:
import torch
from sklearn.model_selection import train_test_split

def filter_classes(dynamic_data, allowed_classes=(0, 1)):
    """Keep only the non-empty subjects whose label is in ``allowed_classes``.

    A subject's label is read from its first window only
    (``subj[0].y.item()``); the remaining windows are not inspected.

    Args:
        dynamic_data: Iterable of subjects, each a sized, indexable sequence
            of window objects. Every window used here must expose ``y`` with
            an ``.item()`` method (presumably a one-element tensor -- confirm
            against the code that builds the windows).
        allowed_classes: Container of labels to keep. The default is an
            immutable tuple so that no mutable object is shared between
            calls; any container supporting ``in`` is accepted.

    Returns:
        A new list holding the retained subjects in their original order.
        The subject objects themselves are not copied.
    """
    return [
        subj
        for subj in dynamic_data
        # The length check comes first so empty subjects are never indexed.
        if len(subj) > 0 and subj[0].y.item() in allowed_classes
    ]

def subject_level_stratified_split(dynamic_data, test_size=0.2, random_state=42):
    """Split subjects into train and test partitions, stratified by label.

    The split is made over whole subjects, so all windows of one subject land
    in the same partition. A subject's label is taken from its first window
    (``windows[0].y.item()``).

    Args:
        dynamic_data: Iterable of subjects, each a sized, indexable sequence
            of window objects exposing ``y`` with an ``.item()`` method.
        test_size: Forwarded unchanged to ``train_test_split``.
        random_state: Forwarded unchanged to ``train_test_split``.

    Returns:
        A ``(train_list, test_list)`` pair as produced by
        ``train_test_split`` on the list of non-empty subjects.
    """
    # Empty subjects carry no label, so they are left out of both partitions.
    kept_subjects = [windows for windows in dynamic_data if len(windows) != 0]
    strata = [windows[0].y.item() for windows in kept_subjects]

    train_part, test_part = train_test_split(
        kept_subjects,
        test_size=test_size,
        stratify=strata,
        shuffle=True,
        random_state=random_state,
    )
    return train_part, test_part

def compute_node_edge_stats(dynamic_data):
    """Compute per-column mean and std of node and edge features.

    The ``x`` and ``edge_attr`` tensors of every window of every subject are
    concatenated along dim 0, then reduced along that same dim.

    Args:
        dynamic_data: Iterable of subjects, each an iterable of window objects
            exposing ``x`` and ``edge_attr`` tensors that can be concatenated
            along dim 0.

    Returns:
        The tuple ``(node_means, node_stds, edge_means, edge_stds)``, each the
        result of ``mean(dim=0)`` / ``std(dim=0)`` over the stacked features.
    """
    # Gather both tensors of each window in a single pass over the data.
    gathered = [
        (window.x, window.edge_attr)
        for subject in dynamic_data
        for window in subject
    ]
    stacked_nodes = torch.cat([pair[0] for pair in gathered], dim=0)
    stacked_edges = torch.cat([pair[1] for pair in gathered], dim=0)

    return (
        stacked_nodes.mean(dim=0),
        stacked_nodes.std(dim=0),
        stacked_edges.mean(dim=0),
        stacked_edges.std(dim=0),
    )

def apply_node_edge_standardization(dynamic_data, node_means, node_stds, edge_means, edge_stds, *, eps=1e-7):
    """Standardize node and edge features of every window, in place.

    Each window's ``x`` is replaced by ``(x - node_means) / (node_stds + eps)``
    and its ``edge_attr`` by
    ``(edge_attr - edge_means) / (edge_stds + eps)``.

    NOTE: the window objects are modified in place by attribute assignment, so
    any other container holding the same objects sees the standardized values,
    and calling this twice on the same data standardizes twice.

    Args:
        dynamic_data: Iterable of subjects, each an iterable of window objects
            with assignable ``x`` and ``edge_attr`` attributes.
        node_means: Value subtracted from every ``x``.
        node_stds: Value (plus ``eps``) every centred ``x`` is divided by.
        edge_means: Value subtracted from every ``edge_attr``.
        edge_stds: Value (plus ``eps``) every centred ``edge_attr`` is divided
            by.
        eps: Small constant added to the standard deviations to avoid division
            by zero on constant columns. Keyword-only; the default keeps the
            previously hard-coded ``1e-7``.

    Returns:
        The same ``dynamic_data`` object that was passed in.
    """
    # The denominators do not depend on the window, so build them only once.
    node_scale = node_stds + eps
    edge_scale = edge_stds + eps

    for subject_list in dynamic_data:
        for data_obj in subject_list:
            data_obj.x = (data_obj.x - node_means) / node_scale
            data_obj.edge_attr = (data_obj.edge_attr - edge_means) / edge_scale
    return dynamic_data

def process_band(dynamic_data_band, allowed_classes=(0, 1), test_size=0.2, random_state=42):
    """Filter, split and standardize the subjects of one band.

    Steps, in order:
      1. keep only subjects whose label is in ``allowed_classes``;
      2. make a subject-level stratified train/test split;
      3. compute node/edge feature statistics on the training subjects only;
      4. standardize both partitions with those training statistics, so no
         test-set information enters the normalization.

    NOTE: standardization goes through ``apply_node_edge_standardization``,
    which assigns new ``x`` / ``edge_attr`` values onto the existing window
    objects. The windows reachable from ``dynamic_data_band`` are therefore
    modified, and running this function twice on the same input standardizes
    the data twice.

    Args:
        dynamic_data_band: Iterable of subjects (sequences of window objects)
            for a single band.
        allowed_classes: Labels to keep; defaults to the previously
            hard-coded classes 0 and 1.
        test_size: Forwarded to the split; defaults to the previously
            hard-coded ``0.2``.
        random_state: Seed forwarded to the split; defaults to ``42``, the
            value the split used implicitly before.

    Returns:
        The tuple ``(train_data_stand, test_data_stand, dynamic_data_stand)``,
        where the last item is the train partition followed by the test
        partition.
    """
    filtered_data = filter_classes(dynamic_data_band, allowed_classes=allowed_classes)
    train_data, test_data = subject_level_stratified_split(
        filtered_data,
        test_size=test_size,
        random_state=random_state,
    )

    # Statistics are fitted on the training partition only.
    node_means, node_stds, edge_means, edge_stds = compute_node_edge_stats(train_data)

    train_data_stand = apply_node_edge_standardization(
        train_data, node_means, node_stds, edge_means, edge_stds
    )
    test_data_stand = apply_node_edge_standardization(
        test_data, node_means, node_stds, edge_means, edge_stds
    )

    dynamic_data_stand = train_data_stand + test_data_stand
    return train_data_stand, test_data_stand, dynamic_data_stand