In [1]:
import mne
from glob import glob
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
import numpy as np
import random
from sklearn.model_selection import train_test_split
from concurrent.futures import ThreadPoolExecutor

from tensorflow.keras.layers import Conv1D,BatchNormalization,LeakyReLU,MaxPool1D,\
GlobalAveragePooling1D,Dense,Dropout,AveragePooling1D
from tensorflow.keras.models import Sequential
from tensorflow.keras.backend import clear_session

In [2]:
all_files_path=glob('../data/*.edf')
print(len(all_files_path))
healthy_file_path=[i for i in all_files_path if  'H' in i.split('\\')[1]]
patient_file_path=[i for i in all_files_path if  'M' in i.split('\\')[1]]
print(len(healthy_file_path),len(patient_file_path))
def read_data(file_path, l_freq=30, h_freq=100, duration=15, overlap=1):
    """Load one EDF recording and return standardized, fixed-length EEG epochs.

    Steps: keep EEG channels, pick the 20 channels listed below in a fixed order,
    apply an average EEG reference, band-pass filter, cut into fixed-length
    epochs and standardize with ``StandardScaler``.

    Parameters
    ----------
    file_path : str
        Path to an ``.edf`` file.
    l_freq, h_freq : float
        Band-pass edges in Hz passed to ``raw.filter``. Default 30-100 Hz is the
        gamma band (1-4=delta, 4-8=theta, 8-12=alpha, 12-30=beta, 30-100=gamma).
    duration, overlap : float
        Epoch length and overlap in seconds for ``mne.make_fixed_length_epochs``.

    Returns
    -------
    numpy.ndarray
        Array of shape (trials, channels, samples).
    """
    raw = mne.io.read_raw_edf(file_path, preload=True)
    raw.pick_types(meg=False, eeg=True, eog=False, ecg=False)  # keep EEG channels only
    channel_to_keep = ['EEG Fp1-LE', 'EEG F3-LE', 'EEG C3-LE', 'EEG P3-LE', 'EEG O1-LE', 'EEG F7-LE', 'EEG T3-LE', 'EEG T5-LE', 'EEG Fz-LE', 'EEG Fp2-LE', 'EEG F4-LE', 'EEG C4-LE', 'EEG P4-LE', 'EEG O2-LE', 'EEG F8-LE', 'EEG T4-LE', 'EEG T6-LE', 'EEG Cz-LE', 'EEG Pz-LE', 'EEG A2-A1']
    # ordered=True makes the channel axis follow `channel_to_keep` for every file
    # (before MNE 1.5 the default kept file order, which may differ between
    # recordings) and makes MNE raise if a requested channel is missing instead
    # of silently returning fewer channels.
    raw.pick_channels(channel_to_keep, ordered=True)
    raw.set_eeg_reference()
    raw.filter(l_freq=l_freq, h_freq=h_freq)
    epochs = mne.make_fixed_length_epochs(raw, duration=duration, overlap=overlap)
    epochs = epochs.get_data()
    # NOTE(review): the reshape makes each *time sample* a StandardScaler feature,
    # so every time index is standardized across all (trial, channel) rows rather
    # than per channel -- confirm this is the intended normalization.
    scaler = StandardScaler()
    data = scaler.fit_transform(epochs.reshape(-1, epochs.shape[-1])).reshape(epochs.shape)
    return data  # trials, channel, length

162
76 86


In [3]:
%%capture
# Load every recording (MNE's verbose output is silenced by %%capture).
control_epochs_array = list(map(read_data, healthy_file_path))
patients_epochs_array = list(map(read_data, patient_file_path))
# One label per epoch: 0 = healthy control, 1 = patient.
control_epochs_labels = [[0] * len(subject_epochs) for subject_epochs in control_epochs_array]
patients_epochs_labels = [[1] * len(subject_epochs) for subject_epochs in patients_epochs_array]

In [4]:
data_list = control_epochs_array + patients_epochs_array
label_list = control_epochs_labels + patients_epochs_labels
# Subject index for every epoch (one list per subject), for subject-wise grouping.
groups_list = [[subject_idx] * len(subject_epochs) for subject_idx, subject_epochs in enumerate(data_list)]

# Shuffle at the subject level. groups_list must be shuffled together with the
# data and labels, otherwise group_array no longer lines up with X / y.
SHUFFLE_SEED = 42  # fixed so subject order (and hence the client shards) is reproducible
combined_list = list(zip(data_list, label_list, groups_list))
random.Random(SHUFFLE_SEED).shuffle(combined_list)
data_list = [subject_data for subject_data, _, _ in combined_list]
label_list = [subject_labels for _, subject_labels, _ in combined_list]
groups_list = [subject_groups for _, _, subject_groups in combined_list]

In [5]:
data_array = np.vstack(data_list)      # (epochs, channels, samples) per read_data
label_array = np.hstack(label_list)    # one 0/1 label per epoch
group_array = np.hstack(groups_list)   # subject index per epoch

# Conv1D is channels-last: (epochs, channels, samples) -> (epochs, samples, channels)
data_array = np.moveaxis(data_array, 1, 2)
X = data_array
y = label_array
assert len(X) == len(y) == len(group_array), "data, labels and groups are misaligned"

In [6]:
def cnnmodel(input_shape=(3840, 20)):
    """Build and compile the 1-D CNN used as both the global and the local models.

    Five Conv1D stages (5 filters, kernel 3) with LeakyReLU, interleaved with
    max/average pooling and dropout, then global average pooling and a single
    sigmoid unit for binary classification.

    Parameters
    ----------
    input_shape : tuple of int
        (samples, channels) of one epoch. The default (3840, 20) matches 15 s
        epochs of 20 channels -- presumably sampled at 256 Hz; adjust if the
        epoch length or sampling rate changes.

    Returns
    -------
    A compiled ``Sequential`` model (Adam, binary cross-entropy).
    """
    clear_session()  # reset Keras global state before building a fresh model
    model = Sequential()
    model.add(Conv1D(filters=5, kernel_size=3, strides=1, input_shape=input_shape))  # 1
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2, strides=2))  # 2
    model.add(Conv1D(filters=5, kernel_size=3, strides=1))  # 3
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2, strides=2))  # 4
    model.add(Dropout(0.5))
    model.add(Conv1D(filters=5, kernel_size=3, strides=1))  # 5
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2, strides=2))  # 6
    model.add(Dropout(0.5))
    model.add(Conv1D(filters=5, kernel_size=3, strides=1))  # 7
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2, strides=2))  # 8
    model.add(Conv1D(filters=5, kernel_size=3, strides=1))  # 9
    model.add(LeakyReLU())
    model.add(GlobalAveragePooling1D())  # 10
    model.add(Dense(1, activation='sigmoid'))  # 11

    # Lower-case 'accuracy' lets Keras choose binary accuracy for the sigmoid
    # output; 'Accuracy' resolves to tf.keras.metrics.Accuracy, which checks exact
    # equality between the float probabilities and the 0/1 labels.
    model.compile('adam', loss='binary_crossentropy', metrics=['accuracy', 'Precision', 'Recall', 'AUC'])
    return model

In [7]:
def load_data(num_clients, X, y):
    """Split (X, y) into ``num_clients`` contiguous shards, one per simulated client.

    ``np.array_split`` returns exactly ``num_clients`` shards and drops no sample.
    When len(X) is not divisible by num_clients, the first ``len(X) % num_clients``
    shards get one extra sample. Stepping through the data in chunks of
    len(X) // num_clients would instead produce an extra remainder shard that a
    per-client loop over range(num_clients) never trains on. It would also fail
    with a zero step when there are fewer samples than clients.

    Parameters
    ----------
    num_clients : int
        Number of shards to create (must be >= 1).
    X, y : array-like
        Inputs and labels, aligned along axis 0.

    Returns
    -------
    clients_X, clients_y : list of numpy.ndarray
        Lists of length ``num_clients``.

    Raises
    ------
    ValueError
        If X and y have different lengths, or num_clients < 1.
    """
    if len(X) != len(y):
        raise ValueError('X and y must have the same length, got %d and %d' % (len(X), len(y)))
    clients_X = np.array_split(X, num_clients)
    clients_y = np.array_split(y, num_clients)
    return clients_X, clients_y

In [8]:
def train_local_model(model, data_X, data_y, epochs=25, batch_size=25):
    """Compile ``model`` and fit it on one client's local data.

    Models produced by ``tf.keras.models.clone_model`` are not compiled, so this
    compiles before fitting.

    Parameters
    ----------
    model : keras Model
        Model to train (modified in place).
    data_X, data_y : array-like
        Local inputs and 0/1 labels.
    epochs, batch_size : int
        Passed to ``model.fit``.

    Returns
    -------
    (model, history)
        The trained model and the ``History`` object returned by ``fit``.
    """
    # Lower-case 'accuracy' -> binary accuracy for the sigmoid output;
    # 'Accuracy' would test probabilities for exact equality with the labels.
    model.compile('adam', loss='binary_crossentropy', metrics=['accuracy', 'Precision', 'Recall', 'AUC'])
    history = model.fit(data_X, data_y, epochs=epochs, batch_size=batch_size)
    return model, history

In [9]:
# Initialize global model
global_model = cnnmodel()
global_model_for_loss = tf.keras.models.clone_model(global_model)
global_model_for_loss.compile('adam',loss='binary_crossentropy',metrics=['Accuracy', 'Precision', 'Recall','AUC'])

# Number of devices
num_devices = 10


# Number of communication rounds
num_communication_rounds = 1

clients_X, clients_y=load_data(num_devices,X,y)

In [10]:
# Snapshot of the freshly initialised global weights, for comparison after local training.
initial_global_weights = global_model.get_weights()
initial_global_weights

[array([[[-2.56307960e-01, -2.05308974e-01, -2.39813954e-01,
          -1.36033043e-01,  2.35177010e-01],
         [-9.40822214e-02, -1.09601229e-01, -1.90509558e-01,
          -1.93158343e-01, -1.63379267e-01],
         [ 2.41241157e-02, -8.68678242e-02, -1.95634812e-01,
          -6.82789087e-02, -1.16551548e-01],
         [-3.98096293e-02,  7.12196231e-02,  1.94320440e-01,
          -9.82287973e-02, -5.74060977e-02],
         [-5.41282892e-02, -2.69900620e-01, -1.41607955e-01,
           1.98347867e-01, -2.68821120e-01],
         [-5.48277944e-02,  1.20708376e-01, -2.68857241e-01,
           2.06749737e-01, -2.54667521e-01],
         [ 1.50589824e-01,  2.22161025e-01, -6.14577979e-02,
           2.02693969e-01, -2.11297199e-01],
         [-1.06286675e-01, -9.87841189e-02, -2.03423023e-01,
           5.62062860e-02,  2.51221329e-01],
         [ 2.04126835e-01,  2.57153064e-01,  3.28634977e-02,
          -2.71559834e-01, -8.56337696e-02],
         [ 7.90646672e-03,  1.49299979e-01, -1

In [11]:
# Sanity check: there is a shard for every device, and each shard pairs inputs
# with labels one-to-one. Shows the shard sizes.
assert len(clients_X) >= num_devices and len(clients_y) >= num_devices
assert all(len(shard_X) == len(shard_y) for shard_X, shard_y in zip(clients_X, clients_y))
[len(shard_X) for shard_X in clients_X]

In [12]:
# Federated learning: each client trains a copy of the current global model on its
# own shard. Aggregation (FedAvg) happens in the cells below, after this loop.
# NOTE(review): because aggregation is outside the loop, num_communication_rounds > 1
# would retrain every client from the same global weights and keep only the last
# round's local_models.
for round_idx in range(num_communication_rounds):  # avoid shadowing builtin round()
    local_models = []
    # Global weights do not change inside a round, so fetch them once.
    current_global_weights = global_model.get_weights()

    for i in range(num_devices):
        local_data_X = clients_X[i]
        local_data_y = clients_y[i]

        # Start each client from the current global weights (clone_model copies
        # only the architecture, so the weights are set explicitly).
        local_model = tf.keras.models.clone_model(global_model)
        local_model.set_weights(current_global_weights)

        # Train on local data (train_local_model also compiles the clone).
        local_model, history = train_local_model(local_model, local_data_X, local_data_y)
        local_model.save('models/local_model_%s.h5' % i)
        # Keep (model, per-epoch training losses, number of local samples) for aggregation.
        local_models.append((local_model, history.history['loss'], len(local_data_y)))


Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25


  saving_api.save_model(


Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoc

In [13]:
# Global weights after local training; expected to match the initial snapshot,
# since the clients trained clones rather than the global model itself.
weights_after_local_training = global_model.get_weights()
weights_after_local_training

[array([[[-2.56307960e-01, -2.05308974e-01, -2.39813954e-01,
          -1.36033043e-01,  2.35177010e-01],
         [-9.40822214e-02, -1.09601229e-01, -1.90509558e-01,
          -1.93158343e-01, -1.63379267e-01],
         [ 2.41241157e-02, -8.68678242e-02, -1.95634812e-01,
          -6.82789087e-02, -1.16551548e-01],
         [-3.98096293e-02,  7.12196231e-02,  1.94320440e-01,
          -9.82287973e-02, -5.74060977e-02],
         [-5.41282892e-02, -2.69900620e-01, -1.41607955e-01,
           1.98347867e-01, -2.68821120e-01],
         [-5.48277944e-02,  1.20708376e-01, -2.68857241e-01,
           2.06749737e-01, -2.54667521e-01],
         [ 1.50589824e-01,  2.22161025e-01, -6.14577979e-02,
           2.02693969e-01, -2.11297199e-01],
         [-1.06286675e-01, -9.87841189e-02, -2.03423023e-01,
           5.62062860e-02,  2.51221329e-01],
         [ 2.04126835e-01,  2.57153064e-01,  3.28634977e-02,
          -2.71559834e-01, -8.56337696e-02],
         [ 7.90646672e-03,  1.49299979e-01, -1

In [14]:
# Aggregation (FedAvg)
global_weights_sum_for_loss = [tf.zeros_like(w) for w in global_model.get_weights()]
global_weights_sum = [tf.zeros_like(w) for w in global_model.get_weights()]
global_weights_sum

[<tf.Tensor: shape=(3, 20, 5), dtype=float32, numpy=
 array([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],
 
        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 

In [15]:
# Accumulate client weights for two aggregation schemes:
#   * loss-weighted:   each client weighted by (1 - mean training loss)
#   * sample-weighted: each client weighted by its number of samples (FedAvg)
# The accumulators are (re)initialised here so re-running this cell does not
# double-count previously added client weights.
template_weights = global_model.get_weights()
global_weights_sum_for_loss = [tf.zeros_like(w) for w in template_weights]
global_weights_sum = [tf.zeros_like(w) for w in template_weights]
total_sample = 0
total_loss = 0
for local_model, loss_list, num_samples_i in local_models:
    local_weights = local_model.get_weights()

    # Mean of the per-epoch training losses recorded by fit().
    mean_loss = np.mean(loss_list)
    # Binary cross-entropy is unbounded, so 1 - loss can go negative; clamp at 0
    # so a poorly fitting client cannot contribute negatively weighted parameters.
    loss_weight = max(0.0, 1 - mean_loss)
    total_loss += loss_weight
    weighted_local_weights_for_loss = [loss_weight * w for w in local_weights]
    global_weights_sum_for_loss = [tf.add(gw_l, wlw_l) for gw_l, wlw_l in zip(global_weights_sum_for_loss, weighted_local_weights_for_loss)]

    total_sample += num_samples_i
    weighted_local_weights = [num_samples_i * w for w in local_weights]
    global_weights_sum = [tf.add(gw, wlw) for gw, wlw in zip(global_weights_sum, weighted_local_weights)]
    print(mean_loss, total_loss, total_sample)

0.26122283226810394
0.7387771677318961
463
0.23204609744250773
1.5067310702893884
926
0.2803817458450794
2.2263493244443087
1389
0.17176048252731563
3.054588841916993
1852
0.18331520996987818
3.8712736319471146
2315
0.2092430473305285
4.662030584616586
2778
0.16145848609972746
5.500572098516859
3241
0.18158126343041658
6.318990835086442
3704
0.16582335541024804
7.153167479676194
4167
0.19451912712305786
7.958648352553136
4630


In [16]:
# Turn the accumulated sums into weighted averages and load them into the global
# models. Guard the denominators so a degenerate aggregation fails loudly instead
# of writing inf/NaN weights into the models.
if total_loss <= 0:
    raise ValueError('Sum of loss-based client weights is not positive (%r); cannot average.' % total_loss)
if total_sample <= 0:
    raise ValueError('No client samples were aggregated; cannot average.')

average_weights_for_loss = [tf.divide(gws_l, total_loss) for gws_l in global_weights_sum_for_loss]
global_model_for_loss.set_weights(average_weights_for_loss)

average_weights = [tf.divide(gws, total_sample) for gws in global_weights_sum]
global_model.set_weights(average_weights)

In [17]:
# Persist both aggregated global models in HDF5 format.
for model_path, model_to_save in (
    ('models/global_model.h5', global_model),
    ('models/global_model_for_loss.h5', global_model_for_loss),
):
    model_to_save.save(model_path)