In [None]:
import os
import mne
import numpy as np
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

# Shared torch device: the CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def process_files(patient_files, normal_files, tmin, tmax):
    """Build one labelled dataset from patient and control recordings.

    Patient files are processed first and labelled 1; control ("normal")
    files follow and are labelled 0. Both groups go through
    ``process_set_files`` with the same epoch window.

    Args:
        patient_files: Paths of the patient recordings.
        normal_files: Paths of the control recordings.
        tmin: Epoch window start, forwarded to ``process_set_files``.
        tmax: Epoch window end, forwarded to ``process_set_files``.

    Returns:
        Tuple ``(X, y)``: the per-group arrays concatenated along axis 0,
        patients before controls.
    """
    feature_parts = []
    label_parts = []

    # Order matters for the output layout: patients (1) first, controls (0) second.
    for group_files, group_label in ((patient_files, 1), (normal_files, 0)):
        group_X, group_y = process_set_files(group_files, tmin, tmax, label=group_label)
        feature_parts.append(group_X)
        label_parts.append(group_y)

    stacked_X = np.concatenate(feature_parts, axis=0)
    stacked_y = np.concatenate(label_parts, axis=0)
    return stacked_X, stacked_y


def _segment_epochs(X, n_segments=3):
    """Cut every epoch into ``n_segments`` equal, contiguous time windows.

    ``X`` is ``(n_epochs, n_channels, n_times)``. Trailing samples that do not
    fill a whole window are dropped. The result is
    ``(n_epochs * n_segments, n_channels, n_times // n_segments)``, ordered
    epoch-major (all windows of epoch 0, then epoch 1, ...). Zero epochs give
    an empty array that is still 3-D.
    """
    n_epochs, n_channels, n_times = X.shape
    seg_len = n_times // n_segments
    trimmed = X[:, :, :seg_len * n_segments]
    windows = trimmed.reshape(n_epochs, n_channels, n_segments, seg_len)
    # Move the window axis next to the epoch axis so flattening is epoch-major.
    return windows.transpose(0, 2, 1, 3).reshape(n_epochs * n_segments, n_channels, seg_len)


def process_set_files(set_files, tmin, tmax, label):
    """Load EEGLAB ``.set`` recordings and turn them into labelled segments.

    For each file: read it with MNE, drop the 'VEOG', 'HEOG' and 'Trigger'
    channels when present, band-pass filter 0.5-80 Hz, epoch around the
    annotations named '40'..'59', then cut each epoch into three equal
    windows along the time axis.

    Args:
        set_files: Iterable of paths to ``.set`` files.
        tmin: Epoch start relative to the event, passed to ``mne.Epochs``.
        tmax: Epoch end relative to the event, passed to ``mne.Epochs``.
        label: Class label assigned to every segment from these files.

    Returns:
        Tuple ``(X, y)``: ``X`` stacks the segments of all files along
        axis 0 with shape ``(n_segments_total, n_channels, n_times // 3)``;
        ``y`` holds ``label`` once per segment.

    Raises:
        RuntimeError: If a file cannot be read. The original error is chained.
        ValueError: If ``set_files`` yields no files.
    """
    all_X = []
    all_y_group = []
    for set_file in set_files:
        try:
            raw = mne.io.read_raw_eeglab(set_file, preload=True)
        except Exception as e:
            # Raise instead of exit(): killing the interpreter (or notebook
            # kernel) from a helper hides the traceback and cannot be handled.
            raise RuntimeError(f"Error reading file {set_file!r}: {e}") from e

        unwanted = [ch for ch in ('VEOG', 'HEOG', 'Trigger') if ch in raw.ch_names]
        if unwanted:
            raw.drop_channels(unwanted)

        raw.filter(0.5, 80., fir_design='firwin', verbose='ERROR')

        events, current_event_id_map = mne.events_from_annotations(raw)
        # Keep only the annotation descriptions '40'..'59' that occur in this file.
        event_id = {
            f'event_{i}': current_event_id_map[str(40 + i)]
            for i in range(20)
            if str(40 + i) in current_event_id_map
        }

        epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                            preload=True, verbose='ERROR', baseline=(0, 0))
        X = epochs.get_data()  # (n_epochs, n_channels, n_times)

        segments = _segment_epochs(X, n_segments=3)
        all_X.append(segments)
        all_y_group.append(np.full(segments.shape[0], label))

    if not all_X:
        raise ValueError("set_files is empty: need at least one .set file to process")

    return np.concatenate(all_X, axis=0), np.concatenate(all_y_group, axis=0)


class EEGDataset(Dataset):
    """Map-style dataset over in-memory NumPy features and labels.

    Features are stored as float32 and labels as int64; items are returned
    as ``(features, label)`` pairs indexed along the first axis.
    """

    def __init__(self, X, y):
        super().__init__()
        # astype() makes private copies in the dtypes the model/loss expect.
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target


class EEGNet(nn.Module):
    """Compact convolutional classifier for multi-channel EEG segments.

    Pipeline: a temporal convolution (kernel 51), a grouped spatial
    convolution spanning all channels, a second temporal convolution
    (kernel 15), then a linear classifier. The two average-pool stages shrink
    the time axis by 4 and then by 8.

    Args:
        num_classes: Number of output logits.
        dropout: Dropout probability used after both pooling stages.
        n_channels: Number of EEG channels in the input (height of the
            spatial convolution kernel). Defaults to 64.
        n_samples: Number of time samples per input segment. Defaults to
            1000, which gives the 992 flattened features used previously.

    Raises:
        ValueError: If ``n_samples`` is too short to survive both poolings.
    """

    def __init__(self, num_classes=2, dropout=0.3, n_channels=64, n_samples=1000):
        super().__init__()
        self.firstconv = nn.Sequential(
            # Odd kernel with padding (k - 1) / 2 keeps the time length unchanged.
            nn.Conv2d(1, 16, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False),
            nn.BatchNorm2d(16)
        )
        self.depthwiseConv = nn.Sequential(
            # Kernel height equals n_channels, so the channel axis collapses to 1.
            nn.Conv2d(16, 32, kernel_size=(n_channels, 1), stride=(1, 1), groups=16, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)),
            nn.Dropout(dropout)
        )
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=(1, 15), stride=(1, 1), padding=(0, 7), bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8)),
            nn.Dropout(dropout)
        )
        self.flatten = nn.Flatten()

        # Both pools floor-divide the time axis: n_samples -> //4 -> //8.
        pooled_len = (n_samples // 4) // 8
        if pooled_len < 1:
            raise ValueError(
                f"n_samples={n_samples} is too short: need at least 32 time samples"
            )
        self.classify = nn.Linear(32 * pooled_len, num_classes)

    def forward(self, x):
        """Return class logits for ``x`` of shape ``(batch, n_channels, n_samples)``."""
        x = x.unsqueeze(1)  # add the singleton "image channel" axis for Conv2d
        x = self.firstconv(x)
        x = self.depthwiseConv(x)
        x = self.separableConv(x)
        x = self.flatten(x)
        out = self.classify(x)
        return out


def train_model(model, train_loader, criterion, optimizer, num_epochs):
    """Train ``model`` in place and print loss/accuracy after every epoch.

    Batches are moved to the device the model's parameters live on, so the
    function does not depend on any module-level device setting.

    Args:
        model: ``nn.Module`` returning per-class logits.
        train_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. The model is updated in place and left in training mode.
    """
    # Follow the model's own device; a parameterless model falls back to CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Weight the batch loss by its size so the epoch figure is a
            # per-sample mean even when the last batch is smaller.
            running_loss += loss.item() * batch_size
            predicted = outputs.detach().argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

        if total == 0:
            # An empty loader would otherwise divide by zero below.
            print(f'Epoch [{epoch+1}/{num_epochs}], no training samples')
            continue

        epoch_loss = running_loss / total
        epoch_accuracy = 100 * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')


def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and print classification metrics.

    Predictions are the arg-max of the model outputs. Batches are moved to
    the device the model's parameters live on, so the function does not
    depend on any module-level device setting.

    Args:
        model: ``nn.Module`` returning per-class logits.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor batches.

    Returns:
        None. Prints scikit-learn's classification report and confusion
        matrix; the model is left in evaluation mode.
    """
    # Follow the model's own device; a parameterless model falls back to CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            y_pred.extend(predicted.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    print("\nClassification Report:")
    print(classification_report(y_true, y_pred))
    print("\nConfusion Matrix:")
    print(confusion_matrix(y_true, y_pred))


def k_fold_cross_validation(X, y, k=5, num_epochs=50, groups=None, random_state=None):
    """Train and evaluate a fresh ``EEGNet`` on each of ``k`` folds.

    For every fold the features are standardised with statistics fitted on
    the training split only, a new model is trained for ``num_epochs`` and
    its classification report / confusion matrix are printed.

    Args:
        X: Feature array of shape ``(n_samples, n_channels, n_times)``.
        y: Label array of shape ``(n_samples,)``.
        k: Number of folds.
        num_epochs: Training epochs per fold.
        groups: Optional array of shape ``(n_samples,)`` (e.g. a recording or
            subject id per sample). When given, ``GroupKFold`` keeps every
            group entirely inside one side of each split.
        random_state: Seed for the shuffled ``KFold`` split. ``None`` uses
            NumPy's global random state. Ignored when ``groups`` is given.

    Returns:
        None. Results are printed.
    """
    # NOTE(review): without ``groups`` the split is per sample, so windows cut
    # from the same epoch/recording by process_set_files can land in both the
    # training and the test split, which inflates the reported scores.
    if groups is not None:
        from sklearn.model_selection import GroupKFold
        splitter = GroupKFold(n_splits=k)
    else:
        # Shuffling inside KFold avoids materialising a permuted copy of X.
        splitter = KFold(n_splits=k, shuffle=True, random_state=random_state)

    for fold, (train_idx, test_idx) in enumerate(splitter.split(X, y, groups), start=1):
        print(f'Fold {fold}/{k}')

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Standardise each (channel, time) feature across samples; fit on the
        # training split only so no test statistics leak into training.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train.reshape(X_train.shape[0], -1)).reshape(X_train.shape)
        X_test = scaler.transform(X_test.reshape(X_test.shape[0], -1)).reshape(X_test.shape)

        train_dataset = EEGDataset(X_train, y_train)
        test_dataset = EEGDataset(X_test, y_test)
        # Reshuffle training batches every epoch; the split indices come back
        # sorted, so a fixed order would feed the classes in blocks.
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        model = EEGNet(num_classes=2, dropout=0.3).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)

        train_model(model, train_loader, criterion, optimizer, num_epochs)
        evaluate_model(model, test_loader)


if __name__ == "__main__":
    # Recordings labelled 1 (patients) and 0 (controls) by process_files.
    patient_files = [
        './061.set',
        './071.set',
        './062.set',
        './072.set',
        './031.set',
        './032.set',
    ]
    normal_files = [
        './007C2.set',
        './007C1.set',
        './008C2.set',
        './008C1.set',
    ]

    # Epoch window passed through to mne.Epochs: tmin=0, tmax=3.
    X, y = process_files(patient_files, normal_files, tmin=0, tmax=3)
    print(f"Processed data shape: {X.shape}, Labels shape: {y.shape}")

    k_fold_cross_validation(X, y, k=5, num_epochs=50)


Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
Reading c:\Users\clock\Desktop\hanzi\072.fdt
Reading 0 ... 881221  =      0.000 ...   881.221 secs...


  X = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)


Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
Reading c:\Users\clock\Desktop\hanzi\031.fdt
Reading 0 ... 894021  =      0.000 ...   894.021 secs...


  X = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)


Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
Reading c:\Users\clock\Desktop\hanzi\032.fdt
Reading 0 ... 871321  =      0.000 ...   871.321 secs...


  X = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)


Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']


  X = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)


Reading c:\Users\clock\Desktop\hanzi\007C2.fdt
Reading 0 ... 896621  =      0.000 ...   896.621 secs...
Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
Reading c:\Users\clock\Desktop\hanzi\007C1.fdt
Reading 0 ... 881621  =      0.000 ...   881.621 secs...


  X = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)


Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
Reading c:\Users\clock\Desktop\hanzi\008C2.fdt
Reading 0 ... 901822  =      0.000 ...   901.822 secs...


  X = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)


Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
Reading c:\Users\clock\Desktop\hanzi\008C1.fdt
Reading 0 ... 903022  =      0.000 ...   903.022 secs...


  X = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)


Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']


  X = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)


Processed data shape: (7200, 64, 1000), Labels shape: (7200,)
Fold 1/5


  from .autonotebook import tqdm as notebook_tqdm


Epoch [1/50], Loss: 0.1012, Accuracy: 95.85%
Epoch [2/50], Loss: 0.0114, Accuracy: 99.69%
Epoch [3/50], Loss: 0.0024, Accuracy: 99.98%
Epoch [4/50], Loss: 0.0004, Accuracy: 100.00%
Epoch [5/50], Loss: 0.0002, Accuracy: 100.00%
Epoch [6/50], Loss: 0.0001, Accuracy: 100.00%
Epoch [7/50], Loss: 0.0001, Accuracy: 100.00%
Epoch [8/50], Loss: 0.0001, Accuracy: 100.00%
Epoch [9/50], Loss: 0.0001, Accuracy: 100.00%
Epoch [10/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [11/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [12/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [13/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [14/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [15/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [16/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [17/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [18/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [19/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [20/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [21/50], Loss: 0.0000, Accuracy: 100.00%
Epoch [22/50], Loss: 0.00

KeyboardInterrupt: 