In [None]:
!pip install pyedflib



In [None]:
import pyedflib
import numpy as np
from scipy.signal import cwt, morlet
import tensorflow as tf

def normalize_eeg(eeg_data):
    """Z-score normalise each row of an EEG array independently.

    Parameters
    ----------
    eeg_data : array_like
        Array whose statistics are taken along axis 1; for the
        (channels, samples) array built by the caller this normalises every
        channel on its own.

    Returns
    -------
    numpy.ndarray
        ``(eeg_data - mean) / std`` computed per row. Rows with zero variance
        (e.g. a flat or disconnected electrode) are only mean-centred, so they
        become all zeros instead of NaN/inf.
    """
    data = np.asarray(eeg_data)
    mean = np.mean(data, axis=1, keepdims=True)
    std = np.std(data, axis=1, keepdims=True)
    # Replace a zero std by 1 to avoid 0/0 -> NaN; such rows are 0 after centring.
    safe_std = np.where(std == 0, 1.0, std)
    return (data - mean) / safe_std

def segment_eeg(normalized_data, segment_length, overlap):
    """Cut a (channels, samples) array into fixed-length windows.

    Parameters
    ----------
    normalized_data : array_like, shape (channels, samples)
    segment_length : int
        Number of samples per window.
    overlap : int
        Despite its name, this is the *hop* (stride) between consecutive
        window starts: windows begin at 0, overlap, 2*overlap, ...
        Neighbouring windows share ``segment_length - overlap`` samples when
        ``overlap < segment_length``.

    Returns
    -------
    numpy.ndarray, shape (n_segments, channels, segment_length)
        Only complete windows are returned; trailing samples that do not fill
        a window are dropped. A recording shorter than one window yields an
        empty array that still carries the channel and window axes.

    Raises
    ------
    ValueError
        If ``segment_length`` or ``overlap`` is not positive.
    """
    data = np.asarray(normalized_data)
    if segment_length <= 0 or overlap <= 0:
        raise ValueError(
            "segment_length and overlap (hop) must be positive, got "
            f"{segment_length} and {overlap}"
        )
    n_channels, n_samples = data.shape[0], data.shape[1]
    starts = range(0, n_samples - segment_length + 1, overlap)
    segments = [data[:, start:start + segment_length] for start in starts]
    if not segments:
        return np.empty((0, n_channels, segment_length), dtype=data.dtype)
    return np.stack(segments)

def cwt_time_frequency(segmented_data, channel_idx, sample_rate, widths=None, w=5.0):
    """Complex Morlet continuous wavelet transform of one channel per segment.

    Parameters
    ----------
    segmented_data : iterable of array_like, each of shape (channels, samples)
    channel_idx : int
        Row of each segment that is transformed.
    sample_rate : float
        Sampling rate in Hz. It does not alter the transform; it only fixes
        how scales map to frequency: the wavelet at scale ``s`` has centre
        frequency ``w * sample_rate / (2 * pi * s)`` Hz.
    widths : array_like of float, optional
        Wavelet scales in samples. Defaults to ``np.arange(1, 100)``
        (99 scales, as before).
    w : float, optional
        Morlet angular-frequency parameter passed to ``scipy.signal.morlet2``
        (5.0 is SciPy's default).

    Returns
    -------
    list of numpy.ndarray
        One complex array of shape (len(widths), samples) per segment.

    Notes
    -----
    ``scipy.signal.morlet`` is documented as incompatible with ``cwt``:
    ``cwt`` calls ``wavelet(N, width)``, which would bind each width to
    morlet's ``w`` (oscillation count) instead of the scale. ``morlet2``
    takes the scale as its second argument and is the cwt-compatible form.
    """
    from scipy.signal import morlet2

    if widths is None:
        widths = np.arange(1, 100)
    return [cwt(segment[channel_idx], morlet2, widths, w=w) for segment in segmented_data]

# EEG channel labels of the 10-20 montage, in the row order used downstream.
EEG_CHANNELS_10_20 = (
    "Fp1", "Fp2", "F7", "F3", "Fz", "F4", "F8",
    "T3", "C3", "Cz", "C4", "T4", "T5", "P3",
    "Pz", "P4", "T6", "O1", "O2",
)


def _read_eeg_channels(edf_file_path, channel_labels):
    """Read ``channel_labels`` from one EDF file; returns (data, sample_frequency).

    ``data`` has one row per label in ``channel_labels`` order;
    ``sample_frequency`` is the header rate of the first requested channel.
    The reader is closed even when a label is missing or a read fails.
    """
    edf_file = pyedflib.EdfReader(edf_file_path)
    try:
        available = edf_file.getSignalLabels()
        missing = [label for label in channel_labels if label not in available]
        if missing:
            raise ValueError(f"{edf_file_path}: missing EEG channels {missing}")
        indices = [available.index(label) for label in channel_labels]
        data = np.array([edf_file.readSignal(idx) for idx in indices])
        return data, edf_file.getSampleFrequency(indices[0])
    finally:
        edf_file.close()


def cwt_time_frequency_plot(edf_file_paths, channel_idx, sample_rate):
    """Build 128x128 CWT-magnitude images for every window of every EDF file.

    Despite the name nothing is plotted; the image stack is returned.

    For each file the 19 channels in ``EEG_CHANNELS_10_20`` are read,
    z-scored per channel (``normalize_eeg``), cut into windows of
    ``4 * sample_rate`` samples whose starts are ``0.75 * window`` samples
    apart (``segment_eeg``), and channel ``channel_idx`` of every window is
    transformed with ``cwt_time_frequency``. The CWT magnitude of each window
    is bicubically resized to 128x128.

    Parameters
    ----------
    edf_file_paths : iterable of str
        EDF recordings, processed in order.
    channel_idx : int
        Index into ``EEG_CHANNELS_10_20`` (0 = "Fp1") of the channel to transform.
    sample_rate : float
        Rate in Hz used to turn 4 s into samples. A warning is issued when it
        differs from the rate stored in a file's header.

    Returns
    -------
    numpy.ndarray, shape (n_windows_total, 128, 128, 1), float32
        Windows of all files concatenated in file order.

    Raises
    ------
    ValueError
        If a file lacks one of the required channel labels.
    """
    import warnings

    segment_length = int(4 * sample_rate)  # 4 seconds * sample rate
    # NOTE(review): segment_eeg uses its third argument as the hop between
    # window starts, so this yields 25 % overlap between neighbouring windows.
    # If 75 % overlap was intended, the hop should be int(segment_length * 0.25).
    hop = int(segment_length * 0.75)

    magnitude_images = []
    for edf_file_path in edf_file_paths:
        eeg_data, file_rate = _read_eeg_channels(edf_file_path, EEG_CHANNELS_10_20)
        if not np.isclose(file_rate, sample_rate):
            warnings.warn(
                f"{edf_file_path}: header sampling rate is {file_rate} Hz but "
                f"sample_rate={sample_rate} was given; windows will not span 4 s."
            )

        segments = segment_eeg(normalize_eeg(eeg_data), segment_length, hop)

        # Resize per file so the full stack of complex CWT matrices is never
        # held in memory at once.
        for cwt_matrix in cwt_time_frequency(segments, channel_idx, sample_rate):
            magnitude = np.abs(cwt_matrix)[..., np.newaxis]
            resized = tf.image.resize(magnitude, (128, 128), method='bicubic')
            magnitude_images.append(resized.numpy())

    if not magnitude_images:
        return np.empty((0, 128, 128, 1), dtype=np.float32)
    return np.stack(magnitude_images)

# EDF recordings of the two groups: h01-h14 and s01-s14.
N_RECORDINGS_PER_GROUP = 14
edf_file_paths = [f"h{i:02d}.edf" for i in range(1, N_RECORDINGS_PER_GROUP + 1)]
edf_file_paths1 = [f"s{i:02d}.edf" for i in range(1, N_RECORDINGS_PER_GROUP + 1)]

CHANNEL_IDX = 0     # index into the 10-20 channel list; 0 is "Fp1"
# NOTE(review): confirm this against the EDF headers; if the recordings use a
# different rate, the "4-second" windows are not 4 s long.
SAMPLE_RATE = 128

resized_image = cwt_time_frequency_plot(edf_file_paths, CHANNEL_IDX, SAMPLE_RATE)
# The second group must come from the "s" files; building it from
# edf_file_paths would make both classes identical.
resized_image1 = cwt_time_frequency_plot(edf_file_paths1, CHANNEL_IDX, SAMPLE_RATE)

print(resized_image.shape)
print(resized_image1.shape)


(1193, 128, 128, 1)
(1193, 128, 128, 1)


In [None]:
import random

# Seeded generator so the split and the shuffle are reproducible.
SEED = 42
TRAIN_FRACTION = 0.8
rng = np.random.default_rng(SEED)


def _split_class(images, label, generator, train_fraction=TRAIN_FRACTION):
    """Randomly send each image of one class to train (prob. train_fraction) or test.

    Returns ``(x_train, y_train, x_test, y_test)`` for that class, with every
    label equal to ``label``.
    """
    images = np.asarray(images)
    in_train = generator.random(len(images)) < train_fraction
    labels = np.full(len(images), label)
    return images[in_train], labels[in_train], images[~in_train], labels[~in_train]


# Class 0 = resized_image, class 1 = resized_image1.
# NOTE(review): windows are split individually, so windows from the same
# recording (including overlapping ones) can fall in both train and test,
# which inflates test accuracy. Split by recording for a subject-level estimate.
class_parts = [
    _split_class(resized_image, 0, rng),
    _split_class(resized_image1, 1, rng),
]
X_train = np.concatenate([part[0] for part in class_parts])
Y_train = np.concatenate([part[1] for part in class_parts])
X_test = np.concatenate([part[2] for part in class_parts])
Y_test = np.concatenate([part[3] for part in class_parts])

# Shuffle images and labels with one shared permutation. Swapping NumPy rows
# via `a[i], a[j] = a[j], a[i]` goes through views and duplicates rows instead.
train_order = rng.permutation(len(X_train))
X_train, Y_train = X_train[train_order], Y_train[train_order]
test_order = rng.permutation(len(X_test))
X_test, Y_test = X_test[test_order], Y_test[test_order]

print(X_train.shape)
print(Y_train.shape)
print(X_test.shape)
print(Y_test.shape)


(939, 128, 128, 1)
(939,)
(254, 128, 128, 1)
(254,)


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Concatenate, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D, Dense, Activation, Multiply, Conv2D, BatchNormalization, Add

# Input: one 128x128 single-channel CWT-magnitude image per EEG window.
input_layer = Input(shape=(128, 128, 1))

# First set of separable convolutions (applied in sequence: 5x5 -> 1x1 -> 3x3)
conv1_1 = SeparableConv2D(16, (5, 5), padding='same', activation='relu')(input_layer)
conv1_2 = SeparableConv2D(16, (1, 1), padding='same', activation='relu')(conv1_1)
conv1_3 = SeparableConv2D(16, (3, 3), padding='same', activation='relu')(conv1_2)

# Concatenate first set of convolutions (3 x 16 = 48 channels)
concat1 = Concatenate()([conv1_1, conv1_2, conv1_3])

# Second set: three parallel branches with 3x3, 5x5 and 1x1 kernels
conv2_1 = SeparableConv2D(32, (3, 3), padding='same', activation='relu')(concat1)
conv2_2 = SeparableConv2D(32, (5, 5), padding='same', activation='relu')(concat1)
conv2_3 = SeparableConv2D(32, (1, 1), padding='same', activation='relu')(concat1)

# Third set: a second convolution on each branch, same kernel size
conv3_1 = SeparableConv2D(32, (3, 3), padding='same', activation='relu')(conv2_1)
conv3_2 = SeparableConv2D(32, (5, 5), padding='same', activation='relu')(conv2_2)
conv3_3 = SeparableConv2D(32, (1, 1), padding='same', activation='relu')(conv2_3)

# Concatenate the branches (3 x 32 = 96 channels) and halve the resolution
concat2 = Concatenate()([conv3_1, conv3_2, conv3_3])
max_pool = MaxPooling2D(pool_size=(2, 2))(concat2)

# ---- Channel-wise attention (CWA) ----
channels = int(max_pool.shape[-1])  # derived rather than hard-coded

global_pooling = GlobalAveragePooling2D()(max_pool)
fc1 = Dense(channels, activation='relu')(global_pooling)
# Linear logits into the sigmoid: a ReLU here would make every logit >= 0 and
# confine all channel weights to [0.5, 1), so no channel could be suppressed.
fc2 = Dense(channels)(fc1)
channel_attention = Activation('sigmoid')(fc2)

# Reshape to (1, 1, C) so the weights broadcast over height and width.
# A Keras layer is used instead of raw tf.expand_dims on a symbolic tensor.
channel_attention = tf.keras.layers.Reshape((1, 1, channels))(channel_attention)
output_feature_map_cwa = Multiply()([max_pool, channel_attention])

# ---- Spatial attention (SA) ----
kernel_size = 9

# First branch: 1x9 convolution with 96 filters, batch norm, ReLU
branch1 = Conv2D(96, (1, kernel_size), padding='same')(concat1)
branch1 = BatchNormalization()(branch1)
branch1 = Activation('relu')(branch1)

# Second branch: 9x1 convolution with 96 filters, batch norm, ReLU
branch2 = Conv2D(96, (kernel_size, 1), padding='same')(concat1)
branch2 = BatchNormalization()(branch2)
branch2 = Activation('relu')(branch2)

# Aggregate each branch to one map (9x1 then 1x9). No ReLU on these: they are
# the sigmoid's logits, and non-negative logits would keep every spatial
# weight >= 0.5.
branch1_2 = Conv2D(1, (kernel_size, 1), padding='same')(branch1)
branch1_2 = BatchNormalization()(branch1_2)

branch2_2 = Conv2D(1, (1, kernel_size), padding='same')(branch2)
branch2_2 = BatchNormalization()(branch2_2)

combined = Add()([branch1_2, branch2_2])
sa_matrix = Activation('sigmoid')(combined)

# NOTE(review): the SA weights are computed from concat1 but applied to the
# raw 1-channel input, so this path contributes a single channel; confirm
# this is the intended design rather than weighting concat1 itself.
output_feature_map_sa = Multiply()([input_layer, sa_matrix])

# Bring the 128x128 SA map to the 64x64 CWA resolution with a Keras layer
# (instead of raw tf.image.resize on a symbolic tensor). The CWA map is
# already 64x64, so it needs no resize.
output_feature_map_sa_resized = tf.keras.layers.Resizing(
    64, 64, interpolation='bilinear'
)(output_feature_map_sa)

# Concatenate SA with CWA
concat_sa_cwa = Concatenate()([output_feature_map_sa_resized, output_feature_map_cwa])

# Another max pooling, then global average pooling
max_pool3 = MaxPooling2D(pool_size=(2, 2))(concat_sa_cwa)
global_avg_pool = GlobalAveragePooling2D()(max_pool3)

# Dense classification head
dense1 = Dense(128, activation='relu')(global_avg_pool)
dense2 = Dense(64, activation='relu')(dense1)

# Two-class softmax output; labels are integer class ids (0/1)
output_layer = Dense(2, activation='softmax')(dense2)

model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
model.summary()

# sparse_categorical_crossentropy matches the integer (not one-hot) labels
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])



Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                )]                                                                
                                                                                                  
 separable_conv2d (SeparableCon  (None, 128, 128, 16  57         ['input_1[0][0]']                
 v2D)                           )                                                                 
                                                                                                  
 separable_conv2d_1 (SeparableC  (None, 128, 128, 16  288        ['separable_conv2d[0][0]']       
 onv2D)                         )                                                             

In [None]:
# X_train / Y_train are already NumPy arrays from the split cell; keep the
# History so the training curves can be inspected later.
history = model.fit(X_train, Y_train, epochs=10, batch_size=32)

# Held-out performance; the test split is not used anywhere else.
test_loss, test_accuracy = model.evaluate(X_test, Y_test, verbose=0)
print(f"test loss {test_loss:.4f}  test accuracy {test_accuracy:.4f}")

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7e73ec88cc10>