## 1. Environment Setup
Seed the random generators, import core dependencies, and detect the training device.

In [None]:
import random

import numpy as np
import torch

from pyhealth.datasets import TUABDataset
from pyhealth.tasks import EEGAbnormalTUAB
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader

# One seed shared by every random number generator the notebook touches.
SEED = 42

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Seed the CUDA generators and pick the device in a single branch.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on device: {device}")

## 2. Load TUAB Dataset
Point to the TUAB dataset root and load the dataset.

In [None]:
# Location of the TUAB EDF files on disk.
TUAB_ROOT = '/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf'  # Update this path

# NOTE(review): dev=True presumably loads only a small subset for quick
# experiments - confirm against the TUABDataset documentation.
dataset = TUABDataset(root=TUAB_ROOT, dev=True)
dataset.stats()

## 3. Prepare PyHealth Dataset
Set the task for the dataset and convert raw samples into PyHealth format for abnormal EEG classification.

In [None]:
# Configure the abnormal-EEG task. The keyword values are handed to
# EEGAbnormalTUAB unchanged (units presumably Hz - confirm in the task docs).
abnormal_task = EEGAbnormalTUAB(
    resample_rate=200,
    bandpass_filter=(0.1, 75.0),
    notch_filter=50.0,
)
sample_dataset = dataset.set_task(abnormal_task)

# Report the size and the declared schemas of the processed dataset.
print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Look at the first processed sample: its keys, signal shape and label.
sample = sample_dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")

## 4. Split Dataset
Divide the processed samples into training, validation, and test subsets before building dataloaders.

In [None]:
BATCH_SIZE = 32

# 70% / 10% / 20% split over individual samples, seeded for reproducibility.
train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

# Check the training subset BEFORE building its loader: constructing a
# shuffling loader over an empty dataset can itself raise (the random sampler
# rejects zero samples), which would hide the clearer message below.
if len(train_ds) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Validation / test loaders are None when their subset is empty, so later
# cells must check for that before iterating.
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None

# Kept from the original cell: also guards against a loader that yields no
# batches even though the subset is non-empty.
if len(train_loader) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

## 5. Inspect Batch Structure
Peek at the first training batch to understand feature shapes and data structure.

In [None]:
# Pull a single batch from the training loader so its structure can be inspected.
first_batch = next(iter(train_loader))

def describe(value):
    """Return a short, human-readable summary of ``value``.

    Objects exposing a ``shape`` attribute are rendered as
    ``TypeName(shape=(...))``, lists and tuples as ``TypeName(len=N)``,
    and anything else as its bare type name.
    """
    type_name = type(value).__name__
    if hasattr(value, "shape"):
        detail = f"shape={tuple(value.shape)}"
    elif isinstance(value, (list, tuple)):
        detail = f"len={len(value)}"
    else:
        return type_name
    return f"{type_name}({detail})"

# Map every field of the batch to its one-line description.
batch_summary = dict(
    (field, describe(payload)) for field, payload in first_batch.items()
)

print("Batch structure:")
for key, desc in batch_summary.items():
    print(f"  {key}: {desc}")

## 6. Instantiate Model
Create a simple CNN model for abnormal EEG (binary) classification.

In [None]:
import torch.nn as nn

class SimpleEEGClassifier(nn.Module):
    """Small 1-D CNN: two conv + max-pool stages followed by a two-layer head.

    Args:
        num_classes: Number of output logits. The default of 6 is kept for
            backward compatibility; a binary task should pass ``1``.
        in_channels: Number of input channels expected by the first
            convolution (default 16).
        seq_len: Number of time steps per input signal (default 2000). It
            determines the size of the first fully connected layer, so inputs
            must have a length that gives the same pooled size.

    Raises:
        ValueError: If ``seq_len`` is too short to survive both conv/pool
            stages.

    With the defaults the temporal length evolves as
    2000 -> conv1 -> 1996 -> pool -> 998 -> conv2 -> 994 -> pool -> 497,
    so the flattened feature size is 64 * 497 = 31808.
    """

    def __init__(self, num_classes=6, in_channels=16, seq_len=2000):
        super().__init__()
        kernel_size = 5

        # Track the temporal length through the network so the flatten size
        # is derived rather than hard-coded (unpadded conv removes
        # kernel_size - 1 steps, MaxPool1d(2) halves with floor).
        after_conv1 = seq_len - (kernel_size - 1)
        after_pool1 = after_conv1 // 2
        after_conv2 = after_pool1 - (kernel_size - 1)
        after_pool2 = after_conv2 // 2
        if after_pool2 < 1:
            raise ValueError(
                f"seq_len={seq_len} is too short for two conv/pool stages "
                f"with kernel_size={kernel_size}"
            )

        # Layers are created in the same order as before so that seeded
        # weight initialisation is unchanged.
        self.conv1 = nn.Conv1d(in_channels, 32, kernel_size=kernel_size, stride=1)
        self.pool = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=kernel_size)
        self.fc1 = nn.Linear(64 * after_pool2, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, signal):
        """Return raw logits of shape ``(batch, num_classes)``.

        ``signal`` is expected as ``(batch, in_channels, seq_len)``.
        """
        x = self.relu(self.conv1(signal))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        # Collapse channel and time axes into one feature vector per sample.
        x = x.flatten(start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Single-logit classifier, moved to the selected device.
model = SimpleEEGClassifier(num_classes=1).to(device)

# Count all parameters and the trainable subset in one pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 7. Test Forward Pass
Verify the model can process a batch and compute outputs.

In [None]:
# Copy the inspected batch, moving every entry that supports `.to` onto the
# device and leaving the other entries untouched.
test_batch = {}
for key, value in first_batch.items():
    test_batch[key] = value.to(device) if hasattr(value, 'to') else value

# Run the model once without tracking gradients.
with torch.no_grad():
    output = model(test_batch['signal'])

print("Model output shape:", output.shape)
print("Sample output:", output[0])

## 8. Configure Loss and Optimizer
Define the loss function and optimizer for training.

In [None]:
# Step size used by the Adam optimizer.
LEARNING_RATE = 1e-3

# Adam over all model parameters; binary cross-entropy applied to raw logits.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

## 9. Train the Model
Launch the training loop to learn from the EEG data.

In [None]:
num_epochs = 5

for epoch in range(num_epochs):
    # ---- Training pass: one optimizer step per batch ----
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        signals, labels = batch['signal'].to(device), batch['label'].to(device)

        optimizer.zero_grad()
        loss = criterion(model(signals), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    mean_train_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_train_loss:.4f}")

    # ---- Validation pass (skipped when there is no validation loader) ----
    if not val_loader:
        continue

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            signals, labels = batch['signal'].to(device), batch['label'].to(device)
            outputs = model(signals)
            val_loss += criterion(outputs, labels).item()
            # Threshold the sigmoid probability at 0.5 to get hard predictions.
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    mean_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    print(f"Validation Loss: {mean_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

## 10. Evaluate on Test Set
Evaluate the trained model on the test set.

In [None]:
# Evaluate the trained model on the held-out test split.
model.eval()

if not test_loader:
    # test_loader is None when the test subset came out empty; iterating it
    # would raise TypeError, and the averages below would divide by zero.
    print("Test set is empty; skipping evaluation.")
else:
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            signals = batch['signal'].to(device)
            labels = batch['label'].to(device)
            outputs = model(signals)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            # Threshold the sigmoid probability at 0.5 to get hard predictions.
            predicted = torch.sigmoid(outputs)
            predicted = (predicted > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")