In [1]:
import os
# Please replace this path with your actual working directory
os.chdir('/home/xz/workspace/github')

from train.utils import *
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader
from train.GACET import GACET
from train.trainer import Trainer

In [3]:
# Cross-day evaluation for one subject: for every (train day, train day,
# test day) arrangement, run 5-fold stratified cross-validation on the two
# training days and score each fold's model on the held-out day.

# Fix the seeds once up front. `set_seed` comes from train.utils
# (presumably it seeds the global Python/NumPy/torch RNGs -- confirm there);
# `g` is the generator handed to every DataLoader below.
set_seed(42)
g = torch.Generator().manual_seed(42)

# Input directories for subject 01 and the per-level pickle files to load.
data_SampEn = Path('./data/SampEn/sub-01').resolve()
data_DE = Path('./data/DE/sub-01').resolve()
task_order = [f'MATB_level{level}.pkl' for level in range(5)]

# One accuracy per (day arrangement, fold); averaged at the very end.
acc_list = []

# Each entry is (first training day, second training day, test day).
day_permutations = [
    ([1], [2], [3]),
    ([1], [3], [2]),
    ([2], [3], [1]),
]


def _load_days(days):
    """Build a DualSourceDataset over both input directories for `days`."""
    return DualSourceDataset(
        data_SampEn,
        data_DE,
        days=days,
        task_order=task_order,
    )


# The device does not depend on the fold, so pick it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for day1, day2, day_test in day_permutations:
    print(f'training on {day1}, {day2} and testing on {day_test}')
    dataset_train_1 = _load_days(day1)
    dataset_train_2 = _load_days(day2)
    dataset_test = _load_days(day_test)

    # Both training datasets must have the same length because the fold
    # indices computed on the first one are reused for the second one; cut
    # the longer dataset down to the shorter one's length.
    min_len = min(len(dataset_train_1), len(dataset_train_2))
    for train_part in (dataset_train_1, dataset_train_2):
        if len(train_part) > min_len:
            train_part.trim_to_length(min_len)

    # Folds are stratified on the labels of the first training dataset.
    # NOTE(review): this assumes the second dataset's labels line up
    # index-for-index with the first after trimming -- confirm in
    # DualSourceDataSplitter.
    kf = StratifiedKFold(n_splits=5, shuffle=False)
    fold_splits = kf.split(dataset_train_1.data[0], dataset_train_1.labels)
    for fold_no, (train_idx, val_idx) in enumerate(fold_splits, start=1):
        print(f'fold {fold_no} start')
        splitter = DualSourceDataSplitter(dataset_train_1, dataset_train_2, train_idx, val_idx)

        # Standardisation statistics are taken from the training split only
        # and then reused for the validation and test data.
        standardized_train = StandardizedDataset(splitter.train_dataset, is_train=True)
        mean, std = standardized_train.get_mean_std()
        standardized_val = StandardizedDataset(splitter.val_dataset, is_train=False, mean=mean, std=std)
        standardized_test = StandardizedDataset(dataset_test, is_train=False, mean=mean, std=std)

        # Only the training loader shuffles; all three share generator `g`.
        loader_options = {'batch_size': 32, 'generator': g}
        train_loader = DataLoader(standardized_train, shuffle=True, **loader_options)
        val_loader = DataLoader(standardized_val, shuffle=False, **loader_options)
        test_loader = DataLoader(standardized_test, shuffle=False, **loader_options)

        # A fresh model, optimizer and loss for every fold.
        model = GACET(num_classes=len(task_order), embed_dim=300)
        model.to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        criterion = torch.nn.CrossEntropyLoss()

        trainer = Trainer(model, train_loader, val_loader, test_loader, criterion, optimizer, device)
        # train() returns the accuracy that is averaged below (presumably the
        # test accuracy of the best-validation model -- confirm in Trainer).
        acc_list.append(trainer.train())

# Mean accuracy over every arrangement and fold, as a percentage.
mean_acc_percent = np.mean(acc_list) * 100
print(f"acc: {mean_acc_percent:.2f}%")

training on [1], [2] and testing on [3]
fold 1 start
The Best Validation Accuracy: 60.71%
Loaded best model state from training.
Test Accuracy: 33.81%
fold 2 start
The Best Validation Accuracy: 65.48%
Loaded best model state from training.
Test Accuracy: 39.05%
fold 3 start
The Best Validation Accuracy: 60.71%
Loaded best model state from training.
Test Accuracy: 33.81%
fold 4 start
The Best Validation Accuracy: 60.71%
Loaded best model state from training.
Test Accuracy: 41.43%
fold 5 start
The Best Validation Accuracy: 61.90%
Loaded best model state from training.
Test Accuracy: 39.05%
training on [1], [3] and testing on [2]
fold 1 start
The Best Validation Accuracy: 64.29%
Loaded best model state from training.
Test Accuracy: 43.81%
fold 2 start
The Best Validation Accuracy: 63.10%
Loaded best model state from training.
Test Accuracy: 40.48%
fold 3 start
The Best Validation Accuracy: 63.10%
Loaded best model state from training.
Test Accuracy: 45.71%
fold 4 start
The Best Validation

The results above are the predictions for Subject 1 using DE and SampEn. The outcomes of the three rounds of five-fold cross-validation are consistent with the Subject_1 performance results presented in Section F2.4, "Dataset 2 (5-class)", on page 35. The average accuracy of 44.41% also matches the result reported in Table 19 on page 24.
