In [11]:
import sys
sys.path.append("./MNE-Python tutorial")

In [15]:
from bnci import *
import torch
import torchvision.transforms as transforms 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [16]:
DATASET_ROOT = "C:/Users/NMAIL/Desktop/BNCI dataset"

In [19]:
# Custom Dataset class for MNE EpochsArray data
class EEGDataset(Dataset):
    def __init__(self, epochs, labels):
        """
        epochs: MNE's EpochsArray object
        labels: Corresponding labels for each epoch
        """
        self.epochs = epochs.get_data()  # (n_epochs, n_channels, n_times)
        self.labels = labels  # e.g., (n_epochs,)
        
    def __len__(self):
        return len(self.epochs)
    
    def __getitem__(self, idx):
        # Get one sample and its label
        sample = self.epochs[idx]  # Shape: (n_channels, n_times)
        label = self.labels[idx]   # Corresponding label
        
        # Convert the sample to PyTorch tensor and reshape to (1, n_times, n_channels)
        sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1)
        label = torch.tensor(label, dtype=torch.long)
        
        return sample, label

In [76]:
epochs_arrays = []
classes = []
for i in range(1,10):
    res_data, res_class = get_data_2a(i, True)
    epochs_array = EEG_to_epochs(res_data, res_class,sfreq = 250,)
    epochs_arrays.append(epochs_array)
    classes.append(res_class)

concat_epochs = mne.concatenate_epochs(epochs_arrays)
concat_labels = np.concatenate(classes)

dataset = EEGDataset(concat_epochs, concat_labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)


https://tutorials.pytorch.kr/beginner/basics/buildmodel_tutorial.html  
https://tutorials.pytorch.kr/beginner/blitz/cifar10_tutorial.html  
참조

In [4]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")


Using cuda device


In [72]:
class ShallowConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.tempConv = nn.Conv2d(1, 40, (25,1))
        self.spaConv = nn.Conv2d(40, 40, (1,22)) 
        # Batch Normalization
        self.batch_norm = nn.BatchNorm2d(40)

        self.pool = nn.AvgPool2d((75,1), (15,1))

        # Linear Layer for Classification
        self.fc = nn.Linear(4760, 4)
    
    def forward(self, x):
        # Temporal Convolution
        x = self.tempConv(x)
        x = self.batch_norm(x)
        x = self.spaConv(x)
        
        # Mean Pooling
        x = self.pool(x)
        
        # Log Activation
        x = torch.log(torch.clamp(x, min=1e-6))  # To avoid log(0)
        
        # Flatten the output for the dense layer
        x = x.view(x.size(0), -1)
        
        # Linear Classification
        x = self.fc(x)
        
        return F.softmax(x, dim=1)
    
net = ShallowConvNet()
net.to(device) #using gpu


ShallowConvNet(
  (tempConv): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
  (spaConv): Conv2d(40, 40, kernel_size=(1, 22), stride=(1, 1))
  (batch_norm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)
  (fc): Linear(in_features=4760, out_features=4, bias=True)
)

In [69]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [70]:
num_epochs = 10

In [75]:
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()  # Zero the parameter gradients
        
        outputs = net(inputs)  # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        
        loss.backward()  # Backward pass (compute gradients)
        optimizer.step()  # Update weights
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss:.4f}, Accuracy: {accuracy:.2f}%")

print("Training complete.")

Epoch [1/10], Loss: 119.1721, Accuracy: 25.50%
Epoch [2/10], Loss: 119.2299, Accuracy: 25.35%
Epoch [3/10], Loss: 119.1642, Accuracy: 25.50%
Epoch [4/10], Loss: 119.1997, Accuracy: 25.23%
Epoch [5/10], Loss: 119.1657, Accuracy: 25.46%
Epoch [6/10], Loss: 119.1736, Accuracy: 25.04%
Epoch [7/10], Loss: 119.1990, Accuracy: 25.31%
Epoch [8/10], Loss: 119.1279, Accuracy: 25.27%
Epoch [9/10], Loss: 119.1015, Accuracy: 25.42%
Epoch [10/10], Loss: 119.1087, Accuracy: 25.46%
Training complete.


In [77]:
evaluation_epochs_arrays = []
evaluation_classes = []
for i in range(1,10):
    res_data, res_class = get_data_2a(i, False)
    epochs_array = EEG_to_epochs(res_data, res_class,sfreq = 250,)
    evaluation_epochs_arrays.append(epochs_array)
    evaluation_classes.append(res_class)

ev_concat_epochs = mne.concatenate_epochs(evaluation_epochs_arrays)
ev_concat_labels = np.concatenate(evaluation_classes)

ev_dataset = EEGDataset(ev_concat_epochs, ev_concat_labels)
dataloader = DataLoader(ev_dataset, batch_size=32, shuffle=True)

(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
(288, 22, 1875)


In [78]:
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
    
        # 신경망에 이미지를 통과시켜 출력을 계산합니다
        outputs = net(inputs)
        
        # 가장 높은 값(energy)를 갖는 분류(class)를 정답으로 선택하겠습니다
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

Accuracy of the network on the 10000 test images: 24 %
