In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to the
# project's own modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Project root that every relative path below ("./processed_data", "./ckpts/")
# is resolved against. Set the EEG_TAC_ROOT environment variable to run the
# notebook on another machine; the original location remains the default.
root_path = os.environ.get("EEG_TAC_ROOT", r"G:\EEG_TAC")
os.chdir(root_path)
print("current work dir:", os.getcwd())

current work dir: G:\EEG_TAC


In [3]:
import os
import logging
import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import random
from sklearn.metrics import f1_score
import torch.nn.functional as F
import shutil

from utils import set_logging_config, dict_to_markdown, set_seed
from datasets import DEAPDataset
from datasets.transforms import (BandDifferentialEntropy, Compose, ToGrid, ToTensor, To2d,
                                 Select, Binary, BaselineRemoval)
from datasets.constants import DEAP_CHANNEL_LOCATION_DICT, DEAP_LOCATION_LIST, STANDARD_1005_CHANNEL_LOCATION_DICT, \
    STANDARD_1020_CHANNEL_LOCATION_DICT
from model_selection import KFoldPerSubjectCrossTrial, KFoldPerSubjectGroupbyTrial, KFoldPerSubject

from models.model import build_model
from engine import train_model_per_subject, save_result

In [4]:
# Experiment configuration. Everything a run depends on lives in `args`,
# which is logged in full and handed to `build_model`.
dataset_name = "deap"
label = "arousal"
chunk_size = 128 * 4

args = dict(
    # --- data ---
    dataset_name=dataset_name,
    label=label,  # alternative: "valence"
    data_dir=r'G:\Data\EEG-Data\deap\data_preprocessed_python',
    feature_dir=f"./processed_data/{dataset_name}_{label}_{chunk_size}",
    chunk_size=chunk_size,
    results_dir="./ckpts/",
    split_mode="no_shuffle",  # alternative: "per_cross"
    subject_num=32,

    # --- model ---
    model_name="ERM",
    num_classes=2,
    in_channel=32,
    embed_size=64,  # best: 64

    # --- ablation switches ---
    graph_variant="time_gcn",  # one of: time_gcn, time_cnn, time_att, fft_cnn, fft_att, fft_hyp
    use_wsd=True,
    use_dfhc=True,
    wavelet_level=2,  # tried 2, 3, 4, 5; best: 2
    base_fun="sine",  # one of: linear, gauss, sine
    graph_layer=2,  # tried 1, 2, 3, 4; best: 2
    add_noise=False,
    noise_std=0.05,  # standard deviation of the added noise

    # --- optimisation ---
    seed=42,
    max_epochs=15,
    batch_size=12,
    kflod=10,  # key spelling is kept as-is because later cells look it up by this name
    lr=0.0008,
    weight_decay=0.0001,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [5]:
# Seed the run, derive the per-run output directory from the configuration,
# and route logging there.
set_seed(args['seed'])

# Directory name that encodes the split mode and every ablation setting, so
# runs with different configurations never share checkpoints or splits.
run_name = "_".join([
    f"{args['split_mode']}",
    f"var{args['graph_variant']}",
    f"wsd{'on' if args['use_wsd'] else 'off'}",
    f"dfhc{'on' if args['use_dfhc'] else 'off'}",
    f"wave{args['wavelet_level']}",
    f"basefun{args['base_fun']}",
    f"layer{args['graph_layer']}",
    f"kflod{args['kflod']}",
    f"emb{args['embed_size']}",
    "new",
])
run_subdir = os.path.join(args['dataset_name'], args['label'], run_name)

# This cell overwrites args["results_dir"] with the run directory below.
# Without this guard, re-running the cell would append the run sub-path a
# second time and nest the run directory inside itself. Paths are normalised
# only for the comparison, so the first-run value of log_dir is unchanged.
_current_results_dir = os.path.normpath(args['results_dir'])
_normalized_subdir = os.path.normpath(run_subdir)
if (_current_results_dir == _normalized_subdir
        or _current_results_dir.endswith(os.sep + _normalized_subdir)):
    log_dir = args['results_dir']
else:
    log_dir = os.path.join(args['results_dir'], run_subdir)

args["results_dir"] = log_dir
set_logging_config(log_dir)
logger = logging.getLogger("main")
logger.info(f" Args:\n {dict_to_markdown(args)}")

[2025-05-10 00:22:50,381] [main]  Args:
 |               | 0                                                                                                   |
|:--------------|:----------------------------------------------------------------------------------------------------|
| dataset_name  | deap                                                                                                |
| label         | arousal                                                                                             |
| data_dir      | G:\Data\EEG-Data\deap\data_preprocessed_python                                                      |
| feature_dir   | ./processed_data/deap_arousal_512                                                                   |
| chunk_size    | 512                                                                                                 |
| results_dir   | ./ckpts/deap\arousal\no_shuffle_vartime_gcn_wsdon_dfhcon_wave2_basefunsine_layer2_kflod10_emb64_new |

In [6]:
# Build the DEAP dataset and pick the cross-validation splitter.
# Processed chunks are read from / written to args['feature_dir'].
# The label transform selects the configured label and passes it through
# Binary(5.0) (presumably a threshold at 5.0 -- confirm in datasets.transforms).
dataset = DEAPDataset(io_path=args['feature_dir'], root_path=args['data_dir'], chunk_size=args["chunk_size"],
                      online_transform=Compose([To2d(), ToTensor()]),
                      label_transform=Compose([Select(args['label']), Binary(5.0), ]),
                      num_worker=0)

logger.info(f"Sample shape: {dataset[0][0].shape}\tTotal samples: {len(dataset)}")

# split_mode -> (splitter class, shuffle flag).
split_strategies = {
    'per_cross': (KFoldPerSubjectCrossTrial, True),
    'per_groupby': (KFoldPerSubjectGroupbyTrial, False),
    'no_shuffle': (KFoldPerSubject, False),
    'shuffle': (KFoldPerSubject, True),
}
if args['split_mode'] not in split_strategies:
    # Fail with an actionable message instead of a bare, message-less error.
    raise ValueError(
        f"Unknown split_mode {args['split_mode']!r}; expected one of {sorted(split_strategies)}"
    )

splitter_cls, shuffle_folds = split_strategies[args['split_mode']]
# Every strategy stores its split under the same path inside the run directory.
split_path = os.path.join(log_dir, f"split_kflod_{args['kflod']}")
cv = splitter_cls(n_splits=args['kflod'], shuffle=shuffle_folds, split_path=split_path)

[2025-05-10 00:22:50,653] [torcheeg] 🔍 | Detected cached processing results, reading cache from ./processed_data/deap_arousal_512.
[2025-05-10 00:22:50,789] [main] Sample shape: torch.Size([1, 32, 512])	Total samples: 19200


In [12]:
# Build the model, report its size, and launch per-subject training.
torch.cuda.empty_cache()
model = build_model(args)

# Parameter counts are reported in millions: once for the whole model,
# then once per direct child module.
model_params_m = sum(weight.numel() for weight in model.parameters()) / 1e6
logger.info(f"Model: {args['model_name']}, Number of parameters: {model_params_m:.4f} M")
for child_name, child_module in model.named_children():
    child_params_m = sum(weight.numel() for weight in child_module.parameters()) / 1e6
    logger.info(f"{child_name}: {child_params_m:.4f}M")

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args['lr'],
    weight_decay=args['weight_decay'],
)

# NOTE(review): sub_idx=[1] looks like it restricts the run to one subject --
# confirm against engine.train_model_per_subject.
final_results_dict = train_model_per_subject(
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    cv=cv,
    subject_num=args['subject_num'],
    max_epochs=args['max_epochs'],
    batch_size=args['batch_size'],
    device=args['device'],
    logger=logger,
    log_dir=log_dir,
    sub_idx=[1],
)

[2025-05-10 00:28:06,912] [main] Model: ERM, Number of parameters: 0.2276 M
[2025-05-10 00:28:06,913] [main] embed: 0.0021M
[2025-05-10 00:28:06,913] [main] wavelet: 0.1852M
[2025-05-10 00:28:06,914] [main] graph_conv: 0.0170M
[2025-05-10 00:28:06,915] [main] base_conv: 0.0205M
[2025-05-10 00:28:06,915] [main] base_scale: 0.0001M
[2025-05-10 00:28:06,916] [main] classifier: 0.0026M
[2025-05-10 00:28:06,917] [main] --------------------------------------------------
[2025-05-10 00:28:06,918] [torcheeg] 📊 | Detected existing split of train and test set, use existing split from ./ckpts/deap\arousal\no_shuffle_vartime_gcn_wsdon_dfhcon_wave2_basefunsine_layer2_kflod10_emb64_new\split_kflod_10.
[2025-05-10 00:28:06,918] [torcheeg] 💡 | If the dataset is re-generated, you need to re-generate the split of the dataset instead of using the previous split.


MemoryError: ./processed_data/deap_arousal_512\_record_0\eeg: Not enough space

Error in callback <bound method AutoreloadMagics.post_execute_hook of <IPython.extensions.autoreload.AutoreloadMagics object at 0x00000242D6741300>> (for post_execute), with arguments args (),kwargs {}:


MemoryError: 

In [None]:
# Write the results to two CSV files inside the run directory: one named for
# the mean over all subjects and one for the filtered subset.
final_results_csv_path = os.path.join(log_dir, 'all_subjects_mean_results.csv')
filtered_results_csv_path = os.path.join(log_dir, 'filtered_subjects_mean_results.csv')

save_result(
    final_results_dict,
    final_results_csv_path,
    filtered_results_csv_path,
    exclude_counts=None,
    logger=logger,
)

In [None]:
# Archive a copy of this notebook in the run directory so the exact code
# behind the results is kept with them. The source path is relative to the
# current working directory.
src_file = "train_deap.ipynb"
dst = os.path.join(log_dir, src_file)

shutil.copy(src=src_file, dst=dst)
print(f"File has saved into: {dst}")