# Parkinson's Disease EEG Data Preprocessing Pipeline

This notebook implements a **memory-efficient** preprocessing pipeline for **combining two EEG datasets**:
1. **PD_Dataset_timing**: BrainVision format (.vhdr, .eeg, .vmrk) - 129 subjects
2. **ds004584-download**: BIDS format with EEGLAB files (.set, .fdt) - 149 subjects

## Important: Combined Dataset Processing
**Both datasets are processed together** in the same pipeline with **uniform preprocessing steps**. This ensures:
- Consistent preprocessing across all subjects regardless of source dataset
- Same parameters applied to both datasets (resampling, filtering, re-referencing, etc.)
- Combined output ready for cross-dataset analysis
- Dataset source is recorded in each subject's file-list entry (`file_info['dataset']`) but does not affect processing

## Pipeline Steps:
1. Data Loading from both datasets (one subject at a time)
2. Standardized Signal Pre-processing (Resampling, Filtering, Re-referencing)
3. Artifact Detection & Correction (Bad Channel Interpolation, then ICA without ICLabel)
4. Subject-Level Alignment (Z-Score Normalization, Riemannian Re-centering)
5. Save preprocessed data for each subject
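
Step 4's Riemannian re-centering whitens each subject's signals with the inverse square root of a reference covariance `M` (in `riemannian_recenter` below, the Riemannian mean of short-window covariances). Each window covariance `C` then maps to `M^(-1/2) C M^(-1/2)`, and the reference itself maps to the identity. A minimal NumPy sketch of only the whitening step; the helper name `inv_sqrtm_spd` and the random example matrix are illustrative assumptions, and no Riemannian mean is computed here:

```python
import numpy as np

def inv_sqrtm_spd(mat):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    return eigvecs @ np.diag(1.0 / np.sqrt(eigvals)) @ eigvecs.T

# Example reference covariance for 3 channels (symmetric positive definite)
rng = np.random.default_rng(0)
a = rng.standard_normal((3, 3))
ref_cov = a @ a.T + 3 * np.eye(3)

whitener = inv_sqrtm_spd(ref_cov)

# Re-centering the reference itself yields the identity matrix
recentered_ref = whitener @ ref_cov @ whitener.T

# Continuous data of shape (n_channels, n_samples) is whitened the same way
signals = rng.standard_normal((3, 1000))
signals_recentered = whitener @ signals
```

Since the covariance of `W @ X` is `W C W^T`, whitening the continuous signals is equivalent to re-centering every window covariance.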

## Memory-Efficient Design:
- Processes one subject at a time to avoid memory issues
- Saves each preprocessed subject immediately
- Can resume: subjects whose saved output file loads successfully are skipped on re-runs
- No ICLabel dependency (artifact components are flagged by EOG/ECG correlation and a kurtosis heuristic)
- Robust error handling for NaN/Inf values and numerical issues


### Required Packages Installation

Before running this notebook, install the required packages:

```bash
pip install mne mne-bids pyriemann scipy numpy pandas
```

**Note**: This notebook does NOT require ICLabel or GPU. It uses memory-efficient, CPU-based processing.


In [8]:
# Import required libraries
import os
import numpy as np
import pandas as pd
import mne
from mne.io import read_raw_brainvision, read_raw_eeglab
from mne.preprocessing import ICA
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Set MNE logging level to reduce output
mne.set_log_level('ERROR')

# Try to import pyriemann for Riemannian re-centering (optional)
try:
    from pyriemann.utils.mean import mean_riemann
    HAS_PYRIEMANN = True
except ImportError:
    HAS_PYRIEMANN = False
    print("Warning: pyriemann not available. Riemannian re-centering will be skipped.")

print("Libraries imported successfully!")
print(f"MNE version: {mne.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"PyRiemann available: {HAS_PYRIEMANN}")


Libraries imported successfully!
MNE version: 1.11.0
NumPy version: 2.0.2
PyRiemann available: True


## 1. Configuration and Paths

Define paths, preprocessing parameters, and output directory.


In [9]:
# Define dataset paths
BASE_PATH = r"C:\Users\Usha Sri\OneDrive\Documents\Parkinson_Project"
PD_DATASET_PATH = os.path.join(BASE_PATH, "PD_Dataset_timing")
BIDS_DATASET_PATH = os.path.join(BASE_PATH, "ds004584-download")
OUTPUT_DIR = os.path.join(BASE_PATH, "preprocessed_data")

# Preprocessing parameters
TARGET_SFREQ = 250  # Target sampling rate (Hz)
LOW_FREQ = 0.5      # Low cutoff frequency (Hz)
HIGH_FREQ = 50      # High cutoff frequency (Hz)

# ICA parameters (reduced for memory efficiency)
ICA_N_COMPONENTS = 15  # Reduced from 20 for memory efficiency
ICA_MAX_ITER = 200     # Reduced iterations for speed

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"PD Dataset Path: {PD_DATASET_PATH}")
print(f"BIDS Dataset Path: {BIDS_DATASET_PATH}")
print(f"Output Directory: {OUTPUT_DIR}")

# Verify paths exist
assert os.path.exists(PD_DATASET_PATH), f"PD dataset path not found: {PD_DATASET_PATH}"
assert os.path.exists(BIDS_DATASET_PATH), f"BIDS dataset path not found: {BIDS_DATASET_PATH}"
print("✓ Both dataset paths verified!")
print(f"\nPreprocessing Parameters:")
print(f"  Target Sampling Rate: {TARGET_SFREQ} Hz")
print(f"  Band-pass Filter: {LOW_FREQ} - {HIGH_FREQ} Hz")
print(f"  ICA Components: {ICA_N_COMPONENTS}")


PD Dataset Path: C:\Users\Usha Sri\OneDrive\Documents\Parkinson_Project\PD_Dataset_timing
BIDS Dataset Path: C:\Users\Usha Sri\OneDrive\Documents\Parkinson_Project\ds004584-download
Output Directory: C:\Users\Usha Sri\OneDrive\Documents\Parkinson_Project\preprocessed_data
✓ Both dataset paths verified!

Preprocessing Parameters:
  Target Sampling Rate: 250 Hz
  Band-pass Filter: 0.5 - 50 Hz
  ICA Components: 15
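
Because filtering runs after resampling, the band-pass edges have to fit the new rate: `HIGH_FREQ` must stay below the Nyquist frequency `TARGET_SFREQ / 2` (125 Hz here) and above `LOW_FREQ`. A small guard along these lines could run at the end of the configuration cell; the function name `check_filter_params` is hypothetical:

```python
def check_filter_params(sfreq, l_freq, h_freq):
    """Raise ValueError if the band-pass edges do not fit the sampling rate."""
    nyquist = sfreq / 2.0
    if not 0 < l_freq < h_freq:
        raise ValueError(f"Need 0 < l_freq < h_freq, got {l_freq} and {h_freq}")
    if h_freq >= nyquist:
        raise ValueError(f"h_freq={h_freq} Hz must be below Nyquist ({nyquist} Hz)")
    return nyquist

# Values from the configuration cell (TARGET_SFREQ, LOW_FREQ, HIGH_FREQ)
nyquist = check_filter_params(250, 0.5, 50)
```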


In [10]:
# Function to get file list (without loading data)
def get_brainvision_file_list(dataset_path):
    """Get list of BrainVision files without loading them."""
    vhdr_files = [f for f in os.listdir(dataset_path) if f.endswith('.vhdr')]
    vhdr_files.sort()
    
    file_list = []
    for vhdr_file in vhdr_files:
        file_path = os.path.join(dataset_path, vhdr_file)
        subject_id = vhdr_file.replace('.vhdr', '')
        group = 'Control' if 'control' in subject_id.lower() else 'PD'
        file_list.append({
            'subject_id': subject_id,
            'group': group,
            'file_path': file_path,
            'dataset': 'PD_Dataset_timing'
        })
    
    return file_list

def get_bids_file_list(dataset_path):
    """Get list of BIDS files without loading them."""
    subject_dirs = [d for d in os.listdir(dataset_path) if d.startswith('sub-')]
    subject_dirs.sort()
    
    file_list = []
    for subject_dir in subject_dirs:
        eeg_path = os.path.join(dataset_path, subject_dir, 'eeg')
        if not os.path.exists(eeg_path):
            continue
        
        set_files = [f for f in os.listdir(eeg_path) if f.endswith('.set')]
        if not set_files:
            continue
        
        set_file = os.path.join(eeg_path, set_files[0])
        file_list.append({
            'subject_id': subject_dir,
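            # NOTE: every subject from this dataset is labelled 'PD'; if the
            # dataset's participants.tsv also lists controls, read labels from it
            'group': 'PD',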
            'file_path': set_file,
            'dataset': 'ds004584-download'
        })
    
    return file_list

# Get file lists (memory efficient - no data loaded yet)
print("Scanning PD_Dataset_timing (BrainVision format)...")
pd_file_list = get_brainvision_file_list(PD_DATASET_PATH)

print("Scanning ds004584-download (BIDS format)...")
bids_file_list = get_bids_file_list(BIDS_DATASET_PATH)

# Combine file lists
all_file_list = pd_file_list + bids_file_list

print(f"\n✓ Found {len(pd_file_list)} files in PD_Dataset_timing")
print(f"✓ Found {len(bids_file_list)} files in ds004584-download")
print(f"✓ Total: {len(all_file_list)} files to process")


Scanning PD_Dataset_timing (BrainVision format)...
Scanning ds004584-download (BIDS format)...

✓ Found 129 files in PD_Dataset_timing
✓ Found 149 files in ds004584-download
✓ Total: 278 files to process
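
`get_bids_file_list` labels every ds004584-download subject as `'PD'`. BIDS datasets usually describe subjects in a top-level `participants.tsv`, so if that file has a group column the labels could be taken from it instead. A sketch assuming the standard BIDS `participant_id` column and a group column whose name matches `group` case-insensitively (an assumption; check the actual header). Both helper names are hypothetical:

```python
import os
import pandas as pd

def read_bids_groups(dataset_path, default='PD'):
    """Map participant_id -> group label from participants.tsv, if available."""
    tsv_path = os.path.join(dataset_path, 'participants.tsv')
    if not os.path.exists(tsv_path):
        return {}
    participants = pd.read_csv(tsv_path, sep='\t')
    # Find a column called 'group' regardless of capitalization (assumption)
    group_cols = [c for c in participants.columns if c.lower() == 'group']
    if 'participant_id' not in participants.columns or not group_cols:
        return {}
    return dict(zip(participants['participant_id'].astype(str),
                    participants[group_cols[0]].fillna(default).astype(str)))

def apply_bids_groups(file_list, groups, default='PD'):
    """Overwrite the 'group' field of each file-list entry using the mapping."""
    for info in file_list:
        info['group'] = groups.get(info['subject_id'], default)
    return file_list
```

Calling `apply_bids_groups(bids_file_list, read_bids_groups(BIDS_DATASET_PATH))` would relabel the entries in place; because `all_file_list` holds the same dict objects, it picks up the new labels too. Subjects missing from the file keep the `'PD'` default.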


## 2. Preprocessing Functions

Define memory-efficient preprocessing functions that process one subject at a time.
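
Both `detect_artifacts_ica_simple` and the interpolation step in `process_single_subject` replace NaN/Inf samples with the median of the channel's finite samples, using the same per-channel loop. A vectorized helper that does the same on a `(n_channels, n_times)` array could be shared by both; the name `replace_nonfinite` is hypothetical:

```python
import numpy as np

def replace_nonfinite(data):
    """Return a copy of a (n_channels, n_times) array with NaN/Inf replaced
    by each channel's median of finite samples (0.0 if none are finite)."""
    data = np.array(data, dtype=float)  # copy, so the input is untouched
    bad = ~np.isfinite(data)
    if not bad.any():
        return data
    # Mask bad samples as NaN so nanmedian ignores them
    masked = np.where(bad, np.nan, data)
    all_bad = bad.all(axis=1)
    medians = np.zeros(data.shape[0])
    if (~all_bad).any():
        medians[~all_bad] = np.nanmedian(masked[~all_bad], axis=1)
    # Fill every bad sample with its own channel's median
    rows, cols = np.nonzero(bad)
    data[rows, cols] = medians[rows]
    return data
```

In the notebook this could replace both loops, e.g. `raw._data = replace_nonfinite(raw.get_data())`, matching how the existing code already assigns to `_data`.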


In [11]:
def load_single_file(file_info):
    """Load a single file based on its format."""
    file_path = file_info['file_path']
    
    if file_info['dataset'] == 'PD_Dataset_timing':
        # BrainVision format
        raw = read_raw_brainvision(file_path, preload=False)
    else:
        # EEGLAB format
        raw = read_raw_eeglab(file_path, preload=False)
    
    return raw

def detect_bad_channels_simple(raw):
    """
    Simple bad channel detection using non-finite, flat-signal and
    extreme-amplitude checks. Memory-efficient alternative to find_bad_channels.
    Very conservative approach to avoid false positives.

    Assumes `raw` is preloaded (preprocess_raw calls load_data()), so no copy
    is made. Note: in this pipeline it runs after the average reference, which
    turns a flat channel into minus the average of the others, so it may no
    longer look flat.
    """
    ch_names = raw.ch_names
    data = raw.get_data()  # EEG in volts
    
    # Channels containing NaN/Inf are definitely bad
    nonfinite = set(ch for ch, row in zip(ch_names, data) if not np.all(np.isfinite(row)))
    bads = [ch for ch in ch_names if ch in nonfinite]
    
    # Detect flat channels. MNE returns EEG in volts, so a variance threshold of
    # 1e-8 V^2 (std of 100 µV) would flag almost every channel; a std below
    # 0.1 µV is used instead (assumed value, tune for your recordings)
    channel_vars = np.var(data, axis=1)
    flat_threshold = (0.1e-6) ** 2  # variance in V^2
    
    for i, ch_var in enumerate(channel_vars):
        if ch_var < flat_threshold and ch_names[i] not in bads:
            bads.append(ch_names[i])
    
    # Detect channels with extreme values (likely disconnected): only channels
    # whose peak amplitude is at least 20x the median peak are marked
    channel_max = np.max(np.abs(data), axis=1)
    extreme_threshold = np.median(channel_max) * 20
    
    for i, ch_max in enumerate(channel_max):
        if ch_max > extreme_threshold and ch_names[i] not in bads:
            bads.append(ch_names[i])
    
    # Strict limit: at most 5% of channels (at least 1, never more than 5)
    max_bad_channels = max(1, min(5, int(len(ch_names) * 0.05)))
    if len(bads) > max_bad_channels:
        # Keep non-finite channels first, then the lowest-variance ones
        bads.sort(key=lambda ch: (ch not in nonfinite, channel_vars[ch_names.index(ch)]))
        bads = bads[:max_bad_channels]
    
    return bads

def preprocess_raw(raw, subject_id):
    """
    Apply standardized preprocessing to a Raw object.
    Memory-efficient version.
    """
    # Load data into memory
    raw.load_data()
    
    # Step 1: Resampling to target sampling rate
    if raw.info['sfreq'] != TARGET_SFREQ:
        raw.resample(TARGET_SFREQ, npad='auto')
    
    # Step 2: Band-pass filtering (0.5 - 50 Hz)
    raw.filter(LOW_FREQ, HIGH_FREQ, fir_design='firwin', 
               skip_by_annotation='edge', verbose=False)
    
    # Step 3: Common Average Reference (CAR)
    raw.set_eeg_reference('average', projection=False, verbose=False)
    
    return raw

def detect_artifacts_ica_simple(raw, subject_id):
    """
    Simple ICA-based artifact detection without ICLabel.
    Uses correlation-based detection for EOG/ECG artifacts.
    Includes data cleaning to handle NaN/Inf values.
    """
    # Make a copy for ICA
    raw_ica = raw.copy()
    
    # Check for and clean NaN/Inf values before ICA
    data = raw_ica.get_data()
    if np.any(np.isnan(data)) or np.any(np.isinf(data)):
        print(f"    Warning: Found NaN/Inf values, cleaning data...")
        # Replace NaN/Inf with zeros (or median of channel)
        for i in range(data.shape[0]):
            ch_data = data[i, :]
            if np.any(np.isnan(ch_data)) or np.any(np.isinf(ch_data)):
                # Replace with median of valid values
                valid_data = ch_data[np.isfinite(ch_data)]
                if len(valid_data) > 0:
                    replacement = np.median(valid_data)
                else:
                    replacement = 0.0
                ch_data[np.isnan(ch_data) | np.isinf(ch_data)] = replacement
                data[i, :] = ch_data
        raw_ica._data = data
    
    # High-pass filter at 1 Hz for ICA (recommended)
    raw_ica.filter(1., None, fir_design='firwin', 
                   skip_by_annotation='edge', verbose=False)
    
    # Check again after filtering
    data = raw_ica.get_data()
    if np.any(np.isnan(data)) or np.any(np.isinf(data)):
        print(f"    Warning: NaN/Inf values after filtering, skipping ICA...")
        return raw, []
    
    # Determine number of components (use fewer for memory efficiency)
    n_channels = len(raw_ica.ch_names)
    n_components = min(ICA_N_COMPONENTS, n_channels - 1)
    
    if n_components < 2:
        return raw, []  # Not enough channels for ICA
    
    # Fit ICA with reduced iterations
    ica = ICA(n_components=n_components, random_state=42, 
              max_iter=ICA_MAX_ITER, verbose=False)
    
    try:
        # Fit ICA with decimation for speed
        ica.fit(raw_ica, decim=5, verbose=False)
        
        # Find artifacts using correlation with EOG/ECG patterns
        exclude_idx = []
        
        # Try to find EOG artifacts; frontal Fp channels serve as an EOG proxy
        # when no dedicated EOG channel exists
        eog_channels = [ch for ch in raw_ica.ch_names if 'EOG' in ch.upper() or 'Fp' in ch]
        if len(eog_channels) > 0:
            try:
                eog_indices, eog_scores = ica.find_bads_eog(raw_ica, 
                                                             ch_name=eog_channels[0],
                                                             verbose=False)
                exclude_idx.extend(eog_indices)
            except Exception:
                pass
        
        # Try to find ECG artifacts (without an ECG channel this typically
        # raises on EEG-only data, and the step is skipped)
        try:
            ecg_indices, ecg_scores = ica.find_bads_ecg(raw_ica, verbose=False)
            exclude_idx.extend(ecg_indices)
        except Exception:
            pass
        
        # Additional heuristic: components with high kurtosis (peaky sources such
        # as blinks or transient artifacts). The 95th-percentile rule is relative:
        # with 15 components it always flags the single most kurtotic component,
        # even in clean data.
        try:
            if len(ica.mixing_matrix_) > 0:
                sources = ica.get_sources(raw_ica).get_data()
                # Check for NaN/Inf in sources
                if np.any(np.isnan(sources)) or np.any(np.isinf(sources)):
                    print(f"    Warning: NaN/Inf in ICA sources, skipping kurtosis detection...")
                else:
                    kurtosis_vals = stats.kurtosis(sources, axis=1)
                    high_kurtosis = np.where(kurtosis_vals > np.percentile(kurtosis_vals, 95))[0]
                    exclude_idx.extend(high_kurtosis.tolist())
        except Exception:
            pass
        
        # Remove duplicates
        exclude_idx = list(set(exclude_idx))
        exclude_idx = [idx for idx in exclude_idx if idx < n_components]
        
        # Apply ICA if artifacts found
        if len(exclude_idx) > 0:
            ica.exclude = exclude_idx
            raw_cleaned = raw.copy()
            ica.apply(raw_cleaned, verbose=False)
            
            # Final check for NaN/Inf after ICA
            data_cleaned = raw_cleaned.get_data()
            if np.any(np.isnan(data_cleaned)) or np.any(np.isinf(data_cleaned)):
                print(f"    Warning: NaN/Inf values after ICA, reverting to original...")
                return raw, []
            
            return raw_cleaned, exclude_idx
        else:
            return raw, []
            
    except Exception as e:
        print(f"    Warning: ICA failed for {subject_id}: {str(e)}")
        return raw, []
    
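def zscore_normalize(raw):
    """Apply Z-score normalization per channel."""
    raw_normalized = raw.copy()
    data = raw_normalized.get_data()
    
    # Z-score each channel; a zero-variance channel would become NaN with
    # stats.zscore, so its std is set to 1 and the channel stays at zero
    mean = data.mean(axis=1, keepdims=True)
    std = data.std(axis=1, keepdims=True)
    std[std == 0] = 1.0
    raw_normalized._data = (data - mean) / std
    
    return raw_normalized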

def riemannian_recenter(raw, subject_id):
    """
    Apply Riemannian re-centering if pyriemann is available.
    Memory-efficient version using smaller windows with stronger regularization.
    """
    if not HAS_PYRIEMANN:
        return raw
    
    try:
        data = raw.get_data().T  # Shape: (n_samples, n_channels)
        
        # Check for NaN/Inf
        if np.any(np.isnan(data)) or np.any(np.isinf(data)):
            print(f"    Warning: NaN/Inf values detected, skipping Riemannian re-centering...")
            return raw
        
        # Use smaller windows for memory efficiency
        window_length = int(raw.info['sfreq'] * 0.5)  # 0.5 second windows
        step_size = window_length // 2
        
        if data.shape[0] < window_length:
            return raw  # Not enough data
        
        covariances = []
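        # Limit to 50 windows. They start at the beginning of the recording, so at
        # 250 Hz (125-sample windows, 62-sample steps) the reference covariance
        # is estimated from roughly the first 13 s only.
        n_windows = min(50, (data.shape[0] - window_length) // step_size + 1)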
        
        for i in range(n_windows):
            start_idx = i * step_size
            if start_idx + window_length > data.shape[0]:
                break
            window_data = data[start_idx:start_idx + window_length, :]
            
            # Check for NaN/Inf in window
            if np.any(np.isnan(window_data)) or np.any(np.isinf(window_data)):
                continue  # Skip this window
            
            cov = np.cov(window_data.T)
            
            # Stronger regularization to ensure positive definiteness
            # Use larger regularization value
            reg_value = max(1e-4, np.trace(cov) / cov.shape[0] * 0.01)  # 1% of average diagonal
            cov += np.eye(cov.shape[0]) * reg_value
            
            # Verify positive definiteness
            eigenvals = np.linalg.eigvalsh(cov)
            if np.any(eigenvals <= 0):
                # If still not positive definite, add more regularization
                cov += np.eye(cov.shape[0]) * (abs(np.min(eigenvals)) + 1e-4)
            
            covariances.append(cov)
        
        if len(covariances) < 3:
            print(f"    Warning: Not enough valid covariance matrices, skipping...")
            return raw
        
        covariances = np.array(covariances)
        
        # Additional check: ensure all covariances are positive definite
        for i, cov in enumerate(covariances):
            eigenvals = np.linalg.eigvalsh(cov)
            if np.any(eigenvals <= 0):
                # Add more regularization
                reg_value = abs(np.min(eigenvals)) + 1e-4
                covariances[i] = cov + np.eye(cov.shape[0]) * reg_value
        
        mean_cov = mean_riemann(covariances)
        
        # Final check on mean covariance
        eigenvals = np.linalg.eigvalsh(mean_cov)
        if np.any(eigenvals <= 0):
            # Add regularization to mean
            reg_value = abs(np.min(eigenvals)) + 1e-4
            mean_cov += np.eye(mean_cov.shape[0]) * reg_value
        
        # Whitening with mean_cov^(-1/2), from a single eigendecomposition
        eigenvals, eigenvecs = np.linalg.eigh(mean_cov)
        eigenvals = np.maximum(eigenvals, 1e-8)  # Floor tiny eigenvalues
        whitening_matrix = eigenvecs @ np.diag(1.0 / np.sqrt(eigenvals)) @ eigenvecs.T
        
        data_recentered = (whitening_matrix @ data.T).T
        
        # Final check for NaN/Inf
        if np.any(np.isnan(data_recentered)) or np.any(np.isinf(data_recentered)):
            print(f"    Warning: NaN/Inf after whitening, reverting...")
            return raw
        
        raw_recentered = raw.copy()
        raw_recentered._data = data_recentered.T
        
        return raw_recentered
            
    except Exception as e:
        print(f"    Warning: Riemannian re-centering failed: {str(e)}")
        return raw

print("✓ Preprocessing functions defined")


✓ Preprocessing functions defined
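
`detect_bad_channels_simple` relies on fixed thresholds, which depend on the units MNE uses (volts) and on the overall signal level. A unit-free alternative, shown here as a swapped-in technique rather than what the notebook runs, scores each channel's log standard deviation with a robust z-score (median and median absolute deviation), so flat and very noisy channels stand out relative to the rest of the montage. The helper name `robust_bad_channels` and the cutoff `z_thresh=5.0` are assumptions, and such a check would ideally run before the average reference is applied:

```python
import numpy as np

def robust_bad_channels(data, ch_names, z_thresh=5.0):
    """Flag channels whose log standard deviation is a robust-z outlier.

    data: (n_channels, n_times) array; ch_names: list of channel names.
    """
    std = np.std(data, axis=1)
    log_std = np.log(np.maximum(std, 1e-20))  # avoid log(0) for dead channels
    median = np.median(log_std)
    mad = np.median(np.abs(log_std - median))
    # 1.4826 * MAD estimates the standard deviation for Gaussian data
    robust_z = (log_std - median) / (1.4826 * mad + 1e-12)
    return [ch for ch, z in zip(ch_names, robust_z) if abs(z) > z_thresh]
```

With the notebook's objects it would be called as `robust_bad_channels(raw.get_data(), raw.ch_names)`. Working on the log scale places a channel 100x too large and one 100x too small equally far from the median.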


## 3. Main Preprocessing Pipeline

Process each subject one at a time and save it immediately to keep memory usage low.


In [16]:
def is_file_valid(file_path):
    """
    Check if a preprocessed file exists and is valid (can be loaded).
    
    Parameters:
    -----------
    file_path : str
        Path to the preprocessed file
        
    Returns:
    --------
    bool
        True if file exists and is valid, False otherwise
    """
    if not os.path.exists(file_path):
        return False
    
    try:
        # Try to load the file to verify it's valid
        raw = mne.io.read_raw_fif(file_path, preload=False, verbose=False)
        # Check if it has data
        if raw.n_times > 0 and len(raw.ch_names) > 0:
            return True
        else:
            return False
    except Exception:
        # File exists but is corrupted or invalid
        return False

def process_single_subject(file_info, output_dir, skip_existing=True):
    """
    Process a single subject through the entire pipeline.
    Saves the result immediately to avoid memory issues.
    
    Parameters:
    -----------
    file_info : dict
        Dictionary with subject information
    output_dir : str
        Directory to save preprocessed data
    skip_existing : bool
        If True, skip subjects that are already processed
        
    Returns:
    --------
    success : bool
        Whether processing was successful
    """
    subject_id = file_info['subject_id']
    output_file = os.path.join(output_dir, f"{subject_id}_preprocessed.fif")
    
    # Check if already processed (verify file is valid, not just exists)
    if skip_existing:
        if is_file_valid(output_file):
            return True  # Return True without printing (will be counted in summary)
        elif os.path.exists(output_file):
            # File exists but is invalid - remove it and reprocess
            print(f"  {subject_id}: Found invalid file, will reprocess...")
            try:
                os.remove(output_file)
            except OSError:
                pass
    
    try:
        print(f"\n{'='*60}")
        print(f"Processing: {subject_id} ({file_info['group']}) from {file_info['dataset']}")
        print(f"{'='*60}")
        
        # Load file
        print(f"  Loading file...")
        raw = load_single_file(file_info)
        
        # Step 1/5: Standardized preprocessing
        print(f"  Step 1/5: Standardized preprocessing (resample, filter, CAR)...")
        raw = preprocess_raw(raw, subject_id)
        
        # Step 2/5: Bad channel detection and interpolation
        print(f"  Step 2/5: Bad channel detection...")
        bads = detect_bad_channels_simple(raw)
        if len(bads) > 0:
            print(f"    Found {len(bads)} bad channels: {bads[:10]}{'...' if len(bads) > 10 else ''}")
            raw.info['bads'] = bads
            
            # Try interpolation with error handling (needs channel positions, e.g. a montage, in raw.info)
            try:
                # Try standard interpolation first
                raw.interpolate_bads(reset_bads=True, verbose=False)
                interpolation_success = True
            except (np.linalg.LinAlgError, ValueError, RuntimeError) as e:
                # If interpolation fails (e.g., SVD convergence issues), try alternative methods
                print(f"    Warning: Standard interpolation failed ({type(e).__name__}), trying alternative...")
                interpolation_success = False
                
                # Try with explicit origin parameter
                try:
                    # Use a simple origin if head shape fitting fails
                    raw.interpolate_bads(reset_bads=True, origin=(0., 0., 0.), verbose=False)
                    interpolation_success = True
                except Exception:
                    # If that also fails, drop the bad channels instead of interpolating
                    # (this subject then has fewer channels than the others)
                    print(f"    Warning: Interpolation not possible, removing bad channels instead...")
                    try:
                        raw.drop_channels(bads, verbose=False)
                        print(f"    ✓ Bad channels removed (interpolation not possible)")
                        interpolation_success = True
                    except Exception as e2:
                        print(f"    Error: Could not remove bad channels: {str(e2)}")
                        # If removal also fails, just mark them as bad but continue
                        raw.info['bads'] = bads
                        interpolation_success = False
            
            if interpolation_success:
                # Check for NaN/Inf after interpolation
                data = raw.get_data()
                if np.any(np.isnan(data)) or np.any(np.isinf(data)):
                    print(f"    Warning: NaN/Inf values after interpolation, cleaning...")
                    # Clean NaN/Inf values
                    for i in range(data.shape[0]):
                        ch_data = data[i, :]
                        if np.any(np.isnan(ch_data)) or np.any(np.isinf(ch_data)):
                            valid_data = ch_data[np.isfinite(ch_data)]
                            if len(valid_data) > 0:
                                replacement = np.median(valid_data)
                            else:
                                replacement = 0.0
                            ch_data[np.isnan(ch_data) | np.isinf(ch_data)] = replacement
                            data[i, :] = ch_data
                    raw._data = data
                
                print(f"    ✓ Bad channels handled")
            else:
                print(f"    Warning: Could not interpolate or remove bad channels, continuing with bad channels marked...")
        else:
            print(f"    No bad channels detected")
        
        # Step 3/5: ICA artifact removal
        print(f"  Step 3/5: ICA artifact removal...")
        raw, excluded_ics = detect_artifacts_ica_simple(raw, subject_id)
        if len(excluded_ics) > 0:
            print(f"    ✓ Removed {len(excluded_ics)} artifact components")
        else:
            print(f"    No artifacts detected")
        
        # Step 4/5: Z-score normalization
        print(f"  Step 4/5: Z-score normalization...")
        raw = zscore_normalize(raw)
        
        # Step 5/5: Riemannian re-centering (skipped if pyriemann is unavailable)
        print(f"  Step 5/5: Riemannian re-centering...")
        raw = riemannian_recenter(raw, subject_id)
        
        # Save preprocessed data
        print(f"  Saving preprocessed data to: {output_file}")
        raw.save(output_file, overwrite=True, verbose=False)
        
        print(f"  ✓ Successfully processed and saved {subject_id}")
        return True
        
    except Exception as e:
        error_msg = str(e)
        error_type = type(e).__name__
        print(f"  ✗ Error processing {subject_id}: {error_type}: {error_msg}")
        # Store error info for later analysis
        if not hasattr(process_single_subject, 'error_log'):
            process_single_subject.error_log = []
        process_single_subject.error_log.append({
            'subject_id': subject_id,
            'error_type': error_type,
            'error_message': error_msg,
            'dataset': file_info.get('dataset', 'unknown')
        })
        return False

# Check existing files before processing
print("=" * 60)
print("Checking existing preprocessed files...")
print("=" * 60)

existing_valid = []
existing_invalid = []
need_processing = []

for file_info in all_file_list:
    subject_id = file_info['subject_id']
    output_file = os.path.join(OUTPUT_DIR, f"{subject_id}_preprocessed.fif")
    
    if is_file_valid(output_file):
        existing_valid.append(subject_id)
    elif os.path.exists(output_file):
        existing_invalid.append(subject_id)
        # Remove invalid file
        try:
            os.remove(output_file)
        except OSError:
            pass
    else:
        need_processing.append(subject_id)

print(f"✓ Found {len(existing_valid)} already processed and valid files")
if len(existing_invalid) > 0:
    print(f"⚠ Found {len(existing_invalid)} invalid files (will be reprocessed)")
print(f"→ Need to process: {len(need_processing)} files")
print(f"  Total: {len(all_file_list)} files")

# Process all subjects
print("\n" + "=" * 60)
print("Starting preprocessing pipeline...")
print(f"Output directory: {OUTPUT_DIR}")
print("=" * 60)

# Initialize error log
process_single_subject.error_log = []

# Track progress
successful = existing_valid.copy()  # Start with already processed files
failed = []
failed_info = []  # Store detailed failure information
newly_processed = []

for file_info in all_file_list:
    subject_id = file_info['subject_id']
    
    # Only show progress for files that need processing
    if subject_id in need_processing:
        print(f"\n[{len(successful) + len(failed) + 1}/{len(all_file_list)}] Processing: {subject_id}")
    else:
        # Skip silently for already processed files
        continue
    
    success = process_single_subject(file_info, OUTPUT_DIR, skip_existing=True)
    
    if success:
        if subject_id not in existing_valid:
            newly_processed.append(subject_id)
            successful.append(subject_id)
        # If already in existing_valid, it's already counted in successful
    else:
        failed.append(subject_id)
        # Get error info if available
        if hasattr(process_single_subject, 'error_log') and process_single_subject.error_log:
            last_error = process_single_subject.error_log[-1]
            if last_error['subject_id'] == subject_id:
                failed_info.append(last_error)

# Save failed files list and error log
failed_file = os.path.join(OUTPUT_DIR, "failed_subjects.txt")
error_log_file = os.path.join(OUTPUT_DIR, "error_log.json")

if len(failed) > 0:
    # Save list of failed subjects
    with open(failed_file, 'w') as f:
        for subj in failed:
            f.write(f"{subj}\n")
    print(f"\n✓ Saved failed subjects list to: {failed_file}")
    
    # Save detailed error log
    if failed_info:
        import json
        with open(error_log_file, 'w') as f:
            json.dump(failed_info, f, indent=2)
        print(f"✓ Saved error log to: {error_log_file}")
elif os.path.exists(failed_file):
    # No failures in this run: remove a stale list so the retry cell
    # does not reprocess subjects that have since succeeded
    os.remove(failed_file)

print("\n" + "=" * 60)
print("Preprocessing Summary")
print("=" * 60)
print(f"Already processed (skipped): {len(existing_valid)}")
print(f"Newly processed: {len(newly_processed)}")
print(f"Failed: {len(failed)}")
print(f"Total successful: {len(successful)}/{len(all_file_list)}")

if len(failed) > 0:
    print(f"\nFailed subjects ({len(failed)}):")
    
    # Group by error type
    error_types = {}
    for err in failed_info:
        err_type = err.get('error_type', 'Unknown')
        if err_type not in error_types:
            error_types[err_type] = []
        error_types[err_type].append(err['subject_id'])
    
    print("\n  Error breakdown:")
    for err_type, subjects in error_types.items():
        print(f"    {err_type}: {len(subjects)} files")
        if len(subjects) <= 10:
            for subj in subjects:
                print(f"      - {subj}")
        else:
            for subj in subjects[:5]:
                print(f"      - {subj}")
            print(f"      ... and {len(subjects) - 5} more")
    
    print(f"\n  All failed subjects:")
    for subj in failed[:30]:  # Show first 30
        print(f"    - {subj}")
    if len(failed) > 30:
        print(f"    ... and {len(failed) - 30} more")
    
    print(f"\n  To retry failed files, see the next cell.")


Checking existing preprocessed files...
✓ Found 278 already processed and valid files
→ Need to process: 0 files
  Total: 278 files

Starting preprocessing pipeline...
Output directory: C:\Users\Usha Sri\OneDrive\Documents\Parkinson_Project\preprocessed_data

Preprocessing Summary
Already processed (skipped): 278
Newly processed: 0
Failed: 0
Total successful: 278/278
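
The saved `.fif` files carry only the signals; each subject's `group` and `dataset` live in `all_file_list` in memory. For cross-dataset analysis it may help to write them to a manifest next to the outputs. A sketch with pandas, where the helper name `write_subject_manifest` and the file name `subject_manifest.csv` are hypothetical:

```python
import os
import pandas as pd

def write_subject_manifest(file_list, output_dir, filename='subject_manifest.csv'):
    """Write subject_id, group, dataset and output path for every subject."""
    rows = []
    for info in file_list:
        output_file = os.path.join(output_dir, f"{info['subject_id']}_preprocessed.fif")
        rows.append({
            'subject_id': info['subject_id'],
            'group': info['group'],
            'dataset': info['dataset'],
            'output_file': output_file,
            'output_exists': os.path.exists(output_file),
        })
    manifest = pd.DataFrame(rows)
    manifest_path = os.path.join(output_dir, filename)
    manifest.to_csv(manifest_path, index=False)
    return manifest, manifest_path
```

Calling `write_subject_manifest(all_file_list, OUTPUT_DIR)` after the loop would give one row per subject, which downstream code can join on `subject_id`.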


## 4. Retry Failed Files (Optional)

If you have failed files, you can retry processing them here. This cell will:
- Load the list of failed subjects
- Attempt to reprocess them
- Use the same error handling as the main pipeline


In [17]:
# Retry failed files
RETRY_FAILED = True  # Set to False to skip the retry step

if RETRY_FAILED:
    failed_file = os.path.join(OUTPUT_DIR, "failed_subjects.txt")
    
    if os.path.exists(failed_file):
        # Load failed subjects
        with open(failed_file, 'r') as f:
            failed_subjects = [line.strip() for line in f if line.strip()]
        
        print("=" * 60)
        print(f"Retrying {len(failed_subjects)} failed subjects...")
        print("=" * 60)
        
        # Find file info for failed subjects
        failed_file_list = [f for f in all_file_list if f['subject_id'] in failed_subjects]
        
        if len(failed_file_list) == 0:
            print("No matching files found in file list.")
        else:
            # Initialize error log
            process_single_subject.error_log = []
            
            retry_successful = []
            retry_failed = []
            
            for i, file_info in enumerate(failed_file_list, 1):
                subject_id = file_info['subject_id']
                print(f"\n[{i}/{len(failed_file_list)}] Retrying: {subject_id}")
                
                # Remove existing file if it exists (might be corrupted)
                output_file = os.path.join(OUTPUT_DIR, f"{subject_id}_preprocessed.fif")
                if os.path.exists(output_file):
                    try:
                        os.remove(output_file)
                    except OSError as e:
                        print(f"  Warning: could not remove {output_file}: {e}")
                
                success = process_single_subject(file_info, OUTPUT_DIR, skip_existing=False)
                
                if success:
                    retry_successful.append(subject_id)
                else:
                    retry_failed.append(subject_id)
            
            print("\n" + "=" * 60)
            print("Retry Summary")
            print("=" * 60)
            print(f"Successfully retried: {len(retry_successful)}/{len(failed_file_list)}")
            print(f"Still failed: {len(retry_failed)}/{len(failed_file_list)}")
            
            if len(retry_successful) > 0:
                print(f"\n✓ Successfully retried:")
                for subj in retry_successful:
                    print(f"  - {subj}")
            
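            if len(retry_failed) > 0:
                print(f"\n✗ Still failed:")
                for subj in retry_failed:
                    print(f"  - {subj}")
            
            # Keep subjects that still fail, plus any that were not found in the file list
            retried_ids = {f['subject_id'] for f in failed_file_list}
            still_failed = retry_failed + [s for s in failed_subjects if s not in retried_ids]
            if still_failed:
                with open(failed_file, 'w') as f:
                    for subj in still_failed:
                        f.write(f"{subj}\n")
                print(f"\n✓ Updated failed subjects list ({len(still_failed)} remaining)")
            else:
                # Nothing left to retry: remove the stale list so later cells don't report old failures
                os.remove(failed_file)
                print(f"\n✓ All failed subjects recovered; removed {failed_file}")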
    else:
        print("No failed_subjects.txt file found. Run the main pipeline first.")
else:
    print("RETRY_FAILED is set to False. Set it to True to retry failed files.")


No failed_subjects.txt file found. Run the main pipeline first.


## 3.2 Analyze Failed Files

View detailed error information for failed files to understand what went wrong.
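The grouping in the cell below boils down to counting log entries per key. A compact sketch with `collections.Counter`, assuming each error-log entry is a dict with optional `error_type`, `dataset`, and `error_message` keys (the same keys the cell reads); `summarize_errors` is a hypothetical helper and the log entries are made up for illustration.

```python
from collections import Counter


def summarize_errors(error_log, prefix_len=50):
    """Count failures by error type, by dataset, and by message prefix."""
    by_type = Counter(err.get("error_type", "Unknown") for err in error_log)
    by_dataset = Counter(err.get("dataset", "unknown") for err in error_log)
    # Messages sharing the same first `prefix_len` characters count as one pattern
    by_message = Counter(err.get("error_message", "")[:prefix_len] for err in error_log)
    return by_type, by_dataset, by_message


# Hypothetical log entries, for illustration only
log = [
    {"subject_id": "sub-01", "dataset": "ds004584", "error_type": "ValueError",
     "error_message": "array must not contain infs or NaNs"},
    {"subject_id": "sub-02", "dataset": "ds004584", "error_type": "ValueError",
     "error_message": "array must not contain infs or NaNs"},
    {"subject_id": "pd-10", "dataset": "PD_Dataset_timing", "error_type": "MemoryError",
     "error_message": ""},
]
by_type, by_dataset, by_message = summarize_errors(log)
for err_type, n in by_type.most_common():
    print(f"{err_type}: {n} files")
```

`Counter.most_common()` already returns keys sorted by count, which replaces the manual `sorted(..., key=lambda x: len(x[1]), reverse=True)` calls.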


In [18]:
# Analyze error log
error_log_file = os.path.join(OUTPUT_DIR, "error_log.json")
failed_file = os.path.join(OUTPUT_DIR, "failed_subjects.txt")

if os.path.exists(error_log_file):
    import json
    
    with open(error_log_file, 'r') as f:
        error_log = json.load(f)
    
    print("=" * 60)
    print("Error Analysis")
    print("=" * 60)
    print(f"Total failed files: {len(error_log)}\n")
    
    # Group by error type
    error_types = {}
    for err in error_log:
        err_type = err.get('error_type', 'Unknown')
        if err_type not in error_types:
            error_types[err_type] = []
        error_types[err_type].append(err)
    
    print("Error breakdown by type:")
    for err_type, errors in sorted(error_types.items(), key=lambda x: len(x[1]), reverse=True):
        print(f"\n  {err_type}: {len(errors)} files")
        
        # Show sample error messages
        sample_errors = errors[:3]
        for err in sample_errors:
            msg = err.get('error_message', '')[:100]  # First 100 chars
            print(f"    - {err['subject_id']} ({err.get('dataset', 'unknown')}): {msg}...")
        if len(errors) > 3:
            print(f"    ... and {len(errors) - 3} more with similar errors")
    
    # Group by dataset
    print("\n" + "-" * 60)
    print("Error breakdown by dataset:")
    dataset_errors = {}
    for err in error_log:
        dataset = err.get('dataset', 'unknown')
        if dataset not in dataset_errors:
            dataset_errors[dataset] = []
        dataset_errors[dataset].append(err)
    
    for dataset, errors in sorted(dataset_errors.items(), key=lambda x: len(x[1]), reverse=True):
        print(f"  {dataset}: {len(errors)} files")
    
    # Common error patterns
    print("\n" + "-" * 60)
    print("Common error patterns:")
    error_messages = {}
    for err in error_log:
        msg = err.get('error_message', '')
        # Use the first 50 characters of the message as the grouping key
        key = msg[:50]
        if key not in error_messages:
            error_messages[key] = []
        error_messages[key].append(err['subject_id'])
    
    for msg, subjects in sorted(error_messages.items(), key=lambda x: len(x[1]), reverse=True)[:5]:
        print(f"  '{msg}...': {len(subjects)} files")
        if len(subjects) <= 5:
            print(f"    Subjects: {', '.join(subjects)}")
        else:
            print(f"    Subjects: {', '.join(subjects[:5])}... and {len(subjects) - 5} more")
    
    print("\n" + "=" * 60)
    print("Full error log saved to:", error_log_file)
    print("Failed subjects list saved to:", failed_file)
    
elif os.path.exists(failed_file):
    print("Error log not found, but failed subjects list exists.")
    with open(failed_file, 'r') as f:
        failed_subjects = [line.strip() for line in f if line.strip()]
    print(f"Failed subjects ({len(failed_subjects)}):")
    for subj in failed_subjects:
        print(f"  - {subj}")
else:
    print("No error log or failed subjects file found. All files processed successfully!")


No error log or failed subjects file found. All files processed successfully!


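## 4. Subject-Level Alignment

Define per-channel Z-score normalization and Riemannian re-centering, then apply them to each artifact-corrected subject. The loop reads in-memory `cleaned_raws` and `cleaned_metadata` lists, which are not defined by the per-subject pipeline above, so running this cell as-is raises the `NameError` shown below.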
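A NumPy-only sketch of the two alignment steps, useful for checking the math outside MNE. It assumes data shaped `(n_channels, n_samples)` and the same 1-second, 50%-overlap windows as the cell below. The helper names are hypothetical, and the affine-invariant geometric mean is computed with a plain fixed-point (Karcher) iteration rather than pyriemann's `mean_riemann`, so results may differ slightly in the last digits. Unlike `stats.zscore`, the z-score here maps flat channels to zeros instead of NaN.

```python
import numpy as np


def zscore_channels(data):
    """Z-score each channel (row); flat channels become zeros instead of NaN."""
    mean = data.mean(axis=1, keepdims=True)
    std = data.std(axis=1, keepdims=True)
    return (data - mean) / np.where(std > 0, std, 1.0)


def _spd_fn(mat, fn):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * fn(vals)) @ vecs.T


def window_covariances(data, win, step, reg=1e-6):
    """Regularized covariance of each (n_channels, win) window, hopping by `step` samples."""
    n_channels, n_samples = data.shape
    covs = [np.cov(data[:, s:s + win]) + reg * np.eye(n_channels)
            for s in range(0, n_samples - win + 1, step)]
    return np.array(covs)


def riemannian_mean(covs, tol=1e-8, max_iter=50):
    """Affine-invariant geometric mean via the Karcher fixed-point iteration."""
    mean = covs.mean(axis=0)  # start from the arithmetic mean
    for _ in range(max_iter):
        sqrt_m = _spd_fn(mean, np.sqrt)
        isqrt_m = _spd_fn(mean, lambda v: 1.0 / np.sqrt(v))
        # Average the covariances in the tangent space at the current estimate
        tangent = np.mean([_spd_fn(isqrt_m @ c @ isqrt_m, np.log) for c in covs], axis=0)
        mean = sqrt_m @ _spd_fn(tangent, np.exp) @ sqrt_m
        if np.linalg.norm(tangent) < tol:
            break
    return mean


def recenter(data, sfreq):
    """Whiten data with mean^(-1/2), the mean taken over 1 s, 50%-overlap window covariances."""
    win = int(sfreq)
    covs = window_covariances(data, win, max(win // 2, 1))
    if len(covs) == 0:
        return data.copy()  # shorter than one window: nothing to estimate
    return _spd_fn(riemannian_mean(covs), lambda v: 1.0 / np.sqrt(v)) @ data
```

Because the geometric mean is affine-invariant, the window covariances of the re-centered data have a geometric mean close to the identity, which is what makes subjects comparable after this step.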


In [19]:
def zscore_normalize(raw, subject_id):
    """
    Apply Z-score normalization to each channel independently.
    This normalizes each subject's data to have zero mean and unit variance.
    
    Parameters:
    -----------
    raw : mne.io.Raw
        Artifact-corrected Raw object
    subject_id : str
        Subject identifier
        
    Returns:
    --------
    raw_normalized : mne.io.Raw
        Z-score normalized Raw object
    """
    print(f"  Z-score normalizing {subject_id}...")
    
    raw_normalized = raw.copy()
    data = raw_normalized.get_data()
    
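    # Z-score normalization per channel: (data - mean) / std
    data_normalized = stats.zscore(data, axis=1)
    # Flat channels (std == 0) produce NaN; map them to zero instead
    data_normalized = np.nan_to_num(data_normalized, nan=0.0)
    
    raw_normalized._data = data_normalized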
    
    print(f"  ✓ Z-score normalization complete")
    
    return raw_normalized

def riemannian_recenter(raw, subject_id):
    """
    Apply Riemannian re-centering to align data distributions.
    This computes the geometric mean of covariance matrices and whitens the data.
    
    Parameters:
    -----------
    raw : mne.io.Raw
        Z-score normalized Raw object
    subject_id : str
        Subject identifier
        
    Returns:
    --------
    raw_recentered : mne.io.Raw
        Riemannian re-centered Raw object
    """
    print(f"  Riemannian re-centering {subject_id}...")
    
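    try:
        # Imported here so that a missing pyriemann is caught below and the step is skipped
        from pyriemann.utils.mean import mean_riemann
        
        # Extract data
        data = raw.get_data().T  # Shape: (n_samples, n_channels)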
        
        # Compute covariance matrices using sliding windows
        # Use 1-second windows with 50% overlap
        window_length = int(raw.info['sfreq'])  # 1 second
        step_size = window_length // 2  # 50% overlap
        
        covariances = []
        for start_idx in range(0, data.shape[0] - window_length + 1, step_size):
            window_data = data[start_idx:start_idx + window_length, :]
            # Compute sample covariance
            cov = np.cov(window_data.T)
            # Regularize to ensure positive definiteness
            cov += np.eye(cov.shape[0]) * 1e-6
            covariances.append(cov)
        
        if len(covariances) > 0:
            covariances = np.array(covariances)
            
            # Compute geometric mean on Riemannian manifold
            mean_cov = mean_riemann(covariances)
            
            # Whiten the data using the geometric mean
            # Compute whitening matrix: W = mean_cov^(-1/2)
            eigenvals, eigenvecs = np.linalg.eigh(mean_cov)
            eigenvals = np.maximum(eigenvals, 1e-10)  # Avoid numerical issues
            whitening_matrix = eigenvecs @ np.diag(1.0 / np.sqrt(eigenvals)) @ eigenvecs.T
            
            # Apply whitening to all data
            data_recentered = (whitening_matrix @ data.T).T
            
            # Create new Raw object
            raw_recentered = raw.copy()
            raw_recentered._data = data_recentered.T
            
            print(f"  ✓ Riemannian re-centering complete")
        else:
            print(f"  Warning: Not enough data for Riemannian re-centering, skipping...")
            raw_recentered = raw.copy()
            
    except Exception as e:
        print(f"  Warning: Riemannian re-centering failed: {str(e)}")
        print(f"  Continuing without Riemannian re-centering...")
        raw_recentered = raw.copy()
    
    return raw_recentered

# Apply subject-level alignment
print("=" * 60)
print("Starting subject-level alignment...")
print("=" * 60)

aligned_raws = []
aligned_metadata = []

for raw, meta in zip(cleaned_raws, cleaned_metadata):
    try:
        # Step 1: Z-score normalization
        raw_normalized = zscore_normalize(raw, meta['subject_id'])
        
        # Step 2: Riemannian re-centering
        raw_aligned = riemannian_recenter(raw_normalized, meta['subject_id'])
        
        aligned_raws.append(raw_aligned)
        aligned_metadata.append(meta)
        
    except Exception as e:
        print(f"✗ Error in alignment for {meta['subject_id']}: {str(e)}")
        continue

print(f"\n✓ Successfully aligned {len(aligned_raws)}/{len(cleaned_raws)} files")


Starting subject-level alignment...


NameError: name 'cleaned_raws' is not defined

## 5. Verify Channel Consistency

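Check that subjects share the same channel set (as specified in the requirements) and the same sampling rate. To keep the check fast, channel names are read from the headers of the first 10 subjects only.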
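This check, like the cells in sections 6 and 7, reads a `summary_df` DataFrame with `subject_id`, `file_path`, `dataset`, `group`, and `sampling_rate` columns, but no cell in this part of the notebook builds it (hence the `NameError` below). A minimal sketch that builds it from the saved files, assuming each file-info dict (as in `all_file_list`) carries `subject_id`, `dataset`, and `group` keys; `build_summary_df` and its `read_sfreq` argument are hypothetical. By default the sampling rate comes from the FIF header via `mne.io.read_raw_fif(..., preload=False)`, which does not load the signal.

```python
import os

import pandas as pd


def _fif_sfreq(path):
    """Read the sampling rate from a FIF header without loading the signal."""
    import mne  # imported lazily so the helper can be tested without MNE
    return mne.io.read_raw_fif(path, preload=False, verbose=False).info["sfreq"]


def build_summary_df(file_list, output_dir, read_sfreq=_fif_sfreq):
    """One row per subject whose preprocessed file exists in output_dir."""
    rows = []
    for info in file_list:
        path = os.path.join(output_dir, f"{info['subject_id']}_preprocessed.fif")
        if not os.path.exists(path):
            continue  # not processed (or failed): leave it out of the summary
        rows.append({
            "subject_id": info["subject_id"],
            "dataset": info.get("dataset", "unknown"),  # assumed key
            "group": info.get("group", "unknown"),      # assumed key
            "file_path": path,
            "sampling_rate": read_sfreq(path),
        })
    columns = ["subject_id", "dataset", "group", "file_path", "sampling_rate"]
    return pd.DataFrame(rows, columns=columns)
```

In the notebook this would be called as `summary_df = build_summary_df(all_file_list, OUTPUT_DIR)` before the channel check.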


In [20]:
# Check channel consistency
print("=" * 60)
print("Channel Consistency Check")
print("=" * 60)

# Load channel names from a few files to check consistency
channel_sets = []
sample_subjects = summary_df['subject_id'].head(10).tolist()

for subject_id in sample_subjects:
    file_path = summary_df[summary_df['subject_id'] == subject_id]['file_path'].iloc[0]
    try:
        raw = mne.io.read_raw_fif(file_path, preload=False, verbose=False)
        channel_sets.append(set(raw.ch_names))
    except Exception as e:
        print(f"  Could not read {subject_id}: {e}")
        continue

if len(channel_sets) > 0:
    # Check if all have same channels
    unique_channel_sets = [list(s) for s in set(tuple(sorted(s)) for s in channel_sets)]
    
    if len(unique_channel_sets) == 1:
        print(f"✓ All checked subjects have the same channels: {len(unique_channel_sets[0])} channels")
        print(f"  Channels: {', '.join(sorted(unique_channel_sets[0])[:10])}...")
        COMMON_CHANNELS = sorted(unique_channel_sets[0])
    else:
        print(f"⚠ Warning: Found {len(unique_channel_sets)} different channel configurations")
        for i, ch_set in enumerate(unique_channel_sets):
            print(f"  Configuration {i+1}: {len(ch_set)} channels")
            print(f"    {', '.join(sorted(ch_set)[:10])}...")
        # Use the most common channel set
        from collections import Counter
        channel_counts = Counter(tuple(sorted(s)) for s in channel_sets)
        most_common = channel_counts.most_common(1)[0][0]
        COMMON_CHANNELS = sorted(list(most_common))
        print(f"\n  Using most common channel set: {len(COMMON_CHANNELS)} channels")
else:
    print("⚠ Could not verify channel consistency")
    COMMON_CHANNELS = None

# Verify sampling rate consistency
sampling_rates = summary_df['sampling_rate'].unique()
if len(sampling_rates) == 1:
    print(f"\n✓ All subjects have the same sampling rate: {sampling_rates[0]} Hz")
else:
    print(f"\n⚠ Warning: Found different sampling rates: {sampling_rates}")


Channel Consistency Check


NameError: name 'summary_df' is not defined

## 6. Extract Data Arrays for Machine Learning (Optional)

Extract data arrays from preprocessed files. This can be memory-intensive, so it's optional.
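If the stacked array does not fit in RAM, one option is to write it straight to disk with `numpy.lib.format.open_memmap`, which creates a `.npy` file that `np.load` reads normally. The sketch below is a hypothetical helper, not part of the notebook. It assumes every subject has the same channels, crops all recordings to the shortest one (required to stack them), and takes a `load_fn` callable returning an `(n_channels, n_samples)` array for a path, for example `lambda p: mne.io.read_raw_fif(p, preload=True, verbose=False).get_data()`. Each file is read twice (once for shapes, once for data) so that only one subject is in memory at a time.

```python
import numpy as np
from numpy.lib.format import open_memmap


def stack_to_npy(paths, out_file, load_fn, dtype=np.float32):
    """Write subjects into one (n_subjects, n_channels, n_samples) .npy file on disk."""
    if not paths:
        raise ValueError("No files to stack")
    # Pass 1: find the common shape (channels must match; samples are cropped to the minimum)
    shapes = [load_fn(p).shape for p in paths]
    n_channels = shapes[0][0]
    if any(s[0] != n_channels for s in shapes):
        raise ValueError("All subjects must have the same number of channels")
    n_samples = min(s[1] for s in shapes)

    # Pass 2: fill the memory-mapped array one subject at a time
    out = open_memmap(out_file, mode="w+", dtype=dtype,
                      shape=(len(paths), n_channels, n_samples))
    for i, p in enumerate(paths):
        out[i] = load_fn(p)[:, :n_samples]
    out.flush()
    del out  # release the memory map
    return n_channels, n_samples
```

`float32` halves the file size compared with the `float64` arrays saved in section 6; pass `dtype=np.float64` to match them exactly.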


In [21]:
def extract_data_arrays_memory_efficient(summary_df, output_dir, max_subjects=None):
    """
    Extract data arrays from preprocessed files.
    Files are read one subject at a time, but every array is kept in memory
    until all subjects are stacked and saved, so peak memory grows with the dataset.
    
    Parameters:
    -----------
    summary_df : pd.DataFrame
        Summary dataframe with file paths
    output_dir : str
        Output directory for saving arrays
    max_subjects : int, optional
        Maximum number of subjects to process (for testing)
        
    Returns:
    --------
    X_file : str
        Path to saved X array file
    y_file : str
        Path to saved y array file
    subject_ids_file : str
        Path to saved subject_ids file
    """
    print("=" * 60)
    print("Extracting data arrays for machine learning...")
    print("=" * 60)
    print("Note: This may take a while and use significant memory.")
    print("      Consider processing in batches if memory is limited.\n")
    
    # Limit subjects if specified
    if max_subjects is not None:
        summary_df = summary_df.head(max_subjects)
        print(f"Processing first {max_subjects} subjects only...\n")
    
    X_list = []
    y_list = []
    subject_ids = []
    
    # enumerate gives a 1-based position even if the DataFrame index is not 0..n-1
    for pos, (_, row) in enumerate(summary_df.iterrows(), 1):
        subject_id = row['subject_id']
        file_path = row['file_path']
        
        try:
            print(f"  Loading {subject_id}... ({pos}/{len(summary_df)})")
            raw = mne.io.read_raw_fif(file_path, preload=True, verbose=False)
            data = raw.get_data()  # Shape: (n_channels, n_samples)
            
            X_list.append(data)
            
            # Encode labels: Control = 0, PD = 1
            label = 0 if row['group'] == 'Control' else 1
            y_list.append(label)
            subject_ids.append(subject_id)
            
            # Clear memory
            del raw, data
            
        except Exception as e:
            print(f"    ✗ Error loading {subject_id}: {str(e)}")
            continue
    
    if len(X_list) == 0:
        print("✗ No data extracted!")
        return None, None, None
    
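    # Recordings can differ in length, and np.array cannot build a 3-D array from
    # arrays of different lengths, so crop every subject to the shortest recording
    min_samples = min(x.shape[1] for x in X_list)
    print(f"\n  Cropping to {min_samples} samples per subject and converting to arrays...")
    X = np.stack([x[:, :min_samples] for x in X_list])  # Shape: (n_subjects, n_channels, n_samples)
    y = np.array(y_list)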
    
    # Save arrays
    print(f"  Saving arrays...")
    X_file = os.path.join(output_dir, "X_preprocessed.npy")
    y_file = os.path.join(output_dir, "y_labels.npy")
    subject_ids_file = os.path.join(output_dir, "subject_ids.npy")
    
    np.save(X_file, X)
    np.save(y_file, y)
    np.save(subject_ids_file, np.array(subject_ids, dtype=object))
    
    print(f"\n✓ Data extraction complete!")
    print(f"  Data shape: {X.shape} (n_subjects, n_channels, n_samples)")
    print(f"  Labels shape: {y.shape}")
    print(f"  Control subjects: {np.sum(y == 0)}")
    print(f"  PD subjects: {np.sum(y == 1)}")
    print(f"\n  Data statistics:")
    print(f"    Mean: {X.mean():.6f}")
    print(f"    Std: {X.std():.6f}")
    print(f"    Min: {X.min():.6f}")
    print(f"    Max: {X.max():.6f}")
    print(f"\n  Files saved:")
    print(f"    {X_file}")
    print(f"    {y_file}")
    print(f"    {subject_ids_file}")
    
    # Clear memory
    del X, y, X_list, y_list
    
    return X_file, y_file, subject_ids_file

# Uncomment to extract data arrays (memory-intensive!)
# X_file, y_file, subject_ids_file = extract_data_arrays_memory_efficient(
#     summary_df, OUTPUT_DIR, max_subjects=None  # Set max_subjects for testing
# )


## 7. Save Summary and Metadata

Save the preprocessing summary for future reference.


In [22]:
# Save summary dataframe
summary_csv = os.path.join(OUTPUT_DIR, "preprocessing_summary.csv")
summary_df.to_csv(summary_csv, index=False)
print(f"✓ Saved preprocessing summary to: {summary_csv}")

# Create a metadata file with preprocessing parameters
metadata = {
    'preprocessing_parameters': {
        'target_sampling_rate_hz': TARGET_SFREQ,
        'bandpass_low_hz': LOW_FREQ,
        'bandpass_high_hz': HIGH_FREQ,
        'ica_n_components': ICA_N_COMPONENTS,
        'ica_max_iter': ICA_MAX_ITER,
        're_referencing': 'Common Average Reference (CAR)',
        'normalization': 'Z-score per channel',
        'riemannian_recentering': HAS_PYRIEMANN
    },
    'total_subjects': len(summary_df),
    'by_dataset': summary_df['dataset'].value_counts().to_dict(),
    'by_group': summary_df['group'].value_counts().to_dict(),
    'output_directory': OUTPUT_DIR
}

import json
metadata_file = os.path.join(OUTPUT_DIR, "preprocessing_metadata.json")
with open(metadata_file, 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"✓ Saved preprocessing metadata to: {metadata_file}")

print("\n" + "=" * 60)
print("Preprocessing Complete!")
print("=" * 60)
print(f"\nAll preprocessed files are saved in: {OUTPUT_DIR}")
print(f"Each subject has been saved as: {{subject_id}}_preprocessed.fif")
print(f"\nTo load a preprocessed file:")
print(f"  import mne")
print(f"  raw = mne.io.read_raw_fif('{OUTPUT_DIR}/{{subject_id}}_preprocessed.fif', preload=True)")
print(f"\nTo extract data arrays for ML (optional, memory-intensive):")
print(f"  Run the cell in section 6 above.")


NameError: name 'summary_df' is not defined

## Notes and Next Steps

### Preprocessing Steps Applied:
1. **Resampling**: All data resampled to 250 Hz
2. **Filtering**: Band-pass filter applied (0.5 - 50 Hz)
3. **Re-referencing**: Common Average Reference (CAR) applied
4. **Bad Channel Detection**: Automatic detection and interpolation
5. **Artifact Removal**: ICA with correlation-based artifact detection (no ICLabel required)
6. **Normalization**: Z-score normalization per channel per subject
7. **Riemannian Re-centering**: Applied if pyriemann is available

### Memory-Efficient Design:
- Each subject is processed and saved individually
- No need to load all data into memory at once
- Can resume processing if interrupted (skips already processed subjects)
- Preprocessed files can be loaded individually for further analysis
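Resuming amounts to skipping subjects whose output already exists. A minimal sketch, assuming the `{subject_id}_preprocessed.fif` naming used throughout this notebook; `pending_subjects` is hypothetical and treats a zero-byte file (for example, one left by an interrupted write) as unfinished. The pipeline's own check reports "valid" files, so it likely inspects existing outputs more thoroughly than this.

```python
import os


def pending_subjects(file_list, output_dir):
    """Return the file-info dicts that still need processing."""
    pending = []
    for info in file_list:
        path = os.path.join(output_dir, f"{info['subject_id']}_preprocessed.fif")
        # A missing or empty output file means the subject must be (re)processed
        if not os.path.exists(path) or os.path.getsize(path) == 0:
            pending.append(info)
    return pending
```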

### Loading Preprocessed Data:
```python
import mne
import os

# Load a single subject
OUTPUT_DIR = "preprocessed_data"  # set to the output directory used by the pipeline
subject_id = "sub-001"
file_path = os.path.join(OUTPUT_DIR, f"{subject_id}_preprocessed.fif")
raw = mne.io.read_raw_fif(file_path, preload=True)

# Access the data
data = raw.get_data()  # Shape: (n_channels, n_samples)
```

### For Machine Learning:
If you need to extract all data into arrays, use the function in section 6.
**Warning**: This is memory-intensive and may not be feasible for large datasets.
Consider processing in batches or using the individual .fif files directly.
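For batch processing, a generator can yield a few subjects at a time instead of building one large array. A sketch, assuming `summary_df` has `subject_id`, `file_path`, and `group` columns as used in section 6; `iter_batches` and its `load_fn` argument are hypothetical, and labels follow the extraction function's convention (Control = 0, any other group = 1). Arrays are returned as a list because recording lengths can differ between subjects.

```python
import numpy as np


def iter_batches(summary_df, batch_size, load_fn):
    """Yield (list_of_arrays, labels, subject_ids) for `batch_size` subjects at a time."""
    for start in range(0, len(summary_df), batch_size):
        chunk = summary_df.iloc[start:start + batch_size]
        # Only this chunk's recordings are held in memory
        arrays = [load_fn(path) for path in chunk["file_path"]]
        labels = np.where(chunk["group"].to_numpy() == "Control", 0, 1)
        yield arrays, labels, chunk["subject_id"].tolist()
```

With MNE, `load_fn` could be `lambda p: mne.io.read_raw_fif(p, preload=True, verbose=False).get_data()`.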

The preprocessed data is now ready for further analysis!
