In [1]:
import mne
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise_distances
from nltools.data import Brain_Data, Design_Matrix, Adjacency
import networkx as nx
from scipy import signal
from mne_connectivity import spectral_connectivity_epochs
import seaborn as sns
import pandas as pd


In [None]:
# Cleaned SEEG recording for patient 6, fully loaded into memory.
RAW_FIF_PATH = '/home/pablo/works/dev_thesis_SEEG/data/pte_6_cleaned.fif'
raw = mne.io.read_raw_fif(RAW_FIF_PATH, preload=True)

In [3]:
# Channel labels and the raw signal array extracted from the recording.
channels = raw.ch_names
data = raw.get_data()

In [None]:
# Channel-averaged PSD up to 50 Hz. Raw.plot_psd is a legacy MNE API; the
# supported form is compute_psd().plot(). amplitude=False keeps the power
# (dB) display that plot_psd showed by default.
raw.compute_psd(fmax=50).plot(average=True, spatial_colors=False, amplitude=False)
plt.show()

In [None]:
# Example usage:
# # Load your raw data
# raw = mne.io.read_raw_fif('your_raw_data.fif', preload=True)

# # Define events and epochs
# events = mne.find_events(raw)
# event_id = {'event_name': 1}  # Modify according to your event
# epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=0.5, baseline=(None, 0), preload=True)

In [50]:
# Frequency bands of interest: name -> (fmin, fmax), used as GC frequency limits.
bands = dict(
    theta=(3.5, 7.5),
    alpha=(7.5, 13),
    beta=(13, 30),
    gamma=(30, 45),
)

In [None]:
# Cut the continuous recording into fixed-length epochs and inspect one channel.
EPOCH_DURATION_S = 120  # length of each epoch, in seconds
epochs = mne.make_fixed_length_epochs(raw, duration=EPOCH_DURATION_S, preload=True)
epochs.plot_image(picks=['OF1'], cmap='RdBu_r')

# Group channel indices by electrode name: the alphabetic characters of each
# label (e.g. 'OF1' -> 'OF'). Groups appear in first-seen channel order.
channel_groups = {}
for idx, ch_name in enumerate(epochs.info['ch_names']):
    prefix = ''.join(filter(str.isalpha, ch_name))
    channel_groups.setdefault(prefix, []).append(idx)

# The loop variable is deliberately not `channels`, which already holds raw.ch_names.
for prefix, indices in channel_groups.items():
    print(f'{prefix}: {indices}')

In [6]:
#Let's calculate the GC for each band.
#Let's begin with theta, taking A and B to be the CC and CA channel groups.
#A=channel_groups['CC']
#B=channel_groups['CA']
#gc_ab,gc_ba,freqs=calculate_and_plot_granger_causality(epochs, A, B,  fmin=bands['theta'][0], fmax=bands['theta'][1])

In [7]:
#Now we sum the net GC values (GC A=>B minus GC B=>A) separately above and below the y=0 line.
#A value above the y=0 line means that, at that frequency, GC from channel group A to group B exceeds GC from B to A (A drives B).
#A value below the y=0 line means that GC from group B to group A dominates (B drives A).
#res=gc_ab.get_data()[0]-gc_ba.get_data()[0]
#A1=0
#A2=0
#for i in range(len(res)):
#    if res[i]>0:
#        A1+=res[i]
#    else:
#        A2+=res[i]
#print(f'The sum of the values over the y=0 line is {A1}')
#print(f'The sum of the values under the y=0 line is {A2}')
        

In [None]:

def calculate_and_plot_granger_causality(epochs, signals_a, signals_b, verbose=True, fmin=5, fmax=30, gc_n_lags=20, plot=True, rank=5):
    """Compute multivariate spectral Granger causality between two channel groups.

    Calls ``spectral_connectivity_epochs`` with ``method=["gc"]`` twice:
    once with group A as seeds and group B as targets (A => B), and once
    the other way round (B => A).

    Parameters
    ----------
    epochs : mne.Epochs
        Data to analyse.
    signals_a, signals_b : sequence of int
        Channel indices forming group A and group B.
    verbose : bool
        Passed through to ``spectral_connectivity_epochs``.
    fmin, fmax : float
        Frequency limits passed through to ``spectral_connectivity_epochs``.
    gc_n_lags : int
        Passed through as ``gc_n_lags``.
    plot : bool
        If True, plot both GC spectra and the net GC spectrum
        ([A => B] - [B => A]).
    rank : int | None
        Maximum rank each group is projected to. It is clipped to the number
        of channels in the group, because mne-connectivity rejects a rank
        larger than the group size (a fixed 5 failed for groups with fewer
        than 5 channels). ``None`` lets mne-connectivity estimate the rank.

    Returns
    -------
    gc_ab, gc_ba
        Connectivity results for A => B and B => A.
    freqs
        Frequencies of the GC spectra (``gc_ab.freqs``).
    """

    def _rank_pair(seeds, targets):
        # Rank arrays in the (seed ranks, target ranks) format of spectral_connectivity_epochs.
        if rank is None:
            return None
        return (np.array([min(rank, len(seeds))]), np.array([min(rank, len(targets))]))

    def _gc(seeds, targets):
        # Single multivariate connection: all `seeds` channels => all `targets` channels.
        return spectral_connectivity_epochs(
            epochs,
            method=["gc"],
            indices=(np.array([seeds]), np.array([targets])),
            fmin=fmin,
            fmax=fmax,
            rank=_rank_pair(seeds, targets),
            gc_n_lags=gc_n_lags,
            verbose=verbose,
        )

    gc_ab = _gc(signals_a, signals_b)  # A => B
    gc_ba = _gc(signals_b, signals_a)  # B => A

    freqs = gc_ab.freqs

    if plot:
        # Both directions on one axis.
        fig, axis = plt.subplots(1, 1)
        axis.plot(freqs, gc_ab.get_data()[0], linewidth=2, label='A => B')
        axis.plot(freqs, gc_ba.get_data()[0], linewidth=2, label='B => A')
        axis.set_xlabel("Frequency (Hz)")
        axis.set_ylabel("Connectivity (A.U.)")
        axis.legend()
        fig.suptitle("GC: [A => B] and [B => A]")
        plt.show()

        # Net GC: positive where A => B dominates, negative where B => A dominates.
        net_gc = gc_ab.get_data() - gc_ba.get_data()
        fig, axis = plt.subplots(1, 1)
        axis.plot((freqs[0], freqs[-1]), (0, 0), linewidth=2, linestyle="--", color="k")
        axis.plot(freqs, net_gc[0], linewidth=2)
        axis.set_xlabel("Frequency (Hz)")
        axis.set_ylabel("Connectivity (A.U.)")
        fig.suptitle("Net GC: [A => B] - [B => A]")
        plt.show()

    return gc_ab, gc_ba, freqs


In [8]:
def calculate_Area(A, B, epochs, band):
    """Split the net GC spectrum between two channel groups by direction.

    Computes GC for A => B and B => A over ``band`` (quietly, without
    plotting) and forms ``net = GC(A => B) - GC(B => A)`` per frequency bin.
    The "area" is a plain sum over frequency bins; it is not scaled by the
    bin width.

    Parameters
    ----------
    A, B : sequence of int
        Channel indices of the two groups.
    epochs : mne.Epochs
        Data to analyse.
    band : (float, float)
        ``(fmin, fmax)`` passed to the GC computation.

    Returns
    -------
    A1 : float
        Sum of the positive net-GC bins (A => B dominates); >= 0.
    A2 : float
        Sum of all other bins (B => A dominates, or a tie); <= 0. Any NaN bin
        propagates into A2.
    """
    gc_ab, gc_ba, _ = calculate_and_plot_granger_causality(
        epochs, A, B, fmin=band[0], fmax=band[1], verbose=False, plot=False)
    net = np.asarray(gc_ab.get_data()[0] - gc_ba.get_data()[0])
    a_dominates = net > 0
    # ~(net > 0) instead of (net <= 0) so NaN bins still land in A2, as before.
    A1 = float(net[a_dominates].sum())
    A2 = float(net[~a_dominates].sum())
    return A1, A2

In [None]:
# Pairwise "net GC area" between electrode groups, for every band.
# matrix[source, target, band] = summed positive net GC where `source` drives
# `target`; the diagonal stays 0.
# Each unordered pair is evaluated once: its A => B surplus (A1) goes to [j, k]
# and its B => A surplus (-A2, made non-negative) goes to [k, j].
# Looping over ordered pairs instead runs every GC fit twice, and the second
# visit overwrites [j, k] with the negative A2 of the reversed pair.
group_names = list(channel_groups.keys())
n_groups = len(group_names)
matrix = np.zeros((n_groups, n_groups, len(bands)))

for i, (band, band_range) in enumerate(bands.items()):
    for j in range(n_groups):
        for k in range(j + 1, n_groups):
            prefix1, prefix2 = group_names[j], group_names[k]
            A1, A2 = calculate_Area(channel_groups[prefix1], channel_groups[prefix2], epochs, band_range)
            matrix[j, k, i] = A1
            matrix[k, j, i] = -A2
            print(f'The area of {prefix1} causing {prefix2} in band {band} is {A1}')
            print(f'The area of {prefix2} causing {prefix1} in band {band} is {-A2}')


# One heatmap per band (rows = source group, columns = target group).
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle('Granger Causality Area')
for i, (band, ax) in enumerate(zip(bands, axs.flat)):
    sns.heatmap(matrix[:, :, i], ax=ax, xticklabels=group_names, yticklabels=group_names, cmap='coolwarm')
    ax.set_title(band)
    ax.set_xlabel('Target')
    ax.set_ylabel('Source')
plt.show()


In [15]:
# Save one CSV per band. Rows and columns follow the order of channel_groups;
# band index i follows the order of `bands`, the same order used to fill `matrix`.
for i, band in enumerate(bands):
    np.savetxt(f'GC_Area_{band}.csv', matrix[:, :, i], delimiter=',')


In [5]:
# Reload the per-band matrices. The save cell writes relative paths, so this
# directory must be the working directory the matrices were saved from.
GC_AREA_DIR = '/teamspace/studios/this_studio'
matrix_theta = np.loadtxt(f'{GC_AREA_DIR}/GC_Area_theta.csv', delimiter=',')
matrix_alpha = np.loadtxt(f'{GC_AREA_DIR}/GC_Area_alpha.csv', delimiter=',')
matrix_beta = np.loadtxt(f'{GC_AREA_DIR}/GC_Area_beta.csv', delimiter=',')
matrix_gamma = np.loadtxt(f'{GC_AREA_DIR}/GC_Area_gamma.csv', delimiter=',')

In [None]:
node_names = list(channel_groups.keys())
NODE_SIZE = 700

# Directed graph from the theta matrix: edge u -> v has weight matrix_theta[u, v]
# (networkx only creates edges for non-zero entries).
G = nx.DiGraph(matrix_theta)

# Edge widths are the weights rescaled to [0, 1]. Guard against an empty graph
# (np.max would raise) and all-equal weights (division by zero -> NaN widths).
edge_weights = np.array([G[u][v]['weight'] for u, v in G.edges()])
weight_range = edge_weights.max() - edge_weights.min() if edge_weights.size else 0.0
if weight_range > 0:
    normalized_weights = (edge_weights - edge_weights.min()) / weight_range
else:
    normalized_weights = np.ones_like(edge_weights)

fig, ax = plt.subplots()
pos = nx.circular_layout(G)
nx.draw_networkx_nodes(G, pos, node_size=NODE_SIZE, ax=ax)
# Pass the same node_size so arrowheads stop at the node border rather than
# being positioned for the default node size.
nx.draw_networkx_edges(G, pos, arrows=True, width=normalized_weights, node_size=NODE_SIZE, ax=ax)
nx.draw_networkx_labels(G, pos, labels=dict(enumerate(node_names)), ax=ax)

ax.set_title("Directed Graph with Node Labels")
ax.axis("off")
plt.show()


In [None]:
# Total outgoing drive of each group in theta: sum each row (source group)
# across targets. axis=1 sums rows; the previous axis=0 summed columns
# (incoming drive), contradicting the intent.
sums = np.sum(matrix_theta, axis=1)
fig, ax = plt.subplots()
ax.bar(list(channel_groups.keys()), sums)
ax.set_title('Granger Causality Area Sum (theta)')
ax.set_xlabel('Source group')
ax.set_ylabel('Summed net GC area')
plt.show()

# TODO: interpret these results.

In [None]:
# Per-band total outgoing drive: sum each row (source group) across targets.
group_names = list(channel_groups.keys())
sums = np.sum(matrix, axis=1)  # shape (n_groups, n_bands)
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle('Granger Causality Area Sum')
for i, (band, ax) in enumerate(zip(bands, axs.flat)):
    ax.bar(group_names, sums[:, i])
    ax.set_title(band)
    ax.set_ylabel('Summed net GC area')
plt.show()