# 1. Environment Setup
Seed the random generators, import core dependencies, and detect the training device.

In [1]:
import random

import numpy as np
import torch

from pyhealth.datasets import SleepEDFDataset
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cpu


# 2. Load Sleep-EDF Dataset
Point to the Sleep-EDF dataset root and load the telemetry subset for sleep stage classification.

In [2]:
dataset = SleepEDFDataset(
    root="../downloads/sleepedf",  # Update this path
)
dataset.stats()

No config path provided, using default config
Initializing sleepedf dataset from ../downloads/sleepedf (dev mode: False)
Scanning table: recordings from F:\coding_projects\pyhealth\downloads\sleepedf\sleepedf-cassette-pyhealth.csv
Collecting global event dataframe...
Collected dataframe with shape: (2, 9)
Dataset: sleepedf
Dev mode: False
Number of patients: 1
Number of events: 2


# 3. Prepare PyHealth Dataset
Set the task for the dataset and convert raw samples into PyHealth format for self-supervised learning.

In [3]:
sample_dataset = dataset.set_task()

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

if len(sample_dataset) == 0:
    raise RuntimeError("The task did not produce any samples. Verify the dataset root.")

# Inspect a sample
sample = sample_dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")

Setting task SleepStaging for sleepedf base dataset...
Generating samples with 1 worker(s)...


Generating samples for SleepStaging with 1 worker:   0%|          | 0/1 [00:00<?, ?it/s]

Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
2650 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 2650 events and 3000 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
2829 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 2829 events and 3000 original time points ...
0 bad epochs dropped


Generating samples for SleepStaging with 1 worker: 100%|██████████| 1/1 [00:14<00:00, 14.03s/it]

Label label vocab: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}



Processing samples: 100%|██████████| 5479/5479 [00:00<00:00, 11703.09it/s]

Generated 5479 samples for task SleepStaging
Total task samples: 5479
Input schema: {'signal': 'tensor'}
Output schema: {'label': 'multiclass'}

Sample keys: dict_keys(['patient_id', 'night', 'patient_age', 'patient_sex', 'signal', 'label'])
Signal shape: torch.Size([7, 3000])





# 4. Split Dataset
Divide the processed samples into training, validation, and test subsets before building dataloaders.

In [4]:
BATCH_SIZE = 32

train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None

if len(train_loader) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

Train/Val/Test sizes: 3835, 548, 1096


# 5. Inspect Batch Structure
Peek at the first training batch to understand feature shapes and data structure.

In [5]:
first_batch = next(iter(train_loader))

def describe(value):
    if hasattr(value, "shape"):
        return f"{type(value).__name__}(shape={tuple(value.shape)})"
    if isinstance(value, (list, tuple)):
        return f"{type(value).__name__}(len={len(value)})"
    return type(value).__name__

batch_summary = {key: describe(value) for key, value in first_batch.items()}
print("Batch structure:")
for key, desc in batch_summary.items():
    print(f"  {key}: {desc}")

Batch structure:
  patient_id: list(len=32)
  night: list(len=32)
  patient_age: list(len=32)
  patient_sex: list(len=32)
  signal: Tensor(shape=(32, 7, 3000))
  label: Tensor(shape=(32,))


# 6. Instantiate ContraWR Model
Create the PyHealth ContraWR model for self-supervised learning on sleep signals and review its parameter footprint.

In [6]:
from pyhealth.models import ContraWR

model = ContraWR(
    dataset=sample_dataset,
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_keys}")
print(f"Model mode: {model.mode}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

  from .autonotebook import tqdm as notebook_tqdm



=== Input data dimensions ===
n_channels: 7
length: 3000
=== Spectrogram parameters ===
n_channels: 7
freq_dim: 65
time_steps: 90
=== Convolution Parameters ===
in_channels: 7, out_channels: 8, freq_dim: 16, time_steps: 22
in_channels: 8, out_channels: 16, freq_dim: 4, time_steps: 5
in_channels: 16, out_channels: 32, freq_dim: 1, time_steps: 1

Feature keys: ['signal']
Label key: ['label']
Model mode: multiclass
Total parameters: 25,326
Trainable parameters: 25,326


# 7. Test Forward Pass
Verify the model can process a batch and compute the contrastive loss.

In [7]:
# Move batch to device
test_batch = {key: value.to(device) if hasattr(value, 'to') else value 
              for key, value in first_batch.items()}

# Forward pass
with torch.no_grad():
    output = model(**test_batch)

print("Model output keys:", output.keys())
if 'loss' in output:
    print(f"Loss value: {output['loss'].item():.4f}")
if 'y_prob' in output:
    print(f"Output probability shape: {output['y_prob'].shape}")

Model output keys: dict_keys(['loss', 'y_prob', 'y_true', 'logit'])
Loss value: 2.2494
Output probability shape: torch.Size([32, 6])


# 8. Configure Trainer
Wrap the model with the PyHealth Trainer and define optimization hyperparameters.

In [8]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    device=str(device),
    enable_logging=False,
)

training_config = {
    "epochs": 10,
    "optimizer_params": {"lr": 1e-3},
    "max_grad_norm": 5.0,
}

ContraWR(
  (encoder): Sequential(
    (0): ResBlock2D(
      (conv1): Conv2d(7, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ELU(alpha=1.0)
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (downsampler): Sequential(
        (0): Conv2d(7, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): ResBlock2D(
      (conv1): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ELU(alpha=

# 9. Train the Model
Launch the self-supervised training loop to learn representations from sleep signal data.

In [9]:
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    **training_config,
)

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: 5.0
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x0000029B647EE2F0>
Monitor: None
Monitor criterion: max
Epochs: 10
Patience: None



Epoch 0 / 10: 100%|██████████| 120/120 [00:06<00:00, 18.71it/s]

--- Train epoch-0, step-120 ---
loss: 0.9076



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 36.13it/s]

--- Eval epoch-0, step-120 ---
accuracy: 0.7646
f1_macro: 0.2430
f1_micro: 0.7646
loss: 0.6233




Epoch 1 / 10: 100%|██████████| 120/120 [00:06<00:00, 18.55it/s]

--- Train epoch-1, step-240 ---
loss: 0.6502



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 36.84it/s]

--- Eval epoch-1, step-240 ---
accuracy: 0.7737
f1_macro: 0.2510
f1_micro: 0.7737
loss: 0.6300




Epoch 2 / 10: 100%|██████████| 120/120 [00:06<00:00, 19.80it/s]

--- Train epoch-2, step-360 ---
loss: 0.6182



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 33.93it/s]

--- Eval epoch-2, step-360 ---
accuracy: 0.7810
f1_macro: 0.2533
f1_micro: 0.7810
loss: 0.6012




Epoch 3 / 10: 100%|██████████| 120/120 [00:06<00:00, 19.66it/s]

--- Train epoch-3, step-480 ---
loss: 0.5995



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 37.76it/s]

--- Eval epoch-3, step-480 ---
accuracy: 0.7755
f1_macro: 0.2516
f1_micro: 0.7755
loss: 0.5852




Epoch 4 / 10: 100%|██████████| 120/120 [00:06<00:00, 19.97it/s]

--- Train epoch-4, step-600 ---
loss: 0.5886



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 36.99it/s]

--- Eval epoch-4, step-600 ---
accuracy: 0.7974
f1_macro: 0.2598
f1_micro: 0.7974
loss: 0.5643




Epoch 5 / 10: 100%|██████████| 120/120 [00:05<00:00, 20.48it/s]

--- Train epoch-5, step-720 ---
loss: 0.5647



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 38.63it/s]

--- Eval epoch-5, step-720 ---
accuracy: 0.7920
f1_macro: 0.2941
f1_micro: 0.7920
loss: 0.5486




Epoch 6 / 10: 100%|██████████| 120/120 [00:06<00:00, 19.77it/s]

--- Train epoch-6, step-840 ---
loss: 0.5616



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 25.33it/s]

--- Eval epoch-6, step-840 ---
accuracy: 0.8011
f1_macro: 0.2607
f1_micro: 0.8011
loss: 0.5348




Epoch 7 / 10: 100%|██████████| 120/120 [00:07<00:00, 15.54it/s]

--- Train epoch-7, step-960 ---
loss: 0.5567



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 33.70it/s]

--- Eval epoch-7, step-960 ---
accuracy: 0.7956
f1_macro: 0.2559
f1_micro: 0.7956
loss: 0.5355




Epoch 8 / 10: 100%|██████████| 120/120 [00:05<00:00, 20.23it/s]

--- Train epoch-8, step-1080 ---
loss: 0.5500



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 35.01it/s]

--- Eval epoch-8, step-1080 ---
accuracy: 0.7920
f1_macro: 0.3068
f1_micro: 0.7920
loss: 0.5316




Epoch 9 / 10: 100%|██████████| 120/120 [00:06<00:00, 19.92it/s]

--- Train epoch-9, step-1200 ---
loss: 0.5496



Evaluation: 100%|██████████| 18/18 [00:00<00:00, 38.01it/s]

--- Eval epoch-9, step-1200 ---
accuracy: 0.7920
f1_macro: 0.3132
f1_micro: 0.7920
loss: 0.5272





# 10. Save Model (Optional)
Save the trained model for future use or fine-tuning on downstream tasks.

In [None]:
# Save model checkpoint
save_path = "contrawr_sleepedf_model.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    'config': training_config,
}, save_path)
print(f"Model saved to: {save_path}")

# To load the model later:
# checkpoint = torch.load(save_path)
# model.load_state_dict(checkpoint['model_state_dict'])

Model saved to: contrawr_sleepedf_model.pth
