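# Final Model Comparison

Train the final `EEGSTFTSpikeClassifier` configuration on normalized STFT features from the CHB-MIT dataset. Spikes are produced by a Poisson encoder. The data uses a seeded 70/10/20 train/validation/test split, and the notebook reports the model's test-set metrics.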


In [1]:
from eeg_snn_encoder.dataset import CHBMITDataset
from eeg_snn_encoder.config import PROCESSED_DATA_DIR

# Normalized STFT features stored as an HDF5 file under the processed-data directory.
STFT_H5_PATH = PROCESSED_DATA_DIR / "stft_normalized.h5"

# Full dataset; it is split into train/val/test in the next cell.
dataset = CHBMITDataset(STFT_H5_PATH)

[32m2025-05-04 22:36:48.885[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: E:\Projects\snn-encoder-test[0m


In [2]:
import os

import torch
from torch.utils.data import DataLoader, random_split

# Seed for the dataset partition (and the training-loader shuffle), so every run
# sees exactly the same train/val/test samples.
SPLIT_SEED = 42
SPLIT_FRACTIONS = [0.7, 0.1, 0.2]  # train / val / test
BATCH_SIZE = 32
# Cap worker processes at the host's CPU count; 15 was the original hard-coded value.
NUM_WORKERS = min(15, os.cpu_count() or 1)

train_dataset, val_dataset, test_dataset = random_split(
    dataset, SPLIT_FRACTIONS, generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Settings shared by all three loaders (persistent workers avoid re-spawning each epoch).
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True)

# Training batches are reshuffled every epoch; a dedicated seeded generator keeps the
# shuffle order reproducible. Val/test stay in fixed order for stable evaluation.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
    **loader_kwargs,
)
val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
import lightning.pytorch as pl

from eeg_snn_encoder.encoders import PoissonEncoder
from eeg_snn_encoder.models.classifier import ModelConfig, EEGSTFTSpikeClassifier
from eeg_snn_encoder.models.lightning import LitEvalSeizureClassifier, OptimizerConfig

# Seed Python, NumPy and torch (plus dataloader workers) so weight initialisation,
# dropout and, presumably, the stochastic Poisson spike encoding are repeatable.
# The split cell uses the same seed value for its own generators.
pl.seed_everything(42, workers=True)

# Tuned model hyperparameters.
# NOTE(review): record where these values came from (e.g. the search run / trial id).
model_params: ModelConfig = {
    "threshold": 0.279848429726772,
    "slope": 7.481824299320472,
    "beta": 0.51966978465064,
    "dropout_rate1": 0.5407012533686575,
    "dropout_rate2": 0.3643962622542456,
}

optimizer_params: OptimizerConfig = {
    "lr": 9.192533042633165e-05,
    "weight_decay": 3.881504690507441e-06,
    "scheduler_factor": 0.4764851331315718,
    "scheduler_patience": 4,
}

# NOTE(review): confirm the meaning/units of interval_freq in PoissonEncoder.
encoder = PoissonEncoder(interval_freq=10)

model = EEGSTFTSpikeClassifier(config=model_params)

lit_model = LitEvalSeizureClassifier(
    model=model,
    optimizer_config=optimizer_params,
    spike_encoder=encoder,
)

trainer = pl.Trainer(max_epochs=100, accelerator="gpu", devices=1, logger=False)
trainer.fit(lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# No ckpt_path is given, so test() evaluates the weights lit_model holds after fit
# (the final epoch), not necessarily the best validation epoch.
trainer.test(lit_model, dataloaders=test_loader)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]