# Testing the model

## Load the data

In [1]:
import h5py
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
data_path = "./CHB-MIT/processed"
ictal_path = os.path.join(data_path, "ictal.h5")
interictal_path = os.path.join(data_path, "interictal.h5")

ictal_file = h5py.File(ictal_path, 'r')
interictal_file = h5py.File(interictal_path, 'r')

print(ictal_file.keys())
print(interictal_file.keys())

<KeysViewHDF5 ['channels', 'data', 'info']>
<KeysViewHDF5 ['channels', 'data', 'info']>


## Convert data to tensor

In [3]:
ictal_np = np.array(ictal_file['data'])
interictal_np = np.array(interictal_file['data'])

ictal_data = torch.tensor(ictal_np, dtype=torch.float32)
interictal_data = torch.tensor(interictal_np, dtype=torch.float32)

print(f"Ictal data shape {ictal_data.shape}")
print(f"Interictal data shape {interictal_data.shape}")

Ictal data shape torch.Size([2509, 22, 2048])
Interictal data shape torch.Size([2509, 22, 2048])


## Create Dataset

In [4]:
class EEGDataset(Dataset):
    def __init__(self, ictal_data, interictal_data):
        # Ensure the data is converted to tensors
        self.data = torch.cat([ictal_data, interictal_data])
        # Labels for ictal and interictal data
        self.labels = torch.cat(
            [
                torch.ones(len(ictal_data)),  # Ictal = 1
                torch.zeros(len(interictal_data)),  # Interictal = 0
            ]
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        eeg_raw = self.data[idx]  # EEG data of shape (22, 2048)
        label = self.labels[idx].long()  # Label: 0 (interictal) or 1 (ictal)
        return eeg_raw, label

In [5]:
dataset = EEGDataset(ictal_data, interictal_data)
train_dataset, test_dataset, val_dataset = random_split(dataset, [0.7, 0.2, 0.1])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

## Implement Encoder
The encoder will implemented in a functional way

In [6]:
# Function to encode continuous EEG data into spikes using rate coding
def rate_coding(data, num_steps=100, gain=1.0):
    """
    Convert continuous EEG data to spike trains using rate coding

    Args:
        data: EEG data tensor of shape (batch_size, channels, time_steps)
        num_steps: Number of time steps for the spike train
        gain: Scaling factor to control firing rate

    Returns:
        Spike tensor of shape (batch_size, channels, num_steps)
    """
    # Normalize data to [0, 1] range for each channel
    data_min = data.min(dim=2, keepdim=True)[0]
    data_max = data.max(dim=2, keepdim=True)[0]
    data_norm = (data - data_min) / (data_max - data_min + 1e-8)

    # Scale by gain factor
    data_norm = data_norm * gain

    # Create spike tensor
    spike_data = torch.zeros(
        (data.shape[0], data.shape[1], num_steps), device=data.device
    )

    # Generate spikes using Bernoulli sampling
    for t in range(num_steps):
        spike_data[:, :, t] = torch.bernoulli(data_norm.mean(dim=2))

    return spike_data

## Testing the model
From this sample model the data is not time domain but the frequency so it need to do the sfft

In [7]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate, SConv2dLSTM
from tqdm import tqdm

In [8]:
def vectorized_stft(eeg_data, n_fft=256, hop_length=32, win_length=128):
    """
    Apply STFT to batched EEG data using vectorization

    Parameters:
    -----------
    eeg_data: torch.Tensor
        EEG data with shape (batch, channels, time_steps)

    Returns:
    --------
    stft_output: torch.Tensor
        STFT output with shape (batch, channels, frequency_bins, time_frames)
    """
    batch_size, n_channels, time_steps = eeg_data.shape
    window = torch.hann_window(win_length)

    # Reshape to (batch*channels, time_steps)
    reshaped_data = eeg_data.reshape(-1, time_steps)

    # Apply STFT to all channels at once
    stft = torch.stft(
        reshaped_data,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )

    # Reshape back to (batch, channels, freq_bins, time_frames)
    freq_bins, time_frames = stft.shape[1], stft.shape[2]
    stft_output = stft.reshape(batch_size, n_channels, freq_bins, time_frames)

    return stft_output

In [None]:
class STFTSpikeClassifier(nn.Module):
    def __init__(
        self,
        input_channels=22,
        threshold=0.05,
        slope=13.42287274232855,
        beta=0.9181805491303656,
        p1=0.5083664100388336,
        p2=0.26260898840708335,
    ):
        super().__init__()

        spike_grad = surrogate.straight_through_estimator()
        spike_grad2 = surrogate.fast_sigmoid(slope=slope)

        # initialize layers - note input_channels=22 for your STFT data
        self.lstm1 = SConv2dLSTM(
            in_channels=input_channels,
            out_channels=16,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )
        self.lstm2 = SConv2dLSTM(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )
        self.lstm3 = snn.SConv2dLSTM(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )

        # Calculate the flattened size based on your frequency dimension (129)
        # After 3 max-pooling layers (each dividing by 2), size becomes: 129 → 64 → 32 → 16
        # For time dimension: 1 (we process one time step at a time)
        self.fc1 = nn.Linear(
            64 * 16 * 1, 512
        )  # Adjust this based on actual output size

        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad2, threshold=threshold)
        self.dropout1 = nn.Dropout(p1)
        self.fc2 = nn.Linear(512, 2)  # Assuming binary classification
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad2, threshold=threshold)
        self.dropout2 = nn.Dropout(p2)

    def forward(self, x):
        # x shape: (batch, channels=22, freq=129, time=57)
        time_steps = x.size(3)

        # Initialize LIF state variables
        mem4 = self.lif1.init_leaky()
        mem5 = self.lif2.init_leaky()
        syn1, mem1 = self.lstm1.init_sconv2dlstm()
        syn2, mem2 = self.lstm2.init_sconv2dlstm()
        syn3, mem3 = self.lstm3.init_sconv2dlstm()

        # Output recording
        spk5_rec = []
        mem5_rec = []

        # Process each time step
        for step in range(time_steps):
            # Extract the current time step and prepare input
            # x_t shape: (batch, channels=22, freq=129, time=1)
            x_t = x[:, :, :, step].unsqueeze(-1)

            # Pass through SConv2dLSTM layers
            spk1, syn1, mem1 = self.lstm1(x_t, syn1, mem1)
            spk2, syn2, mem2 = self.lstm2(spk1, syn2, mem2)
            spk3, syn3, mem3 = self.lstm3(spk2, syn3, mem3)

            # Flatten and feed through fully connected layers
            cur4 = self.dropout1(self.fc1(spk3.flatten(1)))
            spk4, mem4 = self.lif1(cur4, mem4)

            cur5 = self.dropout2(self.fc2(spk4))
            spk5, mem5 = self.lif2(cur5, mem5)

            # Record output spikes and membrane potentials
            spk5_rec.append(spk5)
            mem5_rec.append(mem5)

        # Stack time steps
        return torch.stack(spk5_rec), torch.stack(mem5_rec)

In [16]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.best_model_state = model.state_dict()
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = model.state_dict()
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

In [18]:
import snntorch.functional as SF
from snntorch import spikegen
# Initialize the network
device = torch.device("cuda")
SNN_net = STFTSpikeClassifier().to(device)
# Loss and optimizer
criterion = SF.mse_count_loss()

optimizer = torch.optim.AdamW(
    SNN_net.parameters(),
    lr=5e-5, 
    betas=(0.9, 0.999),
    weight_decay=1e-5
)

# Add a learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6
)


In [15]:
# Training loop
loss_hist = []
val_loss_hist = []
best_val_loss = float("inf")
num_epochs = 50
early_stopping = EarlyStopping(patience=5, delta=0.01)

for epoch in range(num_epochs):
    # Training phase
    SNN_net.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0

    train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for batch_idx, (data, targets) in enumerate(train_loop):
        # Preprocess data on CPU first
        # STFT output: (batch, channels=22, freq=129, time=57)
        scaled_data = vectorized_stft(data)
        scaled_data = torch.abs(scaled_data)

        # Normalize data to between 0 and 1
        if scaled_data.max() > 0:  # Avoid division by zero
            scaled_data = scaled_data / scaled_data.max()

        # Apply delta encoding - this will encode when values cross threshold
        threshold = 0.1  # Adjust based on your data distribution
        # data_spike = spikegen.delta(
        #     scaled_data, threshold=threshold, padding=False, off_spike=False
        # )
        data_spike = spikegen.rate(
            scaled_data, time_var_input=True
        )

        # Move data to device after preprocessing
        data_spike, targets = data_spike.to(device), targets.to(device)

        # Forward pass
        spk_rec, mem_rec = SNN_net(data_spike)

        # Calculate loss using spikes (not just final membrane potential)
        loss_val = criterion(spk_rec, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Update metrics - use spike count for prediction
        train_loss += loss_val.item()

        # Sum spikes across time steps for prediction
        spike_sum = torch.sum(spk_rec, dim=0)
        _, predicted = torch.max(spike_sum, 1)

        total_train += targets.size(0)
        correct_train += (predicted == targets).sum().item()

        # Update progress bar
        train_loop.set_postfix(
            loss=train_loss / (batch_idx + 1), acc=100.0 * correct_train / total_train
        )

        # Store loss
        loss_hist.append(loss_val.item())

    # Validation phase
    SNN_net.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():  # No gradient calculation during validation
        val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
        for batch_idx, (data, targets) in enumerate(val_loop):
            # Preprocess data on CPU first
            scaled_data = vectorized_stft(data)
            scaled_data = torch.abs(scaled_data)

            # Normalize if needed
            if scaled_data.max() > 0:  # Avoid division by zero
                scaled_data = scaled_data / scaled_data.max()

            # Apply delta encoding
            data_spike = spikegen.delta(
                scaled_data, threshold=threshold, padding=False, off_spike=False
            )

            # Move to device
            data_spike, targets = data_spike.to(device), targets.to(device)

            # Forward pass
            spk_rec, mem_rec = SNN_net(data_spike)

            # Calculate loss on spikes
            loss_val = criterion(spk_rec, targets)

            # Update metrics - use spike count for prediction
            val_loss += loss_val.item()

            # Sum spikes across time steps for prediction
            spike_sum = torch.sum(spk_rec, dim=0)
            _, predicted = torch.max(spike_sum, 1)

            total_val += targets.size(0)
            correct_val += (predicted == targets).sum().item()

            # Update progress bar
            val_loop.set_postfix(
                loss=val_loss / (batch_idx + 1), acc=100.0 * correct_val / total_val
            )

        # Calculate average validation loss
        avg_val_loss = val_loss / len(val_loader)
        val_loss_hist.append(avg_val_loss)

        # Update learning rate based on validation loss
        scheduler.step(avg_val_loss)

    # Print epoch summary
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(
        f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {100.*correct_train/total_train:.2f}%"
    )
    print(
        f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {100.*correct_val/total_val:.2f}%"
    )
    print("-" * 60)

    epoch_avg_val_loss = val_loss / len(val_loader)
    early_stopping(epoch_avg_val_loss, SNN_net)
    if early_stopping.early_stop:
        print("Early stopping")
        break

Epoch 1/50 [Train]: 100%|██████████| 110/110 [00:59<00:00,  1.85it/s, acc=78.2, loss=10.3]
Epoch 1/50 [Val]: 100%|██████████| 16/16 [00:04<00:00,  3.25it/s, acc=71.5, loss=17.2]


Epoch 1/50:
Train Loss: 10.3326, Train Acc: 78.20%
Val Loss: 17.1989, Val Acc: 71.46%
------------------------------------------------------------


Epoch 2/50 [Train]: 100%|██████████| 110/110 [00:58<00:00,  1.87it/s, acc=78, loss=10.5]  
Epoch 2/50 [Val]: 100%|██████████| 16/16 [00:04<00:00,  3.23it/s, acc=71.3, loss=17.2]


Epoch 2/50:
Train Loss: 10.5133, Train Acc: 77.97%
Val Loss: 17.2120, Val Acc: 71.26%
------------------------------------------------------------


Epoch 3/50 [Train]: 100%|██████████| 110/110 [00:59<00:00,  1.85it/s, acc=78.5, loss=10.3]
Epoch 3/50 [Val]: 100%|██████████| 16/16 [00:04<00:00,  3.26it/s, acc=71.3, loss=17.2]


Epoch 3/50:
Train Loss: 10.3228, Train Acc: 78.45%
Val Loss: 17.1871, Val Acc: 71.26%
------------------------------------------------------------


Epoch 4/50 [Train]:  30%|███       | 33/110 [00:18<00:42,  1.79it/s, acc=79, loss=10.2]  


KeyboardInterrupt: 