## 1. Environment Setup
Seed RNGs, import dependencies, and choose a device.

In [4]:
import random

import numpy as np
import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cpu
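
Seeding works because each library keeps its own generator state, and resetting the seed replays the same sequence. The sketch below shows this with only the standard-library `random` module; the helper `draw` is hypothetical and exists only for illustration. The same idea applies to `np.random.seed` and `torch.manual_seed`, although some GPU kernels can remain nondeterministic even with fixed seeds.

```python
import random


def draw(seed, n=3):
    """Return n pseudo-random floats after seeding the stdlib RNG."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draw(42)
second = draw(42)
other = draw(43)

print(first == second)  # same seed replays the same sequence
print(first == other)   # a different seed gives a different sequence
```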


## 2. Load TUEV Dataset
Set `root` to the `edf/` directory of your local TUEV download.

In [1]:
from pyhealth.datasets import TUEVDataset

# Example relative path (from this notebook's folder):
#   <repo_root>/downloads/tuev/v2.0.1/edf
# Update as needed.
dataset = TUEVDataset(
    root="../../downloads/tuev/v2.0.1/edf",
    subset="both",  # 'train', 'eval', or 'both'
)
dataset.stats()

No config path provided, using default config
Using both train and eval subsets
Initializing tuev dataset from ../../downloads/tuev/v2.0.1/edf (dev mode: False)
No cache_dir provided. Using default cache dir: C:\Users\Razin\AppData\Local\pyhealth\pyhealth\Cache\743c627e-3664-5ebd-aa5c-dbbe789f5e83
Dataset: tuev
Dev mode: False
Number of patients: 4
Number of events: 4


## 3. Prepare Task Dataset
Apply the `EEGEventsTUEV` task to produce one sample per annotated event.

In [2]:
from pyhealth.tasks import EEGEventsTUEV

sample_dataset = dataset.set_task(EEGEventsTUEV())

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

if len(sample_dataset) == 0:
    raise RuntimeError("The task did not produce any samples. Verify the dataset root.")

sample = sample_dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")

Setting task EEG_events for tuev base dataset...
Applying task transformations on data with 1 workers...
Detected Jupyter notebook environment, setting num_workers to 1
Single worker mode, processing sequentially
Worker 0 started processing 4 patients. (Polars threads: 8)


  0%|          | 0/4 [00:00<?, ?it/s]

Rank 0 inferred the following `['bytes']` data format.


100%|██████████| 4/4 [00:21<00:00,  5.27s/it]

Worker 0 finished processing patients.
Fitting processors on the dataset...





Label label vocab: {0: 0, 3: 1, 4: 2, 5: 3}
Processing samples and saving to C:\Users\Razin\AppData\Local\pyhealth\pyhealth\Cache\743c627e-3664-5ebd-aa5c-dbbe789f5e83\tasks\EEG_events\samples_896eeb17-2673-41d9-b878-99d335df68c4.ld...
Applying processors on data with 1 workers...
Detected Jupyter notebook environment, setting num_workers to 1
Single worker mode, processing sequentially
Worker 0 started processing 1395 samples. (0 to 1395)


  0%|          | 0/1395 [00:00<?, ?it/s]

Rank 0 inferred the following `['str', 'pickle', 'tensor', 'int', 'tensor']` data format.


100%|██████████| 1395/1395 [00:01<00:00, 1363.33it/s]

Worker 0 finished processing samples.





Cached processed samples to C:\Users\Razin\AppData\Local\pyhealth\pyhealth\Cache\743c627e-3664-5ebd-aa5c-dbbe789f5e83\tasks\EEG_events\samples_896eeb17-2673-41d9-b878-99d335df68c4.ld
Total task samples: 1395
Input schema: {'signal': 'tensor'}
Output schema: {'label': 'multiclass'}

Sample keys: dict_keys(['patient_id', 'signal_file', 'signal', 'offending_channel', 'label'])
Signal shape: torch.Size([16, 1280])
Label: 3
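
The log line `Label label vocab: {0: 0, 3: 1, 4: 2, 5: 3}` shows that the label values emitted by the task are re-encoded as contiguous class indices. A minimal sketch of mapping an index back to the original value, assuming that `sample['label']` holds the encoded index (the name `index_to_value` is introduced here for illustration):

```python
# Vocabulary as printed in the log above: task label value -> encoded class index.
label_vocab = {0: 0, 3: 1, 4: 2, 5: 3}

# Invert the mapping to recover the task label value from an encoded index.
index_to_value = {index: value for value, index in label_vocab.items()}

encoded = 3  # assumed to be an encoded index, as in "Label: 3" above
print(index_to_value[encoded])
```

Only four distinct label values appear in this 4-patient subset, which is why the model later reports 4 output classes.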


## 4. Split Dataset
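Split the samples 70/10/20 into train/val/test and build dataloaders. The split is by sample, not by patient, so samples from the same patient can land in more than one split.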

In [5]:
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader

BATCH_SIZE = 32

train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None

if len(train_loader) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

Train/Val/Test sizes: 976, 140, 279
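
The split sizes determine how many batches each loader yields. A small sketch of that arithmetic, assuming the loaders keep the final partial batch (no `drop_last`):

```python
import math

BATCH_SIZE = 32
EPOCHS = 10
split_sizes = {"train": 976, "val": 140, "test": 279}  # from the output above

# With the last partial batch kept, a loader yields ceil(n / batch_size) batches.
batches = {name: math.ceil(n / BATCH_SIZE) for name, n in split_sizes.items()}
total_train_steps = batches["train"] * EPOCHS

print(batches)
print(total_train_steps)
```

These counts agree with the progress-bar totals in the training and evaluation logs later in this notebook (31, 5, and 9 batches) and with the final `step-310` after 10 epochs.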


## 5. Inspect Batch Structure
Check the first batch to confirm shapes match what `ContraWR` expects: `(batch, channels, length)`.

In [6]:
first_batch = next(iter(train_loader))

def describe(value):
    if hasattr(value, 'shape'):
        return f"{type(value).__name__}(shape={tuple(value.shape)})"
    if isinstance(value, (list, tuple)):
        return f"{type(value).__name__}(len={len(value)})"
    return type(value).__name__

print('Batch structure:')
for key, value in first_batch.items():
    print(f"  {key}: {describe(value)}")

Batch structure:
  patient_id: list(len=32)
  signal_file: list(len=32)
  signal: Tensor(shape=(32, 16, 1280))
  offending_channel: list(len=32)
  label: Tensor(shape=(32,))


## 6. Instantiate ContraWR
Create the `ContraWR` model using the task dataset (it infers feature/label keys from the dataset schema).

In [7]:
from pyhealth.models import ContraWR

model = ContraWR(
    dataset=sample_dataset,
    n_fft=128,
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_keys}")
print(f"Model mode: {model.mode}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


=== Input data dimensions ===
n_channels: 16
length: 1280
=== Spectrogram parameters ===
n_channels: 16
freq_dim: 65
time_steps: 37
=== Convolution Parameters ===
in_channels: 16, out_channels: 32, freq_dim: 16, time_steps: 9
in_channels: 32, out_channels: 64, freq_dim: 4, time_steps: 2

Feature keys: ['signal']
Label key: ['label']
Model mode: multiclass
Total parameters: 95,076
Trainable parameters: 95,076
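
The dimensions printed by the model can be reproduced by hand. The sketch below is not PyHealth code: it assumes a one-sided STFT without center padding and a hop length of 32 (an assumption chosen because it reproduces the 37 time steps), and it takes the layer settings from the model printout in Section 8 (`Conv2d` with kernel 3, stride 2, padding 1, followed by `MaxPool2d` with kernel 2, stride 2). The second block's pooling layer is cut off in that printout and is assumed to match the first. The stride-1 `conv2` keeps the size unchanged, so it is omitted.

```python
def stft_dims(length, n_fft, hop_length):
    """Frequency bins and frames of a one-sided STFT without center padding."""
    freq_dim = n_fft // 2 + 1
    time_steps = (length - n_fft) // hop_length + 1
    return freq_dim, time_steps


def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def res_block_dims(freq_dim, time_steps):
    """Apply Conv2d(k=3, s=2, p=1) then MaxPool2d(k=2, s=2) to each axis."""
    dims = []
    for size in (freq_dim, time_steps):
        size = conv_out(size, kernel=3, stride=2, padding=1)
        size = conv_out(size, kernel=2, stride=2, padding=0)
        dims.append(size)
    return tuple(dims)


spec = stft_dims(length=1280, n_fft=128, hop_length=32)  # hop_length is assumed
block1 = res_block_dims(*spec)
block2 = res_block_dims(*block1)
print(spec, block1, block2)
```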


## 7. Test Forward Pass
Run a forward pass without gradient tracking and check that the output contains a loss plus `y_prob` and `y_true` tensors of the expected shapes.

In [8]:
test_batch = {
    key: (value.to(device) if hasattr(value, 'to') else value)
    for key, value in first_batch.items()
}

model.eval()
with torch.no_grad():
    output = model(**test_batch)

print('Model output keys:', output.keys())
print(f"Loss: {output['loss'].item():.4f}")
print(f"y_prob shape: {tuple(output['y_prob'].shape)}")
print(f"y_true shape: {tuple(output['y_true'].shape)}")

Model output keys: dict_keys(['loss', 'y_prob', 'y_true', 'logit'])
Loss: 46.9889
y_prob shape: (32, 4)
y_true shape: (32,)
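
As a reference point for the loss value, the cross-entropy of a classifier that assigns equal probability to each of the 4 classes can be computed directly. This is a sanity-check sketch, not part of the PyHealth API:

```python
import math

num_classes = 4  # from the y_prob shape (32, 4) above

# Cross-entropy of a classifier that assigns probability 1/num_classes to every class.
uniform_loss = -math.log(1.0 / num_classes)
print(f"{uniform_loss:.4f}")
```

The untrained loss of 46.9889 is far above this value, which suggests that the freshly initialized model makes confident wrong predictions on this batch rather than near-uniform ones. The training log in Section 8 shows the loss falling below 1 within the first epoch.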


## 8. Train with PyHealth Trainer
Train `ContraWR` on the TUEV EEG events task.

In [9]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    device=str(device),
    enable_logging=False,
)

training_config = {
    'epochs': 10,
    'optimizer_params': {'lr': 1e-3},
    'max_grad_norm': 5.0,
    # 'monitor': 'accuracy',  # uncomment to track a metric on val
    # 'monitor_criterion': 'max',
}

ContraWR(
  (encoder): Sequential(
    (0): ResBlock2D(
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ELU(alpha=1.0)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (downsampler): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): ResBlock2D(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): 

In [10]:
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    **training_config,
)

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: 5.0
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x0000024B1631A0E0>
Monitor: None
Monitor criterion: max
Epochs: 10
Patience: None



Epoch 0 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-0, step-31 ---
loss: 0.6033


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 22.73it/s]

--- Eval epoch-0, step-31 ---
accuracy: 0.8857
f1_macro: 0.4340
f1_micro: 0.8857
loss: 0.3663






Epoch 1 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-1, step-62 ---
loss: 0.4586


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 22.42it/s]

--- Eval epoch-1, step-62 ---
accuracy: 0.9286
f1_macro: 0.8295
f1_micro: 0.9286
loss: 0.2543






Epoch 2 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-2, step-93 ---
loss: 0.3077


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 24.88it/s]

--- Eval epoch-2, step-93 ---
accuracy: 0.9214
f1_macro: 0.8957
f1_micro: 0.9214
loss: 0.2323






Epoch 3 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-3, step-124 ---
loss: 0.2483


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 25.91it/s]

--- Eval epoch-3, step-124 ---
accuracy: 0.9214
f1_macro: 0.9086
f1_micro: 0.9214
loss: 0.1891






Epoch 4 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-4, step-155 ---
loss: 0.1782


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 22.73it/s]

--- Eval epoch-4, step-155 ---
accuracy: 0.9286
f1_macro: 0.9144
f1_micro: 0.9286
loss: 0.1460






Epoch 5 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-5, step-186 ---
loss: 0.1769


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 24.75it/s]

--- Eval epoch-5, step-186 ---
accuracy: 0.9214
f1_macro: 0.8957
f1_micro: 0.9214
loss: 0.1793






Epoch 6 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-6, step-217 ---
loss: 0.1239


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 24.27it/s]

--- Eval epoch-6, step-217 ---
accuracy: 0.9714
f1_macro: 0.9808
f1_micro: 0.9714
loss: 0.1368






Epoch 7 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-7, step-248 ---
loss: 0.1546


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 24.51it/s]

--- Eval epoch-7, step-248 ---
accuracy: 0.9714
f1_macro: 0.9808
f1_micro: 0.9714
loss: 0.0995






Epoch 8 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-8, step-279 ---
loss: 0.1313


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 24.51it/s]

--- Eval epoch-8, step-279 ---
accuracy: 0.9714
f1_macro: 0.9808
f1_micro: 0.9714
loss: 0.1059






Epoch 9 / 10:   0%|          | 0/31 [00:00<?, ?it/s]

--- Train epoch-9, step-310 ---
loss: 0.1139


Evaluation: 100%|██████████| 5/5 [00:00<00:00, 24.75it/s]

--- Eval epoch-9, step-310 ---
accuracy: 0.9786
f1_macro: 0.8637
f1_micro: 0.9786
loss: 0.0581





## 9. Evaluate on Test Set
Compute multiclass metrics on the held-out test split.

In [11]:
if test_loader is None:
    raise RuntimeError('No test dataloader was created.')

scores = trainer.evaluate(test_loader)
print('Test scores:')
for k, v in scores.items():
    print(f"  {k}: {v:.4f}")

Evaluation: 100%|██████████| 9/9 [00:00<00:00, 19.74it/s]


Test scores:
  accuracy: 0.9749
  f1_macro: 0.9491
  f1_micro: 0.9749
  loss: 0.0824
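
`f1_micro` equals accuracy for single-label multiclass problems, while `f1_macro` averages per-class F1 and is therefore more sensitive to errors on rare classes. The sketch below illustrates the difference with made-up labels; the helper `f1_scores` is hypothetical and is not the routine PyHealth uses to compute its metrics.

```python
def f1_scores(y_true, y_pred):
    """Return (macro_f1, micro_f1) for single-label multiclass predictions."""
    classes = sorted(set(y_true) | set(y_pred))
    per_class, tp_sum, fp_sum, fn_sum = [], 0, 0, 0
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        per_class.append(2 * tp / denom if denom else 0.0)
        tp_sum, fp_sum, fn_sum = tp_sum + tp, fp_sum + fp, fn_sum + fn
    macro = sum(per_class) / len(per_class)
    micro = 2 * tp_sum / (2 * tp_sum + fp_sum + fn_sum)
    return macro, micro


# Made-up toy labels: class 0 is common, class 1 is rare.
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 8 + [0, 1]
macro, micro = f1_scores(y_true, y_pred)
print(f"macro={macro:.4f} micro={micro:.4f}")
```

One missed example of the rare class lowers the macro score much more than the micro score, which is the same pattern seen at epoch 9 of the validation log (accuracy 0.9786, `f1_macro` 0.8637).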


## 10. Save Model (Optional)
Save the trained weights (`model_state_dict`) and the training configuration to a checkpoint file.

In [12]:
save_path = 'contrawr_tuev_model.pth'
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'config': training_config,
    },
    save_path,
)
print(f"Model saved to: {save_path}")

Model saved to: contrawr_tuev_model.pth
