In [None]:
import os
import re
import numpy as np
import torch
import networkx as nx
from torch_geometric.data import Data
from tqdm import tqdm
import warnings
from numpy.linalg import svd


warnings.filterwarnings("ignore", category=FutureWarning)

from nilearn.connectome import ConnectivityMeasure
from scipy.stats import kurtosis, skew, entropy
from scipy.signal import welch, csd, hilbert
from scipy.integrate import simpson  # Changed from simps to simpson
from statsmodels.tsa.stattools import grangercausalitytests
from joblib import Parallel, delayed

def get_label(s):
    """Map a subject number to its integer class label.

    Subjects 1-36 map to 2, 37-65 to 0 and 66-88 to 1; any other number maps
    to ``None`` (the caller skips such subjects). NOTE(review): the ranges
    presumably mirror the dataset's diagnostic groups - confirm against the
    dataset's participant table.
    """
    label_ranges = (
        (1, 36, 2),
        (37, 65, 0),
        (66, 88, 1),
    )
    for low, high, label in label_ranges:
        if low <= s <= high:
            return label
    return None


def svd_entropy(sig, order=10, delay=1):
    """
    Calculate the SVD entropy (in bits) of a 1D time series.

    ``sig`` is time-delay embedded into a matrix with ``order`` columns, where
    column ``i`` is ``sig[i*delay : i*delay + N - (order-1)*delay]``. The
    singular values of that matrix are normalized to sum to 1, and their
    base-2 Shannon entropy is returned.

    Parameters
    ----------
    sig : array-like
        1D signal of length N.
    order : int
        Embedding dimension (number of delayed columns).
    delay : int
        Delay, in samples, between consecutive columns.

    Returns
    -------
    float
        The SVD entropy. Returns 0.0 in three cases: the signal is too short
        to embed, all singular values are zero (e.g. an all-zero signal), or
        the SVD raises ``LinAlgError``.
    """
    sig = np.asarray(sig, dtype=float)
    n_rows = len(sig) - (order - 1) * delay
    if n_rows < 1:
        return 0.0
    # Column i starts i*delay samples in; every column has n_rows samples.
    X = np.column_stack([sig[i * delay: i * delay + n_rows] for i in range(order)])
    try:
        W = svd(X, compute_uv=False)
    except np.linalg.LinAlgError:
        return 0.0
    total = W.sum()
    # `not total > 0` also catches NaN, so a degenerate spectrum never
    # produces a 0/0 normalization.
    if not total > 0:
        return 0.0
    return entropy(W / total, base=2)


def hjorth_parameters(sig):
    """Return the Hjorth parameters ``(activity, mobility, complexity)`` of ``sig``.

    * activity   = var(sig)
    * mobility   = sqrt(var(d1) / var(sig)), or 0 when var(sig) == 0
    * complexity = sqrt(var(d2) / var(d1)) / mobility, or 0 when var(d1) or
      mobility is 0

    Here d1 and d2 are the first and second discrete differences of ``sig``.
    """
    d1 = np.diff(sig)
    d2 = np.diff(d1)

    activity = np.var(sig)
    var_d1 = np.var(d1)
    var_d2 = np.var(d2)

    if activity != 0:
        mobility = np.sqrt(var_d1 / activity)
    else:
        mobility = 0

    if var_d1 != 0 and mobility != 0:
        complexity = np.sqrt(var_d2 / var_d1) / mobility
    else:
        complexity = 0

    return activity, mobility, complexity

def spectral_features(sig, fs=500):
    """Summarize the Welch power spectrum of ``sig`` sampled at ``fs`` Hz.

    Returns ``(mean_psd, rel_power, spec_entropy)``:

    * ``mean_psd``: mean of the Welch PSD estimate.
    * ``rel_power``: Simpson-integrated PSD divided by ``len(freqs) * max(psd)``.
      It is 0 when the PSD peak is not positive.
    * ``spec_entropy``: base-2 Shannon entropy of the PSD normalized to sum to 1.
      It is 0 when no bin is positive.
    """
    freqs, power = welch(sig, fs=fs)
    mean_psd = np.mean(power)
    total_power = simpson(y=power, x=freqs)

    peak = np.max(power)
    rel_power = 0
    if peak > 0:
        rel_power = total_power / (len(freqs) * peak)

    power_sum = np.sum(power)
    if power_sum > 0:
        power_norm = power / power_sum
    else:
        power_norm = np.zeros_like(power)

    spec_entropy = 0
    if np.any(power_norm > 0):
        spec_entropy = entropy(power_norm, base=2)

    return mean_psd, rel_power, spec_entropy

def shannon_entropy(sig, bins=4):
    """Base-2 Shannon entropy of the amplitude histogram of ``sig``.

    ``sig`` is binned with ``np.histogram(sig, bins=bins)``. Empty bins are
    ignored. Returns 0.0 when the histogram holds no samples.
    """
    counts, _ = np.histogram(sig, bins=bins)
    n_samples = np.sum(counts)
    if n_samples == 0:
        return 0.0
    probs = counts[counts > 0] / n_samples
    if probs.size == 0:
        return 0.0
    return entropy(probs, base=2)

def compute_node_signal_features(arr, fs=500):
    """Compute 17 per-channel signal features.

    ``arr`` is indexed as ``arr[:, ch]``, so each column is one channel.
    Presumably rows are time samples - confirm against caller.

    Feature order (one row per channel):
      min, max, kurtosis, skewness, mean, std, linear-trend slope,
      shannon_entropy (4 bins), mean_psd, rel_power, spec_entropy,
      hjorth activity, mobility, complexity, rms, zero-crossing count,
      svd_entropy (order=10, delay=1).

    Returns an array of shape (n_channels, 17).
    """
    def _channel_features(s):
        # Slope of a least-squares line over the sample index.
        slope = np.polyfit(np.arange(len(s)), s, 1)[0]
        mean_psd, rel_power, spec_entropy = spectral_features(s, fs)
        activity, mobility, complexity = hjorth_parameters(s)
        # Strict sign changes between consecutive samples (exact zeros don't count).
        zero_crossings = ((s[:-1] * s[1:]) < 0).sum()
        return [
            s.min(), s.max(), kurtosis(s), skew(s), s.mean(), s.std(), slope,
            shannon_entropy(s, bins=4),
            mean_psd, rel_power, spec_entropy,
            activity, mobility, complexity,
            np.sqrt(np.mean(s ** 2)),
            zero_crossings,
            svd_entropy(s, order=10, delay=1),
        ]

    return np.array([_channel_features(arr[:, ch]) for ch in range(arr.shape[1])])


def compute_node_hubscore(adj_matrix):
    """
    Return the HITS hub score of each node of ``adj_matrix``, treated as undirected.

    Before the graph is built, the matrix is symmetrized with
    ``np.maximum(A, A.T)``. For a directed adjacency, each undirected edge
    therefore carries the larger of its two directed weights. Building an
    ``nx.Graph`` straight from an asymmetric matrix would silently keep only
    one of the two weights, depending on the order entries are visited.

    Parameters
    ----------
    adj_matrix : array-like, shape (n_nodes, n_nodes)
        Weighted adjacency matrix; zero entries produce no edge.

    Returns
    -------
    numpy.ndarray, shape (n_nodes,)
        Hub scores in node order 0..n_nodes-1. All zeros when the graph has
        no edges or HITS fails to converge. Non-finite scores are replaced
        by 0.
    """
    A = np.asarray(adj_matrix, dtype=float)
    n_nodes = A.shape[0]
    A_sym = np.maximum(A, A.T)
    G = nx.from_numpy_array(A_sym, create_using=nx.Graph())
    if G.number_of_edges() == 0:
        # No edges -> no hub structure; avoid normalizing an all-zero score vector.
        return np.zeros(n_nodes, dtype=float)
    try:
        hub_dict, _ = nx.hits(G, max_iter=1000, tol=1e-8, normalized=True)
    except nx.PowerIterationFailedConvergence:
        hub_dict = {n: 0.0 for n in G.nodes()}
    hub_vals = np.array([hub_dict[n] for n in G.nodes()], dtype=float)
    # Guard against NaN/inf leaking into node features.
    return np.nan_to_num(hub_vals, nan=0.0, posinf=0.0, neginf=0.0)

def standard_coherence(x, y, fs=500):
    """Mean magnitude-squared coherence between ``x`` and ``y``.

    The coherence is ``|Pxy|^2 / (Pxx * Pyy + 1e-12)`` per frequency bin. It
    uses scipy's default Welch/CSD segmenting and is averaged over all bins.
    The 1e-12 term avoids division by zero.
    """
    _, pxx = welch(x, fs=fs)
    _, pyy = welch(y, fs=fs)
    _, pxy = csd(x, y, fs=fs)
    denominator = pxx * pyy + 1e-12
    coherence = np.abs(pxy) ** 2 / denominator
    return np.mean(coherence)

def compute_granger_matrix(data, max_lag=1):
    """
    Compute a pairwise Granger-causality strength matrix over all channels.

    Entry ``[i, j]`` is ``1 - p``, where ``p`` is the ``params_ftest`` p-value
    at lag ``max_lag`` from ``grangercausalitytests``. It tests whether
    channel i Granger-causes channel j, so higher values mean stronger
    evidence of causality. The diagonal is left at 0.

    Parameters
    ----------
    data : numpy.ndarray
        Array indexed as ``data[:, channel]`` (columns are channels).
    max_lag : int
        Lag passed as ``maxlag``; only the result at this lag is used.

    Returns
    -------
    numpy.ndarray, shape (n_channels, n_channels)
        Pairs for which the test raises are recorded as 0.0.
    """
    n_channels = data.shape[1]
    granger_matrix = np.zeros((n_channels, n_channels))

    for i in range(n_channels):
        x = data[:, i]
        for j in range(n_channels):
            if i == j:  # Skip self-loops
                continue
            # statsmodels tests whether the 2nd column Granger-causes the 1st,
            # so this tests "channel i -> channel j".
            arr = np.column_stack([data[:, j], x])
            try:
                result = grangercausalitytests(arr, maxlag=max_lag, verbose=False)
                granger_matrix[i, j] = 1.0 - result[max_lag][0]['params_ftest'][1]
            except Exception:
                # Best effort: a pair whose test fails scores 0.0. Unlike a bare
                # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
                granger_matrix[i, j] = 0.0

    return granger_matrix

def single_lag_granger(x, y, max_lag=1):
    """
    Return ``1 - p`` for the hypothesis "``x`` Granger-causes ``y``".

    ``p`` is the ``params_ftest`` p-value at lag ``max_lag`` from
    ``grangercausalitytests``. Higher values mean stronger evidence of
    causality. Returns 0.0 if the test raises.
    """
    # statsmodels tests whether the 2nd column Granger-causes the 1st.
    arr = np.column_stack([y, x])
    try:
        out = grangercausalitytests(arr, maxlag=max_lag, verbose=False)
        return 1.0 - out[max_lag][0]['params_ftest'][1]
    except Exception:
        # Best effort, but never swallow KeyboardInterrupt/SystemExit.
        return 0.0

def compute_plv_phase_lag(u, v, hilbert_cache):
    """
    Phase-locking value and mean absolute phase lag between channels ``u`` and ``v``.

    ``hilbert_cache[k]`` holds a complex (analytic) signal for channel k. The
    caller builds it with ``scipy.signal.hilbert``.

    Returns
    -------
    (plv_val, phase_lg)
        ``plv_val = |mean(exp(1j * dphi))|``. ``phase_lg`` is the mean of
        ``|dphi|`` after wrapping each phase difference ``dphi`` into
        [-pi, pi], so it lies in [0, pi].
    """
    px = np.angle(hilbert_cache[u])
    py = np.angle(hilbert_cache[v])
    diff = px - py  # in [-2pi, 2pi]; fine for the PLV (periodic in 2pi)
    plv_val = np.abs(np.mean(np.exp(1j * diff)))
    # Wrap to [-pi, pi]: e.g. phases 3.1 and -3.1 differ by ~0.083 rad, not 6.2.
    wrapped = np.angle(np.exp(1j * diff))
    phase_lg = np.mean(np.abs(wrapped))
    return plv_val, phase_lg

def compute_edge_feats(u, v, window_data, hilbert_cache, granger_matrix, partial_corr, fs=500):
    """Build the 6-element feature vector for the directed edge ``u -> v``.

    Order: [phase_lag, coherence, cross_corr, plv, granger, partial_corr].
    The Granger and partial-correlation entries are read from the
    precomputed matrices at ``[u, v]``.
    """
    src = window_data[:, u]
    dst = window_data[:, v]
    plv, lag = compute_plv_phase_lag(u, v, hilbert_cache)
    return [
        lag,
        standard_coherence(src, dst, fs=fs),
        np.corrcoef(src, dst)[0, 1],
        plv,
        granger_matrix[u, v],
        partial_corr[u, v],
    ]

def parallel_edge_features(edges, window_data, hilbert_cache, granger_matrix, partial_corr, fs=500, n_jobs=-1, n_edge_feats=6):
    """
    Compute ``compute_edge_feats`` for every ``(u, v)`` in ``edges`` with joblib.

    Parameters
    ----------
    edges : iterable of (int, int)
        Directed edges; the output row order matches this order.
    n_edge_feats : int
        Number of features per edge (6 for ``compute_edge_feats``). It only
        matters when ``edges`` is empty.

    Returns
    -------
    numpy.ndarray, shape (n_edges, n_edge_feats)
        For no edges this is an empty ``(0, n_edge_feats)`` array, not
        shape ``(0,)``, so it stays a valid 2-D ``edge_attr``.
    """
    edges = list(edges)
    if not edges:
        return np.empty((0, n_edge_feats))
    results = Parallel(n_jobs=n_jobs)(
        delayed(compute_edge_feats)(u, v, window_data, hilbert_cache, granger_matrix, partial_corr, fs)
        for (u, v) in edges
    )
    return np.array(results)

def make_graph_sparse(G, threshold):
    """Drop every edge whose absolute ``'weight'`` is below ``threshold``.

    ``G`` is modified in place and also returned for convenience.
    """
    weak_edges = [
        (a, b) for a, b in G.edges()
        if abs(G[a][b]['weight']) < threshold
    ]
    G.remove_edges_from(weak_edges)
    return G

def _build_window_graph(window_data, label_tensor, fs, percentile, max_lag, n_edge_feats=6):
    """Build one PyG ``Data`` graph from a single EEG window.

    ``window_data`` is indexed as ``window_data[:, channel]``. Nodes are
    channels. Edges are the directed Granger-causality links at or above the
    ``percentile``-th percentile of edge weights.
    """
    n_channels = window_data.shape[1]

    # 1) Directed Granger-causality matrix, no self-loops.
    granger_matrix = compute_granger_matrix(window_data, max_lag=max_lag)
    np.fill_diagonal(granger_matrix, 0.0)

    # 2) Directed graph; zero entries produce no edge.
    G = nx.from_numpy_array(granger_matrix, create_using=nx.DiGraph())

    # 3) Keep only the strongest causal links.
    edge_weights = [G[u][v]['weight'] for u, v in G.edges()]
    thr_val = np.percentile(edge_weights, percentile) if edge_weights else 0.0
    sparse_G = make_graph_sparse(G, thr_val)

    # Directed edges (u, v) where u Granger-causes v. reshape(-1, 2) keeps the
    # (2, 0) shape PyG expects even when no edge survives.
    edges = list(sparse_G.edges())
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    # Partial correlation, used as an edge feature.
    measure = ConnectivityMeasure(kind='partial correlation')
    pcorr = measure.fit_transform([window_data])[0]
    np.fill_diagonal(pcorr, 0.0)

    # Node features: per-channel signal features + hub score of the thresholded graph.
    nodelist = sorted(sparse_G.nodes())
    adjacency_thresh = nx.to_numpy_array(sparse_G, nodelist=nodelist)
    node_feats_signal = compute_node_signal_features(window_data, fs)
    hub_vals = compute_node_hubscore(adjacency_thresh)
    combined_node_feats = np.hstack([node_feats_signal, hub_vals.reshape(-1, 1)])
    x_torch = torch.tensor(combined_node_feats, dtype=torch.float)

    # Edge features, using the analytic signal of every channel.
    if edges:
        hilb_cache = [hilbert(window_data[:, ch]) for ch in range(n_channels)]
        ef_array = parallel_edge_features(edges, window_data, hilb_cache, granger_matrix, pcorr, fs=fs, n_jobs=-1)
    else:
        # Keep edge_attr 2-D so it lines up with the empty (2, 0) edge_index.
        ef_array = np.empty((0, n_edge_feats))
    ef_torch = torch.tensor(ef_array, dtype=torch.float)

    return Data(x=x_torch, edge_index=edge_index, edge_attr=ef_torch, y=label_tensor)


def main():
    """Extract windowed Granger-causality graphs for every labelled subject.

    Saves each window's ``Data`` object to ``SAVE_DIR`` and saves the nested
    per-subject list to ``ALL_DATA_PATH``.
    """
    INPUT_DIR = r"C:\Users\fathi\Desktop\Rest_eeg_ds004504-download\derivatives\500hz_bands_final\beta\numpy"
    SAVE_DIR = r"C:\Users\fathi\Desktop\Rest_eeg_ds004504-download\derivatives\500hz_bands_final\beta\numpy\granger_temporal_v1_p95_OVERLAP50"
    # The band in this file name must match INPUT_DIR (beta).
    ALL_DATA_PATH = "granger_all_dynamic_data_ol50_beta.pt"
    os.makedirs(SAVE_DIR, exist_ok=True)

    FS = 500
    WINDOW_SEC = 4
    WINDOW_LEN = FS * WINDOW_SEC
    OVERLAP = 0.5  # fraction of each window shared with the next one (50%)
    STRIDE = int(WINDOW_LEN * (1 - OVERLAP))
    PERCENTILE = 95
    MAX_LAG = 1  # Maximum lag for Granger causality tests

    dynamic_data = []

    # Sorted so subject order in dynamic_data is reproducible.
    all_files = sorted(
        f for f in os.listdir(INPUT_DIR)
        if f.endswith(".npy") and f.startswith("sub-")
    )

    for fn in tqdm(all_files, desc="Extracting graphs"):
        stem, _ = os.path.splitext(fn)
        match = re.match(r"sub-(\d+)", stem)
        if not match:
            continue
        label_val = get_label(int(match.group(1)))
        if label_val is None:
            continue
        label_tensor = torch.tensor([label_val], dtype=torch.long)

        time_series = np.load(os.path.join(INPUT_DIR, fn))
        # Heuristic: orient as (samples, channels), assuming more samples than channels.
        if time_series.shape[0] < time_series.shape[1]:
            time_series = time_series.T
        time_len = time_series.shape[0]

        subject_data = [
            _build_window_graph(
                time_series[start_idx:start_idx + WINDOW_LEN, :],
                label_tensor, FS, PERCENTILE, MAX_LAG,
            )
            for start_idx in range(0, time_len - WINDOW_LEN + 1, STRIDE)
        ]
        dynamic_data.append(subject_data)

        for w_i, d_obj in enumerate(subject_data):
            torch.save(d_obj, os.path.join(SAVE_DIR, f"{stem}_win{w_i}.pt"))

    torch.save(dynamic_data, ALL_DATA_PATH)
    print(f"\nAll dynamic graphs extracted. dynamic_data has {len(dynamic_data)} subjects.")
    print(f"Also saved each window .pt in {SAVE_DIR}")
    print(f"Saved entire memory object to {ALL_DATA_PATH}.")


if __name__ == "__main__":
    main()

# Node features (18 total):
# 1. min, 2. max, 3. kurtosis, 4. skewness, 5. mean, 6. std, 7. slope, 8. shannon_entropy
# 9. mean_psd, 10. relative_power, 11. spectral_entropy
# 12. activity, 13. mobility, 14. complexity
# 15. rms, 16. zero_crossings, 17. svd_entropy, 18. hubscore

# Edge features (6 total):
# 1. phase_lag, 2. coherence, 3. cross_corr, 4. plv, 5. granger_causality, 6. partial_corr

Extracting graphs:   0%|          | 0/88 [00:00<?, ?it/s]