In [1]:
%load_ext autoreload
%autoreload 2
from hmpai.pytorch.generators import SAT1DataLoader
import xarray as xr
from pathlib import Path
import torch
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights
from hmpai.pytorch.utilities import DEVICE
from hmpai.normalization import *
import random
import numpy as np
from torchinfo import summary

In [2]:
# Seed every RNG the notebook touches (Python, NumPy and PyTorch) so that the
# data split, loader shuffling and model weight initialisation in later cells
# are reproducible on Restart & Run All. torch.manual_seed seeds all devices.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

data_path = Path("../data/sat1/split_stage_data.nc")
dataset = xr.load_dataset(data_path)

# Produce train/validation/test splits, using norm_0_to_1 as the normalisation
# function. NOTE(review): 60 is passed straight through to
# split_data_on_participants — presumably the training share in percent; confirm.
train_data, val_data, test_data = split_data_on_participants(dataset, 60, norm_0_to_1)

train_loader = SAT1DataLoader(train_data)
val_loader = SAT1DataLoader(val_data)
test_loader = SAT1DataLoader(test_data)

In [5]:
# Shape of the training array held by the loader.
# NOTE(review): axis order (trials, samples, channels?) is not confirmed here.
train_data_shape = train_loader.dataset.data.shape
train_data_shape

(12120, 147, 30)

In [14]:
# Model dimensions come from the dataset coordinates so they track the data file.
n_channels = len(dataset.channels)
n_samples = len(dataset.samples)
n_classes = len(dataset.labels)

# SAT1Base is the model trained below; SAT1Mlp accepts the same constructor
# arguments and can be swapped in here.
model = SAT1Base(n_channels, n_samples, n_classes).to(DEVICE)

# Cross-entropy weighted by class weights computed from the training loader
# (see calculate_class_weights for how the weights are derived).
class_weights = calculate_class_weights(train_loader).to(DEVICE)
loss = torch.nn.CrossEntropyLoss(weight=class_weights)
opt = torch.optim.NAdam(model.parameters())
epochs = 10

In [12]:
# Layer-by-layer summary for a dummy batch. The singleton axis is the Conv2d
# input-channel dimension; the trailing dims are taken from the training array
# so the summary cannot drift from the data (previously hardcoded as 147, 30).
summary_batch_size = 16
summary(model, (summary_batch_size, 1, *train_loader.dataset.data.shape[1:]))

Layer (type:depth-idx)                   Output Shape              Param #
SAT1Base                                 [16, 5]                   --
├─PartialConv2d: 1-1                     [16, 64, 143, 30]         384
├─ReLU: 1-2                              [16, 64, 143, 30]         --
├─MaxPool2d: 1-3                         [16, 64, 71, 30]          --
├─Conv2d: 1-4                            [16, 128, 69, 30]         24,704
├─ReLU: 1-5                              [16, 128, 69, 30]         --
├─MaxPool2d: 1-6                         [16, 128, 34, 30]         --
├─Conv2d: 1-7                            [16, 256, 32, 30]         98,560
├─ReLU: 1-8                              [16, 256, 32, 30]         --
├─MaxPool2d: 1-9                         [16, 256, 16, 30]         --
├─Flatten: 1-10                          [16, 122880]              --
├─Linear: 1-11                           [16, 128]                 15,728,768
├─ReLU: 1-12                             [16, 128]                 -

In [15]:
# Train for `epochs` epochs, validating after each one. The training loader is
# shuffled at the *start* of every epoch so the first epoch does not run on the
# loader's initial ordering (previously the shuffle came after training, leaving
# epoch 0 unshuffled and wasting the final shuffle).
# NOTE(review): if SAT1DataLoader already shuffles on construction, this is
# simply one extra shuffle before epoch 0.
# Per-epoch metrics are kept in `history` for later inspection/plotting.
history = []
for epoch in range(epochs):
    train_loader.shuffle()
    batch_losses = train(model, train_loader, opt, loss)

    val_losses, val_accuracy = validate(model, val_loader, loss)

    epoch_stats = {
        "epoch": epoch,
        "loss": float(np.mean(batch_losses)),
        "val_loss": float(np.mean(val_losses)),
        "val_accuracy": val_accuracy,
    }
    history.append(epoch_stats)

    print(
        f"Epoch {epoch}, loss: {epoch_stats['loss']}, val_loss: {epoch_stats['val_loss']}, val_accuracy: {val_accuracy}"
    )

  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 0, loss: 1.1649976341223622, val_loss: 1.046687555069826, val_accuracy: 0.55638


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 1, loss: 0.9716742517453676, val_loss: 0.6481188881762173, val_accuracy: 0.75944


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 2, loss: 0.6981636209012337, val_loss: 0.556106546186671, val_accuracy: 0.79031


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 3, loss: 0.6198194572404322, val_loss: 0.5475290050920175, val_accuracy: 0.79286


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 4, loss: 0.5792121211438387, val_loss: 0.515024719737014, val_accuracy: 0.80918


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 5, loss: 0.5508768406341067, val_loss: 0.5058419232465783, val_accuracy: 0.81684


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 6, loss: 0.5281009227332697, val_loss: 0.48963725466509256, val_accuracy: 0.82168


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 7, loss: 0.5053418743641878, val_loss: 0.5071086904376138, val_accuracy: 0.81582


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 8, loss: 0.4925675328668389, val_loss: 0.4761347716864274, val_accuracy: 0.82041


  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

Epoch 9, loss: 0.471193716884683, val_loss: 0.508051329364582, val_accuracy: 0.81888
