# Tuning the model

## Load the data

In [2]:
import h5py
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split

In [3]:
data_path = "./CHB-MIT/processed"
ictal_path, interictal_path = (
    os.path.join(data_path, file_name) for file_name in ("ictal.h5", "interictal.h5")
)

# Handles are intentionally left open: the next cell reads their 'data' datasets.
ictal_file = h5py.File(ictal_path, "r")
interictal_file = h5py.File(interictal_path, "r")

# Show the top-level keys of each file (expected: channels, data, info).
for h5_handle in (ictal_file, interictal_file):
    print(h5_handle.keys())

<KeysViewHDF5 ['channels', 'data', 'info']>
<KeysViewHDF5 ['channels', 'data', 'info']>


## Convert data to tensors

In [4]:
# Materialize each HDF5 'data' dataset in memory, then copy it into a float32 tensor.
ictal_np = np.array(ictal_file["data"])
interictal_np = np.array(interictal_file["data"])

ictal_data, interictal_data = (
    torch.tensor(array, dtype=torch.float32) for array in (ictal_np, interictal_np)
)

print(f"Ictal data shape {ictal_data.shape}")
print(f"Interictal data shape {interictal_data.shape}")

Ictal data shape torch.Size([2509, 22, 2048])
Interictal data shape torch.Size([2509, 22, 2048])


## Create Dataset

In [5]:
class EEGDataset(Dataset):
    """In-memory dataset of labelled EEG windows: ictal (seizure) vs. interictal.

    Parameters
    ----------
    ictal_data, interictal_data : torch.Tensor or array-like
        Windows stacked along dim 0. Non-tensor inputs (e.g. NumPy arrays) are
        converted with ``torch.as_tensor``; tensors are used as-is. All dims
        after the first must match so the two inputs can be concatenated.

    Each item is ``(eeg_window, label)``, where ``label`` is a 0-dim bool
    tensor: ``True`` for ictal, ``False`` for interictal.
    """

    def __init__(self, ictal_data, interictal_data):
        # Convert array-likes to tensors (returns tensors unchanged).
        ictal_data = torch.as_tensor(ictal_data)
        interictal_data = torch.as_tensor(interictal_data)

        # Ictal windows come first, then interictal; the labels follow the same order.
        self.data = torch.cat([ictal_data, interictal_data])
        # Labels are stored as float 1.0 / 0.0 (other cells read `.labels` directly).
        self.labels = torch.cat(
            [
                torch.ones(len(ictal_data)),  # Ictal = 1
                torch.zeros(len(interictal_data)),  # Interictal = 0
            ]
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        eeg_raw = self.data[idx]  # one window, e.g. (channels, time_steps)
        label = self.labels[idx].bool()  # True = ictal, False = interictal
        return eeg_raw, label

In [6]:
# Fixed seed so the train/test/validation partition (and every metric computed
# from it) is reproducible across kernel restarts.
SPLIT_SEED = 42

dataset = EEGDataset(ictal_data, interictal_data)
# Fractions are (train, test, validation) and match the unpacking order below.
train_dataset, test_dataset, val_dataset = random_split(
    dataset, [0.7, 0.2, 0.1], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Evaluation loaders keep a fixed order. Metrics are aggregated over the whole
# split anyway, and a fixed batch composition keeps the per-batch normalisation
# used by the spike encoder stable from epoch to epoch.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [7]:
subset_labels = ("Train", "Test", "Validation")

# Report class balance per split; guard divisions so an empty or single-class
# split prints instead of raising ZeroDivisionError.
for subset_name, subset in zip(subset_labels, (train_dataset, test_dataset, val_dataset)):
    label_counts = {0: 0, 1: 0}

    for idx in subset.indices:
        # Labels are stored as floats (1.0 / 0.0); use int keys.
        label_counts[int(subset.dataset.labels[idx].item())] += 1

    total = sum(label_counts.values())
    ictal_frac = label_counts[1] / total if total else 0.0
    interictal_frac = label_counts[0] / total if total else 0.0
    ratio = label_counts[1] / label_counts[0] if label_counts[0] else float("inf")

    print(f"Dataset: {subset_name}")
    print(f"  Total samples: {total}")
    print(f"  Ictal (seizure): {label_counts[1]} ({ictal_frac:.2%})")
    print(f"  Interictal (normal): {label_counts[0]} ({interictal_frac:.2%})")
    print(f"  Ratio: {ratio:.2f}")
    print("-" * 40)

Dataset: Train
  Total samples: 3513
  Ictal (seizure): 1770 (50.38%)
  Interictal (normal): 1743 (49.62%)
  Ratio: 1.02
----------------------------------------
Dataset: Test
  Total samples: 1004
  Ictal (seizure): 507 (50.50%)
  Interictal (normal): 497 (49.50%)
  Ratio: 1.02
----------------------------------------
Dataset: Validation
  Total samples: 501
  Ictal (seizure): 232 (46.31%)
  Interictal (normal): 269 (53.69%)
  Ratio: 0.86
----------------------------------------


## Testing the model
The model below consumes frequency-domain input rather than raw time-domain samples, so each EEG window is first converted with a short-time Fourier transform (STFT).

In [8]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate, SConv2dLSTM
from tqdm import tqdm

In [27]:
def vectorized_stft(eeg_data, n_fft=256, hop_length=32, win_length=128):
    """
    Apply STFT to batched EEG data using vectorization

    All (batch, channel) signals are flattened into one batch and transformed
    with a single ``torch.stft`` call.

    Parameters:
    -----------
    eeg_data: torch.Tensor
        Real EEG data with shape (batch, channels, time_steps), on any device.
    n_fft: int
        FFT size; the output has ``n_fft // 2 + 1`` frequency bins (129 for 256).
    hop_length: int
        Number of samples between successive STFT frames.
    win_length: int
        Length of the Hann window. ``torch.stft`` pads it on both sides to
        ``n_fft`` when it is shorter.

    Returns:
    --------
    stft_output: torch.Tensor
        Complex STFT output with shape (batch, channels, frequency_bins, time_frames).
        With ``torch.stft``'s default ``center=True``, ``time_frames`` is
        ``1 + time_steps // hop_length`` (65 for 2048 samples and hop_length=32).
    """
    batch_size, n_channels, time_steps = eeg_data.shape

    # Build the window on the input's device (and floating dtype) so that
    # torch.stft also works when the batch lives on the GPU.
    window_dtype = eeg_data.dtype if eeg_data.is_floating_point() else None
    window = torch.hann_window(win_length, device=eeg_data.device, dtype=window_dtype)

    # Reshape to (batch*channels, time_steps)
    reshaped_data = eeg_data.reshape(-1, time_steps)

    # Apply STFT to all channels at once
    stft = torch.stft(
        reshaped_data,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )

    # Reshape back to (batch, channels, freq_bins, time_frames)
    freq_bins, time_frames = stft.shape[1], stft.shape[2]
    stft_output = stft.reshape(batch_size, n_channels, freq_bins, time_frames)

    return stft_output

In [28]:
class STFTSpikeClassifier(nn.Module):
    """Spiking convolutional-LSTM classifier for spike-encoded STFT input.

    Three ``SConv2dLSTM`` layers, each with a (2, 1) max-pool that halves the
    frequency axis, feed two fully connected layers followed by ``Leaky``
    spiking neurons. The input's last axis (STFT time frames) is consumed one
    frame per simulation step.

    Parameters
    ----------
    input_channels : int
        Number of input feature maps (EEG channels). Default 22.
    threshold : float
        Firing threshold shared by all spiking layers.
    slope : float
        Slope of the fast-sigmoid surrogate gradient used by the Leaky layers.
    beta : float
        Membrane decay rate of the Leaky layers.
    p1, p2 : float
        Dropout probabilities applied after ``fc1`` and ``fc2``.
    freq_bins : int
        Size of the input's frequency axis (``n_fft // 2 + 1``). Default 129,
        matching ``n_fft=256``. Used to size ``fc1``.
    num_classes : int
        Number of output neurons. Default 2 (interictal / ictal).
    """

    def __init__(
        self,
        input_channels=22,
        threshold=0.05,
        slope=13.42287274232855,
        beta=0.9181805491303656,
        p1=0.5083664100388336,
        p2=0.26260898840708335,
        freq_bins=129,
        num_classes=2,
    ):
        super().__init__()

        if freq_bins < 8:
            raise ValueError(
                f"freq_bins must be >= 8 to survive three (2, 1) max-pools, got {freq_bins}"
            )

        spike_grad = surrogate.straight_through_estimator()
        spike_grad2 = surrogate.fast_sigmoid(slope=slope)

        self.lstm1 = snn.SConv2dLSTM(
            in_channels=input_channels,
            out_channels=16,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )
        self.lstm2 = snn.SConv2dLSTM(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )
        self.lstm3 = snn.SConv2dLSTM(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=threshold,
            spike_grad=spike_grad,
        )

        # Each (2, 1) max-pool floor-halves the frequency axis, and three
        # successive floor-halvings equal freq_bins // 8 (129 -> 64 -> 32 -> 16).
        # The time axis stays 1 because one frame is processed per step.
        pooled_freq = freq_bins // 8
        self.fc1 = nn.Linear(64 * pooled_freq * 1, 512)

        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad2, threshold=threshold)
        self.dropout1 = nn.Dropout(p1)
        self.fc2 = nn.Linear(512, num_classes)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad2, threshold=threshold)
        self.dropout2 = nn.Dropout(p2)

    def forward(self, x):
        """Run the network over every time frame of ``x``.

        ``x`` is indexed as (batch, channels, freq_bins, time_frames). With the
        ``vectorized_stft`` defaults on 2048-sample windows that is
        (batch, 22, 129, 65).

        Returns ``(spk_rec, mem_rec)``: output spikes and membrane potentials
        stacked along a new leading time dimension,
        (time_frames, batch, num_classes).
        """
        time_steps = x.size(3)

        # Initialize LIF state variables
        mem4 = self.lif1.init_leaky()
        mem5 = self.lif2.init_leaky()
        syn1, mem1 = self.lstm1.init_sconv2dlstm()
        syn2, mem2 = self.lstm2.init_sconv2dlstm()
        syn3, mem3 = self.lstm3.init_sconv2dlstm()

        # Output recording
        spk5_rec = []
        mem5_rec = []

        # Process each time step
        for step in range(time_steps):
            # Take one frame and keep a width-1 time axis:
            # x_t shape (batch, channels, freq_bins, 1)
            x_t = x[:, :, :, step].unsqueeze(-1)

            # Pass through SConv2dLSTM layers
            spk1, syn1, mem1 = self.lstm1(x_t, syn1, mem1)
            spk2, syn2, mem2 = self.lstm2(spk1, syn2, mem2)
            spk3, syn3, mem3 = self.lstm3(spk2, syn3, mem3)

            # Flatten and feed through fully connected layers
            cur4 = self.dropout1(self.fc1(spk3.flatten(1)))
            spk4, mem4 = self.lif1(cur4, mem4)

            cur5 = self.dropout2(self.fc2(spk4))
            spk5, mem5 = self.lif2(cur5, mem5)

            # Record output spikes and membrane potentials
            spk5_rec.append(spk5)
            mem5_rec.append(mem5)

        # Stack time steps
        return torch.stack(spk5_rec), torch.stack(mem5_rec)

In [29]:
import snntorch.functional as SF
from snntorch import spikegen
import optuna

# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [51]:
def _encode_spikes(stft_data, encoding_method):
    """Convert a complex STFT batch into a real-valued spike tensor of the same shape.

    Parameters
    ----------
    stft_data : torch.Tensor
        Complex output of ``vectorized_stft``, indexed as
        (batch, channels, freq_bins, time_frames).
    encoding_method : str
        ``"rate"`` (Bernoulli rate coding of normalised magnitude) or
        ``"delta"`` (spikes on frame-to-frame increases of the signed magnitude).

    Raises
    ------
    ValueError
        If ``encoding_method`` is not recognised.
    """
    magnitude = torch.abs(stft_data)
    peak = magnitude.max()

    if encoding_method == "rate":
        # Normalise to [0, 1] so magnitudes act as firing probabilities.
        if peak > 0:  # Avoid division by zero
            magnitude = magnitude / peak
        # time_var_input=True: the STFT already carries a time axis (the last
        # dim), which the model iterates over.
        return spikegen.rate(magnitude, time_var_input=True)

    if encoding_method == "delta":
        # Magnitude carrying the sign of the real part.
        signed_magnitude = magnitude * torch.sign(stft_data.real)
        if peak > 0:  # Avoid division by zero (would produce NaNs)
            signed_magnitude = signed_magnitude / peak
        # spikegen.delta takes differences along dim 0. Move the STFT time-frame
        # axis (last) to the front so changes are measured over time rather
        # than across samples in the batch, then restore the original layout.
        spikes = spikegen.delta(signed_magnitude.movedim(-1, 0))
        return spikes.movedim(0, -1)

    raise ValueError(f"Unknown encoding_method: {encoding_method!r}")


def objective(trial: optuna.Trial):
    """Optuna objective: train an ``STFTSpikeClassifier`` and return its best validation loss.

    Samples model, optimizer, scheduler and input-encoding hyperparameters from
    ``trial``. It then trains for ``num_epochs`` epochs on ``train_loader`` and
    evaluates on ``val_loader`` after every epoch. Each epoch's mean validation
    loss is reported to Optuna for pruning, and the lowest one is returned.

    Raises
    ------
    optuna.exceptions.TrialPruned
        When the study's pruner stops the trial early.
    """
    # Model hyperparameters
    threshold = trial.suggest_float("threshold", 0.01, 0.1)
    slope = trial.suggest_float("slope", 5.0, 20.0)
    beta = trial.suggest_float("beta", 0.8, 0.99)
    p1 = trial.suggest_float("p1", 0.3, 0.7)
    p2 = trial.suggest_float("p2", 0.1, 0.4)

    # Optimizer hyperparameters
    lr = trial.suggest_float("lr", 1e-6, 1e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)

    # Scheduler hyperparameters
    scheduler_factor = trial.suggest_float("scheduler_factor", 0.1, 0.7)
    scheduler_patience = trial.suggest_int("scheduler_patience", 3, 10)

    # Input encoding is a per-trial choice, so sample it once up front.
    encoding_method = trial.suggest_categorical("encoding_method", ["rate", "delta"])

    # Create model and optimizer
    model = STFTSpikeClassifier(
        input_channels=22, threshold=threshold, slope=slope, beta=beta, p1=p1, p2=p2
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay
    )

    # Create scheduler with sampled parameters
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=scheduler_factor,
        patience=scheduler_patience,
        min_lr=1e-6,
    )

    criterion = SF.mse_count_loss()

    # Training loop
    num_epochs = 15  # Reduced for hyperparameter search
    # Lower is better (the study minimises), so start from +inf.
    best_val_loss = float("inf")

    print(f"Trial {trial.number} Starting training...")
    print(f"Model Parameters: threshold={threshold}, slope={slope}, beta={beta}, p1={p1}, p2={p2}")
    print(f"Optimizer Parameters: lr={lr}, weight_decay={weight_decay}")
    print(f"Scheduler Parameters: factor={scheduler_factor}, patience={scheduler_patience}")

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0
        correct_train = 0
        total_train = 0

        train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        for batch_idx, (data, targets) in enumerate(train_loop):
            # STFT is computed on the CPU batch, then the spikes are moved to the device.
            data_spike = _encode_spikes(vectorized_stft(data), encoding_method)
            data_spike, targets = data_spike.to(device), targets.to(device)

            spk_rec, _ = model(data_spike)
            loss_val = criterion(spk_rec, targets)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            train_loss += loss_val.item()
            # Predicted class = output neuron with the most spikes over all time steps.
            _, predicted = torch.max(torch.sum(spk_rec, dim=0), 1)
            total_train += targets.size(0)
            correct_train += (predicted == targets).sum().item()

            # Update progress bar
            train_loop.set_postfix(
                loss=train_loss / (batch_idx + 1), acc=100.0 * correct_train / total_train
            )

        # ---- Validation phase ----
        model.eval()
        val_loss = 0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")

            for batch_idx, (data, targets) in enumerate(val_loop):
                data_spike = _encode_spikes(vectorized_stft(data), encoding_method)
                data_spike, targets = data_spike.to(device), targets.to(device)

                spk_rec, _ = model(data_spike)
                loss_val = criterion(spk_rec, targets)

                val_loss += loss_val.item()
                _, predicted = torch.max(torch.sum(spk_rec, dim=0), 1)
                total_val += targets.size(0)
                correct_val += (predicted == targets).sum().item()

                # Update progress bar
                val_loop.set_postfix(
                    loss=val_loss / (batch_idx + 1), acc=100.0 * correct_val / total_val
                )

        # Calculate average metrics
        avg_val_loss = val_loss / len(val_loader)

        scheduler.step(avg_val_loss)

        # The pruner sees every epoch's loss; the objective value is the best epoch.
        trial.report(avg_val_loss, epoch)
        best_val_loss = min(best_val_loss, avg_val_loss)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return best_val_loss

In [49]:
# Persist the study in SQLite so trials accumulate across kernel restarts.
study_name = "STFT_SNN_Classifier"
study_storage = "sqlite:///classifier_tuning.db"

study = optuna.create_study(
    study_name=study_name,
    storage=study_storage,
    direction='minimize',
    load_if_exists=True,
)

[I 2025-03-29 13:09:23,784] A new study created in RDB with name: STFT_SNN_Classifier


Run the Optuna hyperparameter search

In [53]:
study.optimize(func=objective, n_trials=4)

Trial 5 Starting training...
Model Parameters: threshold=0.029502183721464022, slope=5.877186513794134, beta=0.8878529652241832, p1=0.5158072293841036, p2=0.33379760536677766
Optimizer Parameters: lr=4.0554645367364053e-05, weight_decay=1.45773874831512e-06
Scheduler Parameters: factor=0.5512718333066357, patience=6


Epoch 1/15 [Train]: 100%|██████████| 110/110 [00:59<00:00,  1.84it/s, acc=71, loss=13.9]  
Epoch 1/15 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.91it/s, acc=75.4, loss=11.5]
Epoch 2/15 [Train]: 100%|██████████| 110/110 [01:00<00:00,  1.81it/s, acc=78.3, loss=10.6]
Epoch 2/15 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.88it/s, acc=77.4, loss=10.3]
Epoch 3/15 [Train]: 100%|██████████| 110/110 [01:01<00:00,  1.79it/s, acc=82.1, loss=9.14]
Epoch 3/15 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.76it/s, acc=81.8, loss=8.77]
Epoch 4/15 [Train]: 100%|██████████| 110/110 [01:01<00:00,  1.78it/s, acc=81.7, loss=9.05]
Epoch 4/15 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.86it/s, acc=82, loss=8.97]  
Epoch 5/15 [Train]: 100%|██████████| 110/110 [01:01<00:00,  1.80it/s, acc=82.1, loss=8.7] 
Epoch 5/15 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.89it/s, acc=83, loss=8.07]  
Epoch 6/15 [Train]: 100%|██████████| 110/110 [01:01<00:00,  1.78it/s, acc=83.7, loss=8.33]
Epoch 6/15 [Val]: 1

Trial 6 Starting training...
Model Parameters: threshold=0.054468577834123635, slope=15.223145953482412, beta=0.8469335181205, p1=0.6309073641251619, p2=0.32738101660601415
Optimizer Parameters: lr=1.3707996208326852e-06, weight_decay=1.1648432256981654e-05
Scheduler Parameters: factor=0.36429153668934655, patience=10


Epoch 1/15 [Train]: 100%|██████████| 110/110 [00:59<00:00,  1.84it/s, acc=49.6, loss=27]  
Epoch 1/15 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.89it/s, acc=53.7, loss=30]  
[I 2025-03-29 15:43:18,593] Trial 6 pruned. 


Trial 7 Starting training...
Model Parameters: threshold=0.03079676758506992, slope=5.530382745445492, beta=0.959057173449269, p1=0.4331110788193915, p2=0.32349814845088487
Optimizer Parameters: lr=3.595027939747839e-06, weight_decay=1.5417379134228117e-05
Scheduler Parameters: factor=0.37218934790231617, patience=10


Epoch 1/15 [Train]: 100%|██████████| 110/110 [01:00<00:00,  1.83it/s, acc=68.9, loss=15.9]
Epoch 1/15 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.87it/s, acc=75.6, loss=14.1]
[I 2025-03-29 15:44:24,522] Trial 7 pruned. 


Trial 8 Starting training...
Model Parameters: threshold=0.051357653057170015, slope=16.80596182659699, beta=0.8115718799492017, p1=0.5281200883681392, p2=0.3903708933624829
Optimizer Parameters: lr=3.39145932461514e-05, weight_decay=2.1952703721429205e-05
Scheduler Parameters: factor=0.3133760298244207, patience=10


Epoch 1/15 [Train]: 100%|██████████| 110/110 [00:56<00:00,  1.95it/s, acc=61.2, loss=16.1]
Epoch 1/15 [Val]: 100%|██████████| 16/16 [00:05<00:00,  3.14it/s, acc=63.9, loss=16.7]
[I 2025-03-29 15:45:26,146] Trial 8 pruned. 


In [54]:
study.best_trial.params

{'threshold': 0.029502183721464022,
 'slope': 5.877186513794134,
 'beta': 0.8878529652241832,
 'p1': 0.5158072293841036,
 'p2': 0.33379760536677766,
 'lr': 4.0554645367364053e-05,
 'weight_decay': 1.45773874831512e-06,
 'scheduler_factor': 0.5512718333066357,
 'scheduler_patience': 6,
 'encoding_method': 'rate'}