In [9]:
import os, sys
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
import hydra
from omegaconf import DictConfig # Operate configs as a dict
import wandb
from termcolor import cprint
from tqdm import tqdm
from torcheeg.datasets import DEAPDataset
from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT
from torcheeg import transforms

from src.datasets import ThingsMEGDataset
from src.models import BasicConvClassifier
from src.utils import set_seed
from src.preprocess import CAR
from src.widget import plot_raw_signal

## Data loader

### Load data

In [10]:
# Configuration: named constants instead of magic values repeated inline.
SEED = 1234
DATA_DIR = "data"  # passed as ThingsMEGDataset's second argument; presumably the data root — confirm

set_seed(SEED)

# DataLoader kwargs shared by every split; consumed by the DataLoader cell below.
loader_args = {"batch_size": 128, "num_workers": 4}

# Load the raw (not yet preprocessed) train / val / test splits.
train_set = ThingsMEGDataset("train", DATA_DIR)
val_set = ThingsMEGDataset("val", DATA_DIR)
test_set = ThingsMEGDataset("test", DATA_DIR)

### Preprocessing

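- Resampling (リサンプリング)
- Filtering (フィルタリング)
- Scaling: z-score normalization (スケーリング)
- Re-referencing: common median reference — done (originally listed under ベースライン補正, "baseline correction"; note that re-referencing across channels is a different step from per-trial baseline correction)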

In [18]:
# torcheeg's DEAP channel layout: each channel name maps to a [row, col]
# position on a 2-D electrode grid (indices in the output range 0-8).
# Displayed for reference; the import lives in the top import cell.
DEAP_CHANNEL_LOCATION_DICT

{'FP1': [0, 3],
 'AF3': [1, 3],
 'F3': [2, 2],
 'F7': [2, 0],
 'FC5': [3, 1],
 'FC1': [3, 3],
 'C3': [4, 2],
 'T7': [4, 0],
 'CP5': [5, 1],
 'CP1': [5, 3],
 'P3': [6, 2],
 'P7': [6, 0],
 'PO3': [7, 3],
 'O1': [8, 3],
 'OZ': [8, 4],
 'PZ': [6, 4],
 'FP2': [0, 5],
 'AF4': [1, 5],
 'FZ': [2, 4],
 'F4': [2, 6],
 'F8': [2, 8],
 'FC6': [3, 7],
 'FC2': [3, 5],
 'CZ': [4, 4],
 'C4': [4, 6],
 'T8': [4, 8],
 'CP6': [5, 7],
 'CP2': [5, 5],
 'P4': [6, 6],
 'P8': [6, 8],
 'PO4': [7, 5],
 'O2': [8, 5]}

In [85]:
# Visual sanity check of the raw training split before any DataLoader wrapping.
plot_raw_signal(train_set)

In [None]:
# Wrap each split in a DataLoader. All three share `loader_args` so batch size
# and worker count are configured in one place (the test loader previously
# hard-coded its own copy). Only the training split is shuffled.
# NOTE(review): items are presumably [ch, seq], giving batches of
# [batch, ch, seq] — TODO confirm against ThingsMEGDataset.__getitem__.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_args)