# Model Testing
First, we load the preprocessed ictal and interictal data (prepared in an earlier step) from HDF5 files using h5py.

In [1]:
import os

import h5py

data_path = "./CHB-MIT/processed"

# Ictal and interictal windows live in sibling HDF5 files in the processed directory.
ictal_path, interictal_path = (
    os.path.join(data_path, f"{stem}.h5") for stem in ("ictal", "interictal")
)

# Open read-only. The handles are intentionally left open: the next cells read
# the 'info' and 'data' datasets from them.
ictal_file, interictal_file = (
    h5py.File(path, 'r') for path in (ictal_path, interictal_path)
)

# Show the top-level keys of each file, one per line.
print(ictal_file.keys(), interictal_file.keys(), sep="\n")

<KeysViewHDF5 ['channels', 'data', 'info']>
<KeysViewHDF5 ['channels', 'data', 'info']>


Each `info` entry is a Python dictionary stored as a UTF-8 byte string (`numpy.bytes_`), so we decode it and parse it back into a dict with `ast.literal_eval`.

In [2]:
import ast

def parse_info_records(raw_records):
    """Parse the entries of an HDF5 'info' dataset back into Python dicts.

    Each entry holds the text of a Python dict literal, e.g.
    "{'file': 'chb01_03.edf', 'start_time': 2996, 'end_time': 3004}",
    stored as UTF-8 bytes. ``ast.literal_eval`` is used rather than ``eval`` so
    that only literals are accepted. Entries that are already ``str`` (e.g. when
    read through h5py's ``.asstr()``) are parsed without decoding.

    Args:
        raw_records: iterable of bytes / ``numpy.bytes_`` / str entries.

    Returns:
        list of dicts, in the same order as ``raw_records``.
    """
    return [
        ast.literal_eval(
            record.decode("utf-8") if isinstance(record, bytes) else record
        )
        for record in raw_records
    ]


ictal_raw_info = ictal_file['info']
interictal_raw_info = interictal_file['info']

ictal_info = parse_info_records(ictal_raw_info)
interictal_info = parse_info_records(interictal_raw_info)

ictal_info[:2]

[{'file': 'chb01_03.edf', 'start_time': 2996, 'end_time': 3004},
 {'file': 'chb01_03.edf', 'start_time': 3000, 'end_time': 3008}]

In [3]:
import numpy as np

# Read each 'data' dataset fully into memory as a NumPy array.
# NOTE(review): presumably laid out as (windows, channels, samples) — confirm
# against the preprocessing step.
ictal_data, interictal_data = (
    np.array(source['data']) for source in (ictal_file, interictal_file)
)

print(
    f"Ictal data shape {ictal_data.shape}",
    f"Interictal data shape {interictal_data.shape}",
    sep="\n",
)

Ictal data shape (2509, 22, 2048)
Interictal data shape (2509, 22, 2048)


In [4]:
import torch
from torch.utils.data import Dataset, DataLoader

class EEGDataset(Dataset):
    """In-memory binary dataset of ictal (label 1) and interictal (label 0) windows.

    The two inputs are concatenated along their first (sample) dimension, so
    every other dimension must agree. Labels are kept as a float32 tensor in
    ``self.labels`` and converted to int64 class indices per item.

    Args:
        ictal_data: tensor of ictal windows, one window per index of dim 0.
        interictal_data: tensor of interictal windows with the same trailing shape.
    """

    def __init__(self, ictal_data, interictal_data):
        n_ictal, n_interictal = len(ictal_data), len(interictal_data)
        # Ictal samples come first, so their labels come first as well.
        self.data = torch.cat((ictal_data, interictal_data))
        self.labels = torch.cat((torch.ones(n_ictal), torch.zeros(n_interictal)))

    def __len__(self):
        """Total number of windows across both classes."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(window, label)``; the label is a 0-d int64 tensor."""
        # Raw EEG window, e.g. shape (22, 2048) for this notebook's data.
        return self.data[idx], self.labels[idx].long()

In [5]:
SEED = 42  # change to get a different (but still reproducible) run
BATCH_SIZE = 32

# Seed torch's global RNG before the model is built in the next cell, so that
# (presumably) its weight initialisation is reproducible as well.
torch.manual_seed(SEED)

ictal_tensor = torch.tensor(ictal_data, dtype=torch.float32)
interictal_tensor = torch.tensor(interictal_data, dtype=torch.float32)

dataset = EEGDataset(ictal_tensor, interictal_tensor)
# A dedicated generator makes the shuffle order reproducible and independent of
# RNG draws made elsewhere (model init, dropout).
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [6]:
from SeqSNN.network import SpikeTemporalConvNet2D

# Hyper-parameters for SpikeTemporalConvNet2D, gathered in one mapping so they
# can be logged or tuned in a single place. NOTE(review) marks assumptions
# about SeqSNN's API that have not been verified here.
snn_config = {
    "num_levels": 4,        # number of temporal conv levels
    "channel": 22,          # NOTE(review): set to the EEG channel count; confirm SeqSNN uses `channel` that way
    "dilation": 2,          # dilation of the temporal convolutions
    "stride": 1,
    "num_steps": 16,        # SNN simulation time steps
    "kernel_size": 3,
    "dropout": 0.2,
    "max_length": 4096,     # upper bound on sequence length; windows here are 2048 samples
    "input_size": 22,       # intended to match the 22 EEG channels per window
    "hidden_size": 128,
    "encoder_type": "conv",
    "num_pe_neuron": 10,    # presumably only relevant when pe_type != "none"
    "pe_type": "none",      # positional encoding disabled
    "pe_mode": "concat",    # presumably ignored while pe_type is "none" — confirm
    "neuron_pe_scale": 1000.0,
}

model = SpikeTemporalConvNet2D(**snn_config)

  WeightNorm.apply(module, name, dim)


In [None]:
import torch

NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# Train on a GPU when one is available; unrolling num_steps=16 SNN steps over
# 2048-sample windows is likely to be slow on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# NOTE(review): CrossEntropyLoss expects `final_spikes` to be raw per-class scores
# of shape (batch, num_classes) with num_classes >= 2, and `labels` to be int64
# class indices. No classification head is attached to the model, so confirm the
# width of `final_spikes`; if it is a hidden-feature vector, add a Linear(..., 2) head.
# NOTE(review): batches arrive as (batch, 22, 2048), i.e. channels first; confirm
# the layout SpikeTemporalConvNet2D.forward expects (it may need eeg_raw.transpose(1, 2)).
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

model.train()  # make sure dropout is active even if the model was put in eval mode earlier
for epoch in range(NUM_EPOCHS):
    print(f"Start epoch {epoch + 1}")
    total_loss = 0.0
    for eeg_raw, labels in dataloader:
        eeg_raw, labels = eeg_raw.to(device), labels.to(device)
        optimizer.zero_grad()

        output_spikes, final_spikes = model(eeg_raw)
        loss = criterion(final_spikes, labels)

        loss.backward()
        optimizer.step()

        # The loss is a batch mean; weight it by batch size so the epoch figure is a
        # true per-sample mean even when the final batch is smaller.
        total_loss += loss.item() * labels.size(0)

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader.dataset)}")

Start epoch 0
