# Tuning the model

## Load the data

In [1]:
import h5py
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split

## Create Dataset

In [2]:
class EEGDataset(Dataset):
    """
    PyTorch Dataset for EEG data with seizure/non-seizure labels, loaded from HDF5 files.

    Each file must contain a dataset named ``"data"`` whose first axis indexes samples.
    Ictal samples are stored first, followed by interictal samples.

    Attributes:
        data (torch.Tensor): Concatenated EEG samples (ictal first, then interictal), float32
        labels (torch.Tensor): Float labels aligned with ``data`` (1.0 for ictal, 0.0 for interictal)
    """

    def __init__(
        self, data_dir, ictal_filename="ictal.h5", interictal_filename="interictal.h5"
    ):
        """
        Load ictal and interictal samples from HDF5 files in ``data_dir``.

        Parameters:
            data_dir (str): Directory containing HDF5 data files
            ictal_filename (str, optional): Filename for ictal data
            interictal_filename (str, optional): Filename for interictal data
        """
        ictal_data = self._load_h5_data(os.path.join(data_dir, ictal_filename))
        interictal_data = self._load_h5_data(os.path.join(data_dir, interictal_filename))

        self.data = torch.cat([ictal_data, interictal_data])
        # Labels aligned with the concatenation order above
        self.labels = torch.cat(
            [
                torch.ones(len(ictal_data)),  # Ictal = 1
                torch.zeros(len(interictal_data)),  # Interictal = 0
            ]
        )

    @staticmethod
    def _load_h5_data(path):
        """Read the ``"data"`` dataset of an HDF5 file into a float32 tensor and close the file."""
        # The context manager releases the file handle even if reading fails. Previously both
        # files were left open for the lifetime of the kernel.
        with h5py.File(path, "r") as h5_file:
            # np.array(...) reads the whole dataset into memory before the file is closed.
            return torch.tensor(np.array(h5_file["data"]), dtype=torch.float32)

    def __len__(self):
        """Number of samples (ictal + interictal)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)``; the label is a bool tensor (True = ictal)."""
        # Presumably (channels, time_steps), e.g. (22, 2048). TODO: confirm against the HDF5 files.
        eeg_raw = self.data[idx]
        label = self.labels[idx].bool()
        return eeg_raw, label

In [3]:
SEED = 42  # fixed seed so the train/test/val split is identical after every kernel restart
BATCH_SIZE = 32

data_path = "./CHB-MIT/processed"
dataset = EEGDataset(data_path)
train_dataset, test_dataset, val_dataset = random_split(
    dataset, [0.7, 0.2, 0.1], generator=torch.Generator().manual_seed(SEED)
)

# Only training batches need shuffling. A fixed order keeps val/test batch composition
# deterministic across epochs and trials.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Testing the model
This model expects frequency-domain input rather than raw time-domain samples, so each batch must first be converted with a short-time Fourier transform (STFT).

In [4]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate, SConv2dLSTM
from tqdm import tqdm

In [5]:
def vectorized_stft(eeg_data, n_fft=256, hop_length=32, win_length=128):
    """
    Apply an STFT to every channel of a batch of EEG signals in a single call.

    Parameters:
    -----------
    eeg_data: torch.Tensor
        EEG data with shape (batch, channels, time_steps)
    n_fft: int
        FFT size; the output has ``n_fft // 2 + 1`` frequency bins
    hop_length: int
        Number of samples between successive STFT frames
    win_length: int
        Length of the Hann window (zero-padded to ``n_fft`` by ``torch.stft``)

    Returns:
    --------
    stft_output: torch.Tensor
        Complex STFT output with shape (batch, channels, frequency_bins, time_frames).
        ``torch.stft`` defaults to ``center=True``, which gives
        ``1 + time_steps // hop_length`` time frames.
    """
    batch_size, n_channels, time_steps = eeg_data.shape
    # Build the window on the input's device and dtype. torch.stft rejects a window on a
    # different device, so a default (CPU) window would break the function for CUDA inputs.
    window = torch.hann_window(win_length, device=eeg_data.device, dtype=eeg_data.dtype)

    # Fold channels into the batch axis so one stft call handles every signal
    flat_signals = eeg_data.reshape(batch_size * n_channels, time_steps)

    stft = torch.stft(
        flat_signals,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )

    # Unfold back to (batch, channels, freq_bins, time_frames)
    freq_bins, time_frames = stft.shape[1], stft.shape[2]
    return stft.reshape(batch_size, n_channels, freq_bins, time_frames)

In [6]:
class STFTSpikeClassifier(nn.Module):
    """
    Spiking classifier that consumes a spike-encoded spectrogram one time frame at a time.

    Three spiking conv-LSTM layers feed two fully connected Leaky (LIF) layers. Each conv-LSTM
    max-pools the frequency axis by 2. The final layer has 2 output neurons (one per class).

    Parameters:
        input_channels (int): Number of input feature maps per frame (EEG channels)
        threshold (float): Firing threshold shared by every spiking layer
        slope (float): Slope of the fast-sigmoid surrogate gradient used by the Leaky layers
        beta (float): Membrane decay rate of the Leaky layers
        p1 (float): Dropout probability applied to the ``fc1`` output
        p2 (float): Dropout probability applied to the ``fc2`` output
        freq_bins (int): Size of the input frequency axis (129 for ``n_fft=256``); sizes ``fc1``
    """

    def __init__(
        self,
        input_channels=22,
        threshold=0.05,
        slope=13.42287274232855,
        beta=0.9181805491303656,
        p1=0.5083664100388336,
        p2=0.26260898840708335,
        freq_bins=129,
    ):
        super().__init__()

        # Straight-through estimator for the conv-LSTM spikes, fast sigmoid for the Leaky layers
        spike_grad = surrogate.straight_through_estimator()
        spike_grad2 = surrogate.fast_sigmoid(slope=slope)

        # Layers are created in the original order so parameter initialisation (RNG draws)
        # and state_dict keys are unchanged.
        self.lstm1 = snn.SConv2dLSTM(
            in_channels=input_channels,
            out_channels=16,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )
        self.lstm2 = snn.SConv2dLSTM(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )
        self.lstm3 = snn.SConv2dLSTM(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )

        # Three (2, 1) max-pools halve the frequency axis (rounding down) and leave the single
        # time frame untouched, e.g. 129 -> 64 -> 32 -> 16 for the default STFT.
        pooled_freq_bins = freq_bins // 2 // 2 // 2
        self.fc1 = nn.Linear(64 * pooled_freq_bins * 1, 512)

        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad2, threshold=threshold)
        self.dropout1 = nn.Dropout(p1)
        self.fc2 = nn.Linear(512, 2)  # one output neuron per class
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad2, threshold=threshold)
        self.dropout2 = nn.Dropout(p2)

    def forward(self, x):
        """
        Run the network over every time frame of ``x``.

        Parameters:
            x (torch.Tensor): Input of shape (batch, channels, freq, time). The last axis
                (STFT frames) is stepped through sequentially.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Output spikes and membrane potentials of the
            final layer, each stacked along a new leading time axis: (time, batch, 2).
        """
        time_steps = x.size(3)

        # Initialize hidden/membrane state for every spiking layer
        mem4 = self.lif1.init_leaky()
        mem5 = self.lif2.init_leaky()
        syn1, mem1 = self.lstm1.init_sconv2dlstm()
        syn2, mem2 = self.lstm2.init_sconv2dlstm()
        syn3, mem3 = self.lstm3.init_sconv2dlstm()

        # Output recording
        spk5_rec = []
        mem5_rec = []

        for step in range(time_steps):
            # Single frame with a trailing time axis of size 1: (batch, channels, freq, 1)
            x_t = x[:, :, :, step].unsqueeze(-1)

            spk1, syn1, mem1 = self.lstm1(x_t, syn1, mem1)
            spk2, syn2, mem2 = self.lstm2(spk1, syn2, mem2)
            spk3, syn3, mem3 = self.lstm3(spk2, syn3, mem3)

            # Flatten the conv-LSTM spikes and feed the fully connected spiking head
            cur4 = self.dropout1(self.fc1(spk3.flatten(1)))
            spk4, mem4 = self.lif1(cur4, mem4)

            cur5 = self.dropout2(self.fc2(spk4))
            spk5, mem5 = self.lif2(cur5, mem5)

            spk5_rec.append(spk5)
            mem5_rec.append(mem5)

        return torch.stack(spk5_rec), torch.stack(mem5_rec)

In [7]:
import snntorch.functional as SF
from snntorch import spikegen
import optuna

# Use the GPU when available; fall back to the CPU so the notebook still runs end to end.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
def _encode_batch(data):
    """
    Convert a raw EEG batch into rate-coded spikes on ``device``.

    Steps: STFT, magnitude, division by the batch-wide maximum so values fall in [0, 1],
    then ``spikegen.rate`` encoding. The result is moved to ``device``.
    """
    magnitude = torch.abs(vectorized_stft(data))
    peak = magnitude.max()
    if peak > 0:  # Avoid division by zero on an all-zero batch
        magnitude = magnitude / peak
    return spikegen.rate(magnitude, time_var_input=True).to(device)


def _run_epoch(model, loader, criterion, desc, optimizer=None):
    """
    Run one pass over ``loader`` and return the mean per-batch loss.

    Trains (forward, backward, optimizer step) when ``optimizer`` is given. Otherwise the
    model is evaluated with gradients disabled. Running loss and accuracy are shown on a
    tqdm progress bar.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    total_loss = 0.0
    correct = 0
    seen = 0

    progress = tqdm(loader, desc=desc, leave=False)
    with torch.set_grad_enabled(training):
        for batch_idx, (data, targets) in enumerate(progress):
            data_spike = _encode_batch(data)
            targets = targets.to(device)

            spk_rec, _ = model(data_spike)
            loss_val = criterion(spk_rec, targets)

            if training:
                optimizer.zero_grad()
                loss_val.backward()
                optimizer.step()

            total_loss += loss_val.item()
            # Predicted class = output neuron with the most spikes summed over time
            _, predicted = torch.max(torch.sum(spk_rec, dim=0), 1)
            seen += targets.size(0)
            correct += (predicted == targets).sum().item()

            progress.set_postfix(
                loss=total_loss / (batch_idx + 1),
                acc=100.0 * correct / seen,
            )

    return total_loss / len(loader)


def objective(trial):
    """
    Optuna objective: train an STFTSpikeClassifier with sampled hyperparameters.

    Parameters:
        trial (optuna.trial.Trial): Trial used to sample hyperparameters and report progress.

    Returns:
        float: The lowest mean validation loss over all epochs (the study minimizes it).

    Raises:
        optuna.exceptions.TrialPruned: If the pruner stops the trial early.
    """
    params = {
        # Model hyperparameters
        "threshold": trial.suggest_float("threshold", 0.01, 0.1),
        "slope": trial.suggest_float("slope", 5.0, 20.0),
        "beta": trial.suggest_float("beta", 0.8, 0.99),
        "p1": trial.suggest_float("p1", 0.3, 0.7),
        "p2": trial.suggest_float("p2", 0.1, 0.4),
        # Optimizer hyperparameters
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        # Scheduler hyperparameters
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.7),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 3, 10),
    }

    model = STFTSpikeClassifier(
        input_channels=22,
        threshold=params["threshold"],
        slope=params["slope"],
        beta=params["beta"],
        p1=params["p1"],
        p2=params["p2"],
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=params["lr"],
        betas=(0.9, 0.999),
        weight_decay=params["weight_decay"],
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=params["scheduler_factor"],
        patience=params["scheduler_patience"],
        min_lr=1e-6,
    )

    criterion = SF.mse_count_loss()

    num_epochs = 15  # Reduced for hyperparameter search
    # Lowest validation loss seen so far. The old code started at 0 and compared with ">",
    # so it tracked the maximum and was never used.
    best_val_loss = float("inf")

    print(f"Trial {trial.number} Starting training...")

    for epoch in range(num_epochs):
        epoch_tag = f"Epoch {epoch+1}/{num_epochs}"
        _run_epoch(model, train_loader, criterion, f"{epoch_tag} [Train]", optimizer=optimizer)
        avg_val_loss = _run_epoch(model, val_loader, criterion, f"{epoch_tag} [Val]")

        scheduler.step(avg_val_loss)
        trial.report(avg_val_loss, epoch)

        best_val_loss = min(best_val_loss, avg_val_loss)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return best_val_loss

In [9]:
from config import DB_CONFIG

from urllib.parse import quote_plus

study_name = "Classifier Rate Encoder Old"
# Percent-encode the credentials. Characters such as "@", ":" or "/" in a password would
# otherwise be parsed as URL delimiters and corrupt the connection string.
# The URL contains the password, so never display or log it.
db_user = quote_plus(str(DB_CONFIG["user"]))
db_password = quote_plus(str(DB_CONFIG["password"]))
storage_url = (
    f"postgresql://{db_user}:{db_password}"
    f"@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
)

study = optuna.create_study(
    direction="minimize",
    study_name=study_name,
    storage=storage_url,
    load_if_exists=True,
)

[I 2025-04-02 01:44:44,575] A new study created in RDB with name: Classifier Rate Encoder Old


Optimizer results: the best hyperparameters found so far in the study.

In [None]:
# Best hyperparameters recorded so far in the database-backed study (shared across sessions).
# NOTE: Optuna raises ValueError here if no trial in the study has completed yet.
best_params = study.best_params
best_params