In [None]:
# Show the kernel's working directory: every path below ('.data/...', '.data_cache/...',
# './.exp/...') is relative, so this must be the project root for the notebook to run.
%pwd

In [None]:
from torcheeg.datasets import DEAPDataset
from torcheeg import transforms
import torch
from src.utils.transforms import DeapAVToStress

from torcheeg.datasets.constants.emotion_recognition.deap import \
    DEAP_CHANNEL_LOCATION_DICT

# DEAP dataset with stress labels derived from the self-reported arousal/valence ratings.
# torcheeg preprocesses the raw recordings once and caches the result under io_path,
# so later runs reuse that cache instead of re-reading root_path.
dataset = DEAPDataset(
    io_path='.data_cache/deap',
    root_path='.data/DEAP/data_preprocessed_python-002/data_preprocessed_python',
    # Applied per sample at load time: To2d adds a leading singleton dimension
    # (presumably (32, 128) -> (1, 32, 128), the 2-D input EEGNet expects — confirm),
    # then ToTensor converts the numpy array to a torch tensor.
    online_transform=transforms.Compose([
        transforms.To2d(),
        transforms.ToTensor()
    ]),
    # Select yields the ratings in the order [arousal, valence]; the project-local
    # DeapAVToStress then maps that pair to a stress class using the thresholds below.
    # NOTE(review): threshold semantics are defined in src/utils/transforms; the model
    # and metrics downstream assume 3 classes — confirm this mapping yields exactly 3.
    # (An earlier variant used transforms.Binary(5.0) + transforms.BinariesToCategory()
    # to produce arousal/valence quadrant classes instead.)
    label_transform=transforms.Compose([
        transforms.Select(['arousal', 'valence']),
        DeapAVToStress(thresholds=[
            [7.5, 2.5],
            [5.0, 5.0]]),
    ]),
    num_worker=8)

In [None]:
# Fractions of the dataset assigned to train / validation / test, and the seed that
# makes the random partition reproducible across kernel restarts.
# NOTE(review): this is a sample-level random split. DEAPDataset yields many windows
# per trial/subject, so windows from the same trial likely land in both train and
# test, which can inflate test accuracy. Consider a subject- or trial-wise split
# (e.g. torcheeg.model_selection) — TODO confirm the intended evaluation protocol.
train_val_test_split = [0.6, 0.2, 0.2]
SPLIT_SEED = 42

# Guard against edits that no longer cover the whole dataset: the test share is the
# remainder below, so fractions not summing to 1 would silently resize the test set.
assert abs(sum(train_val_test_split) - 1.0) < 1e-9, \
    f"split fractions must sum to 1, got {sum(train_val_test_split)}"

train_size = int(train_val_test_split[0] * len(dataset))
val_size = int(train_val_test_split[1] * len(dataset))
# Remainder goes to test so no sample is dropped by int() truncation above.
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
from torch.utils.data import DataLoader
# Shared loader settings; only the dataset and shuffling differ between splits.
batch_size = 64


def make_loader(split_dataset, *, shuffle, batch_size, num_workers=2, pin_memory=True):
    """Build a DataLoader for one dataset split with the notebook's shared settings.

    Args:
        split_dataset: one split of the dataset (here a ``random_split`` subset).
        shuffle: reshuffle samples every epoch (enable for training only).
        batch_size: samples per batch.
        num_workers: worker processes used for loading.
        pin_memory: allocate batches in page-locked host memory.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    return DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


train_loader = make_loader(train_dataset, shuffle=True, batch_size=batch_size)
# Validation/test order is fixed so evaluation is deterministic across runs.
val_loader = make_loader(val_dataset, shuffle=False, batch_size=batch_size)
test_loader = make_loader(test_dataset, shuffle=False, batch_size=batch_size)

In [None]:
from torch.optim import SGD, Adam
from mmengine.runner import Runner
from src.models.torcheeg_mmwraper import MMEEGNet
from src.utils.metrics import AccuracyWithLoss, Accuracy, ConfusionMatrix


# Number of stress classes, shared by the model head and the confusion-matrix metric
# so the two cannot drift apart. NOTE(review): must equal the number of labels
# DeapAVToStress produces in the dataset cell — confirm.
NUM_CLASSES = 3

runner = Runner(
    # The model used for training and validation. mmengine requires a model that
    # implements its model interface; MMEEGNet (src.models.torcheeg_mmwraper) is the
    # project-local wrapper that provides it.
    # chunk_size / num_electrodes should match each sample's EEG window
    # (presumably 128 time samples x 32 DEAP EEG channels — confirm against the dataset).
    model=MMEEGNet(chunk_size=128,
                   num_electrodes=32,
                   dropout=0.5,
                   kernel_1=64,
                   kernel_2=16,
                   F1=8,
                   F2=16,
                   D=2,
                   num_classes=NUM_CLASSES).float(),

    # Working directory for training logs, visualizer output and checkpoints.
    work_dir='./.exp/new_deap_stress/eegnet',

    # Training dataloader (any PyTorch DataLoader works).
    train_dataloader=train_loader,
    # Optimizer wrapper; mmengine's wrapper can also add AMP, gradient accumulation,
    # gradient clipping, etc. Here it wraps plain Adam.
    optim_wrapper=dict(optimizer=dict(type=Adam,
                                      lr=0.001,
                                      betas=(0.9, 0.999),
                                      eps=1e-08)),
    # Training loop config: epoch-based, 200 epochs, validate after every epoch.
    train_cfg=dict(by_epoch=True,
                   max_epochs=200,
                   val_interval=1),

    # Validation dataloader (any PyTorch DataLoader works).
    val_dataloader=val_loader,
    # Validation loop config; an empty dict uses mmengine's default validation loop.
    val_cfg=dict(),
    # Validation metric: the project-local AccuracyWithLoss (not an mmengine default).
    val_evaluator=dict(type=AccuracyWithLoss),

    # Test dataloader (any PyTorch DataLoader works).
    test_dataloader=test_loader,
    # Test loop config; an empty dict uses mmengine's default test loop.
    test_cfg=dict(),
    # Test metrics: accuracy plus an NUM_CLASSES x NUM_CLASSES confusion matrix
    # (both project-local metric classes from src.utils.metrics).
    test_evaluator=[dict(type=Accuracy), dict(type=ConfusionMatrix, num_classes=NUM_CLASSES)],

    # Log scalars through mmengine's visualizer to a TensorBoard backend.
    visualizer=dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')]),

    # Resume from the latest checkpoint in work_dir if one exists (mmengine's behaviour
    # for resume=True without load_from). NOTE(review): once a run has reached
    # max_epochs, re-running training presumably resumes at the end and does little —
    # use a fresh work_dir for a new experiment.
    resume=True,
)

In [None]:
# Train (with per-epoch validation). Runner.train() returns the model; the trailing
# semicolon keeps its full module repr from flooding the cell output.
runner.train();

In [None]:
# Run trained model on test set
# NOTE(review): this evaluates the weights the runner currently holds — after the
# training cell that is presumably the final-epoch model, not the best-validation
# checkpoint (the runner config sets no save_best) — confirm this is intended.
test_results = runner.test()

In [None]:
# Display the metrics dict returned by runner.test() (includes the
# 'confusion_matrix/result' entry plotted in the next cell).
test_results

In [None]:
# Plot the row-normalized test confusion matrix, annotating every cell with its value.
# Label order follows the class indices produced by the stress label transform.
stress_class_names = ['No Stress', 'Slight Stress', 'High Stress']
cm_metric = ConfusionMatrix()
cm_metric.plot(
    test_results['confusion_matrix/result'],
    classes=stress_class_names,
    include_values=True,
    normalize=True,
)