In [22]:
import os
import numpy as np
import scipy.io as sio
import mne
from scipy.signal import resample

In [None]:
SAMPLE_RATE = 128  # fs
SAMPLE_LEN = 128   # T

## Official script

In [23]:
filenames = []
for filename in os.listdir("APAVA/"):
  filenames.append(filename)

In [24]:
filenames.sort()
# filenames

In [25]:
feature_path = 'Processed/APAVA-19/Feature'
if not os.path.exists(feature_path):
    os.makedirs(feature_path)

In [26]:
def interpolate_to_19_channels(eeg_data, input_channels, bad_channels, sfreq=256, montage='standard_1020'):
    """
    Interpolate the input eeg_data to 19 channels with standard order.
    paras:
        eeg_data (numpy.ndarray): Input shape (T, C).
        input_channels (list): list less than 19. Such as ['F3', 'Fz', 'F4', 'Cz', 'P3', 'Pz', 'P4']
        badd_channels (list): list of channels that need to be interpolated.
        sfreq (int): sampling frequency, default 256 Hz。
        montage (str): 'standard_1020'
    
    return:
        numpy.ndarray:  (N, 19)，return interpolated_data with standard order.
    """
    # DESIRED_ORDER
    desired_order = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T3', 'C3', 'Cz', 
                     'C4', 'T4', 'T5', 'P3', 'Pz', 'P4', 'T6', 'O1', 'O2']
    # SET_DATA
    temp = np.zeros((19, eeg_data.shape[0]))
    for idx, _ in enumerate(input_channels):
        temp[idx] = eeg_data[:, idx]
    # CREATE_INFO
    info = mne.create_info(ch_names=input_channels+bad_channels, sfreq=sfreq, ch_types='eeg')
    raw = mne.io.RawArray(temp, info)
    # STANDARD_MONTAGE
    montage = mne.channels.make_standard_montage(montage)
    raw.set_montage(montage)
    # MARK_BADS
    raw.info['bads'] = bad_channels
    # INTERPOLATE_BADS
    raw.interpolate_bads(reset_bads=True)
    # REORDER_CHANNELS
    raw.reorder_channels(desired_order)
    # GET_DATA
    interpolated_data = raw.get_data().T
    return interpolated_data

#### Save feature

In [27]:
def resample_time_series(data, original_fs, target_fs):
    T, C = data.shape
    new_length = int(T * target_fs / original_fs)
    
    resampled_data = np.zeros((new_length, C))
    for i in range(C):
        resampled_data[:, i] = resample(data[:, i], new_length)
        
    return resampled_data

In [28]:
subseq_length = SAMPLE_LEN
stride = SAMPLE_LEN / 2  # Half of the subsequence length for half-overlapping
for i in range(len(filenames)):
    # print('Dataset/'+filename)
    path = "APAVA/" + filenames[i]
    mat = sio.loadmat(path)
    mat_np = mat['data']

    # Get epoch number for each subject
    epoch_num = len(mat_np[0,0][2][0])
    print("Epoch number: ",epoch_num)
    # Each epoch has shape (1280, 16)
    raw_shape = np.zeros((epoch_num, 1280, 16)).shape
    features = []
    # Store in temp
    for j in range(epoch_num):
        temp = np.transpose(mat_np[0,0][2][0][j])
        
        # ['Fz', 'Cz', 'Pz'] does not exist, we need to interpolate them
        input_channels = ['C3', 'C4', 'F3', 'F4', 'F7', 'F8', 'Fp1', 'Fp2', 'O1', 'O2', 'P3', 'P4', 'T3', 'T4', 'T5', 'T6']
        bad_channels = ['Fz', 'Cz', 'Pz']
        temp = interpolate_to_19_channels(temp, input_channels, bad_channels, sfreq=256)
        
        data = resample_time_series(temp, 256, SAMPLE_RATE)  # Downsample to 128 Hz
        # Calculate the number of subsequences that can be extracted
        num_subsequences = (data.shape[0] - subseq_length) // stride + 1
        # Extract the subsequences
        subsequences = [data[i * stride : i * stride + subseq_length, :] for i in range(num_subsequences)]
        feature = np.array(subsequences)
        features.append(feature)
    features = np.array(features).reshape((-1, subseq_length, 19))   # Reshape to (N, 128, 19)

    print(f"Filename: {filenames[i]}")
    print(f"Patient ID: {i+1}")
    print("Raw data:", raw_shape)
    print("Downsampling, segmented and interpolated data", features.shape)
    np.save(feature_path + "/feature_{:02d}.npy".format(i+1),features)
    print("Save feature_{:02d}.npy".format(i+1))
    print("---------------------------------------------\n")

Epoch number:  35
Creating RawArray with float64 data, n_channels=19, n_times=1280
    Range : 0 ... 1279 =      0.000 ...     4.996 secs
Ready.
Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 95.4 mm
Computing interpolation matrix from 16 sensor positions
Interpolating 3 sensors
Creating RawArray with float64 data, n_channels=19, n_times=1280
    Range : 0 ... 1279 =      0.000 ...     4.996 secs
Ready.
Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 95.4 mm
Computing interpolation matrix from 16 sensor positions
Interpolating 3 sensors
Creating RawArray with float64 data, n_channels=19, n_times=1280
    Range : 0 ... 1279 =      0.000 ...     4.996 secs
Ready.
Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 95.4 mm
Computing interpolation matrix fr

#### Save label

In [29]:
AD_positive = [1,3,6,8,9,11,12,13,15,17,19,21]

In [30]:
labels = np.zeros((23, 2))
len(labels)

23

In [31]:
label_path = 'Processed/APAVA-19/Label'
if not os.path.exists(label_path):
    os.makedirs(label_path)

In [32]:
for i in range(len(labels)):
  # The first one is AD label (0 for healthy; 1 for AD patient)
  # The second one is the subject label (the order of subject, ranging from 1 to 23.
  labels[i][1] = i + 1
  if i+1 in AD_positive:
    labels[i][0] = 1
  else:
    labels[i][0] = 0

In [33]:
np.save(label_path + "/label.npy",labels)
print("Save label")

Save label


## Test

In [34]:
# Test the saved npy file
# example

path = 'Processed/APAVA-19/Feature/'

for file in os.listdir(path):
    sub_path = os.path.join(path, file)
    print(np.load(sub_path).shape)

(315, 128, 19)
(225, 128, 19)
(90, 128, 19)
(297, 128, 19)
(9, 128, 19)
(198, 128, 19)
(27, 128, 19)
(288, 128, 19)
(162, 128, 19)
(342, 128, 19)
(423, 128, 19)
(333, 128, 19)
(261, 128, 19)
(351, 128, 19)
(414, 128, 19)
(252, 128, 19)
(531, 128, 19)
(360, 128, 19)
(414, 128, 19)
(333, 128, 19)
(171, 128, 19)
(153, 128, 19)
(18, 128, 19)


In [35]:
np.load("Processed/APAVA-19/Label/label.npy")

array([[ 1.,  1.],
       [ 0.,  2.],
       [ 1.,  3.],
       [ 0.,  4.],
       [ 0.,  5.],
       [ 1.,  6.],
       [ 0.,  7.],
       [ 1.,  8.],
       [ 1.,  9.],
       [ 0., 10.],
       [ 1., 11.],
       [ 1., 12.],
       [ 1., 13.],
       [ 0., 14.],
       [ 1., 15.],
       [ 0., 16.],
       [ 1., 17.],
       [ 0., 18.],
       [ 1., 19.],
       [ 0., 20.],
       [ 1., 21.],
       [ 0., 22.],
       [ 0., 23.]])