In [1]:
# !wget https://www.bbci.de/competition/download/competition_iv/BCICIV_1_mat.zip
# !unzip -q BCICIV_1_mat.zip
# !rm -rf BCICIV_1_mat.zip

In [2]:
# !git clone https://github.com/XJTU-EEG/LibEER.git

In [3]:
import os
import gc
import math
import random
import scipy.io
import numpy as np
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import sys
sys.path.append('LibEER/LibEER')
sys.path.append('LibEER/LibEER/models')

In [4]:
from RGNN import RGNN

In [5]:
TRAIN = ['a', 'b', 'd', 'e', 'g']
VALID = ['c', 'f']

In [6]:
# a - left, foot
# b - left, right
# c - left, right
# d - left, right
# e - left, right
# f - left, foot
# g - left, right

In [7]:
LABELS = {
    'idle': 0,
    'left' : 1,
    'right': 2,
    'foot': 3
}

LR = 1e-2
EPOCHS = 1000
BATCH_SIZE = 64

In [8]:
def get_labels(data):
    N = len(data['cnt'])

    labels = np.zeros((N, 4), dtype=np.uint8)
    labels[:, 0] = 1

    cls_labels = [d[0] for d in data['nfo'][0][0][1][0].tolist()]
    timestamps = data['mrk'][0][0][0]
    cls_idx = data['mrk'][0][0][1]

    t1 = timestamps[np.where(cls_idx == 1)]
    l1 = [0]*4
    l1[LABELS[cls_labels[0]]] = 1
    l1 = np.asarray([l1]*100)
    for t in t1:
        labels[t-50:t+50] = l1 # change to t :t+100

    t2 = timestamps[np.where(cls_idx == -1)]
    l2 = [0]*4
    l2[LABELS[cls_labels[1]]] = 1
    l2 = np.asarray([l2]*100)
    for t in t2:
        labels[t-50:t+50] = l2

    return labels

def get_data(data_id, split_size):
    data = scipy.io.loadmat(f'BCICIV_calib_ds1{data_id}.mat')
    eegs = data['cnt']
    labels = get_labels(data)
    np.savetxt(f"labels_{data_id}.csv", labels, delimiter=",", fmt="%d")
    
    splits = range(split_size, len(eegs), split_size)

    eegs_split = np.array_split(eegs, splits)[:-1]
    labels_split = np.array_split(labels, splits)[:-1]
    return eegs_split, labels_split

In [9]:
train_eegs, train_labels, valid_eegs, valid_labels = [], [], [], []

for train_id in tqdm(TRAIN):
    # print(train_id)
    eegs_split, labels_split = get_data(train_id, split_size=200)
    train_eegs.extend(eegs_split)
    train_labels.extend(labels_split)

for valid_id in tqdm(VALID):
    eegs_split, labels_split = get_data(valid_id, split_size=100)
    valid_eegs.extend(eegs_split)
    valid_labels.extend(labels_split)

100%|██████████| 5/5 [00:00<00:00,  7.30it/s]
100%|██████████| 2/2 [00:00<00:00,  7.34it/s]


In [10]:
class EEGDataset(Dataset):
    def __init__(self, eegs, labels, split):
        self.eegs = eegs
        self.labels = labels
        self.split = split

    def __getitem__(self, i):
        eeg, label = self.eegs[i], self.labels[i]
        if self.split == 'train':
            start_idx = random.randint(0, 100)
            eeg = eeg[start_idx:start_idx+100]
            label = label[start_idx:start_idx+100]

        eeg = torch.from_numpy(eeg).float()
        #eeg = (eeg - MEAN)/STD
        eeg = (eeg - torch.mean(eeg, dim=0))/torch.std(eeg, dim=0)
        eeg = eeg.transpose(0, 1)

        # print(f"Single sample shape: {eeg.shape}")
        
        label_soft = torch.from_numpy(label).float().mean(dim=0)
        #label_hard = torch.argmax(label_soft)
        return eeg, label_soft

    def __len__(self):
        return len(self.eegs)

In [11]:
dataset_train = EEGDataset(train_eegs, train_labels, 'train')
dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, num_workers=0,
                              shuffle=True, pin_memory=True, drop_last=True)
n_train = len(dataloader_train)

dataset_valid = EEGDataset(valid_eegs, valid_labels, 'valid')
dataloader_valid = DataLoader(dataset_valid, batch_size=BATCH_SIZE, num_workers=0,
                              shuffle=False, pin_memory=True, drop_last=False)
n_valid = len(dataloader_valid)

In [12]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

In [13]:
model = RGNN(num_electrodes=59, in_channels=100, num_classes=4, num_layers=2, num_hidden=512, noise_level=0.1,
                 dropout=0.5, domain_adaptation=False, prior_known_init=True)

print(f'Number of parameters: {count_parameters(model):.2f}M')
device = torch.device("mps" if torch.backends.mps.is_available() else
                      "cuda" if torch.cuda.is_available() else "cpu")

# Convert model parameters to float32 before moving to device
model = model.float()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
loss_name = "Cross_Entropy"

# Calculate class weights based on label distribution
# labels = torch.stack([label for _, label in dataset_train]).float()  # Ensure float32
# label_counts = torch.bincount(torch.argmax(labels, dim=1))
# class_weights = 1.0 / label_counts
# class_weights = class_weights / class_weights.sum()
# class_weights = class_weights.float().to(device)  # Ensure float32

# criterion = nn.CrossEntropyLoss(weight=class_weights)
# loss_name = "Weighted_Cross_Entropy"

#grad_scaler = torch.amp.GradScaler('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, betas=[0.9, 0.999], weight_decay=0.001)


#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR,
#                                                steps_per_epoch=10, epochs=EPOCHS//10,
#                                                pct_start=0.1)


config/model_param/RGNN.yaml may not exist or not available
RGNN Model, Parameters:

num layers:                                                     2
num_hidden:                                                   512
dropout:                                                      0.5
domain_adaptation:                                              0

prior_known_init:                                               1
Not Using Default Setting, the performance may be not the best

When you run subject_independent setting, you should set the domain_adaptation to True

num_hidden, noise_level should be set by the dataset
Starting......
Number of parameters: 15.52M


In [14]:
for x, y in dataloader_train:
    print("Input shape:", x.shape)
    print("Label shape:", y.shape)
    break

Input shape: torch.Size([64, 59, 100])
Label shape: torch.Size([64, 4])


In [15]:
for name, param in model.named_parameters():
    print(f"{name}: {param.mean().item():.4f}")

edge_weight: 0.1028
sgc.w.weight: -0.0000
sgc.w.bias: 0.0000
fc.weight: -0.0000
fc.bias: 0.0000
fc2.weight: 0.0009
fc2.bias: 0.0000


In [16]:
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Hyperparameters
patience = 10
best_val_loss = float('inf')
counter = 0

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

for epoch in range(1, EPOCHS + 1):
    model.train()
    cur_lr = f"LR : {optimizer.param_groups[0]['lr']:.2E}"

    pbar_train = tqdm(dataloader_train, total=n_train, bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")
    mloss_train, mloss_val = 0.0, 0.0

    for i, (x, y) in enumerate(pbar_train):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        # Get model output and ensure it's a tensor, not tuple
        output = model(x)
        if isinstance(output, tuple):
            y_hat = output[0]  # Take first element if tuple
        else:
            y_hat = output

        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        mloss_train += loss.item()

        gpu_mem = (
            f"Mem : {torch.cuda.memory_reserved() / 1E9:.3g}GB" if torch.cuda.is_available()
            else "MPS Active" if torch.backends.mps.is_available()
            else "CPU Mode"
        )

        pbar_train.set_description(f"Epoch {epoch}/{EPOCHS}  {gpu_mem}  {cur_lr}  Loss: {mloss_train / (i + 1):.4f}")

    # Validation
    model.eval()
    pbar_val = tqdm(dataloader_valid, total=n_valid, bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")

    y_true, y_preds = [], []
    for i, (x, y) in enumerate(pbar_val):
        x, y = x.to(device), y.to(device)

        with torch.no_grad():
            output = model(x)
            if isinstance(output, tuple):
                y_hat = output[0]
            else:
                y_hat = output

        loss = criterion(y_hat, y)
        mloss_val += loss.item()
        y_preds.append(y_hat)
        y_true.append(y)

        pbar_val.set_description(f"Val Loss: {mloss_val / (i + 1):.4f}")

    y_true = torch.cat(y_true).cpu().numpy()
    y_true = np.argmax(y_true, axis=1)
    y_preds = F.softmax(torch.cat(y_preds), dim=1).argmax(dim=1).cpu().numpy()

    accuracy = np.mean(y_true == y_preds)  
    mets = precision_recall_fscore_support(y_true, y_preds, labels=[0, 1, 2, 3], average=None)

    print(f'Validation Accuracy: {accuracy:.4f}')
    np.savetxt("validation_predictions.csv", y_preds, delimiter=",", fmt="%d")

    print(f'Precision: {[f"{p:.4f}" for p in mets[0]]}')
    print(f'Recall: {[f"{r:.4f}" for r in mets[1]]}')
    print(f'FScore: {[f"{f:.4f}" for f in mets[2]]}')

    # Early Stopping & Save Best Model
    avg_val_loss = mloss_val / len(dataloader_valid)
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        counter = 0
        torch.save(model.state_dict(), f"RGNN_{loss_name}_best_model.pth")
        print("Best model saved!")
    else:
        counter += 1
        print(f"Early stopping counter: {counter}/{patience}")

    if counter >= patience:
        print("Validation loss has increased for multiple epochs, stopping early!")
        break

    # Adjust learning rate
    scheduler.step(avg_val_loss)


Epoch 1/1000  MPS Active  LR : 1.00E-02  Loss: 4.3786: 100%|██████████| 74/74 [00:02<00:00, 34.06it/s] 
Val Loss: 0.5341: 100%|██████████| 60/60 [00:00<00:00, 64.01it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8932
Precision: ['0.8948', '0.0000', '0.0000', '0.0000']
Recall: ['0.9979', '0.0000', '0.0000', '0.0000']
FScore: ['0.9436', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 2/1000  MPS Active  LR : 1.00E-02  Loss: 0.5227: 100%|██████████| 74/74 [00:01<00:00, 41.44it/s]
Val Loss: 0.5040: 100%|██████████| 60/60 [00:00<00:00, 68.59it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8943
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['0.9991', '0.0000', '0.0000', '0.0000']
FScore: ['0.9442', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 3/1000  MPS Active  LR : 1.00E-02  Loss: 0.4945: 100%|██████████| 74/74 [00:01<00:00, 39.01it/s]
Val Loss: 0.4952: 100%|██████████| 60/60 [00:00<00:00, 72.56it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 4/1000  MPS Active  LR : 1.00E-02  Loss: 0.4983: 100%|██████████| 74/74 [00:01<00:00, 40.65it/s]
Val Loss: 0.4774: 100%|██████████| 60/60 [00:00<00:00, 72.96it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 5/1000  MPS Active  LR : 1.00E-02  Loss: 0.4820: 100%|██████████| 74/74 [00:01<00:00, 38.52it/s]
Val Loss: 0.4657: 100%|██████████| 60/60 [00:00<00:00, 68.56it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 6/1000  MPS Active  LR : 1.00E-02  Loss: 0.4776: 100%|██████████| 74/74 [00:01<00:00, 41.07it/s]
Val Loss: 0.4687: 100%|██████████| 60/60 [00:00<00:00, 73.87it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 1/10


Epoch 7/1000  MPS Active  LR : 1.00E-02  Loss: 0.4771: 100%|██████████| 74/74 [00:01<00:00, 41.08it/s]
Val Loss: 0.4671: 100%|██████████| 60/60 [00:00<00:00, 66.41it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 2/10


Epoch 8/1000  MPS Active  LR : 1.00E-02  Loss: 0.4660: 100%|██████████| 74/74 [00:01<00:00, 40.60it/s]
Val Loss: 0.4700: 100%|██████████| 60/60 [00:00<00:00, 63.42it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 3/10


Epoch 9/1000  MPS Active  LR : 1.00E-02  Loss: 0.4655: 100%|██████████| 74/74 [00:01<00:00, 41.60it/s]
Val Loss: 0.4849: 100%|██████████| 60/60 [00:00<00:00, 74.04it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 4/10


Epoch 10/1000  MPS Active  LR : 5.00E-03  Loss: 0.4689: 100%|██████████| 74/74 [00:01<00:00, 42.49it/s]
Val Loss: 0.4649: 100%|██████████| 60/60 [00:00<00:00, 69.87it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 11/1000  MPS Active  LR : 5.00E-03  Loss: 0.4626: 100%|██████████| 74/74 [00:01<00:00, 41.08it/s]
Val Loss: 0.4728: 100%|██████████| 60/60 [00:00<00:00, 70.13it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 1/10


Epoch 12/1000  MPS Active  LR : 5.00E-03  Loss: 0.4697: 100%|██████████| 74/74 [00:01<00:00, 41.23it/s]
Val Loss: 0.4715: 100%|██████████| 60/60 [00:00<00:00, 71.38it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 2/10


Epoch 13/1000  MPS Active  LR : 5.00E-03  Loss: 0.4573: 100%|██████████| 74/74 [00:01<00:00, 39.78it/s]
Val Loss: 0.4655: 100%|██████████| 60/60 [00:00<00:00, 69.29it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 3/10


Epoch 14/1000  MPS Active  LR : 5.00E-03  Loss: 0.4558: 100%|██████████| 74/74 [00:01<00:00, 40.53it/s]
Val Loss: 0.4647: 100%|██████████| 60/60 [00:00<00:00, 69.78it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 15/1000  MPS Active  LR : 5.00E-03  Loss: 0.4714: 100%|██████████| 74/74 [00:01<00:00, 41.59it/s]
Val Loss: 0.4665: 100%|██████████| 60/60 [00:00<00:00, 68.81it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 1/10


Epoch 16/1000  MPS Active  LR : 5.00E-03  Loss: 0.4557: 100%|██████████| 74/74 [00:01<00:00, 40.49it/s]
Val Loss: 0.4670: 100%|██████████| 60/60 [00:00<00:00, 73.34it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 2/10


Epoch 17/1000  MPS Active  LR : 5.00E-03  Loss: 0.4585: 100%|██████████| 74/74 [00:01<00:00, 41.05it/s]
Val Loss: 0.4626: 100%|██████████| 60/60 [00:00<00:00, 69.44it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 18/1000  MPS Active  LR : 5.00E-03  Loss: 0.4535: 100%|██████████| 74/74 [00:01<00:00, 41.71it/s]
Val Loss: 0.4632: 100%|██████████| 60/60 [00:00<00:00, 73.42it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 1/10


Epoch 19/1000  MPS Active  LR : 5.00E-03  Loss: 0.4576: 100%|██████████| 74/74 [00:01<00:00, 41.56it/s]
Val Loss: 0.4700: 100%|██████████| 60/60 [00:00<00:00, 70.85it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 2/10


Epoch 20/1000  MPS Active  LR : 5.00E-03  Loss: 0.4546: 100%|██████████| 74/74 [00:01<00:00, 41.39it/s]
Val Loss: 0.4596: 100%|██████████| 60/60 [00:00<00:00, 69.87it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 21/1000  MPS Active  LR : 5.00E-03  Loss: 0.4617: 100%|██████████| 74/74 [00:01<00:00, 41.01it/s]
Val Loss: 0.4617: 100%|██████████| 60/60 [00:00<00:00, 72.88it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 1/10


Epoch 22/1000  MPS Active  LR : 5.00E-03  Loss: 0.4536: 100%|██████████| 74/74 [00:01<00:00, 40.84it/s]
Val Loss: 0.4588: 100%|██████████| 60/60 [00:00<00:00, 66.97it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 23/1000  MPS Active  LR : 5.00E-03  Loss: 0.4538: 100%|██████████| 74/74 [00:01<00:00, 41.93it/s]
Val Loss: 0.4664: 100%|██████████| 60/60 [00:00<00:00, 74.81it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 1/10


Epoch 24/1000  MPS Active  LR : 5.00E-03  Loss: 0.4508: 100%|██████████| 74/74 [00:01<00:00, 38.03it/s]
Val Loss: 0.4680: 100%|██████████| 60/60 [00:00<00:00, 70.55it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 2/10


Epoch 25/1000  MPS Active  LR : 5.00E-03  Loss: 0.4560: 100%|██████████| 74/74 [00:02<00:00, 33.10it/s]
Val Loss: 0.4690: 100%|██████████| 60/60 [00:00<00:00, 71.17it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 3/10


Epoch 26/1000  MPS Active  LR : 5.00E-03  Loss: 0.4541: 100%|██████████| 74/74 [00:01<00:00, 41.35it/s]
Val Loss: 0.4559: 100%|██████████| 60/60 [00:00<00:00, 67.75it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Best model saved!


Epoch 27/1000  MPS Active  LR : 5.00E-03  Loss: 0.4600: 100%|██████████| 74/74 [00:01<00:00, 41.63it/s]
Val Loss: 0.4579: 100%|██████████| 60/60 [00:00<00:00, 68.60it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 1/10


Epoch 28/1000  MPS Active  LR : 5.00E-03  Loss: 0.4458: 100%|██████████| 74/74 [00:02<00:00, 36.82it/s]
Val Loss: 0.4655: 100%|██████████| 60/60 [00:00<00:00, 66.58it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 2/10


Epoch 29/1000  MPS Active  LR : 5.00E-03  Loss: 0.4574: 100%|██████████| 74/74 [00:02<00:00, 36.39it/s]
Val Loss: 0.4660: 100%|██████████| 60/60 [00:00<00:00, 69.83it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 3/10


Epoch 30/1000  MPS Active  LR : 5.00E-03  Loss: 0.4544: 100%|██████████| 74/74 [00:02<00:00, 36.00it/s]
Val Loss: 0.4654: 100%|██████████| 60/60 [00:00<00:00, 70.65it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 4/10


Epoch 31/1000  MPS Active  LR : 2.50E-03  Loss: 0.4470: 100%|██████████| 74/74 [00:01<00:00, 38.95it/s]
Val Loss: 0.4573: 100%|██████████| 60/60 [00:00<00:00, 63.25it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 5/10


Epoch 32/1000  MPS Active  LR : 2.50E-03  Loss: 0.4555: 100%|██████████| 74/74 [00:01<00:00, 39.12it/s]
Val Loss: 0.4613: 100%|██████████| 60/60 [00:00<00:00, 74.10it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 6/10


Epoch 33/1000  MPS Active  LR : 2.50E-03  Loss: 0.4616: 100%|██████████| 74/74 [00:01<00:00, 40.52it/s]
Val Loss: 0.4695: 100%|██████████| 60/60 [00:00<00:00, 67.36it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 7/10


Epoch 34/1000  MPS Active  LR : 2.50E-03  Loss: 0.4493: 100%|██████████| 74/74 [00:01<00:00, 40.71it/s]
Val Loss: 0.4719: 100%|██████████| 60/60 [00:00<00:00, 68.28it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 8/10


Epoch 35/1000  MPS Active  LR : 1.25E-03  Loss: 0.4495: 100%|██████████| 74/74 [00:01<00:00, 41.47it/s]
Val Loss: 0.4601: 100%|██████████| 60/60 [00:00<00:00, 65.24it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 9/10


Epoch 36/1000  MPS Active  LR : 1.25E-03  Loss: 0.4455: 100%|██████████| 74/74 [00:01<00:00, 40.17it/s]
Val Loss: 0.4665: 100%|██████████| 60/60 [00:00<00:00, 69.07it/s]

Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 10/10
Validation loss has increased for multiple epochs, stopping early!



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
