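# Example Classification Pipelines

Train a small 1-D CNN on auditory ERP epochs from the eyes-open session, export it to ONNX, and run ONNX inference on the held-out eyes-closed session.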

In [1]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

import opencortex.neuroengine.flux.base.operators  # Enable >>
from opencortex.neuroengine.flux.estimation.onnx import ONNXNode
from opencortex.neuroengine.flux.preprocessing.bandpass import BandPassFilterNode
from opencortex.neuroengine.flux.preprocessing.notch import NotchFilterNode
from opencortex.utils.loader import load_data, convert_to_mne
import matplotlib.pyplot as plt
import numpy as np

fs = 250  # sampling rate (Hz), passed to load_data and convert_to_mne
chs = ["Fz", "C3", "Cz", "C4", "Pz", "PO7", "Oz", "PO8"]
DATA_DIR = "../data/aep"


def load_session(csv_path):
    """Load one recorded session from CSV and print its dimensions.

    Args:
        csv_path: Path to the session CSV (the first 5 rows are skipped,
            values are comma-delimited).

    Returns:
        The ``(eeg, trigger, dataframe)`` triple returned by ``load_data``.
    """
    eeg, trigger, dataframe = load_data(csv_path, fs=fs, skiprows=5, delimiter=',')
    print(f"Loaded data with shape:{eeg.shape} and trigger shape: {trigger.shape}")
    print(f"That means we have {eeg.shape[0]} samples and {eeg.shape[1]} channels.")
    return eeg, trigger, dataframe


# Training uses the eyes-open session and testing the eyes-closed one, so the
# evaluation below is across recording conditions.
eeg, trigger, dataframe = load_session(f"{DATA_DIR}/auditory_erp_eyes_open_S1.csv")
# Convert to MNE format (recompute=True recalculates the event labels if the values are negative)
raw_data_train = convert_to_mne(eeg, trigger, fs=fs, chs=chs, recompute=False)

eeg, trigger, dataframe = load_session(f"{DATA_DIR}/auditory_erp_eyes_closed_S1.csv")
raw_data_test = convert_to_mne(eeg, trigger, fs=fs, chs=chs, recompute=False)

Loaded data with shape:(13626, 8) and trigger shape: (13626,)
That means we have 13626 samples and 8 channels.
Creating RawArray with float64 data, n_channels=8, n_times=13626
    Range : 0 ... 13625 =      0.000 ...    54.500 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=13626
    Range : 0 ... 13625 =      0.000 ...    54.500 secs
Ready.
Loaded data with shape:(14159, 8) and trigger shape: (14159,)
That means we have 14159 samples and 8 channels.
Creating RawArray with float64 data, n_channels=8, n_times=14159
    Range : 0 ... 14158 =      0.000 ...    56.632 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=14159
    Range : 0 ... 14158 =      0.000 ...    56.632 secs
Ready.


In [3]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class SimpleEEGNet(pl.LightningModule):
    """Small 1-D CNN that maps an EEG epoch to ``n_classes`` logits.

    Inputs are laid out as ``(batch, n_channels, time)`` for ``nn.Conv1d``.
    Two conv -> ReLU -> max-pool(2) stages shrink the time axis to
    ``n_times // 4`` before a single linear read-out.
    """

    def __init__(self, n_channels=8, n_times=250, n_classes=2, lr=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr

        # Submodules are registered in the same order and under the same names
        # as before so state_dict keys stay stable.
        self.conv1 = nn.Conv1d(n_channels, 32, 5, padding=2)
        self.pool = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(32, 64, 5, padding=2)
        pooled_len = n_times // 4  # two pooling stages, each halving the time axis
        self.fc = nn.Linear(64 * pooled_len, n_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)``."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(torch.relu(conv(features)))
        flat = features.view(features.shape[0], -1)
        return self.fc(flat)

    def _logits_and_loss(self, batch):
        """Run the model on an ``(inputs, targets)`` batch and score it."""
        inputs, targets = batch
        logits = self(inputs)
        return logits, targets, self.criterion(logits, targets)

    def training_step(self, batch, batch_idx):
        _, _, loss = self._logits_and_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, targets, loss = self._logits_and_loss(batch)
        predicted = logits.argmax(dim=1)
        self.log('val_loss', loss)
        self.log('val_acc', (predicted == targets).float().mean())

    def predict_step(self, batch, batch_idx):
        # Prediction batches may arrive either bare or wrapped as (inputs, ...).
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        return self(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [4]:
def export_to_onnx(
    model: torch.nn.Module,
    output_path: str,
    input_shape: tuple,
    device: str = 'cpu',
    opset_version: int = 11,
    verify: bool = True
):
    """
    Export PyTorch model to ONNX format.

    Args:
        model: PyTorch model (or Lightning module)
        output_path: Path to save ONNX file
        input_shape: Shape of input tensor (e.g., (1, 8, 250))
        device: Device to run model on
        opset_version: ONNX opset version
        verify: If True, verify exported model with ONNX Runtime
    """
    import torch
    import numpy as np

    # Move model to device and set to eval mode
    model = model.to(device)
    model.eval()

    # Create dummy input
    dummy_input = torch.randn(*input_shape).to(device)

    # Export to ONNX
    print(f"Exporting model to {output_path}...")
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        },
        opset_version=opset_version,
        export_params=True,
        do_constant_folding=True
    )

    print(f"âœ“ Model exported to {output_path}")

    # Verify with ONNX Runtime
    if verify:
        try:
            import onnxruntime as ort

            print("Verifying exported model...")
            ort_session = ort.InferenceSession(output_path)

            # Test inference
            test_input = np.random.randn(*input_shape).astype(np.float32)
            ort_outputs = ort_session.run(None, {'input': test_input})

            # Compare with PyTorch
            with torch.no_grad():
                torch_output = model(torch.from_numpy(test_input).to(device))
                torch_output = torch_output.cpu().numpy()

            # Check if outputs match
            max_diff = np.max(np.abs(ort_outputs[0] - torch_output))
            print(f"Max difference between PyTorch and ONNX: {max_diff:.6f}")

            if max_diff < 1e-5:
                print("âœ“ ONNX model verified successfully!")
            else:
                print("âš  Warning: Outputs differ slightly (this is often normal)")

        except ImportError:
            print("âš  onnxruntime not installed, skipping verification")
            print("Install with: pip install onnxruntime")

In [5]:
from opencortex.neuroengine.flux.estimation.lightning import LightningNode
from opencortex.neuroengine.flux.preprocessing.dataset import DatasetNode
from sklearn.preprocessing import StandardScaler, LabelEncoder
from opencortex.neuroengine.flux.preprocessing.scaler import ScalerNode
from opencortex.neuroengine.flux.preprocessing.extract import ExtractNode
from opencortex.neuroengine.flux.preprocessing.epochs import EpochingNode
from opencortex.neuroengine.flux.preprocessing.events import ExtractEventsNode, FilterEventsNode, RelabelEventsNode
from opencortex.neuroengine.flux.base.sequential import Sequential

from pytorch_lightning import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")

# Seed python/numpy/torch (and dataloader workers) so weight initialisation
# and the train/validation split are reproducible on Restart & Run All.
# NOTE(review): assumes DatasetNode draws its split from these global RNGs.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Epoch window (seconds relative to each event). MNE epochs include both
# endpoints, so the sample count is round(tmax*fs) - round(tmin*fs) + 1
# (251 at fs=250), not (tmax - tmin) * fs.
EPOCH_TMIN, EPOCH_TMAX = -0.2, 0.8
N_TIMES = int(round(EPOCH_TMAX * fs)) - int(round(EPOCH_TMIN * fs)) + 1

# Despite its name, this pipeline also trains the model (LightningNode is the
# last step); later cells fetch fitted nodes from it by name.
preprocessing = Sequential(
    NotchFilterNode((50, 60), name='NotchFilter'),
    BandPassFilterNode(0.1, 30.0, name='BandPassFilter'),
    # NOTE(review): the recorded log prints "Finding events on: S" for
    # stim_channel='STI' — confirm the intended stim channel is used.
    ExtractEventsNode(stim_channel='STI', auto_label=True, name='ExtractEvents'),
    FilterEventsNode(max_event_id=90, name='FilterEvents'),
    RelabelEventsNode(target_class=1, nontarget_label=3, name='RelabelEvents'),
    EpochingNode(tmin=EPOCH_TMIN, tmax=EPOCH_TMAX, baseline=(-0.1, 0.0), event_id={'T': 1, 'NT': 3}, name='Epoching'),
    # Event codes -> class indices: T (1) -> 0, NT (3) -> 1.
    ExtractNode(label_encoder=LabelEncoder(), apply_label_encoding=True, label_mapping={1: 0, 3: 1}, name='XyExtractor'),
    ScalerNode(scaler=StandardScaler(), per_channel=True, name='StdScaler'),
    DatasetNode(split_size=0.2, batch_size=8, shuffle=True, num_workers=4, name='Dataset'),
    LightningNode(
        # n_times must match the epoch length produced by EpochingNode above.
        model=SimpleEEGNet(n_channels=len(chs), n_times=N_TIMES),
        trainer_config={
            'max_epochs': 5,
            'accelerator': 'cpu',
            'enable_progress_bar': True,
            'enable_model_summary': True,
            'log_every_n_steps': 1,
            'logger': tb_logger,
        },
        name='SimpleEEGNet'
    ),
    name="Preprocessing"
)

trained_model = preprocessing(raw_data_train)




Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 3.50 Hz
- Upper transition bandwidth: 3.50 Hz
- Filter length: 237 samples (0.948 s)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 31.50 Hz)
- Filter length: 8251 samples (33.004 s)

Finding events on: S

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | conv1     | Conv1d           | 1.3 K  | train
1 | pool      | MaxPool1d        | 0      | train
2 | conv2     | Conv1d           | 10.3 K | train
3 | fc        | Linear           | 7.9 K  | train
4 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
19.6 K    Trainable params
0         Non-trainable params
19.6 K    Total params
0.078     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\miche\Desktop\projects\OpenCortexBCI\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:428: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

C:\Users\miche\Desktop\projects\OpenCortexBCI\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:428: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0: 100%|██████████| 9/9 [00:02<00:00,  3.69it/s, v_num=10]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/3 [00:00<?, ?it/s][A
Validation DataLoader 0:  33%|███▎      | 1/3 [00:00<00:00, 42.53it/s][A
Validation DataLoader 0:  67%|██████▋   | 2/3 [00:00<00:00, 31.99it/s][A
Validation DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 37.53it/s][A
Epoch 1: 100%|██████████| 9/9 [00:02<00:00,  3.70it/s, v_num=10]      [A
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/3 [00:00<?, ?it/s][A
Validation DataLoader 0:  33%|███▎      | 1/3 [00:00<00:00, 109.21it/s][A
Validation DataLoader 0:  67%|██████▋   | 2/3 [00:00<00:00, 116.51it/s][A
Validation DataLoader 0: 100%|█████████

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 9/9 [00:05<00:00,  1.60it/s, v_num=10]


In [6]:
# Look up the trained LightningNode by the name it was given in the training
# pipeline, then write its checkpoint to disk.
lightning_node = preprocessing.get_node("SimpleEEGNet")
model_path = "example_eegnet.pth"
lightning_node.save_checkpoint(model_path)

In [7]:
# Export the model trained by the pipeline above. Epochs span tmin=-0.2 s to
# tmax=0.8 s at fs=250 Hz with both endpoints included -> 251 time samples.
export_to_onnx(
    model=trained_model,
    output_path='model.onnx',
    input_shape=(1, len(chs), 251),  # (batch, channels, time)
    verify=True,
);

Exporting model to model.onnx...
✓ Model exported to model.onnx


  torch.onnx.export(


Verifying exported model...
Max difference between PyTorch and ONNX: 0.000000
✓ ONNX model verified successfully!


In [11]:
from opencortex.neuroengine.flux.evaluation.metrics import MetricsNode
from opencortex.neuroengine.flux.estimation.onnx import ONNXNode

# Inference must apply exactly the same processing as training. Instead of
# re-declaring the filter/event/epoching nodes (which lets the two configs
# drift apart), reuse the instances from the training pipeline by name, as is
# already done for the fitted extractor and scaler.
fitted_scaler = preprocessing.get_node("StdScaler")
extractor_node = preprocessing.get_node("XyExtractor")

shared_frontend = [
    preprocessing.get_node(node_name)
    for node_name in (
        "NotchFilter",
        "BandPassFilter",
        "ExtractEvents",
        "FilterEvents",
        "RelabelEvents",
        "Epoching",
    )
]

inference_pipeline = Sequential(
    *shared_frontend,
    extractor_node,
    fitted_scaler,
    # shuffle=False keeps predictions in event order
    DatasetNode(split_size=0.0, batch_size=1, shuffle=False, num_workers=0, name='TestDataset'),
    ONNXNode(model_path='model.onnx', name='ONNXInference'),
    name="Inference",
)

# NOTE(review): the recorded run predicts class 1 (NT, via label_mapping
# {1: 0, 3: 1}) for every test epoch — check class balance / training length
# before reading anything into these predictions.
predictions = inference_pipeline(raw_data_test)
predictions




Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 3.50 Hz
- Upper transition bandwidth: 3.50 Hz
- Filter length: 237 samples (0.948 s)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 31.50 Hz)
- Filter length: 8251 samples (33.004 s)

Finding events on: S

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1])

In [12]:
%%timeit
_ = inference_pipeline(raw_data_test)

Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 3.50 Hz
- Upper transition bandwidth: 3.50 Hz
- Filter length: 237 samples (0.948 s)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 31.50 Hz)
- Filter length: 8251 samples (33.004 s)

Finding events on: S