In [117]:
import numpy as np
import scipy.io
from scipy.interpolate import interp1d
import scipy.signal
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, Dataset

import mne
from mne import create_info
from mne.io import RawArray
from mne.time_frequency import tfr_array_morlet

In [60]:
# Lower and upper filtration bounds
L_FREQ = 40
H_FREQ = 300

# Number of channels in ECoG data
CHANNELS_NUM = 62

# Number of wavelets in the indicated frequency range, with which the
# convolution is performed
WAVELET_NUM = 40

# Final sampling rate for finger flex and ecog data
SAMPLE_RATE = 100

# Time delay hyperparameter
time_delay_secs = 0.2

In [61]:
def finger_flex_preprocess(finger_flex, original_hz=1000, actual_hz=25, upsampled_hz=SAMPLE_RATE):
    """Decimate, cubically re-interpolate and min-max scale finger-flexion traces.

    Steps:
      1. keep every ``original_hz // actual_hz``-th sample along axis 1;
      2. cubically interpolate the kept samples onto a grid that spans the same
         index range with ``upsampled_hz / actual_hz`` times as many points;
      3. rescale every row to the [0, 1] range.

    Parameters
    ----------
    finger_flex : array-like, 2-D
        Traces with time along axis 1 (rows are presumably one per finger --
        TODO confirm against the caller).
    original_hz : int or float
        Rate the array is stored at.
    actual_hz : int or float
        Rate used to pick the decimation step.
    upsampled_hz : int or float
        Rate of the returned traces.

    Returns
    -------
    numpy.ndarray
        Array with the same number of rows as the input and
        ``n_kept * upsampled_hz // actual_hz`` columns, each row scaled to [0, 1].

    Raises
    ------
    ValueError
        If a rate is not positive or ``actual_hz`` exceeds ``original_hz``
        (which would make the decimation step zero).
    """
    if original_hz <= 0 or actual_hz <= 0 or upsampled_hz <= 0:
        raise ValueError("original_hz, actual_hz and upsampled_hz must be positive")

    downscaling_ratio = int(original_hz // actual_hz)
    if downscaling_ratio < 1:
        raise ValueError("actual_hz must not exceed original_hz")

    finger_flex = np.asarray(finger_flex)[:, ::downscaling_ratio]

    n_old = finger_flex.shape[1]
    # Multiply before dividing: truncating ``n_old * (upsampled_hz / actual_hz)``
    # can land one sample short when the float ratio is not exact.
    n_new = int(n_old * upsampled_hz // actual_hz)

    # Both grids cover the same index range [0, n_old - 1]; only the density differs.
    old_time = np.linspace(0, n_old - 1, num=n_old)
    new_time = np.linspace(0, n_old - 1, num=n_new)

    # Cubic interpolation along the time axis.
    interpolator = interp1d(old_time, finger_flex, kind="cubic", axis=1)
    finger_flex_interp = interpolator(new_time)

    # MinMaxScaler works column-wise, so transpose to scale each row separately.
    scaler = MinMaxScaler(feature_range=(0, 1))
    finger_flex_scaled = scaler.fit_transform(finger_flex_interp.T).T

    return finger_flex_scaled


In [62]:
def ecog_preprocess(train, original_sfreq=1000, l_freq=40, h_freq=300, target_sfreq=100, freqs=None):
    """Turn a raw ECoG recording into a per-channel scaled Morlet power spectrogram.

    The recording is wrapped in an MNE ``RawArray``, band-pass filtered,
    resampled, transformed with Morlet wavelets and finally min-max scaled.

    Parameters
    ----------
    train : array-like, shape (n_channels, n_samples)
        Raw ECoG data, channels first. The channel count is taken from the
        array itself.
    original_sfreq : float
        Sampling rate of ``train``.
    l_freq, h_freq : float
        Band edges passed to ``raw.filter``.
    target_sfreq : float
        Rate the filtered signal is resampled to; also the rate handed to the
        wavelet transform.
    freqs : array-like or None
        Wavelet frequencies. ``None`` uses ``np.linspace(5, 50, 40)``.

    Returns
    -------
    numpy.ndarray, shape (n_channels, n_freqs, n_time_steps)
        Power spectrogram scaled to [0, 1] independently for every channel
        (over all of that channel's frequencies and time steps).

    Raises
    ------
    ValueError
        If ``train`` is not two-dimensional.

    Notes
    -----
    NOTE(review): with the defaults the pass band (40-300) reaches above half
    of the resample rate (100 / 2 = 50). This is kept exactly as the pipeline
    was written -- confirm that it is intended.
    """
    train = np.asarray(train)
    if train.ndim != 2:
        raise ValueError("train must be 2-D with shape (n_channels, n_samples)")

    n_channels = train.shape[0]
    ch_names = [f"ch{i}" for i in range(1, n_channels + 1)]
    ch_types = ["ecog"] * n_channels  # Mark all channels as ECoG

    info = create_info(ch_names=ch_names, sfreq=original_sfreq, ch_types=ch_types)

    raw = RawArray(train, info)
    raw.filter(l_freq=l_freq, h_freq=h_freq, fir_design='firwin')
    raw.resample(sfreq=target_sfreq)

    # Wavelet frequencies applied to the resampled signal.
    if freqs is None:
        freqs = np.linspace(5, 50, 40)

    # Get ECoG data in NumPy format
    X = raw.get_data()  # Shape: (n_channels, time_steps)

    # tfr_array_morlet expects a leading epoch axis: (1, n_channels, time_steps)
    X = X[np.newaxis, :, :]

    # Compute spectrogram using Morlet wavelet transform
    X_spectrogram = tfr_array_morlet(
        X, sfreq=target_sfreq, freqs=freqs, output="power"
    )  # Shape: (1, n_channels, n_freqs, time_steps)

    # Remove the epoch axis
    X_spectrogram = X_spectrogram[0]  # Shape: (n_channels, n_freqs, time_steps)

    num_channels, num_freqs, num_time_steps = X_spectrogram.shape
    # Flatten freq-time so that, after transposing, each channel is one column.
    X_spectrogram_reshaped = X_spectrogram.reshape(num_channels, -1)

    scaler = MinMaxScaler(feature_range=(0, 1))

    X_spectrogram_scaled = scaler.fit_transform(X_spectrogram_reshaped.T).T  # Scale per channel
    X_spectrogram_scaled = X_spectrogram_scaled.reshape(num_channels, num_freqs, num_time_steps)

    print("Scaled Spectrogram Shape:", X_spectrogram_scaled.shape)

    # Return the scaled array; previously the unscaled spectrogram was returned
    # and the scaling above had no effect on the result.
    return X_spectrogram_scaled

In [63]:
def crop_for_time_delay(finger_flex, spectrogramms, time_delay_sec=time_delay_secs, sample_rate=SAMPLE_RATE):
    """Shift finger-flexion targets against spectrograms by a fixed delay.

    The delay is converted to a whole number of samples. That many samples are
    dropped from the start of ``finger_flex`` and from the end of
    ``spectrogramms``, both along the last axis, so that index ``i`` of the
    spectrogram lines up with the finger flexion recorded ``time_delay_sec``
    later.

    Parameters
    ----------
    finger_flex : array-like
        Finger-flexion data with time on the last axis.
    spectrogramms : array-like
        Spectrogram data with time on the last axis.
    time_delay_sec : float
        Delay in seconds; must not be negative.
    sample_rate : float
        Samples per second of both inputs.

    Returns
    -------
    tuple
        ``(finger_flex_cropped, spectrogramms_cropped)``.

    Raises
    ------
    ValueError
        If the delay is negative.
    """
    # round() rather than plain int(): products such as 0.57 * 100 come out
    # slightly below the intended integer and would be truncated one sample short.
    time_delay = int(round(time_delay_sec * sample_rate))
    if time_delay < 0:
        raise ValueError("time_delay_sec * sample_rate must be non-negative")

    # The first time_delay samples of finger flex data have no corresponding spectrograms
    finger_flex_cropped = finger_flex[..., time_delay:]

    # The last time_delay samples of spectrograms have no corresponding finger flex data.
    # Clamp at 0 so a delay longer than the signal yields an empty crop instead
    # of a negative stop index that would silently keep the wrong samples.
    keep = max(spectrogramms.shape[-1] - time_delay, 0)
    spectrogramms_cropped = spectrogramms[..., :keep]

    return finger_flex_cropped, spectrogramms_cropped

In [None]:
# Read the subject-1 competition file; it is reused by later cells.
data = scipy.io.loadmat('BCI_Competion4_dataset4_data_fingerflexions/sub1_comp.mat')

# Cast 'train_dg' to float32, transpose it and run the finger-flex preprocessing.
finger_flex = finger_flex_preprocess(data['train_dg'].astype('float32').T)
print("Shape:", finger_flex.shape)

In [None]:
# Cast 'train_data' to float32 and transpose it before computing the spectrogram.
train = np.transpose(data['train_data'].astype('float32'))
X_spectrogram = ecog_preprocess(train)

In [None]:
# Crop the training targets and spectrograms for the configured time delay.
finger_flex_cropped, X_spectrogram_cropped = crop_for_time_delay(
    finger_flex=finger_flex,
    spectrogramms=X_spectrogram,
)

print(finger_flex_cropped.shape, X_spectrogram_cropped.shape)

In [None]:
# Read the subject-1 test-label file and preprocess its 'test_dg' array.
data_val = scipy.io.loadmat('BCI_Competion4_dataset4_data_fingerflexions/sub1_testlabels.mat')

# Cast to float32, transpose and run the finger-flex preprocessing.
finger_flex_val = finger_flex_preprocess(data_val['test_dg'].astype('float32').T)
print("Shape:", finger_flex_val.shape)

In [None]:
# 'test_data' is read from `data` (the sub1_comp.mat file loaded earlier).
val = np.transpose(data['test_data'].astype('float32'))
X_spectrogram_val = ecog_preprocess(val)

In [None]:
# Crop the validation targets and spectrograms for the configured time delay.
finger_flex_cropped_val, X_spectrogram_cropped_val = crop_for_time_delay(
    finger_flex=finger_flex_val,
    spectrogramms=X_spectrogram_val,
)

print(finger_flex_cropped_val.shape, X_spectrogram_cropped_val.shape)

In [81]:
class EcogDataset(Dataset):
    """Sliding-window dataset over paired spectrogram / finger-flexion arrays.

    Item ``i`` is the pair of windows ``[i, i + data_length)`` taken along the
    last axis of both arrays.

    Parameters
    ----------
    data : tuple
        ``(spectrogram, finger_flex)`` NumPy arrays with time on the last axis.
        Both are stored as float32 copies.
    data_length : int
        Window length in samples; must be positive.

    Raises
    ------
    ValueError
        If ``data_length`` is not positive.
    """

    def __init__(self, data, data_length):
        if data_length <= 0:
            raise ValueError("data_length must be positive")
        spectrogram, finger_flex = data
        self.spectrogram = spectrogram.astype('float32')
        self.finger_flex = finger_flex.astype('float32')
        self.data_length = data_length

    def _num_time_steps(self):
        # Use the shorter array so every window is full-length in both.
        return min(self.spectrogram.shape[-1], self.finger_flex.shape[-1])

    def __len__(self):
        # Start indices 0 .. T - data_length are all valid, hence the "+ 1";
        # clamp at 0 because __len__ must not return a negative number.
        return max(0, self._num_time_steps() - self.data_length + 1)

    def __getitem__(self, index):
        n_windows = len(self)
        if index < 0:
            index += n_windows
        # Raising IndexError lets plain ``for item in dataset`` terminate and
        # prevents silently returning truncated windows.
        if not 0 <= index < n_windows:
            raise IndexError("EcogDataset index out of range")
        stop = index + self.data_length
        spectrogram_crop = self.spectrogram[..., index:stop]
        finger_flex_crop = self.finger_flex[..., index:stop]
        return spectrogram_crop, finger_flex_crop

In [None]:
# Window length (in samples) of every dataset item.
data_length = 256

# Training windows.
train_dataset = EcogDataset(
    data=(X_spectrogram_cropped, finger_flex_cropped),
    data_length=data_length,
)
print(X_spectrogram_cropped.shape, finger_flex_cropped.shape)
print(len(train_dataset))

# Validation windows.
val_dataset = EcogDataset(
    data=(X_spectrogram_cropped_val, finger_flex_cropped_val),
    data_length=data_length,
)
print(X_spectrogram_cropped_val.shape, finger_flex_cropped_val.shape)
print(len(val_dataset))

In [83]:
# Batches of 32 windows; only the training loader shuffles.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
)