In [None]:
from torch_geometric.data import Data
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
import pandas as pd
import torch
import os
import numpy as np
from scipy.signal import hilbert, coherence
import scipy.io
import networkx as nx
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr, kendalltau
import dcor
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
import pywt


def calculate_small_worldness(G):
    """Return the small-world coefficient sigma of graph ``G``.

    sigma = (C / C_rand) / (L / L_rand), where C is the average clustering
    coefficient and L the average shortest path length of ``G``. The random
    reference values use the Erdos-Renyi approximations
    ``C_rand = 2m / (n (n - 1))`` (edge density) and
    ``L_rand = ln(n) / ln(2m / n)`` (``2m / n`` is the average degree).

    If ``G`` is disconnected only its largest connected component is used.
    Self-loops are excluded from the edge count ``m``: they are not edges of
    the random reference graph, and counting them (e.g. one self-loop per
    node) would inflate the density and average degree.

    Returns:
        sigma as a float, or ``np.nan`` when it is undefined: empty graph,
        fewer than two nodes in the component, or average degree <= 1
        (``ln(2m / n) <= 0`` makes ``L_rand`` infinite or negative).
    """
    if G.number_of_nodes() == 0:
        # nx.is_connected raises on the null graph.
        return np.nan

    if not nx.is_connected(G):
        largest_cc = max(nx.connected_components(G), key=len)
        G = G.subgraph(largest_cc).copy()

    n = G.number_of_nodes()
    if n < 2:
        return np.nan

    # Only edges between distinct nodes belong in the random-graph reference.
    m = G.number_of_edges() - nx.number_of_selfloops(G)
    avg_degree = 2 * m / n
    if m == 0 or avg_degree <= 1:
        return np.nan

    avg_clustering = nx.average_clustering(G)
    avg_path_length = nx.average_shortest_path_length(G)

    expected_clustering = 2 * m / (n * (n - 1))
    expected_path_length = np.log(n) / np.log(avg_degree)

    return (avg_clustering / expected_clustering) / (avg_path_length / expected_path_length)

def calculate_modularity(G, weight=None):
    """Return the modularity of a greedy community partition of ``G``.

    Communities are found with ``nx.community.greedy_modularity_communities``
    and the resulting partition is scored with ``nx.community.modularity``.
    Both calls use the same ``weight`` setting, so the score measures the same
    objective the partition was optimised for. (The two networkx functions
    have different defaults: ``weight=None`` vs ``weight='weight'``.)

    Args:
        G: networkx graph.
        weight: name of the edge attribute to use as weight, or ``None``
            (default) for an unweighted analysis. Unweighted is the safe
            choice when edge weights can be negative, such as signed
            correlations.

    Returns:
        Modularity as a float, or ``np.nan`` if ``G`` has no edges
        (modularity normalises by the total edge weight).
    """
    if G.number_of_edges() == 0:
        return np.nan
    partition = nx.community.greedy_modularity_communities(G, weight=weight)
    return nx.community.modularity(G, partition, weight=weight)

def calculate_global_efficiency(G):
    """Return the global efficiency of graph ``G``.

    Thin wrapper around ``nx.global_efficiency``, which takes no weight
    argument.
    """
    efficiency = nx.global_efficiency(G)
    return efficiency

def calculate_avg_clustering_coefficient(G):
    """Return the mean local clustering coefficient over the nodes of ``G``.

    Delegates to ``nx.average_clustering`` with its default arguments.
    """
    mean_clustering = nx.average_clustering(G)
    return mean_clustering

def calculate_eigenvector_centrality(G):
    """Return the mean eigenvector centrality over the nodes of ``G``.

    Uses ``nx.eigenvector_centrality`` with its default arguments.
    Returns 0 when networkx cannot compute the centrality, for example on an
    empty graph or when power iteration fails to converge. This keeps one
    degenerate graph from aborting a whole feature-extraction run.
    """
    try:
        centrality = nx.eigenvector_centrality(G)
    except nx.NetworkXException:
        # Base class of PowerIterationFailedConvergence and
        # NetworkXPointlessConcept. Unlike a bare `except:`, this lets
        # KeyboardInterrupt, SystemExit and genuine bugs propagate.
        return 0
    return np.mean(list(centrality.values()))

def calculate_wavelet_coeffs(signal, level=6, wavelet='db4'):
    """Return all coefficients of a multilevel DWT of ``signal``, concatenated.

    Runs ``pywt.wavedec(signal, wavelet, level=level)`` and concatenates the
    returned arrays (approximation first, then detail levels coarsest to
    finest) into a single 1-D array.

    Args:
        signal: 1-D array-like input.
        level: decomposition level passed to ``pywt.wavedec``.
        wavelet: wavelet name passed to ``pywt.wavedec`` (default ``'db4'``).

    NOTE(review): pywt warns when ``level`` exceeds
    ``pywt.dwt_max_level(len(signal), wavelet)``. For ``'db4'`` and the
    81-sample windows used in this file that maximum is 3, so with
    ``level=6`` every level is dominated by boundary effects. Consider
    lowering ``level``.
    """
    coeffs = pywt.wavedec(signal, wavelet, level=level)
    return np.concatenate(coeffs)

def calculate_covariance(signal1, signal2):
    """Return the sample covariance of two equal-length signals.

    Uses ``np.cov`` defaults, i.e. normalisation by N - 1.
    """
    cov_matrix = np.cov(signal1, signal2)
    # The 2x2 matrix is [[var1, cov], [cov, var2]]; take the off-diagonal.
    off_diagonal = cov_matrix[0][1]
    return off_diagonal

def calculate_phase_lag_index(signal1, signal2):
    """Return the phase lag index (PLI) between two signals.

    Instantaneous phases come from the analytic signals (``hilbert``), and
    PLI = |mean(sign(sin(phase1 - phase2)))|. The result lies in [0, 1]:
    0 means no consistent phase lead or lag, 1 means one signal
    consistently leads the other.
    """
    phase1, phase2 = (np.angle(hilbert(s)) for s in (signal1, signal2))
    lag_signs = np.sign(np.sin(phase1 - phase2))
    return abs(np.mean(lag_signs))

def calculate_cross_correlation(signal1, signal2):
    """Return the mean of the full (all-lags) cross-correlation of two signals.

    Mathematically this equals
    ``sum(signal1) * sum(signal2) / (len(signal1) + len(signal2) - 1)``,
    so it reflects the signals' means rather than their lagged similarity.
    """
    full_xcorr = np.correlate(signal1, signal2, mode='full')
    return full_xcorr.mean()

def calculate_pearson_correlation(x, y):
    """Return Pearson's correlation coefficient r between ``x`` and ``y``.

    The p-value that ``scipy.stats.pearsonr`` also computes is discarded.
    """
    result = pearsonr(x, y)
    coefficient = result[0]
    return coefficient


input_directory = '/content/drive/MyDrive/removedspikes_nonan_replace_window10avg/ply_dtrnd/dtrnd-tddr-fit-ply-dtrnd/baselinecorr/ready for segmentation/merged_datas_all_normalized(abslute value max each ch)/rest/mci'
output_directory = '/content/drive/MyDrive/removedspikes_nonan_replace_window10avg/ply_dtrnd/dtrnd-tddr-fit-ply-dtrnd/baselinecorr/ready for segmentation/merged_datas_all_normalized(abslute value max each ch)/temporal2_10_rest_mci_swn+graphfeat+hbr/left'

os.makedirs(output_directory, exist_ok=True)

regions = {
    'right': list(range(16)),
    'middle': list(range(16, 32)),
    'left': list(range(32, 48)),
    'all': list(range(48))
}

# 1-based channel numbers used as graph nodes. This run uses channels 33-48,
# the same channels as regions['left']. Alternatives used in other runs:
#   16 ROI channels: [2, 3, 4, 6, 7, 9, 14, 25, 30, 31, 33, 36, 39, 40, 47, 48]
#   channels 1-16:   list(range(1, 17))
#   channels 17-32:  list(range(17, 33))
#   all channels:    list(range(1, 49))
channels_of_interest = list(range(33, 49))
# Convert to 0-based column indices into the (samples, channels) data arrays.
channels_of_interest = [ch - 1 for ch in channels_of_interest]

# NOTE(review): 8.138 appears to be the sampling rate in Hz. Confirm.
SAMPLING_RATE_HZ = 8.138
window_size = int(SAMPLING_RATE_HZ * 10)  # 81 samples, ~10 s at 8.138 Hz
overlap = 0
START_SKIP = int(SAMPLING_RATE_HZ * 5)    # samples dropped at the start of each recording
END_SKIP = int(SAMPLING_RATE_HZ * 0)      # samples dropped at the end (currently none)
EDGE_THRESHOLD = 0.60  # |Pearson r| between HbO traces required for an inter-channel edge
LABEL = 1              # class label attached to every snapshot of this run (input folder rest/mci)
HBO_SUFFIX = '_hbo_denoised_cleaned_cleaned_filtered.mat'
HBR_SUFFIX = '_hbr_denoised_cleaned_cleaned_filtered.mat'


def _window(data, start, size):
    """Return ``size`` rows of ``data`` from ``start``; if that overruns, the last ``size`` rows."""
    if start + size > len(data):
        return data[-size:, :]
    return data[start:start + size, :]


def _node_features(hbo_signal, hbr_signal):
    """Return the 12 node features: max, min, mean, std, slope, wavelet mean for HbO, then HbR."""
    features = []
    for signal in (hbo_signal, hbr_signal):
        features.extend([
            np.max(signal),
            np.min(signal),
            np.mean(signal),
            np.std(signal),
            np.polyfit(np.arange(len(signal)), signal, 1)[0],  # linear-trend slope
            np.mean(calculate_wavelet_coeffs(signal)),
        ])
    return features


def _edge_features(signal_a, signal_b, pearson_corr):
    """Return [covariance, PLI, Pearson r, mean cross-correlation] for a signal pair."""
    return [
        calculate_covariance(signal_a, signal_b),
        calculate_phase_lag_index(signal_a, signal_b),
        pearson_corr,
        calculate_cross_correlation(signal_a, signal_b),
    ]


def _graph_features(G):
    """Return [small-worldness (-1 if undefined), modularity, global efficiency,
    average clustering, mean eigenvector centrality] for ``G``."""
    small_worldness = calculate_small_worldness(G)
    return [
        small_worldness if not np.isnan(small_worldness) else -1,
        calculate_modularity(G),
        calculate_global_efficiency(G),
        calculate_avg_clustering_coefficient(G),
        calculate_eigenvector_centrality(G),
    ]


def _build_snapshot(hbo_window, hbr_window, channels):
    """Build one graph snapshot (torch_geometric ``Data``) from one time window.

    Nodes are the channels in ``channels`` whose HbO and HbR traces are both
    not all-zero in this window. Node ``k`` is the k-th such channel, so row
    ``k`` of ``x`` and index ``k`` in ``edge_index`` always refer to the
    same channel.

    Edges:
      * a self-loop per node carrying HbO-vs-HbR features of that channel;
      * an undirected edge between two nodes when |Pearson r| of their HbO
        traces exceeds ``EDGE_THRESHOLD``. It is stored in both directions,
        with one ``edge_attr`` row per direction.

    Returns ``None`` if no channel is usable in this window.
    """
    valid = [
        ch for ch in channels
        if not (np.all(hbo_window[:, ch] == 0) or np.all(hbr_window[:, ch] == 0))
    ]
    if not valid:
        return None

    x = [_node_features(hbo_window[:, ch], hbr_window[:, ch]) for ch in valid]

    edge_index = []
    edge_attr = []
    G = nx.Graph()
    for a, ch_a in enumerate(valid):
        hbo_a = hbo_window[:, ch_a]
        hbr_a = hbr_window[:, ch_a]

        # Self-loop: HbO-HbR coupling within one channel.
        self_corr = calculate_pearson_correlation(hbo_a, hbr_a)
        edge_index.append([a, a])
        edge_attr.append(_edge_features(hbo_a, hbr_a, self_corr))
        G.add_edge(a, a, weight=self_corr)

        for b in range(a + 1, len(valid)):
            hbo_b = hbo_window[:, valid[b]]
            corr = calculate_pearson_correlation(hbo_a, hbo_b)
            if abs(corr) > EDGE_THRESHOLD:
                features = _edge_features(hbo_a, hbo_b, corr)
                # edge_attr needs one row per edge_index column.
                edge_index.extend([[a, b], [b, a]])
                edge_attr.extend([features, features])
                G.add_edge(a, b, weight=corr)

    snapshot = Data(
        x=torch.tensor(x, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attr, dtype=torch.float),
        y=torch.tensor([LABEL]),
    )
    snapshot.graph_features = torch.tensor(_graph_features(G), dtype=torch.float)
    return snapshot


# Build and save the sequence of temporal graphs for each patient.
temporal_graphs = []

for filename in sorted(os.listdir(input_directory)):
    if not filename.endswith(HBO_SUFFIX):
        continue
    base_filename = filename[:-len(HBO_SUFFIX)]
    hbo_file_path = os.path.join(input_directory, filename)
    hbr_file_path = os.path.join(input_directory, base_filename + HBR_SUFFIX)
    if not os.path.exists(hbr_file_path):
        tqdm.write(f"Skipping {filename}: no matching HbR file")
        continue

    hbo_data = scipy.io.loadmat(hbo_file_path)['corrected_data']
    hbr_data = scipy.io.loadmat(hbr_file_path)['corrected_data']

    # Use a common length so HbO and HbR windows (including the final,
    # end-aligned one) always cover the same samples.
    n_samples = min(len(hbo_data), len(hbr_data))
    hbo_data = hbo_data[:n_samples]
    hbr_data = hbr_data[:n_samples]

    start_index = START_SKIP
    end_index = n_samples - END_SKIP
    # One window per start position. A window that would overrun the data is
    # replaced by the last `window_size` samples (see _window).
    window_starts = range(start_index, end_index, window_size - overlap)

    temporal_graphs = []
    for i in tqdm(window_starts, desc=f"Processing {filename}", unit="window"):
        snapshot = _build_snapshot(
            _window(hbo_data, i, window_size),
            _window(hbr_data, i, window_size),
            channels_of_interest,
        )
        if snapshot is None:
            tqdm.write(f"{filename}: window starting at sample {i} has no usable channels; skipped")
            continue
        temporal_graphs.append(snapshot)

    # NOTE(review): the edge set differs per window, but StaticGraphTemporalSignal
    # is meant for a single shared edge_index/edge_weight. Its snapshot
    # accessors may fail on these lists; DynamicGraphTemporalSignal is designed
    # for per-snapshot edges. Confirm against the downstream loader.
    temporal_data = StaticGraphTemporalSignal(
        edge_index=[g.edge_index for g in temporal_graphs],
        edge_weight=[g.edge_attr for g in temporal_graphs],
        features=[g.x for g in temporal_graphs],
        targets=[g.y for g in temporal_graphs],  # Adjust the target as per your requirement
    )
    # Keep the per-window graph-level features (small-worldness, modularity, ...)
    # alongside the signal so they are saved with it.
    temporal_data.graph_features = [g.graph_features for g in temporal_graphs]

    output_file_path = os.path.join(output_directory, f'{base_filename}_temporal_graphs.pt')
    torch.save(temporal_data, output_file_path)

print("Temporal graph data has been saved as .pt files for each patient.")

