In [None]:
!pip install datasets transformers torch torchvision

In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from datasets import load_dataset
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast
from safetensors.torch import save_file

In [25]:
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

  scaler = GradScaler()


In [26]:
# Fetch the seizure EEG dataset by its Hugging Face Hub repository id.
# The result is indexed by split name further down (e.g. dataset['train']).
dataset = load_dataset(path="JLB-JLB/seizure_eeg_dev")

# Custom dataset class
class SeizureEEGDataset(Dataset):
    """Expose one split of a Hugging Face dataset as a PyTorch ``Dataset``.

    Each row of the split is expected to provide an ``'image'`` entry (an
    object with a ``mode`` attribute and a ``convert`` method, i.e. a
    PIL-style image) and a ``'label'`` entry.

    Args:
        hf_dataset: Mapping from split name to an indexable collection of rows.
        split: Name of the split to read from ``hf_dataset``.
        transform: Optional callable applied to the greyscale image before it
            is returned.
    """

    def __init__(self, hf_dataset, split='train', transform=None):
        self.data = hf_dataset[split]
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is converted to single-channel mode ``'L'`` when it is not
        already in that mode, then passed through ``transform`` if one was
        supplied.
        """
        # Fetch the row once: every lookup on the underlying split may
        # re-materialise (and re-decode) the whole row, so indexing it
        # separately for the image and the label doubles the work per sample.
        row = self.data[idx]
        image = row['image']
        label = row['label']

        if image.mode != 'L':
            image = image.convert('L')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [27]:
# Preprocessing applied to every image before it reaches the network:
#   1. resize to 224x224,
#   2. convert the image to a tensor,
#   3. normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [28]:
# Wrap the 'train' split so that samples are preprocessed on access
full_dataset = SeizureEEGDataset(dataset, split='train', transform=transform)

# 80/20 train/validation split; the global RNG is seeded for reproducibility
torch.manual_seed(42)
n_total = len(full_dataset)
train_size = int(0.8 * n_total)
val_size = n_total - train_size  # remainder, so the two sizes always sum to n_total
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Hyperparameters
batch_size = 128
num_epochs = 20
learning_rate = 0.001

# Data loaders: both share the same settings; only the training loader shuffles
loader_options = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

In [29]:
class EEGNet(nn.Module):
    """Small convolutional classifier for single-channel images.

    The feature extractor stacks three 3x3 convolutions (1 -> 32 -> 64 -> 128
    channels, padding 1), each followed by ReLU. The first two are followed by
    2x2 max pooling and the last by adaptive average pooling to 1x1. A single
    linear layer maps the resulting 128 features to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # Layer order is significant: it fixes the indices used in the
        # state_dict keys (features.0, features.3, features.6).
        layers = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        pooled = self.features(x)
        # Collapse (batch, 128, 1, 1) to (batch, 128) before the linear layer.
        return self.classifier(pooled.flatten(start_dim=1))

In [30]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network and move its parameters to the chosen device
model = EEGNet()
model.to(device)

# Cross-entropy loss over the class logits, optimised with Adam
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [31]:
# Per-epoch metric history, filled in at the end of every epoch so the curves
# can be inspected or plotted after training.
train_losses = []
val_losses = []
val_accuracies = []

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    # Reset the accumulator at the start of every epoch. Without this the
    # loop raises NameError on a fresh kernel, or silently keeps adding to a
    # value left over from an earlier run.
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass and loss computed under automatic mixed precision.
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Backpropagate the scaled loss, then let the scaler perform the
        # optimizer step and update its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    # Mean of the per-batch training losses for this epoch.
    train_loss = running_loss / len(train_loader)

    # ---- Validation pass (no gradients, full precision) ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Predicted class = index of the largest logit.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    val_loss /= len(val_loader)
    accuracy = 100. * correct / total  # percentage of correct validation predictions

    # Record this epoch's metrics in the history lists.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {accuracy:.2f}%")

print("Training finished!")

  with autocast():


Epoch [1/20], Train Loss: 0.2667, Val Loss: 0.2518, Val Accuracy: 93.11%
Epoch [2/20], Train Loss: 0.2452, Val Loss: 0.2475, Val Accuracy: 93.09%
Epoch [3/20], Train Loss: 0.2404, Val Loss: 0.2346, Val Accuracy: 93.35%
Epoch [4/20], Train Loss: 0.2359, Val Loss: 0.2474, Val Accuracy: 93.39%
Epoch [5/20], Train Loss: 0.2302, Val Loss: 0.2278, Val Accuracy: 93.60%
Epoch [6/20], Train Loss: 0.2270, Val Loss: 0.2306, Val Accuracy: 93.62%
Epoch [7/20], Train Loss: 0.2237, Val Loss: 0.2239, Val Accuracy: 93.49%
Epoch [8/20], Train Loss: 0.2214, Val Loss: 0.2307, Val Accuracy: 93.58%
Epoch [9/20], Train Loss: 0.2164, Val Loss: 0.2115, Val Accuracy: 93.90%
Epoch [10/20], Train Loss: 0.2135, Val Loss: 0.2100, Val Accuracy: 93.88%
Epoch [11/20], Train Loss: 0.2100, Val Loss: 0.2093, Val Accuracy: 94.01%
Epoch [12/20], Train Loss: 0.2086, Val Loss: 0.2077, Val Accuracy: 94.09%
Epoch [13/20], Train Loss: 0.2054, Val Loss: 0.2023, Val Accuracy: 94.23%
Epoch [14/20], Train Loss: 0.2038, Val Loss: 0.

In [34]:
# Snapshot the trained weights and write them to disk in safetensors format
model_state_dict = model.state_dict()
save_file(tensors=model_state_dict, filename="eeg_classifier.safetensors")

In [None]:
# Evaluate the trained model on the 'eval' split of a separate dataset.
# The download is bound to its own name so the training `dataset` defined
# earlier is not overwritten (re-running earlier cells would otherwise pick
# up the wrong data).
test_hf_dataset = load_dataset("JLB-JLB/seizure_eeg_iirFilter_greyscale_224x224_6secWindow")

# Wrap the raw split in the same Dataset class and preprocessing used for
# training. The raw split yields image objects rather than tensors, which the
# DataLoader's default collation cannot batch, and the model must see inputs
# normalised exactly as during training.
test_dataset = SeizureEEGDataset(test_hf_dataset, split='eval', transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

model.eval()
total_correct = 0
total_samples = 0

with torch.no_grad():
    for images, labels in test_dataloader:
        # Inputs must live on the same device as the model's parameters.
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        # Predicted class = index of the largest logit.
        _, predicted = torch.max(outputs, 1)

        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

# Fraction (0-1) of correctly classified test samples.
accuracy = total_correct / total_samples
print(f'Test Accuracy: {accuracy:.2f}')

README.md:   0%|          | 0.00/846 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/49 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/25 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/49 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/25 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/49 [00:00<?, ?files/s]

(…)-00000-of-00049-01eedd6162f48c0e.parquet:   0%|          | 0.00/506M [00:00<?, ?B/s]

(…)-00001-of-00049-4cfde0c5f43042c4.parquet:   0%|          | 0.00/517M [00:00<?, ?B/s]

(…)-00002-of-00049-0ea76ffef3d8f792.parquet:   0%|          | 0.00/507M [00:00<?, ?B/s]

(…)-00003-of-00049-79e34526218de6bf.parquet:   0%|          | 0.00/500M [00:00<?, ?B/s]

(…)-00004-of-00049-7602904820c420ef.parquet:   0%|          | 0.00/491M [00:00<?, ?B/s]

(…)-00005-of-00049-b469f9a610219b8f.parquet:   0%|          | 0.00/510M [00:00<?, ?B/s]

(…)-00006-of-00049-40cffe94f564d05b.parquet:   0%|          | 0.00/509M [00:00<?, ?B/s]

(…)-00007-of-00049-d1839169d094b755.parquet:   0%|          | 0.00/510M [00:00<?, ?B/s]

(…)-00008-of-00049-d17317963d058652.parquet:   0%|          | 0.00/490M [00:00<?, ?B/s]

(…)-00009-of-00049-a7e50d2b357a66ad.parquet:   0%|          | 0.00/512M [00:00<?, ?B/s]

(…)-00010-of-00049-f8bfbcb754eb1b47.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

(…)-00011-of-00049-7d917bfbb36b56e3.parquet:   0%|          | 0.00/476M [00:00<?, ?B/s]

(…)-00012-of-00049-666d8802857f2ee6.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

(…)-00013-of-00049-404f795e5690f852.parquet:   0%|          | 0.00/494M [00:00<?, ?B/s]

(…)-00014-of-00049-fa358b38b86a7f02.parquet:   0%|          | 0.00/481M [00:00<?, ?B/s]

(…)-00015-of-00049-07d74090eb5842a3.parquet:   0%|          | 0.00/501M [00:00<?, ?B/s]