### Test raw loader

In [4]:

import sys
import os
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

for mod in list(sys.modules.keys()):
    if 'config' in mod or 'raw_loader' in mod or 'data.loaders' in mod:
        del sys.modules[mod]
from config import MARKERS, PARTICIPANT_INFO, DATA_DIR
from data.loaders.raw_loader import RawEEGLoader

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

print("=" * 60)
print("CONFIG CHECK")
print("=" * 60)
print("\nMARKERS:")
for k, v in MARKERS.items():
    print(f"  {k}: {v}")

print("\nPARTICIPANT_INFO:")
for sub_id, info in PARTICIPANT_INFO.items():
    crashes = info.get('crashes', [])
    crash_str = f"crashes: {[c['repeated_eeg_trial'] for c in crashes]}" if crashes else "no crashes"
    print(f"  {sub_id} ({info['name']}): {crash_str}")


CONFIG CHECK

MARKERS:
  BASELINE_START: 1
  STIMULUS_START: 2
  STIMULUS_END: 5
  STIMULUS_END_ALT: 6
  BREAK: 3
  EXPERIMENT_RESUME: 14
  EXPERIMENT_END: 15

PARTICIPANT_INFO:
  Sub01 (yannick): no crashes
  Sub02 (daniel): no crashes
  Sub03 (simon): crashes: [32]
  Sub04 (karsten): no crashes
  Sub05 (philipp): no crashes


In [5]:
# Test loading Sub03 (Simon) - should auto-exclude trial 32
print("\n" + "=" * 60)
print("TEST: Loading Sub03 (Simon) with crash handling")
print("=" * 60)

loader = RawEEGLoader()
print(f"\nLoader config:")
print(f"  event_id: {loader.config['event_id']}")
print(f"  epoch_tmin: {loader.config['epoch_tmin']}")
print(f"  epoch_tmax: {loader.config['epoch_tmax']}")
print(f"  baseline: {loader.config['baseline']}")

data = loader.load(
    eeg_path=DATA_DIR / "Sub03" / "Simon.vhdr",
    participant_id="Sub03"
)

print(f"\nLoaded data:")
print(f"  Shape: {data.X.shape}")
print(f"  n_trials: {data.X.shape[0]}")
print(f"  n_channels: {data.X.shape[1]}")
print(f"  n_timepoints: {data.X.shape[2]}")
print(f"  sfreq: {data.sfreq} Hz")
print(f"  Excluded trials: {data.metadata['excluded_trials']}")


TEST: Loading Sub03 (Simon) with crash handling

Loader config:
  event_id: {'stimulus': 2}
  epoch_tmin: -3.0
  epoch_tmax: 32.0
  baseline: (-3.0, -0.1)
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Excluded 1 trials: [32]

Loaded data:
  Shape: (77, 30, 17501)
  n_trials: 77
  n_channels: 30
  n_timepoints: 17501
  sfreq: 500 Hz
  Excluded trials: [32]


In [None]:
# Better quality check using percentiles
import numpy as np

X = data.X

print("DATA QUALITY CHECK (Percentile-based)")
print("=" * 60)

# Convert to µV for readability
X_uv = X * 1e6

print(f"\nAmplitude distribution (µV):")
percentiles = [0, 1, 5, 25, 50, 75, 95, 99, 100]
values = np.percentile(X_uv, percentiles)
for p, v in zip(percentiles, values):
    marker = ""
    if p == 0 or p == 100:
        marker = " ← extreme"
    elif p == 1 or p == 99:
        if abs(v) > 150:
            marker = " artifacts present"
    print(f"  {p:3d}th percentile: {v:8.1f} µV{marker}")

# Check what % of data is within normal range
normal_range = 150  # µV
within_range = np.mean(np.abs(X_uv) < normal_range) * 100
print(f"\n% of data within ±{normal_range} µV: {within_range:.1f}%")

# Per-trial quality
trial_max = np.abs(X_uv).max(axis=(1, 2))
good_trials = np.sum(trial_max < normal_range)
print(f"\nTrials with max amplitude < {normal_range} µV: {good_trials} / {len(X)}")

# Identify bad trials
bad_trial_idx = np.where(trial_max >= normal_range)[0]
print(f"\nTrials exceeding threshold (0-indexed): {len(bad_trial_idx)} trials")
if len(bad_trial_idx) <= 20:
    print(f"  Indices: {list(bad_trial_idx)}")
    print(f"  Max amplitudes (µV): {[f'{trial_max[i]:.0f}' for i in bad_trial_idx]}")

# Check baseline correction
baseline_end = int(3 * data.sfreq)  # First 3 seconds = baseline
baseline_means = X_uv[:, :, :baseline_end].mean(axis=(1, 2))
print(f"\nBaseline check (should be ~0 after correction):")
print(f"  Mean of baseline means: {baseline_means.mean():.2f} µV")
print(f"  Std of baseline means: {baseline_means.std():.2f} µV")
print(f"  Baseline correction appears to be working" if abs(baseline_means.mean()) < 5 else " Baseline may not be corrected properly")

DATA QUALITY CHECK (Percentile-based)

Amplitude distribution (µV):
    0th percentile:  -7740.3 µV ← extreme
    1th percentile:   -254.5 µVartifacts present
    5th percentile:    -71.3 µV
   25th percentile:    -21.9 µV
   50th percentile:     -1.9 µV
   75th percentile:     18.0 µV
   95th percentile:     59.5 µV
   99th percentile:    116.6 µV
  100th percentile:   3772.5 µV ← extreme

% of data within ±150 µV: 97.8%

Trials with max amplitude < 150 µV: 19 / 77

Trials exceeding threshold (0-indexed): 58 trials

Baseline check (should be ~0 after correction):
  Mean of baseline means: -0.26 µV
  Std of baseline means: 2.15 µV
  ✓ Baseline correction appears to be working


In [None]:
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

# Get data
X = data.X * 1e6  # Convert to µV
times = np.linspace(-3, 32, X.shape[2])
ch_names = data.ch_names

# Find a good trial and a bad trial for comparison
trial_max = np.abs(X).max(axis=(1, 2))
good_trial = np.argmin(trial_max)
bad_trial = np.argmax(trial_max)

fig, axes = plt.subplots(3, 2, figsize=(14, 10))

# 1. Good trial - all channels (butterfly plot)
ax = axes[0, 0]
for ch in range(X.shape[1]):
    ax.plot(times, X[good_trial, ch, :], alpha=0.5, linewidth=0.5)
ax.axvline(0, color='r', linestyle='--', label='Stimulus onset')
ax.axvspan(-3, 0, alpha=0.1, color='green', label='Baseline')
ax.set_xlim(-3, 32)
ax.set_ylim(-150, 150)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (µV)')
ax.set_title(f'Good Trial #{good_trial+1} (max: {trial_max[good_trial]:.0f} µV)')

# 2. Bad trial - all channels
ax = axes[0, 1]
for ch in range(X.shape[1]):
    ax.plot(times, X[bad_trial, ch, :], alpha=0.5, linewidth=0.5)
ax.axvline(0, color='r', linestyle='--')
ax.axvspan(-3, 0, alpha=0.1, color='green')
ax.set_xlim(-3, 32)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (µV)')
ax.set_title(f'Bad Trial #{bad_trial+1} (max: {trial_max[bad_trial]:.0f} µV)')

# 3. Average across all trials - single channel (Cz)
ax = axes[1, 0]
ch_idx = ch_names.index('Cz') if 'Cz' in ch_names else 0
mean_signal = X[:, ch_idx, :].mean(axis=0)
std_signal = X[:, ch_idx, :].std(axis=0)
ax.plot(times, mean_signal, 'b', linewidth=1)
ax.fill_between(times, mean_signal - std_signal, mean_signal + std_signal, alpha=0.3)
ax.axvline(0, color='r', linestyle='--')
ax.axvspan(-3, 0, alpha=0.1, color='green')
ax.set_xlim(-3, 32)
ax.set_ylim(-100, 100)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (µV)')
ax.set_title(f'Average ± STD - Channel {ch_names[ch_idx]}')

# 4. Multiple random trials - single channel
ax = axes[1, 1]
np.random.seed(42)
random_trials = np.random.choice(len(X), size=min(10, len(X)), replace=False)
for t in random_trials:
    ax.plot(times, X[t, ch_idx, :], alpha=0.5, linewidth=0.5, label=f'Trial {t+1}')
ax.axvline(0, color='r', linestyle='--')
ax.axvspan(-3, 0, alpha=0.1, color='green')
ax.set_xlim(-3, 32)
ax.set_ylim(-150, 150)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (µV)')
ax.set_title(f'10 Random Trials - Channel {ch_names[ch_idx]}')

# 5. Trial variance over time
ax = axes[2, 0]
variance_over_time = X.var(axis=(0, 1))
ax.plot(times, variance_over_time, 'purple', linewidth=1)
ax.axvline(0, color='r', linestyle='--')
ax.axvspan(-3, 0, alpha=0.1, color='green')
ax.set_xlim(-3, 32)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Variance (µV²)')
ax.set_title('Variance Over Time (spikes = artifacts)')

# 6. Per-channel variability
ax = axes[2, 1]
ch_std = X.std(axis=(0, 2))
bars = ax.barh(range(len(ch_names)), ch_std, color='steelblue')
ax.axvline(100, color='r', linestyle='--', label='100 µV threshold')
ax.set_yticks(range(len(ch_names)))
ax.set_yticklabels(ch_names, fontsize=8)
ax.set_xlabel('Std Amplitude (µV)')
ax.set_title('Per-Channel Variability')
ax.legend()

plt.tight_layout()
save_path = r'c:\Users\yanni\Desktop\TUM MSEI\2. Semester\Lab\Recording\classification\notebooks\epoch_visualization.png'
plt.savefig(save_path, dpi=150)
plt.close()

print(f"Visualization saved to:\n{save_path}")

Visualization saved to:
c:\Users\yanni\Desktop\TUM MSEI\2. Semester\Lab\Recording\classification\notebooks\epoch_visualization.png

Open the PNG file to view the plots.


In [None]:
# Participant Summary Table
import numpy as np
import pandas as pd

print("=" * 70)
print("PARTICIPANT DATA QUALITY SUMMARY")
print("=" * 70)

# Amplitude threshold for "good" trials (in µV)
AMPLITUDE_THRESHOLD = 150

# Collect data for all participants
summary_data = []

for sub_id, info in PARTICIPANT_INFO.items():
    print(f"\nLoading {sub_id} ({info['name']})...", end=" ")
    
    try:
        loader = RawEEGLoader()
        eeg_path = DATA_DIR / sub_id / info['eeg_file']
        data = loader.load(eeg_path=eeg_path, participant_id=sub_id)
        
        # Get crash info
        crashes = info.get('crashes', [])
        crash_trials = [c['repeated_eeg_trial'] for c in crashes] if crashes else []
        
        # Convert to µV and analyze
        X_uv = data.X * 1e6
        n_trials = X_uv.shape[0]
        
        # Per-trial max amplitude
        trial_max = np.abs(X_uv).max(axis=(1, 2))
        
        # Good trials: max amplitude below threshold
        good_trials = np.sum(trial_max < AMPLITUDE_THRESHOLD)
        bad_trials = n_trials - good_trials
        pct_good = (good_trials / n_trials) * 100
        
        # Bad trial indices (0-indexed)
        bad_trial_indices = np.where(trial_max >= AMPLITUDE_THRESHOLD)[0].tolist()
        
        # Get amplitude stats
        mean_max_amp = trial_max.mean()
        max_max_amp = trial_max.max()
        
        summary_data.append({
            'Participant': sub_id,
            'Name': info['name'].capitalize(),
            'Total Trials': n_trials,
            'Crash Excluded': len(crash_trials),
            'Good Trials': good_trials,
            'Bad Trials': bad_trials,
            '% Good': f"{pct_good:.1f}%",
            'Mean Max (µV)': f"{mean_max_amp:.0f}",
            'Max Max (µV)': f"{max_max_amp:.0f}",
            'Bad Trial Indices': bad_trial_indices if len(bad_trial_indices) <= 10 else f"{len(bad_trial_indices)} trials"
        })
        
        print("OK")
        
    except Exception as e:
        print(f"FAILED: {e}")
        summary_data.append({
            'Participant': sub_id,
            'Name': info['name'].capitalize(),
            'Total Trials': 'ERROR',
            'Crash Excluded': len(info.get('crashes', [])),
            'Good Trials': '-',
            'Bad Trials': '-',
            '% Good': '-',
            'Mean Max (µV)': '-',
            'Max Max (µV)': '-',
            'Bad Trial Indices': str(e)[:30]
        })

# Create DataFrame and display
df = pd.DataFrame(summary_data)

print("\n" + "=" * 70)
print(f"SUMMARY TABLE (threshold: ±{AMPLITUDE_THRESHOLD} µV)")
print("=" * 70)

# Display main stats
display_cols = ['Participant', 'Name', 'Total Trials', 'Crash Excluded', 
                'Good Trials', 'Bad Trials', '% Good', 'Mean Max (µV)', 'Max Max (µV)']
print(df[display_cols].to_string(index=False))

# Summary statistics
print("\n" + "-" * 70)
valid_rows = df[df['Total Trials'] != 'ERROR']
if len(valid_rows) > 0:
    total_trials = sum([r for r in valid_rows['Total Trials'] if isinstance(r, int)])
    total_good = sum([r for r in valid_rows['Good Trials'] if isinstance(r, int)])
    total_bad = sum([r for r in valid_rows['Bad Trials'] if isinstance(r, int)])
    total_crash = sum([r for r in valid_rows['Crash Excluded'] if isinstance(r, int)])
    
    print(f"TOTALS: {total_trials} trials across {len(valid_rows)} participants")
    print(f"        {total_good} good trials ({total_good/total_trials*100:.1f}%)")
    print(f"        {total_bad} trials with artifacts (>{AMPLITUDE_THRESHOLD} µV)")
    print(f"        {total_crash} trials excluded due to crashes")

# Show bad trial indices per participant
print("\n" + "=" * 70)
print("BAD TRIAL INDICES (0-indexed, for exclusion)")
print("=" * 70)
for row in summary_data:
    if isinstance(row['Bad Trial Indices'], list):
        if len(row['Bad Trial Indices']) > 0:
            print(f"{row['Participant']}: {row['Bad Trial Indices']}")
        else:
            print(f"{row['Participant']}: None (all trials clean!)")
    else:
        print(f"{row['Participant']}: {row['Bad Trial Indices']}")

### Test ICA Preprocessing

In [None]:
# Test ICA preprocessing on Sub03
# Reload the module to pick up changes
for mod in list(sys.modules.keys()):
    if 'raw_loader' in mod or 'data.loaders' in mod:
        del sys.modules[mod]
from data.loaders.raw_loader import RawEEGLoader

print("=" * 70)
print("TEST: Loading Sub03 with ICA artifact removal")
print("=" * 70)

# Configure loader with ICA enabled
ica_config = {
    'apply_ica': True,
    'ica_n_components': 15,          # Use 15 components (out of 30 channels)
    'ica_method': 'fastica',
    'ica_eog_threshold': 0.4,        # Correlation threshold for EOG detection
    'ica_eog_channels': None,        # Auto-detect using frontal channels
}

loader_ica = RawEEGLoader(preprocessing_config=ica_config)
print(f"\nLoader config (ICA settings):")
print(f"  apply_ica: {loader_ica.config['apply_ica']}")
print(f"  ica_n_components: {loader_ica.config['ica_n_components']}")
print(f"  ica_method: {loader_ica.config['ica_method']}")
print(f"  ica_eog_threshold: {loader_ica.config['ica_eog_threshold']}")

print("\nLoading data with ICA (this may take a minute)...")
data_ica = loader_ica.load(
    eeg_path=DATA_DIR / "Sub03" / "Simon.vhdr",
    participant_id="Sub03"
)

print(f"\nLoaded data:")
print(f"  Shape: {data_ica.X.shape}")
print(f"  n_trials: {data_ica.X.shape[0]}")

# Check ICA metadata
if 'ica' in data_ica.metadata:
    ica_info = data_ica.metadata['ica']
    print(f"\nICA Info:")
    print(f"  Detection method: {ica_info['detection_method']}")
    print(f"  Components excluded: {ica_info['excluded_components']}")
    print(f"  N components used: {ica_info['n_components']}")

In [None]:
# Compare data quality: Before vs After ICA
import numpy as np

print("=" * 70)
print("COMPARISON: Data Quality Before vs After ICA")
print("=" * 70)

# Load without ICA for comparison (use the previously loaded data)
X_before = data.X * 1e6  # From earlier cell (no ICA)
X_after = data_ica.X * 1e6  # With ICA

THRESHOLD = 150  # µV

# Calculate statistics
def compute_quality_stats(X, name):
    trial_max = np.abs(X).max(axis=(1, 2))
    good_trials = np.sum(trial_max < THRESHOLD)
    pct_good = (good_trials / len(X)) * 100
    
    percentiles = np.percentile(X, [1, 99])
    pct_within_range = np.mean(np.abs(X) < THRESHOLD) * 100
    
    return {
        'name': name,
        'n_trials': len(X),
        'good_trials': good_trials,
        'pct_good': pct_good,
        'p1': percentiles[0],
        'p99': percentiles[1],
        'pct_within_range': pct_within_range,
        'max_amplitude': np.abs(X).max(),
        'mean_max_per_trial': trial_max.mean(),
    }

stats_before = compute_quality_stats(X_before, "Without ICA")
stats_after = compute_quality_stats(X_after, "With ICA")

# Print comparison table
print(f"\n{'Metric':<30} {'Without ICA':>15} {'With ICA':>15} {'Change':>12}")
print("-" * 72)
print(f"{'Good trials (< 150µV)':<30} {stats_before['good_trials']:>15} {stats_after['good_trials']:>15} {stats_after['good_trials'] - stats_before['good_trials']:>+12}")
print(f"{'% Good trials':<30} {stats_before['pct_good']:>14.1f}% {stats_after['pct_good']:>14.1f}% {stats_after['pct_good'] - stats_before['pct_good']:>+11.1f}%")
print(f"{'% Data within ±150µV':<30} {stats_before['pct_within_range']:>14.1f}% {stats_after['pct_within_range']:>14.1f}% {stats_after['pct_within_range'] - stats_before['pct_within_range']:>+11.1f}%")
print(f"{'1st percentile (µV)':<30} {stats_before['p1']:>15.1f} {stats_after['p1']:>15.1f} {stats_after['p1'] - stats_before['p1']:>+12.1f}")
print(f"{'99th percentile (µV)':<30} {stats_before['p99']:>15.1f} {stats_after['p99']:>15.1f} {stats_after['p99'] - stats_before['p99']:>+12.1f}")
print(f"{'Max amplitude (µV)':<30} {stats_before['max_amplitude']:>15.1f} {stats_after['max_amplitude']:>15.1f} {stats_after['max_amplitude'] - stats_before['max_amplitude']:>+12.1f}")
print(f"{'Mean max per trial (µV)':<30} {stats_before['mean_max_per_trial']:>15.1f} {stats_after['mean_max_per_trial']:>15.1f} {stats_after['mean_max_per_trial'] - stats_before['mean_max_per_trial']:>+12.1f}")

# Summary
print("\n" + "-" * 72)
improvement = stats_after['pct_good'] - stats_before['pct_good']
if improvement > 0:
    print(f"ICA improved data quality: {improvement:.1f}% more clean trials")
elif improvement < 0:
    print(f"ICA reduced clean trials by {-improvement:.1f}% (may need threshold adjustment)")
else:
    print("ICA had no effect on trial quality at this threshold")