In [None]:
# Source-to-parcel analysis

# Import necessary libraries
from matplotlib.animation import FuncAnimation
# import seaborn as sns  # required for heatmap visualization
import networkx as nx
from scipy.stats import pearsonr
from mne.viz import circular_layout
import pandas as pd
from mne_connectivity.viz import plot_connectivity_circle
import matplotlib.pyplot as plt
import os
import glob
import numpy as npc
import cupy as np  # using gpu acceleration
import cupyx.scipy.fft
import mne
from mne.datasets import fetch_fsaverage
from nilearn import datasets
from nilearn.image import get_data
from scipy.signal import hilbert  # scipy core modified in env, running custom lib
import scipy
import matplotlib
import os.path as op

# Render figures off-screen. pyplot is already imported above, so this
# switches the active backend (no GUI windows are opened).
matplotlib.use('Agg')

# Route scipy.fft calls on CuPy arrays to cupyx.scipy.fft for the whole
# session. `scipy.fft.set_backend` only returns a context manager and has no
# effect unless used in a `with` block, so the global setter is used instead.
# The default SciPy backend is registered as a fallback so that NumPy inputs
# keep being handled by SciPy itself.
scipy.fft.set_global_backend(cupyx.scipy.fft)
scipy.fft.register_backend('scipy')

# Locate (downloading if necessary) the fsaverage template files.
fs_dir = fetch_fsaverage(verbose=True)
subjects_dir = op.dirname(fs_dir)

subject = "fsaverage"
trans = "fsaverage"  # MNE has a built-in fsaverage transformation
src = op.join(fs_dir, "bem", "fsaverage-ico-5-src.fif")
bem = op.join(fs_dir, "bem", "fsaverage-5120-5120-5120-bem-sol.fif")

# Use the matplotlib-based MNE data browser.
mne.viz.set_browser_backend('matplotlib', verbose=None)
mne.set_config('MNE_BROWSER_BACKEND', 'matplotlib')

# Input and output root directories. The trailing slash is required because
# per-subject paths are built by plain string concatenation.
files_in = '../data/in/subjects/'
files_out = '../data/out/subjects/'


def compute_cross_correlation(data_window):
    """Return the Pearson correlation matrix between leading-axis slices.

    All axes after the first are flattened, so each ``data_window[k]`` is
    treated as one variable (one row passed to ``corrcoef``).
    """
    n_vars = data_window.shape[0]
    rows = data_window.reshape(n_vars, -1)
    return np.corrcoef(rows, rowvar=True)

# Compute dPLI at the level of regions


def compute_dPLI(data):
    """Compute a pairwise phase-synchrony matrix between regions.

    Parameters
    ----------
    data : 2-D array, shape (n_times, n_regions)
        Real-valued signals, one column per region.

    Returns
    -------
    2-D float array, shape (n_regions, n_regions)
        Entry ``[i, j]`` is ``|mean_t exp(1j * (phi_i(t) - phi_j(t)))|``, where
        ``phi`` is the instantaneous phase of the Hilbert analytic signal.
        The diagonal is 0.

    Notes
    -----
    NOTE(review): despite the name, this formula is the phase-locking value
    (PLV). It is symmetric and lies in [0, 1]. A directed PLI would average
    ``heaviside(sin(phi_i - phi_j))`` and satisfy ``d[i, j] = 1 - d[j, i]``.
    Kept as-is so the downstream (undirected) thresholding is unchanged;
    confirm which metric is wanted.
    """
    print('Computing dPLI')
    # The analytic signal must be taken along time (axis 0). scipy's default
    # axis=-1 would transform across regions instead.
    analytic_signal = hilbert(data, axis=0)
    phase_data = np.angle(analytic_signal)
    n_times = phase_data.shape[0]

    unit_phasors = np.exp(1j * phase_data)
    # cross[i, j] = sum_t e^{i phi_i(t)} * conj(e^{i phi_j(t)})
    #             = sum_t e^{i (phi_i(t) - phi_j(t))}
    # This replaces the O(n_regions^2) Python loop with one matrix product.
    cross = unit_phasors.T @ unit_phasors.conj()
    dPLI_matrix = np.abs(cross / n_times).astype(np.float64)
    np.fill_diagonal(dPLI_matrix, 0.0)  # self-pairs are left at 0
    return dPLI_matrix

# dPLI_matrix = compute_dPLI(label_time_courses) --> computing static, fc for the entire dataset


def disparity_filter(G, alpha=0.01):
    """Return a copy of ``G`` without edges whose local weight share is small.

    For every edge ``(i, j)`` yielded by ``G.edges(data=True)``, the score
    ``w_ij**2 / sum_k w_ik**2`` is computed. The sum runs over the edges
    incident to the first endpoint ``i``. Edges scoring below ``alpha`` are
    removed from a copy; ``G`` itself is not modified.

    Parameters
    ----------
    G : networkx graph
        Weighted graph. An edge without a ``'weight'`` attribute counts as
        weight 1 (networkx convention).
    alpha : float
        Minimum score an edge needs in order to be kept.

    Returns
    -------
    A filtered copy of ``G`` (same graph type, via ``G.copy()``).

    Notes
    -----
    NOTE(review): this is a squared-weight share evaluated from endpoint
    ``i``'s side only. It is not the Serrano et al. (2009) disparity
    significance test (``alpha_ij = (1 - w_ij / s_i) ** (k_i - 1)``).
    Confirm which filter is intended.
    """
    # Sum of squared incident weights, computed once per node instead of
    # once per edge.
    squared_strength = {}
    disparities = {}
    for i, j, data in G.edges(data=True):
        if i not in squared_strength:
            squared_strength[i] = sum(
                d.get('weight', 1) ** 2 for _, _, d in G.edges(i, data=True))
        total = squared_strength[i]
        # If every incident weight is zero the share is 0/0. Such edges
        # carry no weight, so score them 0 (dropped for any alpha > 0)
        # instead of raising ZeroDivisionError.
        disparities[(i, j)] = (
            data.get('weight', 1) ** 2 / total if total else 0.0)

    G_filtered = G.copy()
    for (i, j), disparity in disparities.items():
        if disparity < alpha:
            G_filtered.remove_edge(i, j)
    return G_filtered


def graph_to_matrix(graph, size):
    """Build a dense symmetric ``size`` x ``size`` weight matrix from ``graph``.

    Node labels are used directly as row/column indices, so they must be
    valid indices into a ``size`` x ``size`` array. Pairs without an edge
    stay 0.
    """
    dense = np.zeros((size, size))
    for u, v, attrs in graph.edges(data=True):
        w = attrs['weight']
        # Fill both triangles so the matrix is symmetric.
        dense[u, v] = w
        dense[v, u] = w
    return dense


def plot_matrix(matrix, title, labels, cmap='viridis', out_dir=None):
    """Save a heatmap of ``matrix`` as ``<out_dir>/<title>.png``.

    Parameters
    ----------
    matrix : 2-D array (NumPy or CuPy)
        Values to display. CuPy arrays are copied to host memory first.
    title : str
        Plot title; also used as the output file name (without extension).
    labels : sequence of str or None
        Tick labels for both axes, one per row/column. ``None`` keeps
        matplotlib's default ticks.
    cmap : str
        Matplotlib colormap name.
    out_dir : str, optional
        Output directory. Defaults to the module-level ``output_path`` that
        the per-subject loop sets.
    """
    if out_dir is None:
        out_dir = output_path
    # matplotlib cannot read device memory; CuPy arrays provide .get() to
    # copy themselves into a NumPy array.
    if hasattr(matrix, '__cuda_array_interface__'):
        matrix = matrix.get()

    # Drawn with matplotlib directly: seaborn is not imported in this file,
    # so the former sns.heatmap call raised NameError.
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(matrix, cmap=cmap, aspect='auto',
                      interpolation='nearest')
    fig.colorbar(image, ax=ax)
    if labels is not None:
        positions = range(len(labels))
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=90)
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)
    ax.set_title(title)
    fig.savefig(os.path.join(out_dir, title + '.png'))
    # pyplot keeps every figure alive until it is closed; without this,
    # repeated calls accumulate figures in memory.
    plt.close(fig)


def update_cross_correlation(window_number):
    """Animation callback: redraw ``ax`` with one window's cross-correlation.

    Reads the module-level ``ax``, ``title``, ``ordered_regions`` and
    ``windowed_cross_correlation_matrices`` and returns ``ax``.

    Named distinctly because another ``update`` (the dPLI circle callback)
    is defined after this one; the shared name made this callback
    unreachable.
    """
    ax.clear()
    raw = windowed_cross_correlation_matrices[window_number]
    # Round-trip through a networkx graph. For a symmetric input this
    # rebuilds the same matrix: zero entries become absent edges and stay 0.
    # NOTE(review): no thresholding happens here despite the plot title;
    # threshold_matrix() may have been intended. Confirm.
    current_matrix = graph_to_matrix(
        nx.convert_matrix.from_numpy_array(raw), raw.shape[0])
    if hasattr(current_matrix, '__cuda_array_interface__'):
        current_matrix = current_matrix.get()  # matplotlib needs host memory

    # seaborn is not imported in this file (sns.heatmap raised NameError),
    # so draw the heatmap with matplotlib directly.
    ax.imshow(current_matrix, cmap='viridis', aspect='auto',
              interpolation='nearest')
    positions = range(len(ordered_regions))
    ax.set_xticks(positions)
    ax.set_xticklabels(ordered_regions, rotation=90)
    ax.set_yticks(positions)
    ax.set_yticklabels(ordered_regions)
    title.set_text(
        f'Thresholded Cross-Correlation Matrix for Window {window_number}')
    return ax


def update(window_number):
    """Animation callback: redraw ``ax`` as a thresholded dPLI circle plot.

    Takes the dPLI matrix of window ``window_number`` from the module-level
    ``windowed_dpli_matrices``, filters it with ``threshold_matrix``, and
    draws it on the module-level ``ax`` using ``ordered_regions`` and
    ``node_angles``. Returns a 1-tuple containing ``ax``.
    """
    ax.clear()
    thresholded = threshold_matrix(windowed_dpli_matrices[window_number])
    plot_title = ('Thresholded Regional Connectivity using dPLI for Window '
                  f'{window_number}')
    plot_connectivity_circle(thresholded, ordered_regions,
                             n_lines=300,
                             node_angles=node_angles,
                             title=plot_title,
                             ax=ax)
    return (ax,)


def threshold_matrix(matrix):
    """Filter a weighted adjacency matrix with ``disparity_filter``.

    The matrix is turned into a networkx graph and filtered with
    ``disparity_filter``'s default ``alpha``. Surviving edges are written
    back symmetrically into a zero matrix with the same shape and dtype as
    ``matrix``; removed edges stay 0.
    """
    filtered_graph = disparity_filter(
        nx.convert_matrix.from_numpy_array(matrix))
    result = np.zeros_like(matrix)
    for row, col, attrs in filtered_graph.edges(data=True):
        weight = attrs['weight']
        result[row, col] = weight
        result[col, row] = weight
    return result


def threshold_graph_by_density(G, density=0.1, directed=False):
    """Keep only the highest-weighted fraction of ``G``'s edges.

    Parameters
    ----------
    G : networkx graph
        Weighted graph; every edge must carry a ``'weight'`` attribute.
    density : float in [0, 1]
        Fraction of G's *existing* edges to keep (rounded down), chosen by
        descending weight value. This is relative to G's edge count, not to
        the number of possible node pairs. Negative weights rank lowest.
    directed : bool
        Build an ``nx.DiGraph`` instead of an ``nx.Graph``.

    Returns
    -------
    networkx.Graph or networkx.DiGraph
        Every node of ``G`` (in G's node order, with attributes), including
        nodes left without edges, plus the retained edges.

    Raises
    ------
    ValueError
        If ``density`` is outside [0, 1].
    """
    if density < 0 or density > 1:
        raise ValueError("Density value must be between 0 and 1.")
    num_edges_desired = int(G.number_of_edges() * density)
    # sorted() is stable (also with reverse=True), so ties keep G's order.
    sorted_edges = sorted(G.edges(data=True), key=lambda edge: edge[2]['weight'],
                          reverse=True)
    G_thresholded = nx.DiGraph() if directed else nx.Graph()
    # Copy the nodes first. Otherwise nodes whose edges were all cut vanish,
    # and the rest appear in edge-weight order. That would change the node
    # count and ordering seen by graph metrics and matrix conversions.
    G_thresholded.add_nodes_from(G.nodes(data=True))
    G_thresholded.add_edges_from(sorted_edges[:num_edges_desired])
    return G_thresholded


# Load the subject IDs, one per line. Blank lines are skipped; previously the
# trailing newline produced an empty subject ID. The `with` block closes the
# file.
with open("./names.txt", "r", encoding="utf-8") as names:
    subject_list = [line.strip() for line in names if line.strip()]
modes = ['EC', 'EO']  # presumably eyes-closed / eyes-open conditions
# Read the custom montage
montage_path = r"../data/in/MFPRL_UPDATED_V2.sfp"
montage = mne.channels.read_custom_montage(montage_path)

schaefer_atlas = datasets.fetch_atlas_schaefer_2018(n_rois=100)
# NOTE(review): this rebinds fs_dir and src from the setup cell. fs_dir
# changes from the fetched fsaverage directory to a local copy, and src
# changes from a path string to a loaded source space. Confirm that is
# intended.
fs_dir = '../data/in/fsaverage'
fname = os.path.join(fs_dir, "bem", "fsaverage-ico-5-src.fif")
src = mne.read_source_spaces(fname, patch_stats=False, verbose=None)

# TODO: still need to generate the following outputs:
# - Participant ID
# - Group assignment (might need @Maxine He to remind us one last time how
#   the numbering relates to the group assignment)
# - Condition (eyes-open or eyes-closed)
# - Modularity
# - Small-worldness
# - Global efficiency
# - Average clustering coefficient
# - Average betweenness centrality

In [None]:
# The parcellation is the same for every subject and condition, so read it
# and derive the region/network names once, outside the loops.
labels = mne.read_labels_from_annot('fsaverage', parc='Schaefer2018_100Parcels_7Networks_order',
                                    subjects_dir=r'../data/in/')

networks = {}          # network name -> list of labels in that network
ordered_regions = []   # region name of each label, in label order
network_labels = []    # network name of each entry in ordered_regions
for label in labels:
    # Names are split on '_'. Field 2 is taken as the network, and fields 2+
    # form the region name (e.g. 'Cont_PFCl_1-lh'). Grouping now uses the
    # same field as network_labels; field 0 is presumably the shared
    # '7Networks' prefix, which lumped every region into one group.
    # Confirm against the annotation.
    parts = label.name.split('_')
    network_name = parts[2]
    networks.setdefault(network_name, []).append(label)
    ordered_regions.append('_'.join(parts[2:]))
    network_labels.append(network_name)

# Sliding-window parameters for the time-resolved analysis.
sampling_rate = 512  # in Hz
window_length_seconds = 1
step_size_seconds = 0.5
window_length_samples = int(window_length_seconds * sampling_rate)
step_size_samples = int(step_size_seconds * sampling_rate)

for subject in subject_list:
    for mode in modes:
        print(subject, mode)
        # defining paths for current subject
        input_path = files_in + subject + '/' + mode + '/'
        output_path = files_out + subject + '/' + mode + '/'

        # loading in time course files
        label_time_courses_file = output_path + \
            f"{subject}_label_time_courses.npy"
        label_time_courses = np.load(label_time_courses_file)

        # Compute cross-correlation between all pairs of regions across
        # windows. The setup below runs once per subject/condition; before,
        # it sat inside the per-label loop and repeated for every label.
        print('Computing cross corelation')

        # Total duration in samples.
        # NOTE(review): assumes label_time_courses is 3-D with time on the
        # last axis and the first axis covering both hemispheres. Confirm.
        num_epochs_per_hemisphere = label_time_courses.shape[0] / 2
        duration_per_epoch = label_time_courses.shape[2] / sampling_rate
        total_duration_samples = int(
            num_epochs_per_hemisphere * duration_per_epoch * sampling_rate)

        # Floor division plus max(0, ...) yields 0 windows when the data is
        # shorter than one window. int() truncates toward zero and gave 1.
        num_windows = max(
            0,
            (total_duration_samples - window_length_samples) // step_size_samples + 1)
        windowed_dpli_matrices = []
        windowed_cross_correlation_matrices = []

        print(networks)


            