In [1]:
# Mount Google Drive at /content/drive so the data folder searched for in the
# later cells is reachable through the regular filesystem.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import h5py
import pandas as pd
import numpy as np
import tensorflow as tf
import os
from scipy.signal import butter, filtfilt
import matplotlib.pyplot as plt
import gc
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D, Activation, BatchNormalization
from sklearn.model_selection import GridSearchCV
from keras.utils import to_categorical
from tensorflow import keras
import pywt
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.layers import Conv2D, ELU, ZeroPadding2D, MaxPooling2D
from tensorflow.keras import regularizers


In [3]:
def load_data(file_path):
    """Read the matrix stored in an HDF5 file.

    The key of the dataset inside the file is derived from the file name by
    ``get_dataset_name``.

    Parameters
    ----------
    file_path : str
        Path to the ``.h5`` file.

    Returns
    -------
    array or None
        The full contents of the dataset (read with ``[:]``), or ``None`` when
        the file has no entry under the derived name.
    """
    dataset_name = get_dataset_name(file_path)
    with h5py.File(file_path, 'r') as f:
        dataset = f.get(dataset_name)
        # ``get`` yields None for a missing key; slicing None would raise a
        # TypeError, so report the absence to the caller instead.
        if dataset is None:
            return None
        return dataset[:]

def get_dataset_name(file_name_with_dir):
    """Derive the dataset key from a file path.

    Keeps the last '/'-separated component of the path and drops everything
    from its final underscore onwards, e.g.
    'dir/rest_105923_1.h5' -> 'rest_105923'. A name that contains no
    underscore yields an empty string.
    """
    # rpartition gives ('', '', s) when the separator is absent, so index 2 is
    # the whole string for a path without any '/'.
    base_name = file_name_with_dir.rpartition('/')[2]
    # Index 0 is everything before the last '_' ('' if there is none).
    return base_name.rpartition('_')[0]

def load_data_by_task(data_folder):
    """Load every ``.h5`` file in a folder together with its class label.

    Files are visited in sorted name order so that the order of samples (and
    of their labels) is the same on every run. A file is left out when its
    name does not map to a known label (``assign_label`` returns None) or when
    ``load_data`` returns None for it.

    Parameters
    ----------
    data_folder : str
        Directory that is listed (non-recursively) for ``.h5`` files.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray) or (None, None)
        The loaded matrices stacked along a new leading axis and the matching
        labels, or ``(None, None)`` when nothing was loaded. Stacking requires
        all matrices to have the same shape.
    """
    meg_data_list = []
    labels = []

    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for file in sorted(os.listdir(data_folder)):
        if not file.endswith('.h5'):
            continue

        label = assign_label(file)
        if label is None:
            # A None label would turn the label array into an object array
            # that cannot be one-hot encoded, so skip unrecognised files.
            continue

        data = load_data(os.path.join(data_folder, file))
        if data is None:
            continue

        meg_data_list.append(data)
        labels.append(label)

    if not meg_data_list:
        return None, None
    return np.stack(meg_data_list, axis=0), np.array(labels)

def assign_label(file_name):
    """Map a file name to its integer class label by prefix.

    'rest' -> 0, 'task_motor' -> 1, 'task_story' -> 2, 'task_working' -> 3.
    Returns None when the name starts with none of these prefixes.
    """
    prefix_to_label = (
        ("rest", 0),
        ("task_motor", 1),
        ("task_story", 2),
        ("task_working", 3),
    )
    for prefix, label in prefix_to_label:
        if file_name.startswith(prefix):
            return label
    return None

def find_fmri_data_folder(start_path):
    """Locate the training folder below the first 'meg_data' directory found.

    Walks ``start_path`` top-down and, for the first directory that has a
    'meg_data' sub-directory, returns that directory joined with
    'meg_data/Intra/train'. The existence of 'Intra/train' itself is not
    checked.

    Raises
    ------
    Exception
        If no directory named 'meg_data' exists anywhere below ``start_path``.
    """
    # Lazy generator: the walk stops as soon as the first match is produced.
    candidates = (
        os.path.join(parent, 'meg_data/Intra/train')
        for parent, subdirs, _ in os.walk(start_path)
        if 'meg_data' in subdirs
    )
    train_folder = next(candidates, None)
    if train_folder is None:
        raise Exception("meg_data folder not found. Please check the directory structure.")
    return train_folder


def butter_lowpass_filter(data, cutoff, fs, order=5):
    """Low-pass filter ``data`` with a digital Butterworth filter.

    Parameters
    ----------
    data : array_like
        Signal to filter; ``filtfilt`` is applied with its default axis.
    cutoff : float
        Cut-off frequency, in the same units as ``fs``.
    fs : float
        Sampling frequency of ``data``.
    order : int, optional
        Order passed to ``butter`` (default 5).

    Returns
    -------
    numpy.ndarray
        The output of ``filtfilt`` (forward-backward filtering) on ``data``.
    """
    nyquist = 0.5 * fs
    # butter expects the cut-off as a fraction of the Nyquist frequency.
    numerator, denominator = butter(order, cutoff / nyquist, btype='low', analog=False)
    return filtfilt(numerator, denominator, data)



def apply_scaling(array):
    """Standardise each sensor's time series to zero mean and unit variance.

    Parameters
    ----------
    array : numpy.ndarray
        3-D array indexed as (observation, sensor, time point).

    Returns
    -------
    numpy.ndarray
        float64 array of the same shape. Every (observation, sensor) row has
        its own mean subtracted and is divided by its own standard deviation.
        A row with zero standard deviation (constant signal) becomes all
        zeros instead of NaN/inf.

    Notes
    -----
    The work is done one observation at a time so that the temporaries stay
    the size of a single observation. An input with zero observations is
    returned as an empty array.
    """
    array_norm = np.zeros((array.shape[0], array.shape[1], array.shape[2]))
    for i in range(array.shape[0]):
        means = np.mean(array[i], axis=1, keepdims=True)  # per-sensor mean
        stds = np.std(array[i], axis=1, keepdims=True)    # per-sensor std
        # Avoid dividing by zero for flat sensors: (x - mean) is already 0
        # there, so dividing by 1 leaves a clean row of zeros.
        stds[stds == 0] = 1.0
        array_norm[i] = (array[i] - means) / stds
    return array_norm


def apply_lowpass(array, original_sampling_rate=2034, downsampling_factor=8):
    """Low-pass filter every sensor trace ahead of downsampling.

    The cut-off is placed at the Nyquist frequency of the *downsampled*
    signal, i.e. ``(original_sampling_rate / downsampling_factor) / 2``.

    Parameters
    ----------
    array : numpy.ndarray
        3-D array indexed as (observation, sensor, time point).
    original_sampling_rate : float, optional
        Sampling rate of ``array`` (default 2034).
    downsampling_factor : int, optional
        Factor by which the signal will be decimated afterwards (default 8).

    Returns
    -------
    numpy.ndarray
        Filtered array with the same shape as ``array``. Floating/complex
        inputs keep their dtype; any other dtype is promoted to float64 so
        the filter output is not truncated to integers.
    """
    new_sampling_rate = original_sampling_rate / downsampling_factor
    cutoff_frequency = new_sampling_rate / 2  # Nyquist of the downsampled signal

    out_dtype = array.dtype if np.issubdtype(array.dtype, np.inexact) else np.float64
    array_filtered = np.zeros(array.shape, dtype=out_dtype)

    for o in range(array.shape[0]):
        for i in range(array.shape[1]):
            array_filtered[o, i, :] = butter_lowpass_filter(
                array[o, i, :], cutoff_frequency, original_sampling_rate
            )

    return array_filtered


def apply_downsampling(array, downsampling_factor=8):
    """Keep every ``downsampling_factor``-th time point of each sensor trace.

    Parameters
    ----------
    array : numpy.ndarray
        3-D array indexed as (observation, sensor, time point).
    downsampling_factor : int, optional
        Decimation step (default 8).

    Returns
    -------
    numpy.ndarray
        New float64 array of shape
        ``(n_observations, n_sensors, n_timepoints // downsampling_factor)``.
        When the number of time points is not a multiple of the factor, the
        trailing partial window is dropped so the output length is always the
        floor of the division.
    """
    _, _, n_timepoints = array.shape
    new_n_timepoints = n_timepoints // downsampling_factor
    # Stop the strided slice at the last full window; a bare ``::factor``
    # would yield ceil(n / factor) samples and disagree with the floor length.
    kept = array[:, :, :new_n_timepoints * downsampling_factor:downsampling_factor]
    # astype copies, so the result does not keep the full-rate array alive.
    return kept.astype(np.float64)




def apply_wavelet_transform(data, wavelet='db4', level=5, original_length=8906):
    """Replace each sensor trace by its concatenated wavelet coefficients.

    For every (observation, sensor) row, ``pywt.wavedec`` is applied and the
    resulting coefficient arrays are concatenated into one vector, which is
    zero-padded at the end or truncated to exactly ``original_length`` values.

    Parameters
    ----------
    data : numpy.ndarray
        3-D array indexed as (observation, sensor, time point).
    wavelet : str, optional
        Wavelet name passed to ``pywt.wavedec`` (default 'db4').
    level : int, optional
        Decomposition level passed to ``pywt.wavedec`` (default 5).
    original_length : int, optional
        Length of the last axis of the result (default 8906). Pass
        ``data.shape[2]`` to get an output with the same shape as ``data``.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(n_observations, n_sensors, original_length)``.
    """
    n_observations, n_sensors = data.shape[0], data.shape[1]
    # Size the buffer from ``original_length`` (not from ``data``) so the
    # assignment below also works when the two lengths differ.
    transformed_data = np.zeros((n_observations, n_sensors, original_length))

    for i in range(n_observations):  # Loop over observations
        for j in range(n_sensors):  # Loop over sensors
            coeffs = pywt.wavedec(data[i, j, :], wavelet, level=level)
            flat_coeffs = np.concatenate(coeffs)[:original_length]
            # The buffer is zero-initialised, so writing only the prefix is
            # equivalent to zero-padding a short coefficient vector.
            transformed_data[i, j, :flat_coeffs.shape[0]] = flat_coeffs

    return transformed_data


# Search Drive for a 'meg_data' directory and point at its Intra/train folder.
fmri_data_folder = find_fmri_data_folder('/content/drive/My Drive')

# NOTE(review): `tasks` and `task_numbers` are not read by any of the visible
# cells — confirm whether they are still needed.
tasks = ['rest', 'task_motor', 'task_story', 'task_working']
task_numbers = ['_1.','_2.','_3.','_4.','_5.','_6.','_7.','_8.']
# Filled by the preprocessing cell with one sensor trace per pipeline stage.
visual_data = []

# Load and stack every .h5 file of the training folder with its class label;
# both names are None when no file could be loaded.
X_task, y_task = load_data_by_task(fmri_data_folder)



In [4]:
# Preprocessing pipeline: scale -> low-pass -> downsample -> add channel axis.
# Each stage's input is deleted right after use to keep peak memory down.
#
# The traces stored in `visual_data` are explicit copies: a plain
# `X[0, 0]` is a NumPy view whose base is the whole stage array, so holding it
# would keep that array alive and make the `del` + `gc.collect()` below a no-op.
if X_task is not None and y_task is not None:
    print("Train shape:", X_task.shape)
    print("Labels shape:", y_task.shape)

    # Raw trace of the first sensor of the first observation
    visual_data.append(X_task[0, 0].copy())

    # Scaling
    X_task_norm = apply_scaling(X_task)
    print("X_task_norm", X_task_norm.shape)
    del X_task
    gc.collect()
    # Same trace after scaling
    visual_data.append(X_task_norm[0, 0].copy())

    # Lowpass filter
    X_task_filtered = apply_lowpass(X_task_norm)
    print("X_task_filtered", X_task_filtered.shape)
    del X_task_norm
    gc.collect()
    # Same trace after filtering
    visual_data.append(X_task_filtered[0, 0].copy())

    # Downsample
    X_task_downsamp = apply_downsampling(X_task_filtered)
    print("X_task_downsamp", X_task_downsamp.shape)
    del X_task_filtered
    gc.collect()
    # Same trace after downsampling
    visual_data.append(X_task_downsamp[0, 0].copy())

    print("After downsampling:", X_task_downsamp.shape)

    # ----- Train the model on this task's data ----- #
    obs_train, sensors_train, points_train = X_task_downsamp.shape

    # Conv2D expects a trailing channel axis: (obs, sensors, time, 1)
    X_train = np.expand_dims(X_task_downsamp, axis=3)
    print("X_train shape:", X_train.shape)

    # One-hot encode the integer labels for categorical cross-entropy
    y_train_encoded = to_categorical(y_task, num_classes=4)

    del y_task, obs_train, visual_data
    gc.collect()

Train shape: (32, 248, 35624)
Labels shape: (32,)
X_task_norm (32, 248, 35624)
X_task_filtered (32, 248, 35624)
X_task_downsamp (32, 248, 4453)
After downsampling: (32, 248, 4453)
X_train shape: (32, 248, 4453, 1)


In [7]:
def EEGNet(input_shape=(248, 4453, 1), n_classes=4):
    """Build the (uncompiled) three-stage convolutional classifier.

    Each stage is Conv2D -> BatchNormalization -> ELU -> Dropout(0.25) ->
    MaxPooling2D; stages 2 and 3 are preceded by explicit zero padding. The
    result is flattened and fed to a softmax Dense layer with an L2 kernel
    penalty of 0.1.

    Parameters
    ----------
    input_shape : tuple of int, optional
        Shape of one sample as (sensors, time points, channels), channels
        last. Defaults to (248, 4453, 1).
    n_classes : int, optional
        Number of output classes (default 4).

    Returns
    -------
    tensorflow.keras.models.Sequential
        The assembled model; the caller is responsible for compiling it.

    Notes
    -----
    Batch normalisation uses ``axis=-1``. Conv2D outputs are channels-last by
    default, so the feature maps live on the last axis; ``axis=1`` (the
    channels-first convention) would instead keep separate statistics per
    sensor row.
    """
    model = Sequential()

    # Layer 1: temporal convolution along the time axis only
    model.add(Conv2D(8, (1, 64), input_shape=input_shape, padding='valid'))
    model.add(BatchNormalization(axis=-1))
    model.add(ELU())
    model.add(Dropout(0.25))
    model.add(MaxPooling2D(pool_size=(1, 4)))

    # Layer 2
    model.add(ZeroPadding2D(padding=((0, 1), (16, 17))))
    model.add(Conv2D(4, (2, 32), padding='valid'))
    model.add(BatchNormalization(axis=-1))
    model.add(ELU())
    model.add(Dropout(0.25))
    model.add(MaxPooling2D(pool_size=(2, 4)))

    # Layer 3
    model.add(ZeroPadding2D(padding=((4, 3), (2, 1))))
    model.add(Conv2D(4, (8, 4), padding='valid'))
    model.add(BatchNormalization(axis=-1))
    model.add(ELU())
    model.add(Dropout(0.25))
    model.add(MaxPooling2D(pool_size=(2, 4)))

    # FC Layer: one softmax unit per class
    model.add(Flatten())
    model.add(Dense(n_classes, activation='softmax', kernel_regularizer=regularizers.l2(0.1)))

    return model

In [8]:
# Create the model
net = EEGNet()

# Compile the model
# Categorical cross-entropy matches the one-hot labels in y_train_encoded.
net.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# NOTE(review): fit() is given no validation data, so the reported accuracy is
# training accuracy only; no random seed is set in the visible cells, so runs
# are presumably not reproducible — confirm.
history = net.fit(X_train,y_train_encoded , epochs=10)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
