# PhysioNet Motor Imagery - Epoching and Trial Extraction

This notebook extracts motor imagery trials from preprocessed PhysioNet data following the methodology from:
**Sun et al. (2023) - Graph Convolution Neural Network Based End-to-End Channel Selection**

## Pipeline:
1. Load preprocessed FIF files (128 Hz, 0.5-40 Hz filtered, CAR referenced)
2. Extract events from annotations (T1: left fist, T2: right fist; this mapping holds only in the left/right-fist runs R03, R04, R07, R08, R11, R12)
3. Create epochs with proper time windows
4. Apply baseline correction
5. Reject bad epochs based on peak-to-peak amplitude criteria
6. Export to NumPy format for deep learning
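MNE performs steps 4 and 5 inside `mne.Epochs`. The NumPy sketch below only illustrates the documented semantics under stated assumptions: data shaped `(n_epochs, n_channels, n_times)` in volts, baseline correction as subtraction of each channel's mean over the baseline window, and rejection when any channel's peak-to-peak amplitude is above `reject` or below `flat`. The helper names `baseline_correct` and `keep_mask` are hypothetical, not MNE API.

```python
import numpy as np


def baseline_correct(data, times, baseline=(-0.5, 0.0)):
    """Subtract each channel's mean over the (inclusive) baseline window, per epoch."""
    in_window = (times >= baseline[0]) & (times <= baseline[1])
    return data - data[..., in_window].mean(axis=-1, keepdims=True)


def keep_mask(data, reject=100e-6, flat=1e-6):
    """True for epochs to keep; thresholds are peak-to-peak values in volts."""
    ptp = data.max(axis=-1) - data.min(axis=-1)  # (n_epochs, n_channels)
    too_large = (ptp > reject).any(axis=1)       # any channel above 100 μV
    too_flat = (ptp < flat).any(axis=1)          # any channel below 1 μV
    return ~(too_large | too_flat)


# Toy example: 3 epochs x 4 channels on a -1.0..5.0 s grid at 128 Hz
sfreq = 128.0
times = np.arange(-128, 641) / sfreq
rng = np.random.default_rng(0)
data = rng.normal(scale=5e-6, size=(3, 4, times.size))  # ~5 μV noise
data[1, 2, 300] += 200e-6  # epoch 1: a 200 μV spike -> rejected
data[2, 0, :] = 0.0        # epoch 2: a flat channel -> rejected

corrected = baseline_correct(data, times)
print(keep_mask(corrected))  # [ True False False]
```

Because peak-to-peak amplitude is unaffected by subtracting a constant, the order of baseline correction and rejection does not change which epochs are dropped.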

## Parameters from Paper:
- Epoch window: -1.0 to 5.0 seconds (relative to cue onset)
- Baseline: -0.5 to 0.0 seconds
- Rejection criteria (peak-to-peak amplitude on any EEG channel): > 100 μV rejected, < 1 μV rejected as flat
- Binary classification: Left fist (T1) vs Right fist (T2)
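At the 128 Hz sampling rate used here, this window fixes the array sizes. A small worked check, assuming MNE's convention that both `tmin` and `tmax` are included (sample indices `round(tmin * sfreq)` through `round(tmax * sfreq)`); `epoch_times` is a hypothetical helper:

```python
import numpy as np


def epoch_times(tmin, tmax, sfreq):
    """Epoch time axis with both endpoints included, as MNE constructs it."""
    start = int(round(tmin * sfreq))
    stop = int(round(tmax * sfreq))
    return np.arange(start, stop + 1) / sfreq


times = epoch_times(-1.0, 5.0, 128.0)
in_baseline = (times >= -0.5) & (times <= 0.0)

print(times.size)              # 769 samples = 6 s * 128 Hz + 1
print(times[0], times[-1])     # -1.0 5.0
print(int(in_baseline.sum()))  # 65 baseline samples
```

Each saved epoch therefore has 769 time points; the channel count depends on the preprocessing output.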

In [None]:
from pathlib import Path
import warnings
from datetime import datetime

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import mne
import matplotlib.pyplot as plt
import seaborn as sns

mne.set_log_level('WARNING')
warnings.filterwarnings('ignore', category=RuntimeWarning)

sns.set_context('notebook', font_scale=1.1)
plt.style.use('seaborn-v0_8')

DATA_ROOT = Path('data/physionet')
DERIVED_DIR = DATA_ROOT / 'derived'
PREPROCESSED_DIR = DERIVED_DIR / 'preprocessed'
EPOCHS_DIR = DERIVED_DIR / 'epochs'
EPOCHS_DIR.mkdir(parents=True, exist_ok=True)

INDEX_PATH = DERIVED_DIR / 'physionet_preprocessed_index.csv'
assert INDEX_PATH.exists(), 'Run preprocessing notebook first'

print(f"Loading preprocessed index from: {INDEX_PATH}")
preprocessed_df = pd.read_csv(INDEX_PATH)
print(f"Found {len(preprocessed_df)} preprocessed runs")
preprocessed_df.head()

## Epoching Configuration

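Following the paper's specifications for the PhysioNet dataset. MNE expects the `reject` and `flat` thresholds in volts and applies them to peak-to-peak amplitude.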

In [None]:
EPOCHING_CONFIG = {
    'tmin': -1.0,  # Start 1 second before cue
    'tmax': 5.0,   # End 5 seconds after cue (total 6 seconds)
    'baseline': (-0.5, 0),  # Baseline correction window
    'reject_criteria': {'eeg': 100e-6},  # Drop epoch if any channel's peak-to-peak > 100 μV (in volts)
    'flat_criteria': {'eeg': 1e-6},  # Drop epoch if any channel's peak-to-peak < 1 μV
    'event_id': {'T1': 1, 'T2': 2},  # Reference only; codes are read from the annotations at epoching time
    'picks': 'eeg',
    'preload': True
}

print("Epoching Configuration:")
for key, value in EPOCHING_CONFIG.items():
    print(f"  {key}: {value}")

## Epoching Functions

In [None]:
def extract_epochs_from_run(fif_path, config):
    """
    Extract motor imagery epochs from a preprocessed FIF file.
    
    Parameters
    ----------
    fif_path : Path
        Path to preprocessed FIF file
    config : dict
        Epoching configuration
        
    Returns
    -------
    epochs_info : dict
        Dictionary with epoching results and metadata
    """
    fif_path = Path(fif_path)
    
    epochs_info = {
        'subject': fif_path.parent.name,
        'run': fif_path.stem.split('_')[0][-3:],
        'status': 'processing',
        'timestamp': datetime.now().isoformat()
    }
    
    try:
        # Load preprocessed data
        raw = mne.io.read_raw_fif(fif_path, preload=True, verbose='ERROR')
        
        epochs_info['sfreq'] = raw.info['sfreq']
        epochs_info['n_channels'] = len(raw.ch_names)
        epochs_info['duration_s'] = raw.times[-1]
        
        # Extract events from annotations
        events, event_dict = mne.events_from_annotations(raw, verbose='ERROR')
        
        # Filter for T1 and T2 events only (left/right fist)
        valid_events = ['T1', 'T2']
        event_id = {k: v for k, v in event_dict.items() if k in valid_events}
        
        if len(event_id) == 0:
            epochs_info['status'] = 'skipped'
            epochs_info['reason'] = 'No T1/T2 events found'
            return epochs_info
        
        epochs_info['events_found'] = {k: int((events[:, 2] == v).sum()) for k, v in event_id.items()}
        epochs_info['total_events'] = sum(epochs_info['events_found'].values())
        
        # Create epochs
        epochs = mne.Epochs(
            raw,
            events,
            event_id=event_id,
            tmin=config['tmin'],
            tmax=config['tmax'],
            baseline=config['baseline'],
            picks=config['picks'],
            preload=config['preload'],
            reject=config['reject_criteria'],
            flat=config['flat_criteria'],
            verbose='ERROR'
        )
        
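        # With preload=True, MNE applies reject/flat while constructing the epochs,
        # so len(epochs.events) is already post-rejection. Count the candidate
        # T1/T2 events instead (epochs cut off by the recording edges count as dropped).
        epochs_info['n_epochs_before_rejection'] = epochs_info['total_events']
        
        # Drop bad epochs (a no-op once the data are preloaded)
        epochs.drop_bad()
        
        epochs_info['n_epochs_after_rejection'] = len(epochs)
        epochs_info['rejection_rate'] = 1.0 - (epochs_info['n_epochs_after_rejection'] /
                                               epochs_info['n_epochs_before_rejection'])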
        
        if len(epochs) == 0:
            epochs_info['status'] = 'failed'
            epochs_info['error'] = 'All epochs rejected'
            return epochs_info
        
        # Get epoch data (in volts) and event codes
        data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)
        codes = epochs.events[:, 2]
        
        # Convert codes to 0/1 (T1=0, T2=1); tolerates runs whose annotations contain only one of T1/T2
        label_map = {event_id[k]: i for i, k in enumerate(['T1', 'T2']) if k in event_id}
        labels = np.array([label_map[c] for c in codes])
        
        epochs_info['label_distribution'] = {
            'left_fist': int((labels == 0).sum()),
            'right_fist': int((labels == 1).sum())
        }
        
        # Save epochs as NPZ
        out_dir = EPOCHS_DIR / epochs_info['subject']
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{epochs_info['subject']}{epochs_info['run']}_epo.npz"
        
        np.savez_compressed(
            out_path,
            data=data,
            labels=labels,
            ch_names=epochs.ch_names,
            sfreq=epochs.info['sfreq'],
            times=epochs.times
        )
        
        epochs_info['output_path'] = str(out_path)
        epochs_info['data_shape'] = data.shape
        epochs_info['file_size_mb'] = out_path.stat().st_size / (1024 * 1024)
        epochs_info['status'] = 'success'
        
    except Exception as e:
        epochs_info['status'] = 'error'
        epochs_info['error'] = str(e)
    
    return epochs_info
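`extract_epochs_from_run` maps T1 to 0 and T2 to 1 for every run it receives, so the caller has to pass only runs in which T1/T2 mean left/right fist. Per the PhysioNet EEG Motor Movement/Imagery dataset description, that depends on the run number. A hypothetical helper (`describe_run`, not part of MNE or the dataset tools), assuming the `S###R##` file naming:

```python
import re

# EEGMMIDB run layout, per the PhysioNet dataset description
LEFT_RIGHT_FIST_RUNS = {3, 4, 7, 8, 11, 12}
FISTS_FEET_RUNS = {5, 6, 9, 10, 13, 14}
IMAGERY_RUNS = {4, 6, 8, 10, 12, 14}


def describe_run(name):
    """Return (run_number, mode, {annotation: meaning}) for names like 'S001R04'."""
    match = re.search(r"R(\d{2})", name)
    if match is None:
        raise ValueError(f"no run number in {name!r}")
    run = int(match.group(1))
    if run in LEFT_RIGHT_FIST_RUNS:
        meaning = {"T1": "left_fist", "T2": "right_fist"}
    elif run in FISTS_FEET_RUNS:
        meaning = {"T1": "both_fists", "T2": "both_feet"}
    else:
        return run, "baseline", {}  # R01/R02: eyes open / eyes closed, no cues
    mode = "imagery" if run in IMAGERY_RUNS else "execution"
    return run, mode, meaning


print(describe_run("S001R04"))  # (4, 'imagery', {'T1': 'left_fist', 'T2': 'right_fist'})
print(describe_run("S001R05"))  # (5, 'execution', {'T1': 'both_fists', 'T2': 'both_feet'})
```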

## Batch Epoching

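Process only the left/right-fist runs: R04, R08, R12 (imagery) and R03, R07, R11 (execution). The baseline runs R01/R02 contain no cues, and in R05, R06, R09, R10, R13 and R14 the T1/T2 annotations mean both fists / both feet, so treating them as left/right fist would corrupt the labels.

In [None]:
# Keep motor imagery/execution runs whose T1/T2 cues mean left/right fist.
# Run numbers are parsed from file names like 'S001R04...' (same convention
# as extract_epochs_from_run).
LEFT_RIGHT_FIST_RUNS = {3, 4, 7, 8, 11, 12}
run_numbers = preprocessed_df['path'].map(
    lambda p: int(Path(p).stem.split('_')[0][-2:])
)
mi_runs = preprocessed_df[
    preprocessed_df['category'].isin(['motor_imagery', 'motor_execution'])
    & run_numbers.isin(LEFT_RIGHT_FIST_RUNS)
].copy()

print(f"Processing {len(mi_runs)} left/right-fist runs (imagery + execution)...\n")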

epoch_records = []
error_records = []

for _, row in tqdm(mi_runs.iterrows(), total=len(mi_runs), desc='Epoching'):
    try:
        result = extract_epochs_from_run(row['path'], EPOCHING_CONFIG)
        
        result.update({
            'category': row['category'],
            'task': row.get('task', '')
        })
        
        if result['status'] == 'success':
            epoch_records.append(result)
        else:
            # 'error', 'failed' (all epochs rejected) or 'skipped' (no T1/T2 events)
            error_records.append(result)
            
    except Exception as e:
        error_records.append({
            'subject': row['subject'],
            'run': row['run'],
            'status': 'error',
            'error': str(e)
        })

epochs_df = pd.DataFrame(epoch_records)
errors_df = pd.DataFrame(error_records)

# Save index
EPOCH_INDEX_PATH = DERIVED_DIR / 'physionet_epochs_index.csv'
epochs_df.to_csv(EPOCH_INDEX_PATH, index=False)

if len(errors_df) > 0:
    ERROR_PATH = DERIVED_DIR / 'physionet_epoching_errors.csv'
    errors_df.to_csv(ERROR_PATH, index=False)
    print(f"\nRuns not epoched: {len(errors_df)}")
    print(f"Log saved to: {ERROR_PATH}")

print("\nEpoching complete!")
print(f"  - Successfully epoched: {len(epochs_df)}")
print(f"  - Not epoched: {len(errors_df)}")
print(f"  - Index saved to: {EPOCH_INDEX_PATH}")

epochs_df.head(10)

## Quality Assessment

In [None]:
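import ast

if len(epochs_df) > 0:
    # Keep successfully epoched runs only; skipped/failed runs have no label counts
    epochs_df = epochs_df[epochs_df['status'] == 'success'].copy()

if len(epochs_df) > 0:
    # label_distribution is a dict in memory but a string if re-read from CSV;
    # ast.literal_eval parses it safely (unlike eval)
    def parse_counts(x):
        return ast.literal_eval(x) if isinstance(x, str) else x

    epochs_df['left_fist_count'] = epochs_df['label_distribution'].apply(
        lambda x: parse_counts(x)['left_fist']
    )
    epochs_df['right_fist_count'] = epochs_df['label_distribution'].apply(
        lambda x: parse_counts(x)['right_fist']
    )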
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Epochs per run
    axes[0, 0].hist(epochs_df['n_epochs_after_rejection'], bins=20, 
                    color='steelblue', edgecolor='black', alpha=0.7)
    axes[0, 0].set_title('Distribution of Epochs per Run', fontsize=12, fontweight='bold')
    axes[0, 0].set_xlabel('Number of Epochs')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].grid(True, alpha=0.3)
    
    # Rejection rate
    axes[0, 1].hist(epochs_df['rejection_rate'] * 100, bins=20, 
                    color='coral', edgecolor='black', alpha=0.7)
    axes[0, 1].set_title('Epoch Rejection Rate', fontsize=12, fontweight='bold')
    axes[0, 1].set_xlabel('Rejection Rate (%)')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].grid(True, alpha=0.3)
    
    # Label balance
    label_data = pd.DataFrame({
        'Class': ['Left Fist'] * len(epochs_df) + ['Right Fist'] * len(epochs_df),
        'Count': list(epochs_df['left_fist_count']) + list(epochs_df['right_fist_count'])
    })
    sns.boxplot(data=label_data, x='Class', y='Count', ax=axes[1, 0], palette='Set2')
    axes[1, 0].set_title('Label Distribution per Run', fontsize=12, fontweight='bold')
    axes[1, 0].set_ylabel('Epochs per Class')
    axes[1, 0].grid(True, alpha=0.3, axis='y')
    
    # Epochs by category
    category_counts = epochs_df.groupby('category')['n_epochs_after_rejection'].sum()
    axes[1, 1].bar(range(len(category_counts)), category_counts.values, 
                   color=['#2ca02c', '#d62728'], alpha=0.7)
    axes[1, 1].set_xticks(range(len(category_counts)))
    axes[1, 1].set_xticklabels(category_counts.index, rotation=45, ha='right')
    axes[1, 1].set_title('Total Epochs by Category', fontsize=12, fontweight='bold')
    axes[1, 1].set_ylabel('Total Epochs')
    axes[1, 1].grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
    print("\nEpoching Statistics:")
    print(f"  - Total runs processed: {len(epochs_df)}")
    print(f"  - Total epochs extracted: {epochs_df['n_epochs_after_rejection'].sum()}")
    print(f"  - Mean epochs per run: {epochs_df['n_epochs_after_rejection'].mean():.2f}")
    print(f"  - Mean rejection rate: {epochs_df['rejection_rate'].mean() * 100:.2f}%")
    print(f"  - Total left fist epochs: {epochs_df['left_fist_count'].sum()}")
    print(f"  - Total right fist epochs: {epochs_df['right_fist_count'].sum()}")
    print(f"  - Total storage used: {epochs_df['file_size_mb'].sum():.2f} MB")
else:
    print("No data to visualize. Run the epoching cell first.")

## Subject-Level Summary

Aggregate epochs per subject for model training

In [None]:
if len(epochs_df) > 0:
    subject_summary = epochs_df.groupby('subject').agg({
        'run': 'count',
        'n_epochs_after_rejection': 'sum',
        'left_fist_count': 'sum',
        'right_fist_count': 'sum',
        'rejection_rate': 'mean',
        'file_size_mb': 'sum'
    }).rename(columns={
        'run': 'n_runs',
        'n_epochs_after_rejection': 'total_epochs'
    })
    
    subject_summary['class_balance_ratio'] = (subject_summary['left_fist_count'] / 
                                               subject_summary['right_fist_count'])
    
    SUBJECT_SUMMARY_PATH = DERIVED_DIR / 'physionet_subject_epochs_summary.csv'
    subject_summary.to_csv(SUBJECT_SUMMARY_PATH)
    
    print(f"\nSubject summary saved to: {SUBJECT_SUMMARY_PATH}")
    print(f"\nTop 10 subjects by epoch count:")
    print(subject_summary.sort_values('total_epochs', ascending=False).head(10))
    
    print(f"\nDataset ready for training!")
    print(f"  - Total subjects: {len(subject_summary)}")
    print(f"  - Total trials: {subject_summary['total_epochs'].sum()}")
    print(f"  - Mean trials per subject: {subject_summary['total_epochs'].mean():.1f}")
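For training, the per-run NPZ files of one subject usually need to be stacked into a single array. A minimal sketch, assuming the `{subject}{run}_epo.npz` layout written above and that all runs of a subject share the same channel order; `load_subject_epochs` is a hypothetical helper and the channel names in the demo are placeholders:

```python
import tempfile
from pathlib import Path

import numpy as np


def load_subject_epochs(subject_dir):
    """Concatenate data and labels from all *_epo.npz files in one subject folder."""
    files = sorted(Path(subject_dir).glob("*_epo.npz"))
    if not files:
        raise FileNotFoundError(f"no epoch files in {subject_dir}")
    data, labels, ch_names = [], [], None
    for path in files:
        with np.load(path) as npz:
            run_channels = npz["ch_names"].tolist()
            if ch_names is None:
                ch_names = run_channels
            elif run_channels != ch_names:
                raise ValueError(f"channel mismatch in {path.name}")
            data.append(npz["data"])
            labels.append(npz["labels"])
    return np.concatenate(data), np.concatenate(labels), ch_names


# Demo with two synthetic runs (3 and 2 epochs, 4 channels, 769 samples)
with tempfile.TemporaryDirectory() as tmp:
    for run, n in (("R04", 3), ("R08", 2)):
        np.savez_compressed(
            Path(tmp) / f"S001{run}_epo.npz",
            data=np.zeros((n, 4, 769)),
            labels=np.arange(n) % 2,
            ch_names=["C3", "Cz", "C4", "Pz"],
            sfreq=128.0,
            times=np.arange(-128, 641) / 128.0,
        )
    X, y, chs = load_subject_epochs(tmp)

print(X.shape, y.tolist(), chs)  # (5, 4, 769) [0, 1, 0, 0, 1] ['C3', 'Cz', 'C4', 'Pz']
```

This holds all of a subject's runs in memory at once, which is usually small at 128 Hz but worth keeping in mind when pooling many subjects.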

## Summary and Next Steps

### What Was Done

1. **Event Extraction**: Extracted T1 (left fist) and T2 (right fist) events from annotations
2. **Epoching**: Created 6-second epochs (-1 to +5 seconds relative to cue)
3. **Baseline Correction**: Applied baseline correction using -0.5 to 0 seconds
4. **Quality Control**: Rejected epochs in which any channel's peak-to-peak amplitude exceeded 100 μV or fell below 1 μV (flat)
5. **Data Export**: Saved epochs as compressed NPZ files for efficient loading

### Output Files

- **Epoch NPZ files**: `data/physionet/derived/epochs/{subject}/{subject}{run}_epo.npz`
- **Epoch index**: `data/physionet/derived/physionet_epochs_index.csv`
- **Subject summary**: `data/physionet/derived/physionet_subject_epochs_summary.csv`

### Data Format

Each NPZ file contains:
- `data`: (n_epochs, n_channels, n_times) - EEG data in volts (the unit MNE uses internally)
- `labels`: (n_epochs,) - Binary labels (0=left fist, 1=right fist)
- `ch_names`: List of channel names
- `sfreq`: Sampling frequency (128 Hz)
- `times`: Time array for epochs
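When reading a file back, a quick consistency check against this format can catch export mistakes early. The rules below are assumptions derived from the list above (matching epoch, channel and time dimensions; labels in {0, 1}), and `check_epoch_file` is a hypothetical name:

```python
import tempfile
from pathlib import Path

import numpy as np


def check_epoch_file(path):
    """Load one *_epo.npz file and verify its arrays are mutually consistent."""
    with np.load(path) as npz:
        data, labels, times = npz["data"], npz["labels"], npz["times"]
        ch_names = npz["ch_names"].tolist()
        sfreq = float(npz["sfreq"])
    if data.ndim != 3:
        raise ValueError(f"expected 3-D data, got shape {data.shape}")
    n_epochs, n_channels, n_times = data.shape
    if labels.shape != (n_epochs,):
        raise ValueError("expected one label per epoch")
    if len(ch_names) != n_channels or times.shape != (n_times,):
        raise ValueError("channel or time axis does not match data")
    if not np.isin(labels, [0, 1]).all():
        raise ValueError("labels must be 0 (left fist) or 1 (right fist)")
    return {"n_epochs": n_epochs, "n_channels": n_channels,
            "n_times": n_times, "sfreq": sfreq}


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "S001R04_epo.npz"
    np.savez_compressed(path, data=np.zeros((2, 3, 769)), labels=np.array([0, 1]),
                        ch_names=["C3", "Cz", "C4"], sfreq=128.0,
                        times=np.arange(-128, 641) / 128.0)
    summary = check_epoch_file(path)

print(summary)  # {'n_epochs': 2, 'n_channels': 3, 'n_times': 769, 'sfreq': 128.0}
```

Channel names are saved as a plain string array, so `np.load` does not need `allow_pickle=True` for these files.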

### Next Steps

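1. **Model Implementation**: Build the EEG-ARNN architecture (TFEM, temporal feature extraction module; CARM, channel active reasoning module)
2. **Training**: Train on PhysioNet data with 10-fold cross-validation
3. **Channel Selection**: Apply the edge-selection (ES) and aggregation-selection (AS) methods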
4. **Evaluation**: Compare with baseline methods and analyze results
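The 10-fold cross-validation step can be prototyped on the exported arrays. The protocol details (trial-level versus subject-level splits, shuffling, seeds) are not specified in this notebook, so the sketch below assumes trial-level, class-stratified folds within one subject; `stratified_folds` is a hypothetical helper, and scikit-learn's `StratifiedKFold` is the usual off-the-shelf alternative:

```python
import numpy as np


def stratified_folds(labels, n_splits=10, seed=0):
    """Yield (train_idx, test_idx), dealing each class round-robin over the folds."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    fold_of = np.empty(labels.size, dtype=int)
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)                         # random order within the class
        fold_of[idx] = np.arange(idx.size) % n_splits
    for k in range(n_splits):
        yield np.flatnonzero(fold_of != k), np.flatnonzero(fold_of == k)


labels = np.array([0] * 45 + [1] * 45)  # a hypothetical subject with 90 trials
folds = list(stratified_folds(labels))
print([int(labels[test].sum()) for _, test in folds])  # [5, 5, 5, 5, 5, 4, 4, 4, 4, 4]
```

Pooling trials from several subjects into the same folds mixes subjects across train and test sets; subject-wise splits are the safer choice if cross-subject generalization is the goal.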