In [1]:
import mne
from glob import glob
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from concurrent.futures import ThreadPoolExecutor
from tensorflow.keras.layers import Conv1D,BatchNormalization,LeakyReLU,MaxPool1D,\
GlobalAveragePooling1D,Dense,Dropout,AveragePooling1D
from tensorflow.keras.models import Sequential
from tensorflow.keras.backend import clear_session

In [2]:
def cnnmodel():
    clear_session()
    model=Sequential()
    model.add(Conv1D(filters=5,kernel_size=3,strides=1,input_shape=(3840,19)))#1
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2,strides=2))#2
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#3
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2,strides=2))#4
    model.add(Dropout(0.5))
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#5
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2,strides=2))#6
    model.add(Dropout(0.5))
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#7
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2,strides=2))#8
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#9
    model.add(LeakyReLU())
    model.add(GlobalAveragePooling1D())#10
    model.add(Dense(1,activation='sigmoid'))#11
    
    model.compile('adam',loss='binary_crossentropy',metrics=['Accuracy', 'Precision', 'Recall','AUC'])
    return model

prebuild_model=cnnmodel()
prebuild_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv1d (Conv1D)             (None, 3838, 5)           290       
                                                                 
 batch_normalization (Batch  (None, 3838, 5)           20        
 Normalization)                                                  
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 3838, 5)           0         
                                                                 
 max_pooling1d (MaxPooling1  (None, 1919, 5)           0         
 D)                                                              
                                                                 
 conv1d_1 (Conv1D)           (None, 1917, 5)           80        
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 1917, 5)           0

In [3]:
all_files_path=glob('data/*.edf')
print(len(all_files_path))

162


In [4]:
healthy_file_path=[i for i in all_files_path if  'H' in i.split('\\')[1]]
patient_file_path=[i for i in all_files_path if  'M' in i.split('\\')[1]]
print(len(healthy_file_path),len(patient_file_path))

76 86


In [5]:
def read_data(file_path):
    raw = mne.io.read_raw_edf(file_path, preload=True)
    raw.pick_types(meg=False, eeg=True, eog=False, ecg=False) # Selecting EEG, EOG and ECG channels
    # Select a specific channel
    channel_to_keep = ['EEG Fp1-LE', 'EEG F3-LE', 'EEG C3-LE', 'EEG P3-LE', 'EEG O1-LE', 'EEG F7-LE', 'EEG T3-LE', 'EEG T5-LE', 'EEG Fz-LE', 'EEG Fp2-LE', 'EEG F4-LE', 'EEG C4-LE', 'EEG P4-LE', 'EEG O2-LE', 'EEG F8-LE', 'EEG T4-LE', 'EEG T6-LE', 'EEG Cz-LE', 'EEG Pz-LE']  
    # Replace with the name of the channel you want to keep
    raw.pick_channels(channel_to_keep)
    raw.set_eeg_reference()
    raw.filter(l_freq=30,h_freq=100)#1-4=delta, 4-8=theta, 8-12=alpha, 12-30=beta, 30-100=gamma
    epochs=mne.make_fixed_length_epochs(raw,duration=15,overlap=1)
    epochs=epochs.get_data()
    scaler = StandardScaler()
    data = scaler.fit_transform(epochs.reshape(-1,epochs.shape[-1])).reshape(epochs.shape)
    return data #trials,channel,length

In [6]:
%%capture
control_epochs_array=[read_data(subject) for subject in healthy_file_path]
patients_epochs_array=[read_data(subject) for subject in patient_file_path]
control_epochs_labels=[len(i)*[0] for i in control_epochs_array]
patients_epochs_labels=[len(i)*[1] for i in patients_epochs_array]

In [7]:
data_list=control_epochs_array+patients_epochs_array
label_list=control_epochs_labels+patients_epochs_labels
groups_list=[[i]*len(j) for i, j in enumerate(data_list)]

data_array=np.vstack(data_list)
label_array=np.hstack(label_list)
group_array=np.hstack(groups_list)

In [8]:
data_array=np.moveaxis(data_array,1,2)

In [9]:
X=data_array
y=label_array

In [10]:
num_clients=10
clients_X, clients_y = np.array_split(X, num_clients), np.array_split(y, num_clients)

In [14]:
def train_client_model(model, data_X, data_y):
    model.fit(data_X,data_y,epochs=50,batch_size=32)
    return model

In [15]:
# Clone the global model for each client

for client_id in range(num_clients):
    globals()['client_model%s' % client_id] = tf.keras.models.clone_model(prebuild_model)
    globals()['client_model%s' % client_id].compile('adam',loss='binary_crossentropy',metrics=['Accuracy', 'Precision', 'Recall','AUC'])

In [16]:
with ThreadPoolExecutor() as executor:
    for i in range(num_clients):
        globals()['future%s' % i] = executor.submit(train_client_model, globals()['client_model%s' % i], clients_X[i], clients_y[i])

Epoch 1/50
Epoch 1/50
Epoch 1/50
Epoch 1/50
Epoch 1/50
Epoch 1/50
Epoch 1/50
Epoch 1/50
Epoch 1/50
Epoch 1/50
Epoch 2/50
 3/15 [=====>........................] - ETA: 1s - loss: 0.6289 - Accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 0.0000e+000000e+0Epoch 2/50
 2/15 [===>..........................] - ETA: 2s - loss: 0.3990 - Accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 0.0000e+000000e+0Epoch 2/50
Epoch 3/50
Epoch 3/50
Epoch 3/50
Epoch 3/50
Epoch 2/50
Epoch 4/50
Epoch 4/50
Epoch 4/50
Epoch 2/50
 1/15 [=>............................] - ETA: 7s - loss: 0.5523 - Accuracy: 0.8125 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.7051Epoch 2/50
Epoch 2/50
Epoch 2/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 5/50
Epoch 3/50
Epoch 3/50
Epoch 3/50
 1/15 [=>............................] - ETA: 5s - loss: 0.1814 - Accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 0.0000e+00Epoch 3/50
Epoch 3/50
Epoch 5/50
Epoch 6/50
Epoch 6/50
Epoch 6/50
Epoch

In [17]:
for j in range(num_clients):
    globals()['client_model%s' % j] = globals()['future%s' % j].result()
    # globals()['client_model%s' % j].save_model("model/"+ globals()['client_model%s' % j]+ ".h5")

In [42]:
client_model0.save_weights('Seungmin_Han/weights/weigths_client_model0.h5')

In [43]:
client_model0.get_weights()

[array([[[ 0.20919803,  0.2358667 ,  0.07248224, -0.15433463,
          -0.0953072 ],
         [ 0.12844926, -0.05449947,  0.14630486,  0.05820862,
           0.16501269],
         [-0.18257724,  0.23056433, -0.1733365 , -0.17856473,
           0.03433203],
         [-0.11583474, -0.11538161, -0.28434113,  0.32932487,
          -0.12160352],
         [ 0.11968027,  0.25155878,  0.16281198, -0.10712156,
           0.13138671],
         [-0.06983322, -0.31195128, -0.05817041,  0.19193402,
          -0.11747868],
         [-0.33135322, -0.01568634, -0.12727027,  0.11387625,
          -0.24641238],
         [ 0.3204249 ,  0.1918589 ,  0.05294228, -0.10341734,
          -0.25234485],
         [-0.15453029, -0.03104983,  0.2803006 , -0.2305263 ,
          -0.17139788],
         [-0.08953354, -0.12319523, -0.0939865 , -0.22825992,
           0.22863   ],
         [-0.08067454, -0.1354185 , -0.10021736, -0.02232257,
           0.06046665],
         [ 0.02473168,  0.11619309, -0.0181693 , -0.00