# Epoch Extraction and Feature Computation for P300 Analysis

These functions and the accompanying script extract epochs from continuous EEG data, baseline-correct each epoch, flag bad epochs based on an amplitude threshold, and compute features over predefined time ranges. The resulting features are saved together with the trial labels for subsequent analyses.

> **Note:** Hyperparameter: `rejection_threshold = 40e-6`

---

## **1. Epoch Extraction and Baseline Correction**

### **Function: `extract_epochs`**
- **Purpose:**  
  Extracts time-locked (and phase-locked) epochs from continuous EEG data using a sliding window. It baseline-corrects each epoch and logs an epoch as bad if its amplitude range (maximum minus minimum) exceeds a set threshold.
- **Key Parameters:**
  - `X`: EEG data of shape *(n_trials, n_channels, n_samples)*.
  - `start_idx` and `end_idx`: Define the range of epoch onsets (e.g., from sample 120 up to, but not including, sample 2520).
  - `step_size`: Determines how far to slide the window for each new epoch (e.g., every 30 samples, corresponding to 250 ms at 120 Hz).
  - `start_sample_offset` and `end_sample_offset`: Define the epoch window relative to each onset (e.g., from -24 to +84 samples, corresponding to -200 ms to 700 ms at 120 Hz).
  - `amplitude_threshold`: Epochs with an amplitude range above this threshold are flagged as “bad”.
- **Output:**
  - A 4D matrix of shape *(n_trials, n_epochs, n_channels, window_size)* containing baseline-corrected epochs.
  - A list of indices identifying the bad epochs.
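
The three steps described above (windowing, baseline correction, amplitude-range flagging) can also be sketched in vectorized form. This is only an illustration under stated assumptions, not the function used in this notebook: `extract_epochs_vectorized` is a hypothetical name, the data are synthetic, and unlike the loop-based version it does not skip out-of-bounds windows (an onset too close to the edge raises an `IndexError`). It returns a boolean mask rather than a list of index tuples.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract_epochs_vectorized(X, onsets, start_offset=-24, end_offset=84,
                              n_baseline=25, threshold=40e-6):
    """Hypothetical vectorized sketch; X has shape (n_trials, n_channels, n_samples)."""
    window_size = end_offset - start_offset
    # View of every possible window: (n_trials, n_channels, n_positions, window_size)
    windows = sliding_window_view(X, window_size, axis=-1)
    # Keep only the windows that begin at onset + start_offset
    starts = np.asarray(onsets) + start_offset
    epochs = windows[:, :, starts, :]
    # Reorder to (n_trials, n_epochs, n_channels, window_size)
    epochs = np.swapaxes(epochs, 1, 2)
    # Baseline correction: subtract the mean of the first n_baseline samples
    epochs = epochs - epochs[..., :n_baseline].mean(axis=-1, keepdims=True)
    # Flag (trial, epoch, channel) entries whose peak-to-peak range exceeds the threshold
    bad_mask = np.ptp(epochs, axis=-1) > threshold
    return epochs, bad_mask


# Synthetic data (assumption): 2 trials, 3 channels, 2640 samples
rng = np.random.default_rng(0)
X = rng.normal(scale=5e-6, size=(2, 3, 2640))
onsets = np.arange(120, 2520, 30)

epochs, bad_mask = extract_epochs_vectorized(X, onsets)
print(epochs.shape)    # (2, 80, 3, 108)
print(bad_mask.shape)  # (2, 80, 3)
```

If index tuples are preferred, `np.argwhere(bad_mask)` gives the flagged (trial, epoch, channel) positions.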

---

## **2. Feature Extraction**

### **Function: `extract_features_from_X`**
- **Purpose:**  
  Computes summary features for each epoch by averaging (or obtaining another statistic over) specified time ranges. These ranges are defined by the user as tuples (start, end indices) within the epoch.
- **Key Parameters:**
  - `X_matrix`: A 4D array of extracted epochs *(n_trials, n_epochs, n_channels, n_samples)*.
  - `ToI`: A list of time-of-interest tuples that define the segments from which features are extracted.
- **Output:**
  - A feature matrix of shape *(n_trials, n_epochs, n_channels, len(ToI))* containing the computed feature (in this case, mean values over the given time ranges).
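
To make the output shape concrete, here is a minimal sketch of the averaging step on a toy array. The values and the two time ranges are made up for illustration; only the axis layout follows the description above.

```python
import numpy as np

# Toy epochs (assumption): 1 trial, 2 epochs, 1 channel, 6 samples
epochs = np.array([[[[0., 1., 2., 3., 4., 5.]],
                    [[10., 10., 10., 20., 20., 20.]]]])

# Two illustrative ranges; start is inclusive, end is exclusive
toi = [(0, 3), (3, 6)]

# One mean per range, stacked along a new last axis
features = np.stack(
    [epochs[..., start:end].mean(axis=-1) for start, end in toi],
    axis=-1,
)

print(features.shape)        # (1, 2, 1, 2)
print(features[0, :, 0, :])  # rows are epochs: [1, 4] and [10, 20]
```

Swapping `.mean` for `.max` or another reduction along the last axis would give a different statistic over the same ranges.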

---

## **3. Marking Bad Epochs**

### **Function: `mark_bad_epochs`**
- **Purpose:**  
  Flags epochs that have been identified as “bad” (e.g., due to excessive amplitude range) by replacing the data and corresponding labels with a sentinel value (`np.nan`).
- **Key Parameters:**
  - `X` and `z`: The input EEG epochs and their associated labels.
  - `bad_idx`: A list of (trial, epoch, channel) tuples. An epoch is marked across all channels if any one of its channels was flagged.
- **Output:**
  - Modified EEG data and labels with bad epochs marked.
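
The marking step and its downstream use can be sketched on toy arrays as follows. The shapes and the `bad_idx` entries are assumptions chosen for illustration. Integer labels are cast to float first, because NaN cannot be stored in an integer array.

```python
import numpy as np

# Toy data (assumption): 2 trials, 3 epochs, 4 channels, 5 samples
X = np.ones((2, 3, 4, 5))
z = np.zeros((2, 3, 2), dtype=int)
bad_idx = [(0, 1, 2), (0, 1, 3), (1, 0, 0)]  # (trial, epoch, channel)

# Collapse the channel entry: one bad channel invalidates the whole epoch
pairs = np.array(sorted({(trial, epoch) for trial, epoch, _ in bad_idx}))

X_marked = X.copy()
z_marked = z.astype(float)
X_marked[pairs[:, 0], pairs[:, 1]] = np.nan
z_marked[pairs[:, 0], pairs[:, 1]] = np.nan

# Downstream: boolean mask of usable epochs, shape (n_trials, n_epochs)
keep = ~np.isnan(X_marked).any(axis=(2, 3))
print(keep)
```

Because the arrays keep their full shape, the trial structure survives; later stages can use `keep` or NaN-aware reductions such as `np.nanmean` to ignore the marked epochs.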

---

## **4. Script Overview**

### **Main Processing Script:**
- **Data Loading:**
  - For each subject, the script loads a preprocessed `.npz` file containing the EEG data (`X`), trial labels (`y`), additional labels (`z`), and the sampling frequency (`fs`).
  - The file path is built from subject-specific directories; subjects whose file does not exist are reported and skipped.

- **Epoch Extraction:**
  - The function `extract_epochs` is called to segment the continuous data into epochs.  
  - Epochs are extracted using a sliding window that advances by 30 samples (250 ms) and spans 108 samples (from -200 ms to 700 ms relative to each epoch onset).
  - Any epochs where the amplitude range exceeds `40e-6` are flagged as bad (their indices are stored in `bad_idx`).

- **Feature Computation:**
  - A list of time ranges (`ToI`) is defined as (start, end) sample indices within the epoch. The script uses six consecutive ranges covering samples 30 to 108.
  - The function `extract_features_from_X` computes features (e.g., mean values) within each of these time segments across all epochs, channels, and trials.
  
- **Epoch Marking:**
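  - Rejected epochs can be flagged with a sentinel value (`np.nan`) for downstream rejection. In the script as shown, the call to `mark_bad_epochs` is commented out, and the output file carries a `_noreject` suffix.
  - Marking rather than removing epochs preserves the trial structure while still identifying bad data.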

- **Saving the Output:**
  - The resulting feature matrix, along with the trial labels (`y` and `z`) and the sampling frequency (`fs`), is saved into a new `.npz` file in a subject-specific directory.
  - This output can be used in further decoding or statistical analysis pipelines.
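
As a rough sketch of how the saved file might be read back in a later pipeline stage, the round trip below writes placeholder arrays to a temporary directory and loads them again. The file name, the channel count of 8, and the all-zero contents are assumptions; only the keys `X`, `y`, `z`, and `fs` follow the script.

```python
import os
import tempfile

import numpy as np

# Placeholder arrays (assumption): 4 trials, 80 epochs, 8 channels, 6 features
features = np.zeros((4, 80, 8, 6))
y = np.array([0, 1, 0, 1])
z = np.zeros((4, 80, 2))
fs = 120

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example_features.npz")  # hypothetical file name
    np.savez(path, X=features, y=y, z=z, fs=fs)

    # np.load returns a lazy archive; read what is needed before it is closed
    with np.load(path) as data:
        keys = sorted(data.files)
        X_loaded = data["X"]
        fs_loaded = int(data["fs"])  # scalars come back as 0-d arrays

print(keys)                       # ['X', 'fs', 'y', 'z']
print(X_loaded.shape, fs_loaded)  # (4, 80, 8, 6) 120
```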



## Define Helper Functions

Given time window:

Time window = [-200 ms, 700 ms] → corresponds to sample offsets [-24, 84] at 120 Hz.
Each segment is 108 samples long (84 - (-24) = 108).

Sliding window every 250 ms:

250 ms is equivalent to 30 samples at 120 Hz.
With the default arguments, epoch onsets are at samples 120, 150, 180, ..., 2490; `end_idx` = 2520 is exclusive.

Resulting matrix:

For each channel, multiple 108-sample time windows are extracted.
Repeating this for all channels and trials yields a matrix of dimensions (n_trials, n_windows, n_channels, 108 samples).
The number of windows is (2520 - 120) / 30 = 80, which matches the expected 80 epochs per trial.
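
The window arithmetic can be checked numerically. The sketch below assumes the 120 Hz sampling rate and the default arguments of `extract_epochs` defined in the next cell; `ms_to_samples` is a hypothetical helper introduced only for this check.

```python
import numpy as np

fs = 120  # sampling rate in Hz, as stated in the text


def ms_to_samples(ms, fs):
    """Convert a duration in milliseconds to a whole number of samples."""
    return int(round(ms * fs / 1000))


start_offset = ms_to_samples(-200, fs)   # -24
end_offset = ms_to_samples(700, fs)      # 84
step_size = ms_to_samples(250, fs)       # 30
window_size = end_offset - start_offset  # 108

# Epoch onsets for the default arguments (the end value is exclusive)
onsets = np.arange(120, 2520, step_size)
print(len(onsets), onsets[0], onsets[-1])  # 80 120 2490

# Smallest recording length (in samples) for which the last window still fits
min_n_samples = int(onsets[-1]) + end_offset
print(window_size, min_n_samples)  # 108 2574
```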

In [1]:
import numpy as np


def extract_epochs(X, start_idx=120, end_idx=2520, step_size=30, 
                   start_sample_offset=-24, end_sample_offset=84, 
                   amplitude_threshold=40e-6):
    """
    Function to extract epochs from time-series data for ERP features, 
    baseline-correct each epoch, and identify bad epochs based on amplitude threshold.
    
    Parameters:
    - X: Input data array of shape (n_trials, n_channels, n_samples)
    - start_idx: The starting sample index for the first epoch (default=120)
    - end_idx: Exclusive upper bound for epoch onset indices; with the defaults the final epoch starts at sample 2490 (default=2520)
    - step_size: Step size in samples between consecutive epoch onsets (default=30)
    - start_sample_offset: The offset for the start of the time window (default=-24, corresponds to -200 ms)
    - end_sample_offset: The offset for the end of the time window (default=84, corresponds to 700 ms)
    - amplitude_threshold: Threshold for identifying bad epochs based on amplitude range (default=40e-6)
    
    Returns:
    - output_matrix: A 4D array of extracted and baseline-corrected epochs of shape 
                     (n_trials, n_epochs, n_channels, window_size)
    - bad_epochs_idx: List of indices of bad epochs for each trial and channel 
                      where amplitude range exceeds the threshold.
    """
    # Check input dimensions
    if X.ndim != 3:
        raise ValueError(f"Input X must have 3 dimensions (n_trials, n_channels, n_samples), but got {X.ndim} dimensions.")
    
    n_trials, n_channels, n_samples = X.shape
    window_size = end_sample_offset - start_sample_offset  # 108 samples for the default offsets
    epoch_timestamps = np.arange(start_idx, end_idx, step_size)    # (80,)
    n_epochs = len(epoch_timestamps)
    
    # Initialize the output matrix for the epochs and a list for bad epoch indices
    output_matrix = np.zeros((n_trials, n_epochs, n_channels, window_size))
    bad_epochs_idx = []  # To store (trial, epoch, channel) indices of bad epochs
    
    # Loop over trials, channels, and epochs to extract and baseline-correct the windows
    for i_trial in range(n_trials):
        for i_channel in range(n_channels):
            data = X[i_trial, i_channel, :]

            for i_epoch, t in enumerate(epoch_timestamps):
                epoch_start_idx = t + start_sample_offset  # Start at t - 24 samples (-200 ms)
                epoch_end_idx = t + end_sample_offset      # End at t + 84 samples (700 ms)
                
                # Ensure the window stays within bounds; out-of-bounds epochs are left as zeros
                if epoch_start_idx >= 0 and epoch_end_idx <= n_samples:
                    epoch_data = data[epoch_start_idx:epoch_end_idx]
                    
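                    # Baseline correction: subtract the mean of the first 25 samples, i.e.,
                    # the pre-onset interval plus the onset sample for the default offsets
                    baseline_mean = np.mean(epoch_data[:25])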
                    epoch_data = epoch_data - baseline_mean
                    
                    # Store the epoch in the output matrix
                    output_matrix[i_trial, i_epoch, i_channel, :] = epoch_data
                    
                    # Check amplitude range after baseline subtraction
                    min_amp, max_amp = np.min(epoch_data), np.max(epoch_data)
                    amplitude_range = max_amp - min_amp
                    
                    # Log bad epochs if amplitude range exceeds threshold
                    if amplitude_range > amplitude_threshold:
                        bad_epochs_idx.append((i_trial, i_epoch, i_channel))
    
    # Return the 4D output matrix and the indices of bad epochs
    return output_matrix, bad_epochs_idx


### Build Feature Vectors

In [2]:
def extract_features_from_X(X_matrix, ToI):
    """
    Extracts the mean amplitude over specified time ranges for each trial, epoch, and channel in the input data.

    Parameters:
    - X_matrix: A 4D numpy array of shape (n_trials, n_epochs, n_channels, n_samples) representing the input data.
    - ToI: A list of tuples, where each tuple contains the start (inclusive) and end (exclusive) sample indices
           of a time range of interest.

    Returns:
    - feature_matrix: A 4D numpy array of shape (n_trials, n_epochs, n_channels, len(ToI)) containing the mean
                      values over the specified time ranges for each trial, epoch, and channel.
    """
    # Extract the shape of the input matrix
    n_trials, n_epochs, n_channels, n_samples = X_matrix.shape 
    
    # Initialize the feature matrix to store the mean value for each time range
    feature_matrix = np.zeros((n_trials, n_epochs, n_channels, len(ToI)))

    # Loop over the time ranges (ToI) and compute the mean value for each range
    for i_range, (start, end) in enumerate(ToI):
        # Average along the last axis (time samples) within the specified range
        feature_matrix[:, :, :, i_range] = np.mean(X_matrix[:, :, :, start:end], axis=-1)

    # Return the feature matrix
    return feature_matrix
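
The `ToI` indices passed to this function are sample positions within the 108-sample epoch, not times. The sketch below converts the ranges used later in the script into milliseconds relative to epoch onset, assuming the 120 Hz sampling rate and the default `start_sample_offset` of -24. `index_to_ms` is a hypothetical helper. Since the end index of each range is exclusive, the reported end time is that of the first excluded sample.

```python
fs = 120                   # sampling rate in Hz
start_sample_offset = -24  # epoch starts 24 samples (200 ms) before onset

# Ranges as defined in the main script
toi = [(30, 38), (38, 48), (48, 57), (57, 69), (69, 87), (87, 108)]


def index_to_ms(index, fs=fs, start_offset=start_sample_offset):
    """Time of an in-epoch sample index, in ms relative to epoch onset."""
    return (index + start_offset) * 1000 / fs


toi_ms = [(index_to_ms(start), index_to_ms(end)) for start, end in toi]

for (start, end), (start_ms, end_ms) in zip(toi, toi_ms):
    print(f"samples [{start}, {end}) -> {start_ms:.1f} to {end_ms:.1f} ms")
```

Under these assumptions the six ranges tile the interval from 50 ms to 700 ms after onset without gaps.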


In [4]:
def mark_bad_epochs(X, z, bad_idx):
    """
    Marks bad epochs in both EEG data (X) and labels (z) by setting them to NaN (or another sentinel, e.g., -1).

    Parameters
    ----------
    X : ndarray
        4D array of shape (n_trials, n_epochs, n_channels, n_timepoints).
    z : ndarray
        3D array of shape (n_trials, n_epochs, label_dim).
    bad_idx : list of tuples
        List of (trial, epoch, channel) indices indicating bad epochs.

    Returns
    -------
    X_marked : ndarray
        Same shape as X, with bad epochs set to NaN (or a chosen sentinel).
    z_marked : ndarray
        Same shape as z, with bad epochs set to NaN (or a chosen sentinel).
    """
    # Convert list of (trial, epoch, channel) to a set of (trial, epoch) pairs
    bad_trial_epoch_pairs = set((trial, epoch) for trial, epoch, _ in bad_idx)

    # Make copies so we don't overwrite the original arrays
    X_marked = np.copy(X)
    z_marked = np.copy(z).astype(np.float64)

    # Mark each bad epoch in both X and z
    for trial_idx, epoch_idx in bad_trial_epoch_pairs:
        X_marked[trial_idx, epoch_idx, :, :] = np.nan 
        z_marked[trial_idx, epoch_idx, :]    = np.nan 

    return X_marked, z_marked


### Execute

In [6]:
import os
from os.path import join

import numpy as np
import mne
import matplotlib.pyplot as plt

mne.set_log_level('warning')
task = "covert"   
# Define directories
wd = r'C:\Users\Radovan\OneDrive\Radboud\Studentships\Jordy Thielen\root'
os.chdir(wd)
data_dir = join(wd, "data")
experiment_dir = join(data_dir, "experiment")
files_dir = join(experiment_dir, 'files')
sourcedata_dir = join(experiment_dir, 'sourcedata')
derivatives_dir = join(experiment_dir, 'derivatives')

subjects = ["VPpdia", "VPpdib", "VPpdic", "VPpdid", "VPpdie", "VPpdif", "VPpdig", "VPpdih", "VPpdii", "VPpdij",
            "VPpdik", "VPpdil", "VPpdim", "VPpdin", "VPpdio", "VPpdip", "VPpdiq", "VPpdir", "VPpdis", "VPpdit",
            "VPpdiu", "VPpdiv", "VPpdiw", "VPpdix", "VPpdiy", "VPpdiz", "VPpdiza", "VPpdizb", "VPpdizc"]

rejection_threshold = 40e-6
for subject in subjects:
    # Load the NPZ file
    file_dir = os.path.join(derivatives_dir, 'preprocessed', "p300", f"sub-{subject}")
    file_path = os.path.join(file_dir, f"sub-{subject}_task-{task}_p300.npz") 

    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        continue
    
    npz_data = np.load(file_path)

    # Extract data from the npz object
    X = npz_data['X']  # EEG data: trials x channels x samples
    y = npz_data['y']  # Labels indicating cued side: trials
    z = npz_data['z']  # Left and right targets: trials x epochs x sides
    fs = npz_data['fs']  # Sampling frequency

    X_matrix, bad_idx = extract_epochs(X, amplitude_threshold=rejection_threshold)
    # Tag bad epochs by setting above-threshold epochs to NaN.
    # Either mark them like this or remove them at this stage. Key: retain the trial structure!
    # Removing epochs makes the epoch dimension inhomogeneous across trials, which raises an error
    # when building a NumPy array; flattening the trial and epoch dimensions is not an option
    # because it discards the trial structure.
    # Possible solution: build a list of arrays and save them individually. Decoding then needs
    # to be adapted to this data structure.
    #X_matrix, z = mark_bad_epochs(X_matrix, z, bad_idx)
    
    # Define periods for feature extraction
    ToI = [(30, 38), (38, 48), (48, 57), (57, 69), (69, 87), (87, 108)]

    feature_matrix = extract_features_from_X(X_matrix, ToI)

    # Save in one NPZ object
    save_dir = os.path.join(derivatives_dir, "features", "p300", f"sub-{subject}")
    os.makedirs(save_dir, exist_ok=True)
    # .._noreject
    # ..features.npz
    np.savez(os.path.join(save_dir, f"sub-{subject}_task-{task}_p300_features_noreject.npz"), X=feature_matrix, y=y, z=z, fs=fs)