# BigP3BCI Demo
This notebook demonstrates a P300 classification pipeline using the BigP3BCI dataset.

## Setup
Load libraries and display versions.

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import numpy as np
import sklearn
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import balanced_accuracy_score, confusion_matrix

# Record the versions of the core libraries so results can be tied to an environment.
print("MNE", mne.__version__)
print("NumPy", np.__version__)
print("scikit-learn", sklearn.__version__)

## Data location
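Set the `NEURO_DATA_ROOT` environment variable to the root of the BigP3BCI dataset (the directory that contains `bigP3BCI-data`). If the variable is unset, it defaults to `~/neuro-data`. If that directory does not exist, the notebook falls back to the versioned dataset folder under `./data` in the current working directory.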

In [None]:
# Root of the BigP3BCI dataset, i.e. the directory that contains `bigP3BCI-data`.
# An unset *or empty* NEURO_DATA_ROOT falls back to ~/neuro-data. The empty string
# must not be used as the default: Path("") is Path("."), which always exists, so
# the fallback below would never trigger and the notebook would look in the CWD.
data_root = Path(
    os.environ.get("NEURO_DATA_ROOT") or Path.home() / "neuro-data"
).expanduser()
if not data_root.exists():
    # Fall back to a copy of the versioned dataset unpacked under ./data.
    repo_base = Path.cwd() / "data"
    data_root = (
        repo_base
        / "bigp3bci-an-open-diverse-and-machine-learning-ready-p300-based-brain-computer-interface-dataset-1.0.0"
    )
subject_dir = data_root / "bigP3BCI-data" / "StudyA" / "A_01" / "SE001"
train_path = subject_dir / "Train" / "CB" / "A_01_SE001_CB_Train01.edf"
test_path = subject_dir / "Test" / "CB" / "A_01_SE001_CB_Test06.edf"
print("Train file", train_path)
print("Test file", test_path)

## Load raw EEG
We load one calibration run (`Train01`) and one test run (`Test06`) for subject A_01, session SE001, `CB` condition.

In [None]:
# Both runs are read identically: fully loaded into memory, with MNE logging silenced.
edf_options = {"preload": True, "verbose": False}
raw_train, raw_test = (
    mne.io.read_raw_edf(edf_path, **edf_options) for edf_path in (train_path, test_path)
)
print(raw_train)

## Extract stimulus events
Events are stored in the `StimulusBegin` channel. The `StimulusType` channel encodes whether the flash contained the target (1) or not (0).

In [None]:
def extract_events(raw, begin_channel="StimulusBegin", type_channel="StimulusType"):
    """Build an MNE-style events array from the stimulus marker channels.

    One event is emitted per rising edge of ``begin_channel`` (the first sample
    at which it becomes positive). A marker that stays high for several
    consecutive samples therefore still yields exactly one event per stimulus.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Recording that contains ``begin_channel`` and ``type_channel``.
    begin_channel : str
        Channel that is positive when a stimulus begins.
    type_channel : str
        Channel whose value at the onset sample is used as the event code
        (per the dataset description: 1 = target flash, 0 = non-target flash).

    Returns
    -------
    events : ndarray of int, shape (n_events, 3)
        Columns are (sample index incl. ``raw.first_samp``, 0, event code),
        the layout expected by ``mne.Epochs``.
    """
    stim_begin = raw.get_data(picks=[begin_channel])[0]
    stim_type = raw.get_data(picks=[type_channel])[0]
    is_high = (stim_begin > 0).astype(np.int8)
    # A rising edge is a 0 -> 1 transition; prepending 0 makes a marker that is
    # already high on the very first sample count as an onset too.
    onsets = np.flatnonzero(np.diff(is_high, prepend=0) == 1)
    # Round rather than truncate: EDF physical scaling can turn 1 into 0.9999999,
    # which astype(int) alone would silently relabel as 0.
    codes = np.rint(stim_type[onsets]).astype(int)
    # MNE event sample indices are absolute, i.e. offset by the raw's first sample.
    events = np.column_stack(
        [onsets + raw.first_samp, np.zeros(len(onsets), dtype=int), codes]
    )
    return events


# One events array per run; report how many non-target (0) / target (1) flashes training has.
train_events, test_events = map(extract_events, (raw_train, raw_test))
train_label_counts = np.unique(train_events[:, 2], return_counts=True)
print("Train events", train_label_counts)

## Preprocessing
We band-pass filter from 0.1–30 Hz and resample to 128 Hz.

In [None]:
# Band-pass 0.1-30 Hz (in place) at the native sampling rate.
raw_train.filter(0.1, 30.0, fir_design="firwin", verbose=False)
raw_test.filter(0.1, 30.0, fir_design="firwin", verbose=False)
# The events were extracted at the native sampling rate, so their sample indices
# must be rescaled together with the data; otherwise the epochs below would be
# cut at the wrong samples (or dropped as out of range). MNE rounds the rescaled
# indices, which introduces at most one new-sample (~8 ms) of jitter at 128 Hz.
raw_train, train_events = raw_train.resample(128, events=train_events, verbose=False)
raw_test, test_events = raw_test.resample(128, events=test_events, verbose=False)

## Epoch extraction
We cut epochs from −0.2 to 0.8 s relative to each stimulus onset. Baseline correction uses the 200 ms pre-stimulus interval.

In [None]:
tmin, tmax = -0.2, 0.8
event_id = dict(nontarget=0, target=1)
# Train and test runs are epoched with exactly the same settings: a -0.2..0.8 s
# window, baseline-corrected on the pre-stimulus interval, loaded into memory.
shared_epoch_settings = dict(
    event_id=event_id,
    tmin=tmin,
    tmax=tmax,
    baseline=(tmin, 0),
    preload=True,
    verbose=False,
)
train_epochs = mne.Epochs(raw_train, train_events, **shared_epoch_settings)
test_epochs = mne.Epochs(raw_test, test_events, **shared_epoch_settings)
train_epochs

## ERP grand average
Plot the average waveform for target and non-target trials.

In [None]:
# Compare target vs non-target averages on the EEG channels only. The stimulus
# marker channels are part of the raw recording and would otherwise be averaged
# (and plotted) alongside the EEG.
# NOTE(review): the EDF may contain further non-EEG state channels; exclude them too.
erp_picks = [
    ch for ch in train_epochs.ch_names if ch not in {"StimulusBegin", "StimulusType"}
]
condition_averages = {
    condition: train_epochs[condition].average(picks=erp_picks)
    for condition in ("target", "nontarget")
}
mne.viz.plot_compare_evokeds(
    condition_averages,
    combine="mean",
    show_sensors=False,
    title="Training run: mean ERP across EEG channels",
);

## Feature extraction and classification
We vectorize the 250–450 ms window and train an LDA on calibration data, then evaluate on the test run.

In [None]:
# Features come from every epoch channel except the stimulus marker channels.
# Those are read from the same EDF as the EEG and appear to be carried into the
# epochs as ordinary channels. StimulusType is the class label itself (it is
# what extract_events uses as the event code), so feeding it to the classifier
# would leak the answer.
# NOTE(review): the EDF may contain further non-EEG state channels; exclude them too.
marker_channels = {"StimulusBegin", "StimulusType"}
feature_picks = [ch for ch in train_epochs.ch_names if ch not in marker_channels]

# 250-450 ms post-stimulus window (end sample excluded), flattened to (n_epochs, n_features).
window = train_epochs.time_as_index([0.25, 0.45])
X_train = train_epochs.get_data(picks=feature_picks)[:, :, window[0] : window[1]].reshape(
    len(train_epochs), -1
)
y_train = train_epochs.events[:, 2]
X_test = test_epochs.get_data(picks=feature_picks)[:, :, window[0] : window[1]].reshape(
    len(test_epochs), -1
)
y_test = test_epochs.events[:, 2]

clf = LinearDiscriminantAnalysis()
clf.fit(X_train, y_train)
score = clf.score(X_test, y_test)
print(f"Test accuracy: {score:.3f}")
# In P300 paradigms targets are typically a minority of flashes, so plain accuracy
# rewards always answering "non-target"; balanced accuracy averages per-class recall.
balanced_score = balanced_accuracy_score(y_test, clf.predict(X_test))
print(f"Balanced accuracy: {balanced_score:.3f}")

## Confusion matrix
Examine classifier performance on the test run.

In [None]:
class_names = ["Non-target", "Target"]
y_pred = clf.predict(X_test)
# Pin the label order so the matrix is always 2x2 and lines up with the tick
# labels below, even if one class is absent from y_test or y_pred.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
print(cm)

fig, ax = plt.subplots()
im = ax.imshow(cm, cmap="Blues")
ax.set_title("LDA confusion matrix (test run)")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_xticks([0, 1])
ax.set_xticklabels(class_names)
ax.set_yticks([0, 1])
ax.set_yticklabels(class_names)
# Switch to white text on dark cells so every count stays readable.
contrast_threshold = cm.max() / 2
for (i, j), v in np.ndenumerate(cm):
    ax.text(
        j,
        i,
        str(v),
        ha="center",
        va="center",
        color="white" if v > contrast_threshold else "black",
    )
fig.colorbar(im, ax=ax)
plt.show()