## 1. Environment Setup
Import dependencies, seed the Python, NumPy, and PyTorch RNGs for reproducibility, and choose a device.

In [None]:
import random

import numpy as np
import torch

SEED = 42

# Seed the Python, NumPy and PyTorch (CPU) random number generators.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available (seeding every CUDA device as well);
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Running on device: {device}")

## 2. Load TUEV Dataset
Point `root` to the `edf/` directory of TUEV.

In [None]:
from pyhealth.datasets import TUEVDataset

# Location of the TUEV `edf/` directory, relative to this notebook's folder:
#   <repo_root>/downloads/tuev/v2.0.1/edf
# Update as needed.
TUEV_ROOT = "../../downloads/tuev/v2.0.1/edf"
TUEV_SUBSET = "both"  # 'train', 'eval', or 'both'

dataset = TUEVDataset(root=TUEV_ROOT, subset=TUEV_SUBSET)
dataset.stats()

## 3. Prepare Task Dataset
Apply the `EEGEventsTUEV` task to produce one sample per annotated event.

In [None]:
from pyhealth.tasks import EEGEventsTUEV

# Apply the EEGEventsTUEV task to the loaded dataset.
eeg_task = EEGEventsTUEV()
sample_dataset = dataset.set_task(eeg_task)

num_samples = len(sample_dataset)
print(f"Total task samples: {num_samples}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Stop here if nothing was produced; every later cell needs at least one sample.
if num_samples == 0:
    raise RuntimeError("The task did not produce any samples. Verify the dataset root.")

# Preview the first sample.
sample = sample_dataset[0]
print(f"\nSample keys: {sample.keys()}")
signal_shape = sample['signal'].shape
print(f"Signal shape: {signal_shape}")
print(f"Label: {sample['label']}")

## 4. Split Dataset
Split the samples 70/10/20 into train/val/test sets and build a dataloader for each split (only the training loader shuffles).

In [None]:
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader

BATCH_SIZE = 32
SPLIT_RATIOS = [0.7, 0.1, 0.2]  # train / val / test

splits = split_by_sample(sample_dataset, SPLIT_RATIOS, seed=SEED)
train_ds, val_ds, test_ds = splits
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

# Only the training loader shuffles its samples.
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

# Validation and test loaders are optional: an empty split yields None.
val_loader = None
if len(val_ds):
    val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE)

test_loader = None
if len(test_ds):
    test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE)

# Training cannot proceed without at least one batch.
if len(train_loader) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

## 5. Inspect Batch Structure
Check the first batch to confirm shapes match what `ContraWR` expects: `(batch, channels, length)`.

In [None]:
# Pull a single batch from the training loader so its structure can be inspected.
first_batch = next(iter(train_loader))

def describe(value):
    """Return a short, human-readable summary of ``value``.

    Objects exposing a ``shape`` attribute are summarized as
    ``TypeName(shape=(...))``, lists and tuples as ``TypeName(len=N)``,
    and anything else by its type name alone.
    """
    type_name = type(value).__name__
    try:
        dims = value.shape
    except AttributeError:
        pass
    else:
        return f"{type_name}(shape={tuple(dims)})"
    if isinstance(value, (list, tuple)):
        return f"{type_name}(len={len(value)})"
    return type_name

# One indented summary line per batch entry, printed under a single header.
summary = [f"  {key}: {describe(value)}" for key, value in first_batch.items()]
print('Batch structure:', *summary, sep='\n')

## 6. Instantiate ContraWR
Create the `ContraWR` model using the task dataset (it infers feature/label keys from the dataset schema).

In [None]:
from pyhealth.models import ContraWR

# Build the model from the task dataset, then move it to the selected device.
model = ContraWR(dataset=sample_dataset, n_fft=128)
model = model.to(device)

# Count all parameters and the trainable subset in a single pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_keys}")
print(f"Model mode: {model.mode}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 7. Test Forward Pass
Run a no-grad forward pass and verify the loss/outputs.

In [None]:
# Move every entry that supports `.to(...)` onto the target device; anything
# else in the batch is passed through unchanged.
test_batch = {}
for name, item in first_batch.items():
    test_batch[name] = item.to(device) if hasattr(item, 'to') else item

# Forward pass in eval mode with gradient tracking disabled.
model.eval()
with torch.no_grad():
    output = model(**test_batch)

print('Model output keys:', output.keys())
loss_value = output['loss'].item()
print(f"Loss: {loss_value:.4f}")
prob_shape = tuple(output['y_prob'].shape)
print(f"y_prob shape: {prob_shape}")
true_shape = tuple(output['y_true'].shape)
print(f"y_true shape: {true_shape}")

## 8. Train with PyHealth Trainer
Train `ContraWR` on the TUEV EEG events task.

In [None]:
from pyhealth.trainer import Trainer

# The device is handed to the Trainer as a string.
trainer = Trainer(model=model, device=str(device), enable_logging=False)

# Keyword arguments forwarded to `trainer.train(...)`.
training_config = dict(
    epochs=10,
    optimizer_params=dict(lr=1e-3),
    max_grad_norm=5.0,
    # monitor='accuracy',  # uncomment to track a metric on val
    # monitor_criterion='max',
)

In [None]:
# Gather the dataloaders and the hyperparameters into one set of keyword
# arguments for the training call.
fit_kwargs = {
    'train_dataloader': train_loader,
    'val_dataloader': val_loader,
    **training_config,
}
trainer.train(**fit_kwargs)

## 9. Evaluate on Test Set
Compute multiclass metrics on the held-out test split.

In [None]:
# The test loader is None when the test split came out empty.
if test_loader is None:
    raise RuntimeError('No test dataloader was created.')

scores = trainer.evaluate(test_loader)

# Report each metric to four decimal places.
print('Test scores:')
for metric_name, metric_value in scores.items():
    print(f"  {metric_name}: {metric_value:.4f}")

## 10. Save Model (Optional)
Save the trained weights to a checkpoint file.

In [None]:
save_path = 'contrawr_tuev_model.pth'

# Store the weights together with the training configuration used to produce them.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': training_config,
}
torch.save(checkpoint, save_path)
print(f"Model saved to: {save_path}")