In [1]:
%load_ext autoreload
%autoreload 2
import copy
from hmpai.pytorch.generators import SAT1DataLoader
import xarray as xr
from pathlib import Path
import torch
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate
from hmpai.pytorch.utilities import DEVICE
from hmpai.normalization import *
import random
import numpy as np
from torchinfo import summary

In [2]:
# Seed every RNG this notebook relies on. Python's `random` and NumPy were
# already seeded; PyTorch was not, so the model's weight initialisation in the
# next cell differed on every Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

data_path = Path("../data/sat1/split_stage_data.nc")

dataset = xr.load_dataset(data_path)
# NOTE(review): the meaning of `60` is defined by split_data_on_participants
# (possibly a split size/percentage) — confirm and promote it to a named constant.
# `norm_0_to_1` is the normalisation function handed to the split helper.
train_data, val_data, test_data = split_data_on_participants(dataset, 60, norm_0_to_1)

train_loader = SAT1DataLoader(train_data)
val_loader = SAT1DataLoader(val_data)
test_loader = SAT1DataLoader(test_data)

In [3]:
# Model sized from the dataset's channel, sample and label coordinates.
# SAT1Base takes the same constructor arguments and can be swapped in here
# as an alternative baseline.
model = SAT1Mlp(len(dataset.channels), len(dataset.samples), len(dataset.labels)).to(
    DEVICE
)
# The model summary shows SAT1Mlp ends in a plain Linear layer (no softmax),
# which is the raw-logit input CrossEntropyLoss expects.
loss = torch.nn.CrossEntropyLoss()
# NAdam with PyTorch's default hyperparameters (no arguments overridden).
opt = torch.optim.NAdam(model.parameters())
epochs = 10

In [32]:
# Layer-by-layer summary for a batch of 16 single-channel 147 x 30 inputs
# (147 * 30 = 4410, the Flatten width in the output below). Presumably
# (samples, channels) — TODO confirm axis order against the data loader.
summary(model, input_size=(16, 1, 147, 30))

Layer (type:depth-idx)                   Output Shape              Param #
SAT1Mlp                                  [16, 5]                   --
├─Flatten: 1-1                           [16, 4410]                --
├─Linear: 1-2                            [16, 256]                 1,129,216
├─ReLU: 1-3                              [16, 256]                 --
├─Linear: 1-4                            [16, 5]                   1,285
Total params: 1,130,501
Trainable params: 1,130,501
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 18.09
Input size (MB): 0.28
Forward/backward pass size (MB): 0.03
Params size (MB): 4.52
Estimated Total Size (MB): 4.84

In [4]:
# Per-epoch metrics go into `history` so completed epochs stay inspectable
# (e.g. for plotting) even if the run is interrupted part-way.
# The weights with the lowest validation loss are snapshotted in
# `best_model_state`; the loop itself leaves `model` at its final-epoch
# weights, so call `model.load_state_dict(best_model_state)` before evaluating
# on `test_loader` if a later epoch got worse.
history = []
best_val_loss = float("inf")
best_model_state = None

for epoch in range(epochs):
    batch_losses = train(model, train_loader, opt, loss)

    # Shuffle data before next epoch
    train_loader.shuffle()

    val_losses, val_accuracy = validate(model, val_loader, loss)

    train_loss_mean = np.mean(batch_losses)
    val_loss_mean = np.mean(val_losses)
    history.append(
        {
            "epoch": epoch,
            "loss": train_loss_mean,
            "val_loss": val_loss_mean,
            "val_accuracy": val_accuracy,
        }
    )

    if val_loss_mean < best_val_loss:
        best_val_loss = val_loss_mean
        # deepcopy: state_dict() holds references to the live parameters,
        # which later epochs would overwrite in place.
        best_model_state = copy.deepcopy(model.state_dict())

    print(
        f"Epoch {epoch}, loss: {train_loss_mean}, val_loss: {val_loss_mean}, val_accuracy: {val_accuracy}"
    )

  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 0, loss: 21.31070074198898, val_loss: 1.3070899970677434, val_accuracy: 0.4824


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 1, loss: 1.2986475286742025, val_loss: 1.1590024670776056, val_accuracy: 0.50918


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 2, loss: 1.1988962546994852, val_loss: 1.1326504588127135, val_accuracy: 0.52245


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 3, loss: 1.1706481753913698, val_loss: 1.1089973503229569, val_accuracy: 0.53036


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 4, loss: 1.160849751852935, val_loss: 1.1916433601963277, val_accuracy: 0.47755


  0%|          | 0/757 [00:00<?, ?it/s]

KeyboardInterrupt: 