In [1]:
import mne
from glob import glob
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
import numpy as np
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from concurrent.futures import ThreadPoolExecutor

In [2]:
prebuild_model = tf.keras.models.load_model("models/gamma_15s_50epoch_32batch.h5")

In [3]:
global_model = tf.keras.models.clone_model(prebuild_model)
local_model = tf.keras.models.clone_model(prebuild_model)

In [4]:
global_model.get_weights()

[array([[[ 0.02816325,  0.15715706,  0.02133852,  0.22156346,
           0.15642938],
         [ 0.27488345,  0.20100933, -0.2570672 , -0.2093944 ,
          -0.19707787],
         [ 0.13406494,  0.01726156, -0.22177693, -0.05798385,
          -0.22747459],
         [ 0.09898734, -0.06253535, -0.19950631, -0.21455342,
          -0.25312325],
         [-0.13595888, -0.25165692, -0.02322572,  0.08385831,
          -0.18754375],
         [-0.12383203,  0.28745425,  0.24576849, -0.01905242,
          -0.12666263],
         [ 0.0321548 ,  0.04893631, -0.12766705,  0.06343153,
          -0.0578707 ],
         [ 0.13878933, -0.07673205,  0.07820114, -0.18228981,
           0.21056885],
         [-0.07793842, -0.07349086, -0.12969851, -0.10168543,
          -0.02606878],
         [-0.17187685,  0.27068895, -0.27083066, -0.16891013,
           0.11887419],
         [ 0.18114305, -0.2589288 , -0.12340793, -0.14055932,
           0.14748517],
         [-0.28244373,  0.1863859 , -0.05265331, -0.01

In [None]:
local_model.save_weights("try1_32batch_15s_50epoch/local_model_weights.weights.h5")

In [None]:
global_model.save_weights("try1_32batch_15s_50epoch/global_model_weights.weights.h5")

In [4]:
all_files_path=glob('sd_copy/*.edf')
print(len(all_files_path))

162


In [5]:
healthy_file_path=[i for i in all_files_path if  'H' in i.split('\\')[1]]
patient_file_path=[i for i in all_files_path if  'M' in i.split('\\')[1]]
print(len(healthy_file_path),len(patient_file_path))

76 86


In [6]:
def read_data(file_path):
    raw = mne.io.read_raw_edf(file_path, preload=True)
    raw.pick_types(meg=False, eeg=True, eog=False, ecg=False) # Selecting EEG, EOG and ECG channels
    # Select a specific channel
    channel_to_keep = ['EEG Fp1-LE', 'EEG F3-LE', 'EEG C3-LE', 'EEG P3-LE', 'EEG O1-LE', 'EEG F7-LE', 'EEG T3-LE', 'EEG T5-LE', 'EEG Fz-LE', 'EEG Fp2-LE', 'EEG F4-LE', 'EEG C4-LE', 'EEG P4-LE', 'EEG O2-LE', 'EEG F8-LE', 'EEG T4-LE', 'EEG T6-LE', 'EEG Cz-LE', 'EEG Pz-LE']  
    # Replace with the name of the channel you want to keep
    raw.pick_channels(channel_to_keep)
    raw.set_eeg_reference()
    raw.filter(l_freq=30,h_freq=100)#1-4=delta, 4-8=theta, 8-12=alpha, 12-30=beta, 30-100=gamma
    epochs=mne.make_fixed_length_epochs(raw,duration=15,overlap=1)
    epochs=epochs.get_data()
    scaler = StandardScaler()
    data = scaler.fit_transform(epochs.reshape(-1,epochs.shape[-1])).reshape(epochs.shape)
    return data #trials,channel,length

In [7]:
%%capture
control_epochs_array=[read_data(subject) for subject in healthy_file_path]
patients_epochs_array=[read_data(subject) for subject in patient_file_path]
control_epochs_labels=[len(i)*[0] for i in control_epochs_array]
patients_epochs_labels=[len(i)*[1] for i in patients_epochs_array]
data_list=control_epochs_array+patients_epochs_array
label_list=control_epochs_labels+patients_epochs_labels
groups_list=[[i]*len(j) for i, j in enumerate(data_list)]
data_array=np.vstack(data_list)
label_array=np.hstack(label_list)
group_array=np.hstack(groups_list)

In [8]:
print(data_array.shape,label_array.shape)

(4638, 19, 3840) (4638,)


In [9]:
data_array=np.moveaxis(data_array,1,2)

In [10]:
X=data_array
y=label_array

In [None]:
# y=tf.keras.utils.to_categorical(y)

In [11]:
y.shape

(4638,)

In [12]:
num_clients=3
clients_X, clients_y = np.array_split(X, num_clients), np.array_split(y, num_clients)
clients_group_array = np.array_split(group_array, num_clients)

In [13]:
def train_client_model(model, data_X, data_y, group_data):
    data_X=data_X
    data_y=data_y
    Group_data= group_data
    gkf=GroupKFold(n_splits=2)
    for train_index, val_index in gkf.split(data_X, data_y, groups=group_data):
        train_features,train_labels=data_X[train_index],data_y[train_index]
        val_features,val_labels=data_X[val_index],data_y[val_index]
        scaler=StandardScaler()
        train_features = scaler.fit_transform(train_features.reshape(-1, train_features.shape[-1])).reshape(train_features.shape)
        val_features = scaler.transform(val_features.reshape(-1, val_features.shape[-1])).reshape(val_features.shape)
        # model=cnnmodel()
        model.fit(train_features,train_labels,epochs=5,batch_size=25,validation_data=(val_features,val_labels))
        evaluation_results = model.evaluate(val_features,val_labels)

    return model

In [16]:
# Clone the global model for each client

for client_id in range(num_clients):
    globals()['client_model%s' % client_id] = tf.keras.models.clone_model(local_model)
    globals()['client_model%s' % client_id].compile('adam',loss='binary_crossentropy',metrics=['Accuracy', 'Precision', 'Recall','AUC'])


In [18]:

with ThreadPoolExecutor() as executor:
    for i in range(num_clients):
        globals()['future%s' % i] = executor.submit(train_client_model, globals()['client_model%s' % i], clients_X[i], clients_y[i],clients_group_array[i])
    


Epoch 1/5
Epoch 1/5
Epoch 1/5
Epoch 2/5
Epoch 2/5
Epoch 2/5
Epoch 3/5
 2/31 [>.............................] - ETA: 3s - loss: 0.5057 - Accuracy: 1.0000 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.0000e+00Epoch 3/5
 2/31 [>.............................] - ETA: 4s - loss: 0.4847 - Accuracy: 0.8600 - precision: 0.8409 - recall: 1.0000 - auc: 0.9168Epoch 3/5
Epoch 4/5
 2/31 [>.............................] - ETA: 3s - loss: 0.1481 - Accuracy: 1.0000 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.0000e+00Epoch 4/5
Epoch 4/5
Epoch 5/5
Epoch 5/5
 4/31 [==>...........................] - ETA: 5s - loss: 0.2326 - Accuracy: 0.9100 - precision: 0.8657 - recall: 1.0000 - auc: 0.9762Epoch 5/5
Epoch 1/5
Epoch 1/5
Epoch 2/5
Epoch 2/5
 3/31 [=>............................] - ETA: 4s - loss: 0.0032 - Accuracy: 1.0000 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.0000e+00Epoch 1/5
Epoch 3/5
Epoch 3/5
Epoch 2/5
Epoch 4/5
Epoch 4/5
Epoch 3/5
Epoch 5/5
Epoch 5/5
Epoch 5/5


In [21]:

for j in range(num_clients):
    globals()['client_model%s' % j] = globals()['future%s' % j].result()


In [25]:
client_model2.get_weights()

[array([[[ 0.23915273,  0.09087033,  0.10885888, -0.1566909 ,
          -0.1552707 ],
         [-0.0223843 ,  0.08566833, -0.01153228,  0.06614044,
          -0.2022529 ],
         [ 0.05676519,  0.09741896,  0.306782  ,  0.05215501,
           0.03201086],
         [-0.20256838, -0.23985922, -0.07778507, -0.28046748,
          -0.06665123],
         [ 0.108548  , -0.32049042,  0.06765447, -0.2419018 ,
          -0.2133915 ],
         [ 0.22083935, -0.16632928, -0.09049753, -0.15297529,
           0.27921727],
         [ 0.05169789, -0.09425769,  0.13482223, -0.06618448,
           0.08597369],
         [ 0.09054499,  0.03898213, -0.16376656,  0.10294817,
          -0.16684765],
         [-0.29140592, -0.20506375, -0.21247613, -0.05688139,
           0.16908535],
         [-0.30069202, -0.11855864, -0.03729041,  0.13993913,
          -0.07008845],
         [ 0.10670479,  0.04547297,  0.25404945,  0.04185706,
           0.24598776],
         [-0.21886027,  0.33465734, -0.12155315,  0.30