# Testing the model

## Load the data

In [1]:
import os

import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split

In [2]:
data_path = "./CHB-MIT/processed"
ictal_path, interictal_path = (
    os.path.join(data_path, "ictal.h5"),
    os.path.join(data_path, "interictal.h5"),
)

# Open both HDF5 files read-only. The handles are deliberately left open:
# the following cell reads their "data" datasets.
ictal_file, interictal_file = (
    h5py.File(ictal_path, "r"),
    h5py.File(interictal_path, "r"),
)

# Show the top-level groups stored in each file.
print(ictal_file.keys())
print(interictal_file.keys())

<KeysViewHDF5 ['channels', 'data', 'info']>
<KeysViewHDF5 ['channels', 'data', 'info']>


## Convert data to tensor

In [3]:
# Materialise each HDF5 "data" dataset as a NumPy array ...
ictal_np, interictal_np = (
    np.array(ictal_file["data"]),
    np.array(interictal_file["data"]),
)

# ... then copy both into float32 tensors for PyTorch.
ictal_data, interictal_data = (
    torch.tensor(array, dtype=torch.float32) for array in (ictal_np, interictal_np)
)

print(f"Ictal data shape {ictal_data.shape}")
print(f"Interictal data shape {interictal_data.shape}")

Ictal data shape torch.Size([2509, 22, 2048])
Interictal data shape torch.Size([2509, 22, 2048])


## Create Dataset

In [4]:
class EEGDataset(Dataset):
    """Binary seizure dataset pairing EEG windows with ictal/interictal labels.

    Samples are stored ictal-first, then interictal, in the order given.
    Ictal windows are labelled 1 and interictal windows 0.
    """

    def __init__(self, ictal_data, interictal_data):
        """
        Args:
            ictal_data: tensor or array-like (e.g. NumPy array) of ictal
                windows, indexed along dim 0.
            interictal_data: same for interictal windows; all dimensions
                after dim 0 must match ``ictal_data``.
        """
        # Actually convert the inputs: torch.cat only accepts tensors.
        # as_tensor avoids a copy when the input is already a tensor.
        ictal = torch.as_tensor(ictal_data)
        interictal = torch.as_tensor(interictal_data)

        self.data = torch.cat([ictal, interictal])
        # Float labels aligned with self.data: 1 for ictal, 0 for interictal.
        self.labels = torch.cat(
            [
                torch.ones(len(ictal)),  # Ictal = 1
                torch.zeros(len(interictal)),  # Interictal = 0
            ]
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(window, label)``.

        ``label`` is an int64 scalar tensor (0 = interictal, 1 = ictal). For the
        CHB-MIT tensors loaded in this notebook, ``window`` is (22, 2048).
        """
        eeg_raw = self.data[idx]
        label = self.labels[idx].long()
        return eeg_raw, label

In [5]:
SPLIT_SEED = 42  # fixed so split membership is reproducible across kernel restarts

dataset = EEGDataset(ictal_data, interictal_data)

# NOTE(review): windows are split at random, so windows from the same
# recording/seizure may land in both train and test. Consider a patient- or
# seizure-level split for an unbiased test estimate.
train_dataset, test_dataset, val_dataset = random_split(
    dataset,
    [0.7, 0.2, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

## Implement Encoder
The encoder is implemented as a plain function rather than as an `nn.Module`.

In [6]:
# Function to encode continuous EEG data into spikes using rate coding
def rate_coding(data, num_steps=100, gain=1.0):
    """
    Convert continuous EEG data to spike trains using rate coding.

    Each channel is min-max normalised over its time axis and then averaged to a
    single firing probability. That probability is scaled by ``gain`` and
    clipped to [0, 1]. It then drives an independent Bernoulli draw at every
    output step.

    Args:
        data: EEG data tensor of shape (batch_size, channels, time_steps)
        num_steps: Number of time steps for the spike train
        gain: Scaling factor to control firing rate. Probabilities pushed
            above 1 are clipped to 1 (the channel always fires).

    Returns:
        Spike tensor of shape (batch_size, channels, num_steps) with values in
        {0, 1}, in the default float dtype and on ``data.device``.
    """
    # Normalize data to [0, 1] range for each channel
    data_min = data.min(dim=2, keepdim=True)[0]
    data_max = data.max(dim=2, keepdim=True)[0]
    data_norm = (data - data_min) / (data_max - data_min + 1e-8)

    # One firing probability per (batch, channel). It is constant across steps,
    # so compute it once. Clamp because torch.bernoulli rejects p outside
    # [0, 1], which any gain > 1 (or < 0) would otherwise produce.
    firing_prob = (data_norm.mean(dim=2) * gain).clamp(0.0, 1.0)

    # Sample all steps at once instead of looping over num_steps.
    probs = firing_prob.unsqueeze(-1).expand(-1, -1, num_steps).contiguous()
    return torch.bernoulli(probs).to(torch.get_default_dtype())

## Testing the model
This model consumes frequency-domain input rather than raw time-domain EEG, so each window is first transformed with a short-time Fourier transform (STFT).

In [7]:
import snntorch as snn
import torch
import torch.nn as nn
from snntorch import SConv2dLSTM, spikegen, surrogate
from tqdm import tqdm

In [8]:
# Main training function with data loaders
def train_spiking_eeg_with_loaders(
    train_loader, val_loader, test_loader, num_epochs=10, num_steps=100
):
    """
    Placeholder for a DataLoader-based training entry point.

    NOTE(review): despite its name, this function does not train anything yet.
    It only selects a device (CUDA when available, otherwise CPU), prints it and
    returns it. The loader and epoch/step arguments form the intended interface
    but are currently unused. The training loop in this notebook runs as
    top-level cell code instead.

    Args:
        train_loader: DataLoader for training data (currently unused)
        val_loader: DataLoader for validation data (currently unused)
        test_loader: DataLoader for test data (currently unused)
        num_epochs: Number of training epochs (currently unused)
        num_steps: Number of time steps for spike encoding (currently unused)

    Returns:
        torch.device: the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    return device

In [9]:
def vectorized_stft(eeg_data, n_fft=256, hop_length=32, win_length=128):
    """
    Apply STFT to batched EEG data using vectorization.

    All (batch, channel) signals are flattened into one 2-D batch so a single
    ``torch.stft`` call transforms every channel. The result is then reshaped
    back.

    Parameters:
    -----------
    eeg_data: torch.Tensor
        EEG data with shape (batch, channels, time_steps), on any device.
    n_fft, hop_length, win_length: int
        Forwarded to ``torch.stft``. A Hann window of ``win_length`` is used.

    Returns:
    --------
    stft_output: torch.Tensor
        Complex STFT output with shape (batch, channels, frequency_bins,
        time_frames). With torch.stft's defaults (onesided, center=True),
        frequency_bins = n_fft // 2 + 1 and
        time_frames = 1 + time_steps // hop_length
        (e.g. 129 x 65 for 2048 samples with the defaults here).
    """
    batch_size, n_channels, time_steps = eeg_data.shape

    # torch.stft requires the window on the same device as the input, so a
    # CPU-only window would break CUDA input. Match the floating dtype too.
    window = torch.hann_window(
        win_length,
        device=eeg_data.device,
        dtype=eeg_data.dtype if eeg_data.is_floating_point() else None,
    )

    # Reshape to (batch*channels, time_steps)
    reshaped_data = eeg_data.reshape(-1, time_steps)

    # Apply STFT to all channels at once
    stft = torch.stft(
        reshaped_data,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )

    # Reshape back to (batch, channels, freq_bins, time_frames)
    freq_bins, time_frames = stft.shape[1], stft.shape[2]
    return stft.reshape(batch_size, n_channels, freq_bins, time_frames)

In [10]:
class STFTSpikeClassifier(nn.Module):
    """Spiking classifier over spike-encoded STFT input.

    Three spiking conv-LSTM layers, each max-pooling the frequency axis by 2,
    feed two fully connected leaky integrate-and-fire (LIF) layers. The input's
    last axis (STFT time frames) is the simulation time axis: one frame is fed
    per step, and the final layer's spikes and membranes are recorded at every
    step.

    Args:
        input_channels: number of input channels (EEG electrodes), 22 by default.
        freq_bins: number of STFT frequency bins in the input. The default of
            129 matches ``n_fft=256`` and reproduces the original ``fc1`` size.
    """

    def __init__(self, input_channels=22, freq_bins=129):
        super().__init__()

        # Tuned hyperparameters; their provenance is not recorded in this notebook.
        self.thr = 0.05  # firing threshold shared by every spiking layer
        slope = 13.42287274232855
        beta = 0.9181805491303656
        p1 = 0.5083664100388336  # dropout after fc1
        p2 = 0.26260898840708335  # dropout after fc2
        spike_grad = surrogate.straight_through_estimator()
        spike_grad2 = surrogate.fast_sigmoid(slope=slope)

        # max_pool=(2, 1): halve the frequency axis, keep the single time column.
        self.lstm1 = snn.SConv2dLSTM(
            in_channels=input_channels,
            out_channels=16,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=self.thr,
            spike_grad=spike_grad,
        )
        self.lstm2 = snn.SConv2dLSTM(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=self.thr,
            spike_grad=spike_grad,
        )
        self.lstm3 = snn.SConv2dLSTM(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            max_pool=(2, 1),
            threshold=self.thr,
            spike_grad=spike_grad,
        )

        # Frequency size after three floor-halving pools, e.g. 129 -> 64 -> 32 -> 16.
        # The time size is 1 because forward() feeds one frame at a time.
        pooled_freq = freq_bins // 2 // 2 // 2
        self.fc1 = nn.Linear(64 * pooled_freq * 1, 512)

        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad2, threshold=self.thr)
        self.dropout1 = nn.Dropout(p1)
        self.fc2 = nn.Linear(512, 2)  # two output classes
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad2, threshold=self.thr)
        self.dropout2 = nn.Dropout(p2)

    def forward(self, x):
        """Run the network over every time frame of ``x``.

        Args:
            x: tensor of shape (batch, channels, freq_bins, time_frames).

        Returns:
            (spk5_rec, mem5_rec): output spikes and membrane potentials of the
            final LIF layer. Each is stacked along a new leading time dimension,
            giving shape (time_frames, batch, 2).
        """
        time_steps = x.size(3)

        # Fresh hidden state for every forward pass.
        mem4 = self.lif1.init_leaky()
        mem5 = self.lif2.init_leaky()
        syn1, mem1 = self.lstm1.init_sconv2dlstm()
        syn2, mem2 = self.lstm2.init_sconv2dlstm()
        syn3, mem3 = self.lstm3.init_sconv2dlstm()

        spk5_rec = []
        mem5_rec = []

        for step in range(time_steps):
            # Keep a trailing time axis of length 1: (batch, channels, freq_bins, 1).
            x_t = x[:, :, :, step].unsqueeze(-1)

            spk1, syn1, mem1 = self.lstm1(x_t, syn1, mem1)
            spk2, syn2, mem2 = self.lstm2(spk1, syn2, mem2)
            spk3, syn3, mem3 = self.lstm3(spk2, syn3, mem3)

            cur4 = self.dropout1(self.fc1(spk3.flatten(1)))
            spk4, mem4 = self.lif1(cur4, mem4)

            cur5 = self.dropout2(self.fc2(spk4))
            spk5, mem5 = self.lif2(cur5, mem5)

            spk5_rec.append(spk5)
            mem5_rec.append(mem5)

        return torch.stack(spk5_rec), torch.stack(mem5_rec)

In [11]:
import snntorch.functional as SF

# Fall back to CPU (as train_spiking_eeg_with_loaders does) so the notebook
# still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SNN_net = STFTSpikeClassifier().to(device)

# Loss on output spike counts, with snntorch's default target rates.
criterion = SF.mse_count_loss()

optimizer = torch.optim.AdamW(
    SNN_net.parameters(),
    lr=5e-5,
    betas=(0.9, 0.999),
    weight_decay=1e-5,  # light regularization
)

# Halve the learning rate after 5 epochs without validation-loss improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6
)

In [12]:
# Training loop
loss_hist = []
val_loss_hist = []
best_val_loss = float("inf")
num_epochs = 50

for epoch in range(num_epochs):
    # Training phase
    SNN_net.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0

    train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for batch_idx, (data, targets) in enumerate(train_loop):
        # Preprocess data on CPU first
        # STFT output: (batch, channels=22, freq=129, time=57)
        scaled_data = vectorized_stft(data)
        scaled_data = torch.abs(scaled_data)

        # Normalize data to between 0 and 1
        if scaled_data.max() > 0:  # Avoid division by zero
            scaled_data = scaled_data / scaled_data.max()

        # Apply delta encoding - this will encode when values cross threshold
        threshold = 0.1  # Adjust based on your data distribution
        data_spike = spikegen.delta(
            scaled_data, threshold=threshold, padding=False, off_spike=False
        )

        # Move data to device after preprocessing
        data_spike, targets = data_spike.to(device), targets.to(device)

        # Forward pass
        spk_rec, mem_rec = SNN_net(data_spike)

        # Calculate loss using spikes (not just final membrane potential)
        loss_val = criterion(spk_rec, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Update metrics - use spike count for prediction
        train_loss += loss_val.item()

        # Sum spikes across time steps for prediction
        spike_sum = torch.sum(spk_rec, dim=0)
        _, predicted = torch.max(spike_sum, 1)

        total_train += targets.size(0)
        correct_train += (predicted == targets).sum().item()

        # Update progress bar
        train_loop.set_postfix(
            loss=train_loss / (batch_idx + 1), acc=100.0 * correct_train / total_train
        )

        # Store loss
        loss_hist.append(loss_val.item())

    # Validation phase
    SNN_net.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():  # No gradient calculation during validation
        val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
        for batch_idx, (data, targets) in enumerate(val_loop):
            # Preprocess data on CPU first
            scaled_data = vectorized_stft(data)
            scaled_data = torch.abs(scaled_data)

            # Normalize if needed
            if scaled_data.max() > 0:  # Avoid division by zero
                scaled_data = scaled_data / scaled_data.max()

            # Apply delta encoding
            data_spike = spikegen.delta(
                scaled_data, threshold=threshold, padding=False, off_spike=False
            )

            # Move to device
            data_spike, targets = data_spike.to(device), targets.to(device)

            # Forward pass
            spk_rec, mem_rec = SNN_net(data_spike)

            # Calculate loss on spikes
            loss_val = criterion(spk_rec, targets)

            # Update metrics - use spike count for prediction
            val_loss += loss_val.item()

            # Sum spikes across time steps for prediction
            spike_sum = torch.sum(spk_rec, dim=0)
            _, predicted = torch.max(spike_sum, 1)

            total_val += targets.size(0)
            correct_val += (predicted == targets).sum().item()

            # Update progress bar
            val_loop.set_postfix(
                loss=val_loss / (batch_idx + 1), acc=100.0 * correct_val / total_val
            )

        # Calculate average validation loss
        avg_val_loss = val_loss / len(val_loader)
        val_loss_hist.append(avg_val_loss)

        # Update learning rate based on validation loss
        scheduler.step(avg_val_loss)

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(SNN_net.state_dict(), "best_spiking_eeg_model_2.pth")
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

    # Print epoch summary
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(
        f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {100.*correct_train/total_train:.2f}%"
    )
    print(
        f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {100.*correct_val/total_val:.2f}%"
    )
    print("-" * 60)

Epoch 1/50 [Train]: 100%|██████████| 110/110 [01:09<00:00,  1.59it/s, acc=64.7, loss=15]  
Epoch 1/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.02it/s, acc=68.3, loss=14.9]


Saved best model with validation loss: 14.8673
Epoch 1/50:
Train Loss: 14.9815, Train Acc: 64.67%
Val Loss: 14.8673, Val Acc: 68.26%
------------------------------------------------------------


Epoch 2/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.78it/s, acc=68.9, loss=13.9]
Epoch 2/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.14it/s, acc=69.7, loss=14.1]


Saved best model with validation loss: 14.1433
Epoch 2/50:
Train Loss: 13.8507, Train Acc: 68.86%
Val Loss: 14.1433, Val Acc: 69.66%
------------------------------------------------------------


Epoch 3/50 [Train]: 100%|██████████| 110/110 [00:16<00:00,  6.52it/s, acc=69.7, loss=13.6]
Epoch 3/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.39it/s, acc=69.9, loss=14.3]


Epoch 3/50:
Train Loss: 13.6011, Train Acc: 69.71%
Val Loss: 14.2833, Val Acc: 69.86%
------------------------------------------------------------


Epoch 4/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.46it/s, acc=70.3, loss=13.5]
Epoch 4/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.29it/s, acc=71.5, loss=14]  


Saved best model with validation loss: 14.0339
Epoch 4/50:
Train Loss: 13.4758, Train Acc: 70.25%
Val Loss: 14.0339, Val Acc: 71.46%
------------------------------------------------------------


Epoch 5/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.53it/s, acc=70.9, loss=13.1]
Epoch 5/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.08it/s, acc=70.9, loss=12.9]


Saved best model with validation loss: 12.8822
Epoch 5/50:
Train Loss: 13.0836, Train Acc: 70.88%
Val Loss: 12.8822, Val Acc: 70.86%
------------------------------------------------------------


Epoch 6/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.52it/s, acc=71.6, loss=13.1]
Epoch 6/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.73it/s, acc=71.9, loss=12.7]


Saved best model with validation loss: 12.6677
Epoch 6/50:
Train Loss: 13.1262, Train Acc: 71.62%
Val Loss: 12.6677, Val Acc: 71.86%
------------------------------------------------------------


Epoch 7/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.40it/s, acc=72.2, loss=12.6]
Epoch 7/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.05it/s, acc=72.1, loss=14]  


Epoch 7/50:
Train Loss: 12.6216, Train Acc: 72.22%
Val Loss: 13.9867, Val Acc: 72.06%
------------------------------------------------------------


Epoch 8/50 [Train]: 100%|██████████| 110/110 [00:15<00:00,  7.26it/s, acc=72.4, loss=12.7]
Epoch 8/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 10.82it/s, acc=72.7, loss=12.9]


Epoch 8/50:
Train Loss: 12.7398, Train Acc: 72.45%
Val Loss: 12.9470, Val Acc: 72.65%
------------------------------------------------------------


Epoch 9/50 [Train]: 100%|██████████| 110/110 [00:15<00:00,  7.28it/s, acc=73, loss=12.4]  
Epoch 9/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.42it/s, acc=72.7, loss=13.6]


Epoch 9/50:
Train Loss: 12.4177, Train Acc: 72.96%
Val Loss: 13.6220, Val Acc: 72.65%
------------------------------------------------------------


Epoch 10/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.37it/s, acc=73.9, loss=12.2]
Epoch 10/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.92it/s, acc=73.5, loss=12.2]


Saved best model with validation loss: 12.1774
Epoch 10/50:
Train Loss: 12.2070, Train Acc: 73.90%
Val Loss: 12.1774, Val Acc: 73.45%
------------------------------------------------------------


Epoch 11/50 [Train]: 100%|██████████| 110/110 [00:15<00:00,  7.10it/s, acc=74.1, loss=12.2]
Epoch 11/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.76it/s, acc=72.7, loss=11.8]


Saved best model with validation loss: 11.7786
Epoch 11/50:
Train Loss: 12.1571, Train Acc: 74.07%
Val Loss: 11.7786, Val Acc: 72.65%
------------------------------------------------------------


Epoch 12/50 [Train]: 100%|██████████| 110/110 [00:15<00:00,  7.30it/s, acc=73.4, loss=12.3]
Epoch 12/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.31it/s, acc=73.3, loss=12.7]


Epoch 12/50:
Train Loss: 12.3378, Train Acc: 73.38%
Val Loss: 12.7331, Val Acc: 73.25%
------------------------------------------------------------


Epoch 13/50 [Train]: 100%|██████████| 110/110 [00:15<00:00,  7.21it/s, acc=73.7, loss=12.2]
Epoch 13/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 10.79it/s, acc=73.9, loss=11.6]


Saved best model with validation loss: 11.5878
Epoch 13/50:
Train Loss: 12.2176, Train Acc: 73.73%
Val Loss: 11.5878, Val Acc: 73.85%
------------------------------------------------------------


Epoch 14/50 [Train]: 100%|██████████| 110/110 [00:15<00:00,  7.27it/s, acc=73.4, loss=12.3]
Epoch 14/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.74it/s, acc=72.9, loss=11.8]


Epoch 14/50:
Train Loss: 12.2628, Train Acc: 73.38%
Val Loss: 11.8071, Val Acc: 72.85%
------------------------------------------------------------


Epoch 15/50 [Train]: 100%|██████████| 110/110 [00:15<00:00,  7.22it/s, acc=73.7, loss=12.3]
Epoch 15/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.57it/s, acc=73.3, loss=13.2]


Epoch 15/50:
Train Loss: 12.2842, Train Acc: 73.73%
Val Loss: 13.1721, Val Acc: 73.25%
------------------------------------------------------------


Epoch 16/50 [Train]: 100%|██████████| 110/110 [00:20<00:00,  5.40it/s, acc=74.8, loss=11.9]
Epoch 16/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.86it/s, acc=73.3, loss=12.5]


Epoch 16/50:
Train Loss: 11.8804, Train Acc: 74.78%
Val Loss: 12.4940, Val Acc: 73.25%
------------------------------------------------------------


Epoch 17/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.73it/s, acc=74.5, loss=11.9]
Epoch 17/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.12it/s, acc=74.7, loss=12]  


Epoch 17/50:
Train Loss: 11.9344, Train Acc: 74.47%
Val Loss: 11.9832, Val Acc: 74.65%
------------------------------------------------------------


Epoch 18/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.73it/s, acc=73.4, loss=12.1]
Epoch 18/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.03it/s, acc=74.3, loss=12.4]


Epoch 18/50:
Train Loss: 12.1259, Train Acc: 73.41%
Val Loss: 12.3835, Val Acc: 74.25%
------------------------------------------------------------


Epoch 19/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.71it/s, acc=74.3, loss=11.8]
Epoch 19/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.04it/s, acc=73.7, loss=11.8]


Epoch 19/50:
Train Loss: 11.8427, Train Acc: 74.30%
Val Loss: 11.8344, Val Acc: 73.65%
------------------------------------------------------------


Epoch 20/50 [Train]: 100%|██████████| 110/110 [00:18<00:00,  5.92it/s, acc=74.5, loss=12]  
Epoch 20/50 [Val]: 100%|██████████| 16/16 [00:44<00:00,  2.80s/it, acc=73.9, loss=12.8]


Epoch 20/50:
Train Loss: 11.9969, Train Acc: 74.49%
Val Loss: 12.7548, Val Acc: 73.85%
------------------------------------------------------------


Epoch 21/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.84it/s, acc=74.2, loss=11.9]
Epoch 21/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.57it/s, acc=73.9, loss=11.9]


Epoch 21/50:
Train Loss: 11.8816, Train Acc: 74.15%
Val Loss: 11.9119, Val Acc: 73.85%
------------------------------------------------------------


Epoch 22/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.69it/s, acc=75.4, loss=11.9]
Epoch 22/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.45it/s, acc=74.1, loss=12]  


Epoch 22/50:
Train Loss: 11.8525, Train Acc: 75.38%
Val Loss: 12.0236, Val Acc: 74.05%
------------------------------------------------------------


Epoch 23/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.78it/s, acc=75, loss=11.7]  
Epoch 23/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.27it/s, acc=75.6, loss=11.5]


Saved best model with validation loss: 11.4580
Epoch 23/50:
Train Loss: 11.6656, Train Acc: 74.95%
Val Loss: 11.4580, Val Acc: 75.65%
------------------------------------------------------------


Epoch 24/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.77it/s, acc=73.7, loss=11.8]
Epoch 24/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.09it/s, acc=75, loss=11.8]  


Epoch 24/50:
Train Loss: 11.8055, Train Acc: 73.67%
Val Loss: 11.7886, Val Acc: 75.05%
------------------------------------------------------------


Epoch 25/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.78it/s, acc=74.9, loss=11.7]
Epoch 25/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.38it/s, acc=75.2, loss=11.8]


Epoch 25/50:
Train Loss: 11.7129, Train Acc: 74.86%
Val Loss: 11.8039, Val Acc: 75.25%
------------------------------------------------------------


Epoch 26/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.60it/s, acc=74.3, loss=11.9]
Epoch 26/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.57it/s, acc=74.5, loss=12]  


Epoch 26/50:
Train Loss: 11.8639, Train Acc: 74.27%
Val Loss: 12.0264, Val Acc: 74.45%
------------------------------------------------------------


Epoch 27/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.72it/s, acc=73.7, loss=11.8]
Epoch 27/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.36it/s, acc=74.1, loss=12.4]


Epoch 27/50:
Train Loss: 11.7953, Train Acc: 73.67%
Val Loss: 12.3827, Val Acc: 74.05%
------------------------------------------------------------


Epoch 28/50 [Train]: 100%|██████████| 110/110 [00:16<00:00,  6.50it/s, acc=74.9, loss=11.5]
Epoch 28/50 [Val]: 100%|██████████| 16/16 [00:01<00:00,  9.82it/s, acc=75, loss=11.3]  


Saved best model with validation loss: 11.2944
Epoch 28/50:
Train Loss: 11.5457, Train Acc: 74.89%
Val Loss: 11.2944, Val Acc: 75.05%
------------------------------------------------------------


Epoch 29/50 [Train]: 100%|██████████| 110/110 [00:16<00:00,  6.63it/s, acc=75.3, loss=11.3]
Epoch 29/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.64it/s, acc=74.9, loss=12]  


Epoch 29/50:
Train Loss: 11.2995, Train Acc: 75.35%
Val Loss: 12.0312, Val Acc: 74.85%
------------------------------------------------------------


Epoch 30/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.44it/s, acc=74.6, loss=11.8]
Epoch 30/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.16it/s, acc=74.7, loss=12.1]


Epoch 30/50:
Train Loss: 11.8284, Train Acc: 74.61%
Val Loss: 12.0997, Val Acc: 74.65%
------------------------------------------------------------


Epoch 31/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.52it/s, acc=74.9, loss=11.4]
Epoch 31/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.93it/s, acc=74.5, loss=11.5]


Epoch 31/50:
Train Loss: 11.4434, Train Acc: 74.86%
Val Loss: 11.5307, Val Acc: 74.45%
------------------------------------------------------------


Epoch 32/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.50it/s, acc=74, loss=11.6]  
Epoch 32/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.33it/s, acc=74.1, loss=11.8]


Epoch 32/50:
Train Loss: 11.6300, Train Acc: 73.98%
Val Loss: 11.7915, Val Acc: 74.05%
------------------------------------------------------------


Epoch 33/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.81it/s, acc=75.4, loss=11.4]
Epoch 33/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.39it/s, acc=73.9, loss=11.8]


Epoch 33/50:
Train Loss: 11.3851, Train Acc: 75.41%
Val Loss: 11.8426, Val Acc: 73.85%
------------------------------------------------------------


Epoch 34/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.70it/s, acc=75.4, loss=11.5]
Epoch 34/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.84it/s, acc=74.9, loss=11.9]


Epoch 34/50:
Train Loss: 11.5419, Train Acc: 75.38%
Val Loss: 11.9076, Val Acc: 74.85%
------------------------------------------------------------


Epoch 35/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.74it/s, acc=75.8, loss=11.4]
Epoch 35/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.76it/s, acc=74.5, loss=11.6]


Epoch 35/50:
Train Loss: 11.3965, Train Acc: 75.78%
Val Loss: 11.6353, Val Acc: 74.45%
------------------------------------------------------------


Epoch 36/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.75it/s, acc=75.2, loss=11.5]
Epoch 36/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.49it/s, acc=74.5, loss=11.7]


Epoch 36/50:
Train Loss: 11.5359, Train Acc: 75.21%
Val Loss: 11.6687, Val Acc: 74.45%
------------------------------------------------------------


Epoch 37/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.59it/s, acc=75, loss=11.2]  
Epoch 37/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.82it/s, acc=74.7, loss=11.8]


Epoch 37/50:
Train Loss: 11.2179, Train Acc: 74.95%
Val Loss: 11.8181, Val Acc: 74.65%
------------------------------------------------------------


Epoch 38/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.60it/s, acc=76.3, loss=11.3]
Epoch 38/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.76it/s, acc=74.5, loss=12]  


Epoch 38/50:
Train Loss: 11.2679, Train Acc: 76.29%
Val Loss: 12.0106, Val Acc: 74.45%
------------------------------------------------------------


Epoch 39/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.39it/s, acc=76.1, loss=11.2]
Epoch 39/50 [Val]: 100%|██████████| 16/16 [00:02<00:00,  7.88it/s, acc=75.2, loss=11.4]


Epoch 39/50:
Train Loss: 11.2395, Train Acc: 76.06%
Val Loss: 11.3543, Val Acc: 75.25%
------------------------------------------------------------


Epoch 40/50 [Train]: 100%|██████████| 110/110 [00:14<00:00,  7.80it/s, acc=75.2, loss=11.4]
Epoch 40/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 11.56it/s, acc=73.1, loss=12.2]


Epoch 40/50:
Train Loss: 11.3709, Train Acc: 75.18%
Val Loss: 12.1897, Val Acc: 73.05%
------------------------------------------------------------


Epoch 41/50 [Train]: 100%|██████████| 110/110 [00:16<00:00,  6.48it/s, acc=75.5, loss=11.4]
Epoch 41/50 [Val]: 100%|██████████| 16/16 [00:01<00:00, 12.17it/s, acc=74.7, loss=11.7]


Epoch 41/50:
Train Loss: 11.4070, Train Acc: 75.46%
Val Loss: 11.6798, Val Acc: 74.65%
------------------------------------------------------------


Epoch 42/50 [Train]: 100%|██████████| 110/110 [00:28<00:00,  3.83it/s, acc=75.6, loss=11.3]
Epoch 42/50 [Val]: 100%|██████████| 16/16 [00:04<00:00,  3.57it/s, acc=74.3, loss=11.9]


Epoch 42/50:
Train Loss: 11.3256, Train Acc: 75.58%
Val Loss: 11.8962, Val Acc: 74.25%
------------------------------------------------------------


Epoch 43/50 [Train]: 100%|██████████| 110/110 [00:39<00:00,  2.79it/s, acc=75.5, loss=11.5]
Epoch 43/50 [Val]: 100%|██████████| 16/16 [00:04<00:00,  3.25it/s, acc=74.5, loss=11.9]


Epoch 43/50:
Train Loss: 11.4581, Train Acc: 75.49%
Val Loss: 11.9453, Val Acc: 74.45%
------------------------------------------------------------


Epoch 44/50 [Train]: 100%|██████████| 110/110 [00:39<00:00,  2.75it/s, acc=75.1, loss=11.4]
Epoch 44/50 [Val]: 100%|██████████| 16/16 [00:04<00:00,  3.34it/s, acc=75.8, loss=11.3]


Epoch 44/50:
Train Loss: 11.4271, Train Acc: 75.06%
Val Loss: 11.3495, Val Acc: 75.85%
------------------------------------------------------------


Epoch 45/50 [Train]: 100%|██████████| 110/110 [00:37<00:00,  2.93it/s, acc=75.8, loss=11.3]
Epoch 45/50 [Val]: 100%|██████████| 16/16 [00:05<00:00,  3.19it/s, acc=74.9, loss=12]  


Epoch 45/50:
Train Loss: 11.3490, Train Acc: 75.78%
Val Loss: 11.9806, Val Acc: 74.85%
------------------------------------------------------------


Epoch 46/50 [Train]: 100%|██████████| 110/110 [00:39<00:00,  2.76it/s, acc=75.4, loss=11.4]
Epoch 46/50 [Val]: 100%|██████████| 16/16 [00:09<00:00,  1.62it/s, acc=74.9, loss=11.4]


Epoch 46/50:
Train Loss: 11.4372, Train Acc: 75.43%
Val Loss: 11.4173, Val Acc: 74.85%
------------------------------------------------------------


Epoch 47/50 [Train]: 100%|██████████| 110/110 [00:40<00:00,  2.72it/s, acc=75.5, loss=11.2]
Epoch 47/50 [Val]: 100%|██████████| 16/16 [00:07<00:00,  2.19it/s, acc=75, loss=11.7]  


Epoch 47/50:
Train Loss: 11.2335, Train Acc: 75.49%
Val Loss: 11.7030, Val Acc: 75.05%
------------------------------------------------------------


Epoch 48/50 [Train]: 100%|██████████| 110/110 [00:40<00:00,  2.74it/s, acc=74.7, loss=11.6]
Epoch 48/50 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.80it/s, acc=75.2, loss=11.6]


Epoch 48/50:
Train Loss: 11.6326, Train Acc: 74.72%
Val Loss: 11.5595, Val Acc: 75.25%
------------------------------------------------------------


Epoch 49/50 [Train]: 100%|██████████| 110/110 [00:37<00:00,  2.91it/s, acc=75.9, loss=11.2]
Epoch 49/50 [Val]: 100%|██████████| 16/16 [00:04<00:00,  3.49it/s, acc=75.4, loss=11.5]


Epoch 49/50:
Train Loss: 11.1632, Train Acc: 75.89%
Val Loss: 11.5052, Val Acc: 75.45%
------------------------------------------------------------


Epoch 50/50 [Train]: 100%|██████████| 110/110 [00:37<00:00,  2.90it/s, acc=75.7, loss=11.3]
Epoch 50/50 [Val]: 100%|██████████| 16/16 [00:05<00:00,  2.97it/s, acc=75.4, loss=11.5]

Epoch 50/50:
Train Loss: 11.2939, Train Acc: 75.66%
Val Loss: 11.5286, Val Acc: 75.45%
------------------------------------------------------------



