# BigP3BCI Enhanced Visualization
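This notebook demonstrates a P300 classification pipeline on one BigP3BCI session (subject A_01, session SE001), with additional graphics: target/non-target ERPs, an LDA weight scalp map and a confusion matrix.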

## Setup
Load packages and display versions.

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import numpy as np
import seaborn as sns
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import (
    StratifiedKFold,
    cross_val_predict,
    cross_val_score,
    cross_validate,
)

# Record the versions of the main numerical/plotting libraries for reproducibility.
print("MNE", mne.__version__)
print("NumPy", np.__version__)
print("seaborn", sns.__version__)

## Data location
Locate the BigP3BCI training and test recordings under `$NEURO_DATA_ROOT` (default `~/neuro-data`), falling back to synthetic data if either file is absent.

In [None]:
# Dataset root: $NEURO_DATA_ROOT if set, otherwise ~/neuro-data.
root = Path(os.environ.get("NEURO_DATA_ROOT", str(Path.home() / "neuro-data")))
subject = root.joinpath("bigP3BCI-data", "StudyA", "A_01", "SE001")
train_path = subject.joinpath("Train", "CB", "A_01_SE001_CB_Train01.edf")
test_path = subject.joinpath("Test", "CB", "A_01_SE001_CB_Test06.edf")

# Real data is used only when *both* recordings are present.
use_synth = not all(path.exists() for path in (train_path, test_path))
if not use_synth:
    print("Train file", train_path)
    print("Test file", test_path)
else:
    print("BigP3BCI files not found – using synthetic data.")

## Load raw EEG

In [None]:
# Standard 10-20 montage, shared by both recordings in either branch.
montage = mne.channels.make_standard_montage("standard_1020")
if use_synth:
    # 60 s of two-channel white noise at 256 Hz. MNE stores EEG in volts and
    # plots it in µV, so the noise is scaled to ~10 µV; unit-variance noise
    # would otherwise show up as ~1e6 µV in every figure.
    info = mne.create_info(["Cz", "Pz"], sfreq=256.0, ch_types="eeg")
    rng = np.random.default_rng(42)
    data = 10e-6 * rng.standard_normal((2, 256 * 60))
    raw_train = mne.io.RawArray(data, info)
    # The synthetic "test" recording is an exact copy of the training one.
    raw_test = raw_train.copy()
    raw_train.set_montage(montage)
    raw_test.set_montage(montage)
else:
    raw_train = mne.io.read_raw_edf(train_path, preload=True, verbose=False)
    raw_test = mne.io.read_raw_edf(test_path, preload=True, verbose=False)
    for recording in (raw_train, raw_test):
        # Strip the "EEG_" prefix so channel names match the montage. The map is
        # built per recording: rename_channels rejects names a recording lacks.
        rename_map = {
            ch: ch.replace("EEG_", "")
            for ch in recording.ch_names
            if ch.startswith("EEG_")
        }
        recording.rename_channels(rename_map)
        # Non-EEG channels (e.g. stimulus markers) have no 10-20 position.
        recording.set_montage(montage, on_missing="ignore")
# NOTE(review): raw_test is neither preprocessed nor evaluated in this notebook.

## Extract stimulus events

In [None]:
if use_synth:
    # One event per second with random target/non-target labels. A dedicated,
    # seeded generator keeps labels (and hence CV scores) reproducible on rerun.
    events_rng = np.random.default_rng(43)
    events = mne.make_fixed_length_events(raw_train, id=1, duration=1.0)
    events[:, 2] = events_rng.choice([0, 1], size=len(events))
    train_events = events
else:
    stim_begin = raw_train.get_data(picks=["StimulusBegin"])[0]
    stim_type = raw_train.get_data(picks=["StimulusType"])[0]
    # StimulusBegin may stay high for several consecutive samples; only the
    # first sample of each high run (rising edge) is a stimulus onset.
    is_high = stim_begin > 0
    onsets = np.flatnonzero(np.diff(is_high.astype(int), prepend=0) == 1)
    # MNE event array: (sample, previous value, event id); id 1 = target, 0 = non-target.
    train_events = np.c_[
        onsets, np.zeros(len(onsets), int), stim_type[onsets].astype(int)
    ]
print("Train events", np.unique(train_events[:, 2], return_counts=True))

## Preprocessing
Band-pass filter the training recording (0.1–30 Hz) and resample it to 128 Hz.

In [None]:
# Band-pass 0.1–30 Hz, then downsample to 128 Hz (Nyquist 64 Hz, well above the
# 30 Hz low-pass edge). Resampling renumbers samples, so the event array is
# resampled jointly; otherwise every onset would point at the wrong time.
raw_train.filter(0.1, 30.0, fir_design="firwin", verbose=False)
raw_train, train_events = raw_train.resample(128, events=train_events, verbose=False)

## Epoch extraction

In [None]:
tmin, tmax = -0.2, 0.8
event_id = dict(nontarget=0, target=1)
# Keep only scalp channels, i.e. those with a position in the 10-20 montage.
# The EDF also carries bookkeeping channels such as StimulusBegin/StimulusType
# (presumably typed as EEG by the EDF reader). StimulusType encodes the label,
# so epoching it would leak the target into the classifier features. Those
# channels would also break the scalp topomap, since they have no position.
montage_names = set(montage.ch_names)
eeg_channels = [ch for ch in raw_train.ch_names if ch in montage_names]
assert eeg_channels, "no channel of raw_train matches the 10-20 montage"
train_epochs = mne.Epochs(
    raw_train,
    train_events,
    event_id,
    tmin,
    tmax,
    picks=eeg_channels,
    baseline=(tmin, 0),
    preload=True,
    verbose=False,
)

## Average ERPs (target vs. non-target)

In [None]:
plt.close("all")
sns.set_context("talk")
# Per-condition averages over this session's epochs (not a cross-subject grand average).
evokeds = {
    "Target": train_epochs["target"].average(),
    "Non-target": train_epochs["nontarget"].average(),
}
# A single channel is picked, so no `combine` is needed. show=False keeps MNE from
# rendering the figure before the title below is attached.
fig = mne.viz.plot_compare_evokeds(evokeds, picks="Cz", show=False)[0]
fig.suptitle("Average ERP at Cz")
plt.show()

## Classification

In [None]:
# Features: every sample in the 250–450 ms post-stimulus window, flattened
# channel-major (feature index = channel * n_times + time).
window = train_epochs.time_as_index([0.25, 0.45])
X = train_epochs.get_data()[:, :, window[0] : window[1]].reshape(len(train_epochs), -1)
y = train_epochs.events[:, 2]
# Shrinkage LDA: channels × samples features can outnumber trials, which leaves
# the plain covariance estimate ill-conditioned. Automatic Ledoit-Wolf
# shrinkage regularizes it.
clf = LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto")
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
# Target stimuli are usually rarer than non-targets, so plain accuracy is
# reported alongside balanced accuracy (mean per-class recall).
cv_results = cross_validate(
    clf, X, y, cv=cv, scoring=["accuracy", "balanced_accuracy"]
)
scores = cv_results["test_accuracy"]
print("Mean CV accuracy:", scores.mean())
print("Mean CV balanced accuracy:", cv_results["test_balanced_accuracy"].mean())
# Final model on all epochs, used by the visualization cells below.
clf.fit(X, y)

## LDA coefficient scalp map

In [None]:
# X was flattened channel-major, so reshape(nchan, -1) recovers (channel, time).
# The weights are then averaged over time to get one value per channel.
# NOTE: raw LDA weights act as spatial filters, not activation patterns, so
# interpret the map with care.
coefs = clf.coef_.reshape(train_epochs.info["nchan"], -1).mean(axis=1)
evoked = mne.EvokedArray(coefs[:, None], train_epochs.info, tmin=0)
# time_format is %-formatted with the time value, and a plain label like "LDA"
# raises TypeError. Instead, suppress the per-axes label and title the figure.
# The values are weights, not µV.
fig = evoked.plot_topomap(
    times=0,
    scalings=1,
    units="a.u.",
    time_format="",
    cmap="RdBu_r",
    show=False,
)
fig.suptitle("LDA weights (mean over 250–450 ms)")
plt.show()

## Confusion matrix

In [None]:
# Out-of-fold predictions: every epoch is predicted by a model that never saw it.
# Predicting with the model fit on all of X would only show training-set fit.
y_pred = cross_val_predict(clf, X, y, cv=cv)
# Fix the label order (0 = non-target, 1 = target) so the matrix is always 2×2
# and matches the tick labels below.
cm = confusion_matrix(y, y_pred, labels=[0, 1])
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax)
ax.set_title("Cross-validated confusion matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_xticklabels(["Non-target", "Target"])
ax.set_yticklabels(["Non-target", "Target"], rotation=0)
plt.show()