# 1. Environment Setup
Import core dependencies, seed the random number generators for reproducibility, and detect the training device.

In [None]:
import random

import numpy as np
import torch

from pyhealth.datasets import SleepEDFDataset
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader

SEED = 42

# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU) so that
# runs are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

cuda_available = torch.cuda.is_available()
if cuda_available:
    # Also seed all visible GPUs.
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if cuda_available else "cpu")
print(f"Running on device: {device}")

# 2. Load Sleep-EDF Dataset
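Point to the Sleep-EDF dataset root (update the path below) and load the recordings for sleep stage classification. The notebook targets the telemetry subset, but no subset argument is passed here. Confirm that the loader's default matches what you intend.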

In [None]:
# Root directory of the local Sleep-EDF download -- update this path.
sleepedf_root = "../downloads/sleepedf"

dataset = SleepEDFDataset(root=sleepedf_root)
dataset.stats()

# 3. Prepare PyHealth Dataset
Apply the dataset's default task (`set_task()` is called without arguments) to convert the raw recordings into PyHealth samples for self-supervised learning.

In [None]:
# Convert the raw dataset into task samples using the dataset's default task.
sample_dataset = dataset.set_task()

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# An empty task output usually means the dataset root is wrong or the raw
# files were not found; stop here with a clear message.
if len(sample_dataset) == 0:
    raise RuntimeError("The task did not produce any samples. Verify the dataset root.")

# Inspect a sample. Iterate over the features declared by the task's input
# schema instead of assuming a hard-coded "signal" key, and only report a
# shape for values that actually expose one (tensors/arrays).
sample = sample_dataset[0]
print(f"\nSample keys: {sample.keys()}")
for feature_key in sample_dataset.input_schema:
    feature_value = sample[feature_key]
    if hasattr(feature_value, "shape"):
        print(f"{feature_key} shape: {tuple(feature_value.shape)}")
    else:
        print(f"{feature_key} type: {type(feature_value).__name__}")

# 4. Split Dataset
Divide the processed samples into training, validation, and test subsets (70% / 10% / 20%, seeded for reproducibility) before building dataloaders.

In [None]:
BATCH_SIZE = 32
split_ratios = [0.7, 0.1, 0.2]  # train / val / test fractions

train_ds, val_ds, test_ds = split_by_sample(sample_dataset, split_ratios, seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

# Only the training loader shuffles. Evaluation loaders stay None when their
# split is empty, so downstream code can skip them.
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

val_loader = None
if len(val_ds):
    val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE)

test_loader = None
if len(test_ds):
    test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE)

if not len(train_loader):
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

# 5. Inspect Batch Structure
Peek at the first training batch to understand feature shapes and data structure.

In [None]:
# Pull a single batch from the training loader for inspection.
train_iterator = iter(train_loader)
first_batch = next(train_iterator)
del train_iterator

def describe(value):
    """Return a compact, human-readable summary of one batch entry.

    Objects exposing a ``shape`` attribute (e.g. tensors, arrays) are
    reported with their shape, lists/tuples with their length, and
    anything else by its type name alone.
    """
    kind = type(value).__name__
    if hasattr(value, "shape"):
        detail = f"shape={tuple(value.shape)}"
    elif isinstance(value, (list, tuple)):
        detail = f"len={len(value)}"
    else:
        return kind
    return f"{kind}({detail})"

# Summarize every entry first, then print the report in one pass.
batch_summary = {
    field_name: describe(field_value)
    for field_name, field_value in first_batch.items()
}

print("Batch structure:")
for field_name in batch_summary:
    print(f"  {field_name}: {batch_summary[field_name]}")

# 6. Instantiate ContraWR Model
Create the PyHealth ContraWR model for self-supervised learning on sleep signals and review its parameter footprint.

In [None]:
from pyhealth.models import ContraWR

# Build the model from the task dataset and move it to the selected device.
model = ContraWR(dataset=sample_dataset)
model = model.to(device)

# Count parameters once over a materialized list so both totals see the same set.
param_list = list(model.parameters())
total_params = sum(param.numel() for param in param_list)
trainable_params = sum(param.numel() for param in param_list if param.requires_grad)

print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_keys}")
print(f"Model mode: {model.mode}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# 7. Test Forward Pass
Verify the model can process a batch and compute the contrastive loss.

In [None]:
# Move every entry that supports `.to()` (tensors) onto the model's device;
# other entries are passed through unchanged.
test_batch = {
    key: value.to(device) if hasattr(value, "to") else value
    for key, value in first_batch.items()
}

# Run the sanity-check forward pass in eval mode. torch.no_grad() only
# disables gradient tracking. It does NOT stop BatchNorm layers (if any) from
# updating their running statistics, or Dropout from firing. Without eval(),
# this check could silently alter the model before training starts. The
# previous train/eval mode is restored afterwards, even if the forward fails.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        output = model(**test_batch)
finally:
    model.train(was_training)

print("Model output keys:", output.keys())
if "loss" in output:
    print(f"Loss value: {output['loss'].item():.4f}")
if "y_prob" in output:
    print(f"Output probability shape: {output['y_prob'].shape}")

# 8. Configure Trainer
Wrap the model with the PyHealth Trainer and define optimization hyperparameters.

In [None]:
from pyhealth.trainer import Trainer

# Optimization hyperparameters, forwarded verbatim to `trainer.train`.
training_config = dict(
    epochs=10,
    optimizer_params=dict(lr=1e-3),
    max_grad_norm=5.0,  # gradient-norm clipping threshold
)

# Trainer wrapper around the model; file logging is disabled for this notebook.
trainer = Trainer(model=model, device=str(device), enable_logging=False)

# 9. Train the Model
Launch the self-supervised training loop to learn representations from sleep signal data.

In [None]:
# Train on the training split. Validation runs only when a validation loader
# exists (val_loader is None for an empty split).
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    **training_config,
)

# Score the trained model on the held-out test split, which is built
# alongside the train/val loaders, whenever that split is non-empty.
if test_loader is not None:
    test_scores = trainer.evaluate(test_loader)
    print(f"Test metrics: {test_scores}")

# 10. Save Model (Optional)
Save the trained model for future use or fine-tuning on downstream tasks.

In [None]:
# Save model checkpoint
save_path = "contrawr_sleepedf_model.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    'config': training_config,
}, save_path)
print(f"Model saved to: {save_path}")

# To load the model later:
# checkpoint = torch.load(save_path)
# model.load_state_dict(checkpoint['model_state_dict'])