## 1. Environment Setup
Seed the random generators, import core dependencies, and detect the training device.

In [1]:
import random

import numpy as np
import torch

from pyhealth.datasets import TUEVDataset
from pyhealth.tasks import EEGEventsTUEV
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


## 2. Load TUEV Dataset
Point to the TUEV dataset root and load the dataset.

In [3]:
dataset = TUEVDataset(
    root='/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf',  # Update this path
    dev=True
)
dataset.stats()

No config path provided, using default config
Using both train and eval subsets
Wrote train metadata to cache: /home/jp65/.cache/pyhealth/tuev/tuev-train-pyhealth.csv
Wrote eval metadata to cache: /home/jp65/.cache/pyhealth/tuev/tuev-eval-pyhealth.csv
Using cached metadata from /home/jp65/.cache/pyhealth/tuev
Initializing tuev dataset from /home/jp65/.cache/pyhealth/tuev (dev mode: True)
No cache_dir provided. Using default cache dir: /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7
Scanning table: train from /home/jp65/.cache/pyhealth/tuev/tuev-train-pyhealth.csv
Scanning table: eval from /home/jp65/.cache/pyhealth/tuev/tuev-eval-pyhealth.csv
Dev mode enabled: limiting to 1000 patients
Caching event dataframe to /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7/global_event_df.parquet...




Dataset: tuev
Dev mode: True
Number of patients: 189
Number of events: 259


## 3. Prepare PyHealth Dataset
Set the task for the dataset and convert raw samples into PyHealth format for abnormal EEG classification.

In [4]:
sample_dataset = dataset.set_task(EEGEventsTUEV(
    resample_rate=200,    # Resample rate
    bandpass_filter=(0.1, 75.0),    # Bandpass filter
    notch_filter=50.0,    # Notch filter
))

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Inspect a sample
sample = sample_dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")

Setting task EEG_events for tuev base dataset...
Applying task transformations on data with 1 workers...
Detected Jupyter notebook environment, setting num_workers to 1
Single worker mode, processing sequentially
Worker 0 started processing 189 patients. (Polars threads: 128)


  0%|          | 0/189 [00:00<?, ?it/s]

Rank 0 inferred the following `['bytes']` data format.


100%|██████████| 189/189 [05:03<00:00,  1.61s/it]

Worker 0 finished processing patients.
Fitting processors on the dataset...





Label label vocab: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}
Processing samples and saving to /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7/tasks/EEG_events_d595851a-c8f6-5f1d-bdfb-b3d527c2deb0/samples_47c27255-9fc0-5271-bd99-638ffecdb1cc.ld...
Applying processors on data with 1 workers...
Detected Jupyter notebook environment, setting num_workers to 1
Single worker mode, processing sequentially
Worker 0 started processing 53377 samples. (0 to 53377)


  0%|          | 0/53377 [00:00<?, ?it/s]

Rank 0 inferred the following `['str', 'pickle', 'tensor', 'int', 'tensor']` data format.


100%|██████████| 53377/53377 [00:50<00:00, 1066.60it/s]

Worker 0 finished processing samples.
Cached processed samples to /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7/tasks/EEG_events_d595851a-c8f6-5f1d-bdfb-b3d527c2deb0/samples_47c27255-9fc0-5271-bd99-638ffecdb1cc.ld
Total task samples: 53377
Input schema: {'signal': 'tensor'}
Output schema: {'label': 'multiclass'}

Sample keys: dict_keys(['patient_id', 'signal_file', 'signal', 'offending_channel', 'label'])
Signal shape: torch.Size([16, 1000])
Label: 5





## 4. Split Dataset
Divide the processed samples into training, validation, and test subsets before building dataloaders.

In [5]:
BATCH_SIZE = 32

train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None

if len(train_loader) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

Train/Val/Test sizes: 37363, 5338, 10676


## 5. Inspect Batch Structure
Peek at the first training batch to understand feature shapes and data structure.

In [6]:
first_batch = next(iter(train_loader))

def describe(value):
    if hasattr(value, "shape"):
        return f"{type(value).__name__}(shape={tuple(value.shape)})"
    if isinstance(value, (list, tuple)):
        return f"{type(value).__name__}(len={len(value)})"
    return type(value).__name__

batch_summary = {key: describe(value) for key, value in first_batch.items()}
print("Batch structure:")
for key, desc in batch_summary.items():
    print(f"  {key}: {desc}")

Batch structure:
  patient_id: list(len=32)
  signal_file: list(len=32)
  signal: Tensor(shape=(32, 16, 1000))
  offending_channel: list(len=32)
  label: Tensor(shape=(32,))


## 6. Instantiate Model
Create a simple CNN model for EEG event classification.

In [9]:
import torch.nn as nn

class SimpleEEGClassifier(nn.Module):
    def __init__(self, num_classes=6):
        super().__init__()
        self.conv1 = nn.Conv1d(16, 32, kernel_size=5, stride=1)
        self.pool = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=5)
        # After conv1: 1280 -> 1276, pool: 638
        # conv2: 638 -> 634, pool: 317
        self.fc1 = nn.Linear(15808, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, signal):
        x = self.relu(self.conv1(signal))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleEEGClassifier(num_classes=6).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Total parameters: 2,037,222
Trainable parameters: 2,037,222


## 7. Test Forward Pass
Verify the model can process a batch and compute outputs.

In [10]:
# Move batch to device
test_batch = {key: value.to(device) if hasattr(value, 'to') else value 
              for key, value in first_batch.items()}

# Forward pass
with torch.no_grad():
    output = model(test_batch['signal'])

print("Model output shape:", output.shape)
print("Sample output:", output[0])

Model output shape: torch.Size([32, 6])
Sample output: tensor([ 0.4550, -0.1300,  0.4457, -0.0651,  0.2269, -0.5862], device='cuda:0')


## 8. Configure Loss and Optimizer
Define the loss function and optimizer for training.

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

## 9. Train the Model
Launch the training loop to learn from the EEG data.

In [12]:
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        signals = batch['signal'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        outputs = model(signals)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")
    
    # Validation
    if val_loader:
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                signals = batch['signal'].to(device)
                labels = batch['label'].to(device)
                outputs = model(signals)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                predicted = torch.argmax(outputs, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f"Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

Epoch 1/5, Loss: 0.6861
Validation Loss: 0.2473, Accuracy: 93.95%
Epoch 2/5, Loss: 0.2535
Validation Loss: 0.1431, Accuracy: 96.33%
Epoch 3/5, Loss: 0.4147
Validation Loss: 0.3619, Accuracy: 88.27%
Epoch 4/5, Loss: 0.1992
Validation Loss: 0.1431, Accuracy: 96.97%
Epoch 5/5, Loss: 0.1228
Validation Loss: 0.1890, Accuracy: 96.25%


## 10. Evaluate on Test Set
Evaluate the trained model on the test set.

In [13]:
model.eval()
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        signals = batch['signal'].to(device)
        labels = batch['label'].to(device)
        outputs = model(signals)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        predicted = torch.argmax(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

Test Loss: 0.1885, Accuracy: 96.11%
