In [1]:
%load_ext autoreload
%autoreload 2
from hmpai.pytorch.generators import SAT1DataLoader
import xarray as xr
from pathlib import Path
import torch
from hmpai.pytorch.models import SAT1Base
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate
from hmpai.pytorch.utilities import DEVICE
from hmpai.normalization import norm_dummy
import random
import numpy as np

In [2]:
random.seed(42)
np.random.seed(42)
data_path = Path("../data/sat1/split_stage_data.nc")

dataset = xr.load_dataset(data_path)
train_data, val_data, test_data = split_data_on_participants(dataset, 60, norm_dummy)

# eeg_dataset = EegDataset(data)
train_loader = SAT1DataLoader(train_data)
val_loader = SAT1DataLoader(val_data)
test_loader = SAT1DataLoader(test_data)

In [19]:
model = SAT1Base(len(dataset.channels), len(dataset.samples), len(dataset.labels)).to(
    DEVICE
)
loss = torch.nn.CrossEntropyLoss()
opt = torch.optim.NAdam(model.parameters())
epochs = 10

In [20]:
for epoch in range(epochs):
    batch_losses = train(model, train_loader, opt, loss)

    # Shuffle data before next epoch
    train_loader.shuffle()

    val_losses, val_accuracy = validate(model, val_loader, loss)

    print(
        f"Epoch {epoch}, loss: {np.mean(batch_losses)}, val_loss: {np.mean(val_losses)}, val_accuracy: {val_accuracy}"
    )

  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 0, loss: 0.8009541091980361, val_loss: 0.5199890367230591, val_accuracy: 0.811


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 1, loss: 0.636851741662901, val_loss: 0.5077349846156276, val_accuracy: 0.816


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 2, loss: 0.5752055038460573, val_loss: 0.5505207806065374, val_accuracy: 0.816


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 3, loss: 0.5448356793561022, val_loss: 0.48351270673530444, val_accuracy: 0.82


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 4, loss: 0.5175450766407436, val_loss: 0.48263640835577126, val_accuracy: 0.829


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 5, loss: 0.49240607940536507, val_loss: 0.5078647947585097, val_accuracy: 0.834


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 6, loss: 0.47144039415187217, val_loss: 0.5277855453746659, val_accuracy: 0.821


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 7, loss: 0.4421960789183723, val_loss: 0.5088404929303393, val_accuracy: 0.832


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 8, loss: 0.433515852137593, val_loss: 0.5280860272901399, val_accuracy: 0.824


  0%|          | 0/757 [00:00<?, ?it/s]

KeyboardInterrupt: 