In [None]:
import mne
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise_distances
from nltools.data import Brain_Data, Design_Matrix, Adjacency
import networkx as nx
from scipy import signal
from mne_connectivity import spectral_connectivity_epochs
import seaborn as sns

# Helper functions

In [None]:
def calculate_ER(data, window_size, overlap, fs=None):
    """
    Calculate a time-varying Energy Ratio (ER) of a 1-D signal using a sliding window.

    For every window the power spectral density is estimated with
    ``scipy.signal.periodogram`` (density scaling) and summed over four bands:
    theta [3.5, 7.4), alpha [7.4, 12.4), beta [12.4, 24) and gamma [24, 97] Hz.
    The window's ER is (E_beta + E_gamma) / (E_theta + E_alpha).

    Parameters:
    - data: 1-D array with the samples of one channel
    - window_size: length of each window, in samples
    - overlap: hop between the starts of consecutive windows, in samples.
      Despite its name this is the step, not the number of shared samples;
      consecutive windows share ``window_size - overlap`` samples.
    - fs: sampling frequency in Hz. If None (default) the notebook-global
      ``raw.info['sfreq']`` is used, as in the original version.

    Returns:
    - time_points: float array with the center of each window, in samples
    - ER_values: float array with the ER of each window. A window with zero
      theta+alpha energy gives inf/nan (numpy division by zero).
    """
    if fs is None:
        # Backward-compatible fallback: read the rate from the global MNE Raw object.
        fs = raw.info['sfreq']

    starts = range(0, len(data) - window_size + 1, overlap)
    time_points = np.empty(len(starts))
    ER_values = np.empty(len(starts))

    for idx, start in enumerate(starts):
        end = start + window_size
        f, S = signal.periodogram(data[start:end], fs=fs, scaling='density')

        # Energy in each frequency band for the current window
        ETheta = np.sum(S[(f >= 3.5) & (f < 7.4)])
        EAlpha = np.sum(S[(f >= 7.4) & (f < 12.4)])
        EBeta = np.sum(S[(f >= 12.4) & (f < 24)])
        EGamma = np.sum(S[(f >= 24) & (f <= 97)])

        ER_values[idx] = (EBeta + EGamma) / (ETheta + EAlpha)
        time_points[idx] = (start + end) / 2  # center of the window, in samples

    return time_points, ER_values


# Reading and preprocessing

In [None]:
# Nihon Kohden recording (.EEG) analysed in this notebook.
# TODO: make this configurable instead of hard-coding a personal absolute path.
DATA_PATH = '/home/pablo/Documents/Universidad Data/Maestría en Matemáticas Aplicadas/Tesis/data/FA330022.EEG'

raw = mne.io.read_raw_nihon(DATA_PATH, preload=True)
# Keep only EEG channels (bio and misc channels are dropped).
raw.pick_types(eeg=True, bio=False, misc=False)

# Drop channels excluded from the analysis.
channels_to_remove = ['E']
raw.drop_channels(channels_to_remove)

# Low-pass at 97 Hz only; l_freq=None is MNE's documented way to skip the high-pass.
# 97 Hz is also the upper edge of the gamma band used by calculate_ER.
raw.filter(l_freq=None, h_freq=97.0)

# Power-line frequency to remove: 50 Hz in most countries, 60 Hz in e.g. the USA.
notch_freq = 60
raw.notch_filter(freqs=notch_freq)


# Getting EI

In [None]:
# EI per channel, in three steps:
#   1. Energy Ratio ER per sliding window (calculate_ER), min-max normalized to [0, 1].
#   2. Page-Hinkley-style cumulative statistic U_n on ER; the first window where
#      U_n - min(U_0..U_n) exceeds lambda_ is the channel's alarm time N_k.
#   3. EI_k = 1 / (N_k - N0 + tau) * sum(ER_k[N_k : N_k + H]), with N0 the earliest
#      alarm over all channels and H = 5 seconds expressed in windows.
#
# NOTE(review): `window_size` and `overlap` (hop between window starts, both in
# samples) are not defined anywhere in this notebook and must exist before this
# cell runs. TODO: move them to a configuration cell.


def minmax_normalize(x):
    """Rescale `x` linearly to [0, 1]; a constant series maps to zeros instead of 0/0."""
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    span = x.max() - x.min()
    if span == 0:
        return np.zeros_like(x)
    return (x - x.min()) / span


def page_hinkley(er, v=0.1, lambda_=125):
    """Page-Hinkley-style change detection on an ER series.

    U[i] = sum(er[:i] - mean(er) - v), so U[0] = 0 and the current window i is
    NOT included (kept exactly as in the original loop). An alarm fires at every
    i where U[i] - min(U[0..i]) > lambda_.

    Returns (U, alarms): U with the alarm positions set to 0 (as the original loop
    stored it) and the int array of alarm window indices, in increasing order.
    """
    er = np.asarray(er, dtype=float)
    if er.size == 0:
        return np.zeros(0), np.zeros(0, dtype=int)
    increments = er - er.mean() - v
    # Cumulative sum shifted by one window: O(N) instead of re-summing every prefix.
    u = np.concatenate(([0.0], np.cumsum(increments)))[:er.size]
    running_min = np.minimum.accumulate(u)
    alarms = np.flatnonzero(u - running_min > lambda_)
    u[alarms] = 0.0
    return u, alarms


fs = raw.info['sfreq']
channels = raw.ch_names
n_channels = len(channels)
eeg = raw.get_data()  # one row per channel; read once instead of once per channel
# Same number of windows calculate_ER produces for a signal of this length.
n_windows = len(range(0, eeg.shape[1] - window_size + 1, overlap))

v = 0.1        # drift subtracted at every step of U_n
lambda_ = 125  # alarm threshold on U_n - min(U)

U_n_matrix = np.zeros((n_channels, n_windows))
ER_matrix = np.zeros((n_channels, n_windows))
ER_n_array = np.zeros(n_channels)
alarm_time = np.full(n_channels, np.nan)  # NaN = channel never crossed the threshold

for k in range(n_channels):
    time_points, ER_values = calculate_ER(eeg[k], window_size, overlap)
    ER = minmax_normalize(ER_values)
    ER_matrix[k, :] = ER
    ER_n_array[k] = ER.mean()
    U_n, alarms = page_hinkley(ER, v=v, lambda_=lambda_)
    U_n_matrix[k, :] = U_n
    if alarms.size:
        alarm_time[k] = alarms[0]

if np.isnan(alarm_time).all():
    raise RuntimeError('No channel crossed the detection threshold; try a smaller lambda_.')

N0 = np.nanmin(alarm_time)  # earliest detection over all channels
tau = 1
# H = 5 seconds expressed in WINDOWS (the columns of ER_matrix): windows start
# every `overlap` samples, so 5 s correspond to 5 * fs / overlap windows.
H = int(round(5 * fs / overlap))

Ei = np.zeros(n_channels)  # channels without an alarm keep EI = 0
for k in np.flatnonzero(~np.isnan(alarm_time)):
    start = int(alarm_time[k])
    Ei[k] = np.sum(ER_matrix[k, start:start + H]) / (alarm_time[k] - N0 + tau)

# Normalized ER per channel (rows) and window (columns). 'nearest' avoids blending
# neighbouring channels; the default extent puts row i at y = i, so the tick
# labels line up with the rows.
fig, ax = plt.subplots()
im = ax.imshow(ER_matrix, cmap='viridis', interpolation='nearest', aspect='auto')
fig.colorbar(im, ax=ax)
ax.set_yticks(np.arange(n_channels))
ax.set_yticklabels(channels)
ax.set_xlabel('Window number')
ax.set_ylabel('Channel name')
ax.set_title('Normalized ER')
plt.show()

# EI per channel, normalized by the largest EI.
Ei_n = Ei / Ei.max() if Ei.max() > 0 else Ei
fig, ax = plt.subplots()
ax.bar(channels, Ei_n)
ax.set_xlabel('Channel name')
ax.set_ylabel('EI (normalized to max)')
ax.tick_params(axis='x', labelrotation=90)
plt.show()

# Connectivity Measures

In [None]:
# Split the recording into fixed-length epochs with 25 % overlap. MNE's `overlap`
# argument is in SECONDS, so the fraction is converted here (the original passed
# 0.25, i.e. a quarter of a second).
epoch_duration = 100.0     # seconds
epoch_overlap_frac = 0.25  # fraction of an epoch shared with the next one
epochs = mne.make_fixed_length_epochs(
    raw, duration=epoch_duration, overlap=epoch_overlap_frac * epoch_duration,
    preload=True)

fmin, fmax = 4., 9.  # compute connectivity within 4-9 Hz
sfreq = raw.info['sfreq']  # sampling frequency
tmin = 0.0  # use each epoch from its first sample (seconds)

# PLI, wPLI and dPLI in a single call: the spectral estimates are shared and one
# connectivity object per method is returned, in the order of `method`.
con_pli, con_wpli, con_dpli = spectral_connectivity_epochs(
    epochs, method=['pli', 'wpli', 'dpli'], mode='multitaper', sfreq=sfreq,
    fmin=fmin, fmax=fmax, faverage=True, tmin=tmin, mt_adaptive=False, n_jobs=1)

fig, axs = plt.subplots(1, 3, figsize=(14, 5), sharey=True)
for ax, con, title in zip(axs, (con_pli, con_wpli, con_dpli), ('PLI', 'wPLI', 'dPLI')):
    # get_data('dense') is indexed (node, node, band); faverage=True leaves one band.
    im = ax.imshow(con.get_data('dense')[:, :, 0], vmin=0, vmax=1)
    ax.set_title(title)
    ax.set_xlabel("Node 2")
axs[0].set_ylabel("Node 1")

fig.colorbar(im, ax=axs.ravel())
plt.show()

In [None]:
# Coherence averaged over two bands: (8-13 Hz) and (13-30 Hz).
fmin = (8., 13.)
fmax = (13., 30.)
tmin = 0.0
sfreq = raw.info['sfreq']  # the sampling frequency

coh = spectral_connectivity_epochs(
    epochs, method='coh', mode='fourier', sfreq=sfreq, fmin=fmin,
    fmax=fmax, faverage=True, tmin=tmin, mt_adaptive=False, n_jobs=1)
freqs = coh.freqs

band_messages = (
    'Frequencies in Hz over which coherence was averaged for alpha: ',
    'Frequencies in Hz over which coherence was averaged for beta: ',
)
for message, band_freqs in zip(band_messages, freqs):
    print(message)
    print(band_freqs)

# Dense coherence, indexed (node, node, band); fetched once and sliced per band.
coh_dense = coh.get_data('dense')
fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharey=True)
for ax, band_idx, band_title in zip(axs, (0, 1), ("Alpha", "Beta")):
    im = ax.imshow(coh_dense[:, :, band_idx], vmin=0, vmax=1)
    ax.set_title(band_title)
    ax.set_xlabel("Node 2")
axs[0].set_ylabel("Node 1")

fig.colorbar(im, ax=axs.ravel())
plt.show()

# Correlation-based graph connectivity

In [None]:
# Collect every channel's signal as a (1, n_times) array, in raw.ch_names order.
# `chanels` (sic; later cells use this spelling) is defined here because on a
# fresh kernel this cell runs before the cell that used to define it.
chanels = raw.ch_names
eeg_data = raw.get_data()  # read once instead of once per channel
data = [eeg_data[[i]] for i in range(len(chanels))]

# Overlay the first (up to) 15 channels in one plot.
n_plotted = min(15, len(data))
fig, ax = plt.subplots()
for i in range(n_plotted):
    ax.plot(data[i][0], label=chanels[i])
ax.set_xlabel('Sample')
ax.set_ylabel('Amplitude')
ax.legend(fontsize='small', ncol=3)
plt.show()

In [None]:
# Pearson correlation between every pair of channels, computed from the signals
# stored in `data` by the previous cell (one (1, n_times) array per channel).
chanels = raw.ch_names
signals = np.vstack([channel[0] for channel in data[:len(chanels)]])  # one row per channel
corr = np.corrcoef(signals)  # full matrix in one call instead of n^2 pairwise calls

# Correlation matrix heatmap, labelled with the channel names
sns.heatmap(corr, square=True, vmin=-1, vmax=1, cmap='RdBu_r',
            xticklabels=chanels, yticklabels=chanels)

In [None]:
# Binarize the correlation matrix with an arbitrary cut-off of 0.5.
# NOTE(review): with only `upper` given, nltools presumably keeps values >= 0.5
# (negative correlations are discarded); confirm against Adjacency.threshold.
a = Adjacency(corr, matrix_type='similarity', labels=list(chanels))
a_thresholded = a.threshold(upper=.5, binarize=True)

a_thresholded.plot()

In [None]:
# Draw the thresholded graph; node size and color both encode the node degree.
plt.figure(figsize=(20,15))
G = a_thresholded.to_graph()
pos = nx.kamada_kawai_layout(G)
node_and_degree = G.degree()
degree_of = dict(node_and_degree)  # node -> degree, in graph node order

nx.draw_networkx_edges(G, pos, width=3, alpha=.2)
nx.draw_networkx_labels(G, pos, font_size=14, font_color='darkslategray')

nx.draw_networkx_nodes(
    G, pos,
    nodelist=list(degree_of),
    node_size=[deg * 100 for deg in degree_of.values()],
    node_color=list(degree_of.values()),
    cmap=plt.cm.Reds_r, linewidths=2, edgecolors='darkslategray', alpha=1,
)


In [None]:
# Distribution of node degrees in the thresholded graph
degree_values = [deg for _, deg in G.degree]
plt.hist(degree_values, bins=20, color='lightseagreen', alpha=0.7)
plt.ylabel('Frequency', fontsize=18)
plt.xlabel('Degree', fontsize=18)

In [None]:
# Horizontal bar chart with the degree of every channel (graph node)
degree_by_channel = dict(G.degree)
plt.figure(figsize=(20,15))
plt.barh(list(degree_by_channel), list(degree_by_channel.values()), color='lightseagreen')
plt.xlabel('Degree', fontsize=18)
plt.ylabel('Channel', fontsize=18)
plt.title('Degree per channel', fontsize=20)
plt.show()