## 1. Environment Setup
Seed the random generators, import core dependencies, and detect the training device.

In [1]:
import random

import numpy as np
import torch

from pyhealth.datasets import TUEVDataset
from pyhealth.tasks import EEGEventsTUEV
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader
from pyhealth.models import BIOT

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


## 2. Load TUEV Dataset
Point to the TUEV dataset root and load the dataset.

In [2]:
dataset = TUEVDataset(
    root='/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf',  # Update this path
    dev=True
)
dataset.stats()

No config path provided, using default config
Using both train and eval subsets
Using cached metadata from /home/jp65/.cache/pyhealth/tuev
Initializing tuev dataset from /home/jp65/.cache/pyhealth/tuev (dev mode: True)
No cache_dir provided. Using default cache dir: /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7
Dataset: tuev
Dev mode: True
Number of patients: 189
Number of events: 259


## 3. Prepare PyHealth Dataset
Set the task for the dataset and convert raw samples into PyHealth format for abnormal EEG classification.

In [3]:
sample_dataset = dataset.set_task(EEGEventsTUEV(
    resample_rate=200,    # Resample rate
    bandpass_filter=(0.1, 75.0),    # Bandpass filter
    notch_filter=50.0,    # Notch filter
))

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Inspect a sample
sample = sample_dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")

Setting task EEG_events for tuev base dataset...
Found cached processed samples at /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7/tasks/EEG_events_d595851a-c8f6-5f1d-bdfb-b3d527c2deb0/samples_47c27255-9fc0-5271-bd99-638ffecdb1cc.ld, skipping processing.
Total task samples: 53377
Input schema: {'signal': 'tensor'}
Output schema: {'label': 'multiclass'}

Sample keys: dict_keys(['patient_id', 'signal_file', 'signal', 'offending_channel', 'label'])
Signal shape: torch.Size([16, 1000])
Label: 5


## 4. Split Dataset
Divide the processed samples into training, validation, and test subsets before building dataloaders.

In [4]:
BATCH_SIZE = 32

train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None

if len(train_loader) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

Train/Val/Test sizes: 37363, 5338, 10676


## 5. Inspect Batch Structure
Peek at the first training batch to understand feature shapes and data structure.

In [5]:
first_batch = next(iter(train_loader))

def describe(value):
    if hasattr(value, "shape"):
        return f"{type(value).__name__}(shape={tuple(value.shape)})"
    if isinstance(value, (list, tuple)):
        return f"{type(value).__name__}(len={len(value)})"
    return type(value).__name__

batch_summary = {key: describe(value) for key, value in first_batch.items()}
print("Batch structure:")
for key, desc in batch_summary.items():
    print(f"  {key}: {desc}")

Batch structure:
  patient_id: list(len=32)
  signal_file: list(len=32)
  signal: Tensor(shape=(32, 16, 1000))
  offending_channel: list(len=32)
  label: Tensor(shape=(32,))


## 6. Instantiate BIOT


In [6]:
model = BIOT(
    dataset=sample_dataset,
    emb_size=256,
    heads= 8,
    depth=4,
    n_fft=200,
    hop_length=100,
    n_classes=6,
    n_channels=16).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Total parameters: 3,187,734
Trainable parameters: 3,187,718


## 7. Test Forward Pass
Verify the model can process a batch and compute outputs.

In [7]:
# Move batch to device
test_batch = {key: value.to(device) if hasattr(value, 'to') else value 
              for key, value in first_batch.items()}

# Forward pass
with torch.no_grad():
    outputs = model(**test_batch)

print("Output keys:", outputs.keys())
print(f"Loss: {outputs['loss'].item():.4f}")
print(f"Logits shape: {outputs['logit'].shape}")
print(f"y_prob shape: {outputs['y_prob'].shape}")

# Get embeddings
embeddings = model.get_embeddings(**test_batch)
print(f"Embeddings shape: {embeddings['embeddings'].shape}")

  return _VF.stft(  # type: ignore[attr-defined]


Output keys: dict_keys(['loss', 'y_prob', 'y_true', 'logit'])
Loss: 119.1745
Logits shape: torch.Size([32, 6])
y_prob shape: torch.Size([32, 6])
Embeddings shape: torch.Size([32, 256])


## 8. Train Model
Train the model using PyHealth's Trainer:

In [11]:
from pyhealth.trainer import Trainer

trainer = Trainer(model=model, device=device,metrics=["balanced_accuracy", "cohen_kappa"],)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=3,
    monitor="cohen_kappa",
    
)

BIOT(
  (biot): BIOTClassifier(
    (biot): BIOTEncoder(
      (patch_embedding): PatchFrequencyEmbedding(
        (projection): Linear(in_features=101, out_features=256, bias=True)
      )
      (transformer): LinearAttentionTransformer(
        (layers): SequentialSequence(
          (layers): ModuleList(
            (0-3): 4 x ModuleList(
              (0): PreNorm(
                (fn): SelfAttention(
                  (local_attn): LocalAttention(
                    (dropout): Dropout(p=0.2, inplace=False)
                  )
                  (to_q): Linear(in_features=256, out_features=256, bias=False)
                  (to_k): Linear(in_features=256, out_features=256, bias=False)
                  (to_v): Linear(in_features=256, out_features=256, bias=False)
                  (to_out): Linear(in_features=256, out_features=256, bias=True)
                  (dropout): Dropout(p=0.2, inplace=False)
                )
                (norm): LayerNorm((256,), eps=1e-05, elementwise

Epoch 0 / 3:   0%|          | 0/1168 [00:00<?, ?it/s]

--- Train epoch-0, step-1168 ---
loss: 0.9935


Evaluation: 100%|██████████| 167/167 [00:04<00:00, 41.07it/s]


--- Eval epoch-0, step-1168 ---
balanced_accuracy: 0.6476
cohen_kappa: 0.8055
loss: 0.2974
New best cohen_kappa score (0.8055) at epoch-0, step-1168



Epoch 1 / 3:   0%|          | 0/1168 [00:00<?, ?it/s]

--- Train epoch-1, step-2336 ---
loss: 0.3636


Evaluation: 100%|██████████| 167/167 [00:04<00:00, 38.73it/s]

--- Eval epoch-1, step-2336 ---
balanced_accuracy: 0.7337
cohen_kappa: 0.8751
loss: 0.1913
New best cohen_kappa score (0.8751) at epoch-1, step-2336








Epoch 2 / 3:   0%|          | 0/1168 [00:00<?, ?it/s]

--- Train epoch-2, step-3504 ---
loss: 0.2402


Evaluation: 100%|██████████| 167/167 [00:03<00:00, 44.24it/s]

--- Eval epoch-2, step-3504 ---
balanced_accuracy: 0.7485
cohen_kappa: 0.8929
loss: 0.1738
New best cohen_kappa score (0.8929) at epoch-2, step-3504





Loaded best model


## 9. Evaluate on Test Set
Evaluate the trained model on the test set.

In [13]:
trainer.evaluate(test_loader)

Evaluation: 100%|██████████| 334/334 [00:07<00:00, 42.33it/s]


{'balanced_accuracy': 0.7463781895065341,
 'cohen_kappa': 0.8961779343856422,
 'loss': 0.2021648968224174}