
# BCI Competition IV 2a - Comprehensive EDA

Exploratory analysis of the BCI Competition IV 2a motor imagery dataset, covering session inventory, signal quality metrics, spectral characteristics, and subject-level anomaly detection to support downstream channel selection and modelling.



## Analysis roadmap
- Inspect GDF files and build an inventory for training (T) and evaluation (E) sessions per subject.
- Extract session-level signal quality indicators (variance, clipping, spectral balance) using MNE.
- Visualise cue/event distributions, amplitude stability, and band-power patterns with Matplotlib/Seaborn/Plotly.
- Flag noisy or flat electrodes and subject-level anomalies to guide data cleaning.
- Provide interactive utilities and exportable CSV reports for modelling workflows.


In [None]:

import os
from pathlib import Path
from functools import lru_cache

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

from tqdm.auto import tqdm

import mne

# Global presentation defaults shared by every cell below.
mne.set_log_level('WARNING')  # keep MNE quiet unless something actually goes wrong
pd.options.display.float_format = '{:0.3f}'.format  # three-decimal floats in DataFrame reprs
plt.style.use('seaborn-v0_8')
sns.set_context('notebook', font_scale=1.1)
pio.renderers.default = 'notebook_connected'


In [None]:

# Location of the raw A??T.gdf / A??E.gdf files. Can be overridden through the
# BCI2A_DATA_ROOT environment variable; defaults to the repository-relative path.
DATA_ROOT = Path(os.environ.get('BCI2A_DATA_ROOT', 'data/BCI_2a'))
DERIVED_DIR = DATA_ROOT.parent / 'BCI_2a_derived'
RUN_SUMMARY_CACHE = DERIVED_DIR / 'run_summary.pkl'
ERROR_LOG_CACHE = DERIVED_DIR / 'run_summary_errors.json'

# Fail loudly, and before creating any output folder, when the raw data is missing.
# An explicit raise (unlike `assert`) is not stripped when Python runs with -O.
if not DATA_ROOT.exists():
    raise FileNotFoundError(f"Expected BCI 2a files under {DATA_ROOT.as_posix()}")
DERIVED_DIR.mkdir(parents=True, exist_ok=True)

# Session files are named A<2-digit subject><T|E>.gdf, e.g. A01T.gdf.
GDF_FILES = sorted(DATA_ROOT.glob('A??[TE].gdf'))
print(f'Detected {len(GDF_FILES)} GDF sessions')
GDF_FILES[:6]



## Session metadata reference
- Subjects: A01-A09 (9 participants)
- Session suffix T = training (with labels); E = evaluation
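- GDF event codes of interest:
  - 768: start of trial
  - 769: cue onset (left hand)
  - 770: cue onset (right hand)
  - 771: cue onset (feet)
  - 772: cue onset (tongue)
  - 783: cue onset, class unknown (used in the evaluation sessions, where labels are withheld)
  - 1023: rejected trial (artifact)
  - 1072: eye movements
  - 32766: start of a new run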


In [None]:

# Lookup table describing what each session suffix in the file names means.
_session_rows = [
    {'code': 'T', 'description': 'Training (labelled)'},
    {'code': 'E', 'description': 'Evaluation (unlabelled)'},
]
SESSION_METADATA = pd.DataFrame(_session_rows, columns=['code', 'description'])
SESSION_METADATA



## Inventory check
Build a matrix of available sessions per subject and identify missing files.


In [None]:

# Subject IDs (e.g. 'A01') inferred from the GDF files found on disk.
subjects = sorted({gdf.stem[:3] for gdf in GDF_FILES})


def _inventory_entry(subject_id, code):
    """Describe whether the GDF file for one subject/session pair exists."""
    candidate = DATA_ROOT / f"{subject_id}{code}.gdf"
    present = candidate.exists()
    return {
        'subject': subject_id,
        'session': code,
        'available': present,
        'file': candidate if present else None,
    }


# One row per (subject, session) pair so that missing sessions show up explicitly.
inventory_records = [_inventory_entry(sid, code) for sid in subjects for code in ('T', 'E')]
inventory_df = pd.DataFrame(inventory_records)

missing_total = inventory_df.loc[~inventory_df['available'], 'session'].count()
print('Missing sessions:', missing_total)
inventory_df.head()


In [None]:

# Subject x session matrix: 1 where the GDF file exists, 0 otherwise.
presence = inventory_df['available'].astype(int)
inv_matrix = (
    inventory_df.assign(present=presence)
    .pivot(index='subject', columns='session', values='present')
    .fillna(0)
)
fig_inv, ax_inv = plt.subplots(figsize=(6, 4))
sns.heatmap(inv_matrix, cmap='Greens', annot=True, cbar=False, ax=ax_inv)
ax_inv.set_title('Session availability per subject (1 = present)')
fig_inv.tight_layout()
plt.show()



### Helper utilities for session-level metrics
Summaries include amplitude stats, clipping, spectral features, and event counts. Cached outputs prevent repeated heavy computation.


In [None]:
import warnings
from typing import Dict, Any

# Keep every Nth sample when computing amplitude statistics (std, peak-to-peak, clipping).
EEG_DECIM_FOR_STATS = 4
# Sessions are resampled to this rate (Hz) before Welch PSD estimation.
TARGET_PSD_SFREQ = 125.0
# Per-channel standard-deviation thresholds (microvolts) for flat / noisy electrodes.
FLAT_STD_THRESHOLD_UV = 0.25
NOISY_STD_THRESHOLD_UV = 150.0
# Absolute amplitude (microvolts) above which a sample counts towards the clipping fraction.
CLIP_THRESHOLD_UV = 200.0

# Frequency bands (Hz) for relative band-power features: (lower edge, upper edge).
BAND_DEFS = {
    'delta': (1.0, 4.0),
    'theta': (4.0, 8.0),
    'alpha': (8.0, 13.0),
    'beta': (13.0, 30.0),
    'gamma': (30.0, 45.0),
}

# GDF event codes of interest (BCI Competition IV 2a description) -> summary column name.
# 768 marks the start of every trial; 783 is the cue marker of the evaluation (E)
# sessions, where the class of the cue is withheld.
EVENT_ID_MAP = {
    768: 'start_trial',
    769: 'left_hand',
    770: 'right_hand',
    771: 'feet',
    772: 'tongue',
    783: 'cue_unknown',
}

# MNE reads the BCI IV 2a GDF files with generic labels ("EEG-0", "EEG-C3", ...).
# Map them onto the 10-20 names of the published 22-electrode layout so that
# `set_montage` can attach positions and channels can be addressed as "C3", "Cz", ...
BCI2A_CHANNEL_RENAME = {
    'EEG-Fz': 'Fz',
    'EEG-0': 'FC3', 'EEG-1': 'FC1', 'EEG-2': 'FCz', 'EEG-3': 'FC2', 'EEG-4': 'FC4',
    'EEG-5': 'C5', 'EEG-C3': 'C3', 'EEG-6': 'C1', 'EEG-Cz': 'Cz',
    'EEG-7': 'C2', 'EEG-C4': 'C4', 'EEG-8': 'C6',
    'EEG-9': 'CP3', 'EEG-10': 'CP1', 'EEG-11': 'CPz', 'EEG-12': 'CP2', 'EEG-13': 'CP4',
    'EEG-14': 'P1', 'EEG-Pz': 'Pz', 'EEG-15': 'P2',
    'EEG-16': 'POz',
}


# Each cached entry holds a fully preloaded recording, so the cache is kept small;
# sessions are summarised one after another, so a few slots are enough.
@lru_cache(maxsize=4)
def _load_raw_cache(gdf_path: str) -> mne.io.BaseRaw:
    """Load one GDF session, keep only EEG channels and attach a 10-20 montage.

    The returned Raw object is shared between callers through ``lru_cache``:
    callers must ``.copy()`` it before any in-place operation (resample, pick, ...).

    Parameters
    ----------
    gdf_path : str
        Path to the ``.gdf`` file (a string so it can serve as the cache key).

    Returns
    -------
    mne.io.BaseRaw
        Preloaded recording restricted to EEG-typed channels.
    """
    raw = mne.io.read_raw_gdf(gdf_path, preload=True, verbose='ERROR')
    # Rename only labels that are present, so files with other naming still load.
    rename = {old: new for old, new in BCI2A_CHANNEL_RENAME.items() if old in raw.ch_names}
    if rename:
        raw.rename_channels(rename)
    # Without `eog=`, read_raw_gdf types every channel as EEG, including the EOG
    # electrodes; retype them so pick_types(eeg=True) below excludes them.
    eog_names = [name for name in raw.ch_names if name.upper().startswith('EOG')]
    if eog_names:
        raw.set_channel_types({name: 'eog' for name in eog_names})
    picks = mne.pick_types(raw.info, eeg=True)
    raw.pick(picks)
    raw.set_montage('standard_1020', on_missing='ignore', match_case=False)
    return raw


def _count_cue_events(events: np.ndarray, event_id: Dict[str, int]) -> Dict[str, int]:
    """Count how often each GDF event code listed in EVENT_ID_MAP occurs.

    ``mne.events_from_annotations`` maps annotation descriptions (e.g. the string
    ``'769'``) to its own integer ids, so each raw GDF code is translated through
    ``event_id`` before being compared with ``events[:, 2]``.
    """
    counts = {}
    for gdf_code, label in EVENT_ID_MAP.items():
        mapped_id = event_id.get(str(gdf_code))
        if mapped_id is None:
            counts[label] = 0  # code absent from this session's annotations
        else:
            counts[label] = int(np.count_nonzero(events[:, 2] == mapped_id))
    return counts


def _spectral_features(raw: mne.io.BaseRaw) -> Dict[str, float]:
    """Relative band powers (mean/std across channels) and a 50 Hz line-noise ratio."""
    psd_raw = raw.copy()  # resample() works in place; never modify the cached Raw
    if abs(float(raw.info['sfreq']) - TARGET_PSD_SFREQ) > 1e-6:
        psd_raw.resample(TARGET_PSD_SFREQ, npad='auto')

    nyquist = psd_raw.info['sfreq'] / 2.0
    fmax_psd = min(60.0, nyquist - 1.0)
    psd = psd_raw.compute_psd(method='welch', fmin=1.0, fmax=fmax_psd, n_fft=512,
                              n_overlap=256, verbose='ERROR')
    psd_data = psd.get_data()
    freqs = psd.freqs

    total_mask = (freqs >= 1.0) & (freqs <= 45.0)
    total_power = psd_data[:, total_mask].sum(axis=1)
    total_power[total_power == 0] = np.nan  # avoid division by zero on dead channels

    features = {}
    for band, (fmin, fmax) in BAND_DEFS.items():
        mask = (freqs >= fmin) & (freqs < fmax)
        band_ratio = psd_data[:, mask].sum(axis=1) / total_power
        features[f'band_{band}_rel_power_mean'] = float(np.nanmean(band_ratio))
        features[f'band_{band}_rel_power_std'] = float(np.nanstd(band_ratio))

    # Power around 50 Hz relative to beta-band power.
    line_mask = (freqs >= 48.0) & (freqs <= 52.0)
    beta_mask = (freqs >= 13.0) & (freqs <= 30.0)
    line_power = float(np.nanmean(psd_data[:, line_mask])) if np.any(line_mask) else np.nan
    beta_power = float(np.nanmean(psd_data[:, beta_mask])) if np.any(beta_mask) else np.nan
    features['line_noise_ratio'] = (
        float(line_power / beta_power) if beta_power and beta_power > 0 else np.nan
    )
    return features


def summarize_session(gdf_path: Path) -> Dict[str, Any]:
    """Compute session-level quality metrics for one GDF recording.

    Parameters
    ----------
    gdf_path : Path or str
        Path to a file named like ``A01T.gdf`` located under ``DATA_ROOT``.

    Returns
    -------
    dict
        Flat record with amplitude statistics (microvolts), flat/noisy channel
        lists, clipping fraction, per-code event counts, relative band powers,
        the line-noise ratio and a ``cache_key`` relative to ``DATA_ROOT.parent``.
    """
    gdf_path = Path(gdf_path)
    subject_id = gdf_path.stem[:3]
    session_code = gdf_path.stem[3]

    # Shared cached Raw: only read here; _spectral_features copies before resampling.
    raw = _load_raw_cache(str(gdf_path))
    sfreq = float(raw.info['sfreq'])
    eeg_names = raw.ch_names

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)
        data_full = raw.get_data()
    # Decimate for the amplitude statistics only, and convert to microvolts.
    data_uv = data_full[:, ::EEG_DECIM_FOR_STATS] * 1e6

    channel_std = data_uv.std(axis=1)
    # np.ptp instead of the ndarray.ptp method, which was removed in NumPy 2.0.
    channel_ptp = np.ptp(data_uv, axis=1)

    flat_channels = [name for name, std in zip(eeg_names, channel_std) if std < FLAT_STD_THRESHOLD_UV]
    noisy_channels = [name for name, std in zip(eeg_names, channel_std) if std > NOISY_STD_THRESHOLD_UV]
    clip_fraction = float((np.abs(data_uv) > CLIP_THRESHOLD_UV).mean())

    events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

    summary = {
        'subject': subject_id,
        'session': session_code,
        'sfreq': sfreq,
        'n_channels': len(eeg_names),
        # n_times / sfreq is the recording length; times[-1] is one sample short.
        'duration_s': float(raw.n_times / sfreq),
        'mean_channel_std_uv': float(np.mean(channel_std)),
        'median_channel_std_uv': float(np.median(channel_std)),
        'p90_channel_std_uv': float(np.percentile(channel_std, 90)),
        'max_channel_std_uv': float(np.max(channel_std)),
        'p95_channel_ptp_uv': float(np.percentile(channel_ptp, 95)),
        'max_channel_ptp_uv': float(np.max(channel_ptp)),
        'clip_fraction_over_200uv': clip_fraction,
        'max_abs_signal_uv': float(np.max(np.abs(data_uv))),
        'flat_channel_count': len(flat_channels),
        'noisy_channel_count': len(noisy_channels),
        'flat_channels': flat_channels,
        'noisy_channels': noisy_channels,
        'events_total': int(events.shape[0]),
    }
    summary.update(_count_cue_events(events, event_id))
    summary.update(_spectral_features(raw))
    summary['cache_key'] = str(gdf_path.relative_to(DATA_ROOT.parent))
    return summary

In [None]:

def build_run_summary(subject_subset=None, force_recompute=False):
    """Return ``(run_df, error_df)`` with one summary row per available session.

    A pickled cache at RUN_SUMMARY_CACHE is reused when it covers every requested
    subject and ``force_recompute`` is False. Freshly computed results are written
    back only for full-cohort runs (``subject_subset is None``), so a partial run
    never overwrites the complete cache. Per-file failures are collected in
    ``error_df`` instead of aborting the whole run.
    """
    requested = subjects if subject_subset is None else list(subject_subset)

    if RUN_SUMMARY_CACHE.exists() and not force_recompute:
        cached_runs = pd.read_pickle(RUN_SUMMARY_CACHE)
        cached_errors = (pd.read_json(ERROR_LOG_CACHE)
                         if ERROR_LOG_CACHE.exists() else pd.DataFrame())
        covers_request = (
            not cached_runs.empty
            and 'subject' in cached_runs.columns
            and set(requested) <= set(cached_runs['subject'].unique())
        )
        if covers_request:
            print('Loaded cached summaries from disk')
            selected = cached_runs[cached_runs['subject'].isin(requested)]
            return selected.reset_index(drop=True), cached_errors
        print('Cached summary incomplete or empty - recomputing')

    summaries, failures = [], []
    for subject in tqdm(requested, desc='Summarising sessions'):
        for session_code in ('T', 'E'):
            gdf_path = DATA_ROOT / f"{subject}{session_code}.gdf"
            if not gdf_path.exists():
                continue
            try:
                summaries.append(summarize_session(gdf_path))
            except Exception as exc:  # best effort: record the failure and move on
                failures.append({
                    'subject': subject,
                    'session': session_code,
                    'file': str(gdf_path),
                    'error': repr(exc),
                })

    run_df = pd.DataFrame(summaries)
    error_df = pd.DataFrame(failures)
    if subject_subset is None:
        run_df.to_pickle(RUN_SUMMARY_CACHE)
        error_df.to_json(ERROR_LOG_CACHE, orient='records', indent=2)
        print(f'Persisted run summary cache to {RUN_SUMMARY_CACHE}')
    return run_df, error_df



### Build or load session-level summaries
Set SUBJECT_FILTER to work on a smaller subset while iterating.


In [None]:

SUBJECT_FILTER = None  # e.g. ['A01', 'A02'] to iterate on a subset of subjects
# Try the cache first; if that yields nothing, retry once with a forced recomputation.
for force in (False, True):
    RUN_DF, ERROR_DF = build_run_summary(subject_subset=SUBJECT_FILTER, force_recompute=force)
    if not RUN_DF.empty:
        break
RUN_DF.head()


In [None]:
# Set to True (and re-run this cell) to discard the on-disk summaries, e.g. after
# changing summarize_session. The flag keeps a top-to-bottom run of the notebook
# from deleting the cache that the previous cell has just built.
CLEAR_SUMMARY_CACHE = False

if CLEAR_SUMMARY_CACHE:
    if RUN_SUMMARY_CACHE.exists():
        RUN_SUMMARY_CACHE.unlink()
        print(f'Deleted cached summary: {RUN_SUMMARY_CACHE}')
    if ERROR_LOG_CACHE.exists():
        ERROR_LOG_CACHE.unlink()
        print(f'Deleted error log: {ERROR_LOG_CACHE}')
    print('Cache cleared. Will force recomputation.')
else:
    print('Keeping cached summaries (set CLEAR_SUMMARY_CACHE = True to clear them).')

In [None]:

print(f'Sessions processed: {len(RUN_DF)}')
# Show per-file failures collected by build_run_summary, if any.
if ERROR_DF.empty:
    print('No read errors encountered')
else:
    display(ERROR_DF)



## Event distribution
Check cue counts per session and the balance across subjects.


In [None]:

# Motor-imagery cue columns produced by summarize_session (only those present are used).
cue_cols = ['left_hand', 'right_hand', 'feet', 'tongue']
available_cues = [cue for cue in cue_cols if cue in RUN_DF.columns]
trial_summary = RUN_DF.loc[:, ['subject', 'session', *available_cues]].fillna(0)
trial_summary = trial_summary.set_index(['subject', 'session']).astype(int)
trial_summary.head()


In [None]:

# Long format: one row per (subject, session, cue) with its count.
trial_long = (
    RUN_DF.melt(id_vars=['subject', 'session'], value_vars=available_cues,
                var_name='cue', value_name='count')
    .dropna()
)
fig_cues, ax_cues = plt.subplots(figsize=(10, 6))
sns.boxplot(data=trial_long, x='cue', y='count', hue='session', ax=ax_cues)
ax_cues.set_title('Cue count distribution by session type')
fig_cues.tight_layout()
plt.show()


In [None]:

# Subject x session grid of annotation counts; absent sessions show as 0.
events_per_session = RUN_DF.pivot_table(index='subject', columns='session',
                                        values='events_total', aggfunc='sum')
pivot_counts = events_per_session.fillna(0)
fig = px.imshow(
    pivot_counts,
    labels={'color': 'Events'},
    color_continuous_scale='Viridis',
    title='Total events per subject/session',
)
fig.update_layout(height=400)
fig.show()



## Signal amplitude and stability
Assess flat or noisy channels and clipping behaviour.


In [None]:

# Session-level amplitude thresholds, shared by the flag and the reference lines below.
MIN_MEAN_STD_UV = 1.0
MAX_MEAN_STD_UV = 80.0
MAX_CLIP_FRACTION = 0.01

RUN_DF['has_flat_issue'] = RUN_DF['flat_channel_count'] > 0
RUN_DF['has_noisy_issue'] = RUN_DF['noisy_channel_count'] > 0
mean_std = RUN_DF['mean_channel_std_uv']
# A session is flagged if any single criterion fires.
RUN_DF['amp_issue_flag'] = (
    RUN_DF['has_flat_issue']
    | RUN_DF['has_noisy_issue']
    | (mean_std < MIN_MEAN_STD_UV)
    | (mean_std > MAX_MEAN_STD_UV)
    | (RUN_DF['clip_fraction_over_200uv'] > MAX_CLIP_FRACTION)
)

fig_amp, ax_amp = plt.subplots(figsize=(10, 6))
sns.scatterplot(data=RUN_DF,
                x='mean_channel_std_uv', y='clip_fraction_over_200uv',
                hue='amp_issue_flag', style='session', ax=ax_amp)
guide_style = dict(linestyle='--', color='grey', linewidth=1)
for x_ref in (MIN_MEAN_STD_UV, MAX_MEAN_STD_UV):
    ax_amp.axvline(x_ref, **guide_style)
ax_amp.axhline(MAX_CLIP_FRACTION, **guide_style)
ax_amp.set_title('Mean channel stdev vs clipping fraction (BCI 2a)')
fig_amp.tight_layout()
plt.show()


In [None]:

# Notched boxes give a rough visual check of whether the T and E medians differ.
fig = px.box(
    RUN_DF,
    x='session',
    y='mean_channel_std_uv',
    color='session',
    title='Channel variance by session',
)
fig.update_traces(notched=True)
fig.show()



## Frequency-domain characteristics
Inspect relative band power distributions and line noise levels.


In [None]:

# Per-session mean relative power columns, e.g. 'band_alpha_rel_power_mean'.
band_cols = [c for c in RUN_DF.columns if c.startswith('band_') and c.endswith('_mean')]
band_long = (
    RUN_DF.melt(id_vars=['subject', 'session'], value_vars=band_cols,
                var_name='band', value_name='relative_power')
    .dropna()
)
# 'band_alpha_rel_power_mean' -> 'alpha'
band_long['band'] = (band_long['band']
                     .str.replace('band_', '')
                     .str.replace('_rel_power_mean', ''))
fig_band, ax_band = plt.subplots(figsize=(10, 6))
sns.violinplot(data=band_long, x='band', y='relative_power', hue='session', split=True, ax=ax_band)
ax_band.set_title('Relative band power distribution by session type')
fig_band.tight_layout()
plt.show()


In [None]:

# Reference line for the 50 Hz / beta power ratio.
LINE_NOISE_THRESHOLD = 1.5
fig = px.scatter(
    RUN_DF,
    x='line_noise_ratio',
    y='band_alpha_rel_power_mean',
    color='session',
    symbol='amp_issue_flag',
    hover_data=['subject'],
    title='Line noise vs alpha power (BCI 2a)',
)
fig.add_vline(x=LINE_NOISE_THRESHOLD, line_dash='dash', line_color='red',
              annotation_text='line/noise threshold')
fig.show()



### Example topography
Display alpha band power for a representative clean session.


In [None]:
from mne.viz import plot_topomap

# Prefer sessions without amplitude issues so the example really is a clean one;
# fall back to every session if all of them are flagged.
candidates = RUN_DF
if 'amp_issue_flag' in RUN_DF.columns:
    clean_mask = ~RUN_DF['amp_issue_flag'].astype(bool)
    if clean_mask.any():
        candidates = RUN_DF[clean_mask]
example_row = candidates.sort_values('band_alpha_rel_power_mean', ascending=False).iloc[0]
example_path = DATA_ROOT / f"{example_row['subject']}{example_row['session']}.gdf"
# Copy: the cached Raw is shared and pick/resample below modify in place.
example_raw = _load_raw_cache(str(example_path)).copy()

# plot_topomap needs a sensor position for every channel. The montage is attached with
# on_missing='ignore', so keep only channels whose position is defined (finite, non-zero).
positioned = [
    ch['ch_name'] for ch in example_raw.info['chs']
    if np.all(np.isfinite(ch['loc'][:3])) and np.any(ch['loc'][:3] != 0)
]
if not positioned:
    raise RuntimeError(
        f"No channel of {example_path.name} has a montage position; "
        "check that channel names match the standard_1020 montage."
    )
example_raw.pick(positioned)

if abs(example_row['sfreq'] - TARGET_PSD_SFREQ) > 1e-6:
    example_raw.resample(TARGET_PSD_SFREQ, npad='auto')

example_nyquist = example_raw.info['sfreq'] / 2.0
example_fmax = min(45.0, example_nyquist - 1.0)

psd = example_raw.compute_psd(method='welch', fmin=1.0, fmax=example_fmax, n_fft=512,
                              n_overlap=256, verbose='ERROR')
freqs = psd.freqs
data = psd.get_data()
# Same band edges and half-open convention as the relative band-power features.
alpha_lo, alpha_hi = BAND_DEFS['alpha']
alpha_mask = (freqs >= alpha_lo) & (freqs < alpha_hi)
alpha_power = data[:, alpha_mask].mean(axis=1)

fig, ax = plt.subplots(figsize=(5, 5))
plot_topomap(alpha_power, example_raw.info, axes=ax, show=False, cmap='viridis')
ax.set_title(f"Alpha power topography - {example_row['subject']} {example_row['session']}")
plt.tight_layout()
plt.show()


## Subject-level health summary
Aggregate session metrics to flag persistent issues.


In [None]:

from scipy.stats import zscore

# Session metrics compared across the cohort to find statistical outliers.
numeric_cols = [
    'mean_channel_std_uv', 'median_channel_std_uv', 'p90_channel_std_uv',
    'max_channel_std_uv', 'p95_channel_ptp_uv', 'max_channel_ptp_uv',
    'clip_fraction_over_200uv', 'max_abs_signal_uv', 'line_noise_ratio',
    'band_alpha_rel_power_mean', 'band_beta_rel_power_mean', 'band_theta_rel_power_mean'
]

# Drop outputs of a previous execution of this cell; otherwise the merge below would
# produce 'max_abs_z_x' / 'max_abs_z_y' and the flag line would raise KeyError.
RUN_DF = RUN_DF.drop(columns=['max_abs_z', 'zscore_flag'], errors='ignore')

z_df = RUN_DF[['subject', 'session'] + numeric_cols].dropna()
z_scores = z_df[numeric_cols].apply(zscore, nan_policy='omit')
# Zero-variance columns yield NaN z-scores; the row-wise max skips NaN values.
z_df = z_df.assign(max_abs_z=z_scores.abs().max(axis=1))
RUN_DF = RUN_DF.merge(z_df[['subject', 'session', 'max_abs_z']], on=['subject', 'session'], how='left')
# Sessions dropped above (any NaN metric) get NaN max_abs_z and therefore no flag.
RUN_DF['zscore_flag'] = RUN_DF['max_abs_z'] > 3.0


In [None]:

def _sessions_with_issue(counts):
    """Number of sessions in which at least one channel was flagged."""
    return int((counts > 0).sum())


# Per-subject roll-up of the session metrics (named aggregation: output name=(column, func)).
subject_summary = RUN_DF.groupby('subject').agg(
    sessions_available=('session', 'nunique'),
    total_duration_s=('duration_s', 'sum'),
    total_events=('events_total', 'sum'),
    total_left=('left_hand', 'sum'),
    total_right=('right_hand', 'sum'),
    total_feet=('feet', 'sum'),
    total_tongue=('tongue', 'sum'),
    median_mean_std_uv=('mean_channel_std_uv', 'median'),
    min_mean_std_uv=('mean_channel_std_uv', 'min'),
    max_mean_std_uv=('mean_channel_std_uv', 'max'),
    avg_clip_fraction=('clip_fraction_over_200uv', 'mean'),
    max_clip_fraction=('clip_fraction_over_200uv', 'max'),
    median_line_noise_ratio=('line_noise_ratio', 'median'),
    max_line_noise_ratio=('line_noise_ratio', 'max'),
    sessions_with_flat_channels=('flat_channel_count', _sessions_with_issue),
    sessions_with_noisy_channels=('noisy_channel_count', _sessions_with_issue),
    sessions_with_amp_issue=('amp_issue_flag', 'sum'),
    sessions_with_zscore_outlier=('zscore_flag', 'sum'),
)
# A subject is suspect if any session-level or aggregate criterion fires.
subject_summary['suspect_subject'] = (
    subject_summary['sessions_with_amp_issue'].gt(0)
    | subject_summary['sessions_with_zscore_outlier'].gt(0)
    | subject_summary['max_clip_fraction'].gt(0.05)
    | subject_summary['median_line_noise_ratio'].gt(1.5)
)
subject_summary.sort_values('sessions_with_amp_issue', ascending=False).head()


In [None]:

subject_plot_df = subject_summary.reset_index()
# Marker area scales with the number of flagged sessions. Offset by one so subjects
# with no flagged session get the smallest visible marker instead of size 0.
subject_plot_df['marker_size'] = subject_plot_df['sessions_with_amp_issue'] + 1
fig = px.scatter(subject_plot_df,
                 x='median_mean_std_uv', y='median_line_noise_ratio',
                 size='marker_size', color='suspect_subject',
                 hover_data=['subject', 'avg_clip_fraction', 'sessions_with_amp_issue'],
                 title='Subject-level signal health overview (BCI 2a)')
# Red: the median line-noise threshold used for `suspect_subject`.
fig.add_hline(y=1.5, line_dash='dash', line_color='red')
# Grey: visual reference at 4 uV median channel std.
fig.add_vline(x=4.0, line_dash='dash', line_color='grey')
fig.show()



## Interactive session explorer
Interactively inspect channel waveforms for a given subject and session.


In [None]:

def plot_session_timeseries(subject_id: str, session_code: str, seconds: float = 8.0, channels=None):
    """Plot the first ``seconds`` of selected EEG channels of one session with Plotly.

    Parameters
    ----------
    subject_id : str
        Subject identifier such as ``'A01'``.
    session_code : str
        ``'T'`` (training) or ``'E'`` (evaluation).
    seconds : float
        Window length, starting at the beginning of the recording.
    channels : list of str, optional
        Channel names to show; defaults to ``['C3', 'Cz', 'C4', 'Pz']``.
        Names not present in the recording are skipped.

    Raises
    ------
    FileNotFoundError
        If the GDF file for the session does not exist.
    ValueError
        If none of the requested channels is present in the recording.
    """
    gdf_path = DATA_ROOT / f"{subject_id}{session_code}.gdf"
    if not gdf_path.exists():
        raise FileNotFoundError(gdf_path)
    # The cached Raw is already preloaded and is only read here, so no copy is needed.
    raw = _load_raw_cache(str(gdf_path))
    if channels is None:
        channels = ['C3', 'Cz', 'C4', 'Pz']
    # Keep the requested order, drop unknown names and duplicates.
    available = list(dict.fromkeys(ch for ch in channels if ch in raw.ch_names))
    if not available:
        raise ValueError(
            f'None of {list(channels)} found in {gdf_path.name}; '
            f'available channels: {raw.ch_names}'
        )
    sfreq = raw.info['sfreq']
    stop = int(min(seconds * sfreq, raw.n_times))
    # Integer indices keep the data rows aligned with the labels in `available`.
    pick_idx = [raw.ch_names.index(ch) for ch in available]
    array = raw.get_data(picks=pick_idx, stop=stop) * 1e6  # volts -> microvolts
    time = np.arange(stop) / sfreq
    # Row-major flatten: all samples of channel 0, then channel 1, ... matches tile/repeat.
    df = pd.DataFrame({
        'time_s': np.tile(time, len(available)),
        'channel': np.repeat(available, stop),
        'amplitude_uV': array.reshape(len(available) * stop),
    })
    fig = px.line(df, x='time_s', y='amplitude_uV', color='channel',
                  title=f'{subject_id}{session_code} - first {seconds} s')
    fig.update_layout(xaxis_title='Time (s)', yaxis_title='Amplitude (uV)')
    fig.show()

plot_session_timeseries('A01', 'T', seconds=6)



## Export artefact flags for downstream modelling
Save session- and subject-level reports for filtering before training.


In [None]:

RUN_REPORT_PATH = DERIVED_DIR / 'bci2a_run_summary.csv'
SUBJECT_REPORT_PATH = DERIVED_DIR / 'bci2a_subject_summary.csv'

# Session report: one row per (subject, session), without the positional index.
session_report = RUN_DF.sort_values(['subject', 'session'])
session_report.to_csv(RUN_REPORT_PATH, index=False)

# Subject report: suspect subjects first; the 'subject' index is written as a column.
subject_report = subject_summary.sort_values('suspect_subject', ascending=False)
subject_report.to_csv(SUBJECT_REPORT_PATH)

for label, report_path in (('session', RUN_REPORT_PATH), ('subject', SUBJECT_REPORT_PATH)):
    print(f'Exported {label} summary -> {report_path}')



## Next steps
- Filter RUN_DF by mp_issue_flag or cue counts to select clean sessions.
- Use the subject-level summary to exclude problematic participants before per-person model training.
- Extend the analysis with ERD/ERS visualisations or connectivity metrics for channel selection.
