In this quick tour, we show how easy it is to start an EEG analysis project by modifying only a few lines of the PyTorch tutorial.
The torcheeg.datasets module contains dataset objects for many real-world EEG datasets, such as DEAP, DREAMER, and SEED. In this tutorial, we use the DEAP dataset. Each Dataset accepts three parameters: online_transform, offline_transform, and label_transform. The first two modify samples, and the third modifies labels.
from torcheeg import transforms
from torcheeg.datasets import DEAPDataset
from torcheeg.datasets.constants.emotion_recognition.deap import DEAP_CHANNEL_LOCATION_DICT

dataset = DEAPDataset(io_path='./deap',
                      root_path='./data_preprocessed_python',
                      # Applied while the intermediate results are generated
                      offline_transform=transforms.Compose([
                          transforms.BandDifferentialEntropy(),
                          transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
                      ]),
                      # Applied each time a sample is loaded
                      online_transform=transforms.ToTensor(),
                      # Applied to the labels
                      label_transform=transforms.Compose([
                          transforms.Select('valence'),
                          transforms.Binary(5.0),
                      ]))
Here, offline_transform modifies samples while the intermediate results are generated and processed, online_transform modifies samples each time they are loaded, and label_transform modifies labels. We strongly recommend placing time-consuming NumPy transforms in offline_transform, and PyTorch and data-augmentation transforms in online_transform.
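To make the offline/online distinction concrete, here is a minimal sketch in plain Python. ToyDataset and the three helper functions are hypothetical stand-ins, not TorchEEG classes. The sketch only assumes that offline results are computed once and kept, while the online and label transforms run on every access, and that the label pipeline selects valence and thresholds it at 5.0.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset with three transform hooks.

    Not TorchEEG's implementation: it only mimics when each hook runs.
    """

    def __init__(self, samples, labels, offline_transform,
                 online_transform, label_transform):
        # Offline transform: applied once, results kept for later reuse.
        self.cache = [offline_transform(s) for s in samples]
        self.labels = labels
        self.online_transform = online_transform
        self.label_transform = label_transform

    def __len__(self):
        return len(self.cache)

    def __getitem__(self, index):
        # Online and label transforms: applied on every access.
        x = self.online_transform(self.cache[index])
        y = self.label_transform(self.labels[index])
        return x, y


calls = {"offline": 0, "online": 0}


def expensive_feature(sample):
    # Placeholder for a costly NumPy-style feature extractor.
    calls["offline"] += 1
    return [v * v for v in sample]


def cheap_conversion(sample):
    # Placeholder for a cheap per-access step such as tensor conversion.
    calls["online"] += 1
    return tuple(sample)


def binarize_valence(label):
    # Assumed behaviour: pick 'valence', then threshold at 5.0.
    return int(label["valence"] >= 5.0)


toy = ToyDataset(
    samples=[[1.0, 2.0], [3.0, 4.0]],
    labels=[{"valence": 7.1, "arousal": 2.0},
            {"valence": 3.2, "arousal": 8.0}],
    offline_transform=expensive_feature,
    online_transform=cheap_conversion,
    label_transform=binarize_valence,
)

# Read the first sample three times, as several epochs would.
for _ in range(3):
    x, y = toy[0]

print(calls)  # offline ran once per sample, online once per access
```

Because the expensive step runs only when the dataset is built, repeated epochs pay only for the cheap per-access step, which is the reason for the recommendation above.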
Next, we need to divide the dataset into a training set and a test set. In the field of EEG analysis, commonly used data partitioning methods include k-fold cross-validation and leave-one-out cross-validation. In this tutorial, we use k-fold cross-validation on the entire dataset (KFoldDataset) as an example of dataset partitioning.
from torcheeg.model_selection import KFoldDataset
k_fold = KFoldDataset(n_splits=5, split_path='./split', shuffle=True)
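For readers new to k-fold cross-validation, the idea behind the split can be sketched with the standard library alone. The function k_fold_indices below is a hypothetical illustration, not how KFoldDataset is implemented; the library may shuffle, group, and store its splits differently.

```python
import random


def k_fold_indices(n_samples, n_splits=5, shuffle=True, seed=0):
    """Yield (train_indices, val_indices) pairs for k-fold cross-validation."""
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # Spread the remainder so fold sizes differ by at most one.
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        stop = start + base + (1 if fold < extra else 0)
        val = indices[start:stop]
        train = indices[:start] + indices[stop:]
        yield train, val
        start = stop


folds = list(k_fold_indices(n_samples=12, n_splits=5))
for fold, (train, val) in enumerate(folds):
    print(f"fold {fold}: {len(train)} train / {len(val)} val")
```

Every sample appears in exactly one validation fold, and the remaining samples form the training set for that fold.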
Let's define a simple but effective CNN model:
import torch
import torch.nn as nn


class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Each block pads 3 zeros per spatial dimension and applies a 4x4
        # kernel, so the spatial size of the grid is unchanged.
        self.conv1 = nn.Sequential(
            nn.ZeroPad2d((1, 2, 1, 2)),
            nn.Conv2d(4, 64, kernel_size=4, stride=1),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.ZeroPad2d((1, 2, 1, 2)),
            nn.Conv2d(64, 128, kernel_size=4, stride=1),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.ZeroPad2d((1, 2, 1, 2)),
            nn.Conv2d(128, 256, kernel_size=4, stride=1),
            nn.ReLU()
        )
        self.conv4 = nn.Sequential(
            nn.ZeroPad2d((1, 2, 1, 2)),
            nn.Conv2d(256, 64, kernel_size=4, stride=1),
            nn.ReLU()
        )

        # Classifier on the flattened 9 x 9 grid with 64 channels
        self.lin1 = nn.Linear(9 * 9 * 64, 1024)
        self.lin2 = nn.Linear(1024, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        # Keep the batch dimension, flatten channels and grid
        x = x.flatten(start_dim=1)
        x = self.lin1(x)
        x = self.lin2(x)
        return x
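The input size of the first linear layer, 9 * 9 * 64, follows from the padding and kernel sizes used in the convolutions. The arithmetic can be checked with a small helper; conv_output_size is a hypothetical name, and the 9 x 9 input grid is an assumption inferred from nn.Linear(9 * 9 * 64, 1024).

```python
def conv_output_size(size, kernel_size, stride=1, pad_before=0, pad_after=0):
    """Output length of a convolution along one spatial dimension."""
    padded = size + pad_before + pad_after
    return (padded - kernel_size) // stride + 1


# Assumption: the input grid is 9 x 9, as implied by nn.Linear(9 * 9 * 64, 1024).
size = 9
for layer in range(4):
    # ZeroPad2d((1, 2, 1, 2)) adds 1 + 2 = 3 zeros per spatial dimension.
    size = conv_output_size(size, kernel_size=4, stride=1,
                            pad_before=1, pad_after=2)

flattened = size * size * 64  # conv4 outputs 64 channels
print(size, flattened)  # 9 5184
```

Since each block maps a 9 x 9 grid back to 9 x 9, only the channel count changes between layers, and the flattened feature vector has 5184 entries.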
In your own research, you may also want to use GNN- or Transformer-based models and build more complex projects. Please refer to the examples in the examples/ folder.
The training and validation scripts for the model are taken from the PyTorch tutorial with little modification. The only thing worth noting is that the Dataset yields three values for each sample: the EEG signal (denoted by X in the code), the baseline signal (denoted by b in the code), and the sample label (denoted by y in the code). To remove the baseline, we subtract the baseline signal from the original signal and pass the result to the model (see pred = model(X - b)).
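The baseline subtraction is a plain element-wise difference. As a rough illustration without tensors, the sketch below uses nested lists; remove_baseline and the sample values are hypothetical, and in the tutorial code the same step is simply X - b on tensors.

```python
def remove_baseline(signal, baseline):
    """Element-wise signal - baseline for equally shaped nested lists."""
    if isinstance(signal, list):
        return [remove_baseline(s, b) for s, b in zip(signal, baseline)]
    return signal - baseline


# One hypothetical sample laid out as (signal, baseline, label).
sample = ([[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.5], [1.0, 1.0]], 1)
X, b, y = sample

model_input = remove_baseline(X, b)
print(model_input, y)
```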
from torch.utils.data import DataLoader

# Use the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
batch_size = 64
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch_idx, batch in enumerate(dataloader):
        # Each batch holds the EEG signal, the baseline signal, and the label
        X = batch[0].to(device)
        b = batch[1].to(device)
        y = batch[2].to(device)

        # Compute prediction error
        pred = model(X - b)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            loss, current = loss.item(), batch_idx * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def valid(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X = batch[0].to(device)
            b = batch[1].to(device)
            y = batch[2].to(device)

            pred = model(X - b)
            val_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # Average the loss over batches and the accuracy over samples
    val_loss /= num_batches
    correct /= size
    print(
        f"Validation Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n"
    )
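for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    # Start every fold from a freshly initialized model and optimizer so that
    # weights trained on one fold do not leak into the evaluation of the next.
    model = CNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    epochs = 5
    for t in range(epochs):
        print(f"Fold {i+1}, epoch {t+1}\n-------------------------------")
        train(train_loader, model, loss_fn, optimizer)
        valid(val_loader, model, loss_fn)
print("Done!")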
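With k-fold cross-validation, results are usually reported as an average over the folds. The valid function above only prints its metrics, so the sketch below assumes you have collected one accuracy per fold yourself, for example by returning correct from valid. The helper summarize_folds is hypothetical, and the numbers are placeholders for illustration, not measured results.

```python
from statistics import mean, stdev


def summarize_folds(accuracies):
    """Return the mean and sample standard deviation of per-fold accuracies."""
    spread = stdev(accuracies) if len(accuracies) > 1 else 0.0
    return mean(accuracies), spread


# Placeholder numbers for illustration only, not measured results.
fold_accuracies = [0.5, 0.75, 0.25, 0.5, 0.5]

avg, spread = summarize_folds(fold_accuracies)
print(f"accuracy: {avg:.3f} +/- {spread:.3f}")
```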