# Example: Wrapping Custom Sequential Data for PyHealth

This notebook shows how to wrap preprocessed sequential data (for example, synthetic EHR data) in a custom class that inherits from `pyhealth.datasets.SampleEHRDataset`.

This allows custom data formats to be used within the PyHealth ecosystem, for example, as input to uncertainty quantification studies or other downstream tasks.

**Focus:** The goal is to show the data-wrapping mechanism using `CustomSequentialEHRDataPyHealth` from the accompanying `.py` file. Model training and analysis are omitted for brevity, but the resulting dataset could serve as input to such tasks.

In [None]:
import numpy as np
import torch

# The wrapper class is defined in the accompanying .py file
from pyhealth_custom_dataset_wrapper import CustomSequentialEHRDataPyHealth

## 1. Generate Minimal Synthetic Sequential Data

First, we create some dummy sequential data (sequences and labels) in the format our wrapper expects: lists of PyTorch tensors. In a real scenario, this data would come from your preprocessing pipeline.

In [None]:
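num_patients = 20
input_dim = 3  # number of features per time step

# One tensor of shape (seq_len, input_dim) per patient, with variable seq_len
patient_sequences = []
for _ in range(num_patients):
    seq_len = np.random.randint(5, 15)  # upper bound is exclusive: lengths 5..14
    sequence = torch.randn(seq_len, input_dim)
    patient_sequences.append(sequence)

# One binary label per patient, stored as a float tensor of shape (1,)
patient_labels = [torch.tensor([float(np.random.rand() > 0.5)]) for _ in range(num_patients)]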

print(f"Generated {len(patient_sequences)} synthetic patient sequences.")
if patient_sequences:
    print(f"Example sequence shape: {patient_sequences[0].shape}")
    print(f"Example label: {patient_labels[0]}")

## 2. Wrap Data using PyHealth Custom Dataset Wrapper

Now we instantiate our `CustomSequentialEHRDataPyHealth` class (built on `pyhealth.datasets.SampleEHRDataset`) with the two generated lists: one of sequences and one of labels.
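
The wrapper's implementation lives in the accompanying `.py` file and is not reproduced in this notebook. As a rough sketch of the kind of conversion such a wrapper might perform, the hypothetical helper `build_samples` below pairs each sequence with its label in a per-patient dictionary. The dictionary keys mirror the ones read back from the wrapped dataset in this notebook (`patient_id`, `sequence_data`, `label`); the helper name, the `patient_{i}` ID scheme, and the length check are assumptions made for illustration, not the wrapper's actual code.

```python
import torch


def build_samples(sequences, labels):
    """Hypothetical helper: pair each sequence with its label in a dict."""
    if len(sequences) != len(labels):
        raise ValueError(
            f"Got {len(sequences)} sequences but {len(labels)} labels."
        )
    samples = []
    for i, (seq, label) in enumerate(zip(sequences, labels)):
        samples.append(
            {
                "patient_id": f"patient_{i}",  # assumed ID scheme
                "sequence_data": seq,          # tensor of shape (seq_len, input_dim)
                "label": label,                # tensor of shape (1,)
            }
        )
    return samples


# Tiny demo with two patients and three features per time step
demo_sequences = [torch.randn(4, 3), torch.randn(6, 3)]
demo_labels = [torch.tensor([1.0]), torch.tensor([0.0])]
demo_samples = build_samples(demo_sequences, demo_labels)
print(demo_samples[0]["patient_id"], tuple(demo_samples[0]["sequence_data"].shape))
```

Checking the list lengths up front is a cheap way to catch misaligned sequences and labels before they reach any dataset class.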

In [None]:
print("\n--- Wrapping Data with PyHealth Custom Dataset Class ---")
try:
    # Instantiate the PyHealth-compatible dataset wrapper
    pyhealth_dataset = CustomSequentialEHRDataPyHealth(
        list_of_patient_sequences=patient_sequences,
        list_of_patient_labels=patient_labels,
        root="." # Use current directory for any potential caching by PyHealth
    )
    print("Successfully wrapped data into CustomSequentialEHRDataPyHealth.")
    print(f"Number of samples in PyHealth dataset: {len(pyhealth_dataset)}")

    if len(pyhealth_dataset) > 0:
        # Demonstrate getting a sample (processed by the task_fn)
        first_sample = pyhealth_dataset[0] 
        print("\nExample of first sample retrieved via the PyHealth wrapper:")
        print(f"  Patient ID: {first_sample['patient_id']}")
        print(f"  Sequence Data Shape: {first_sample['sequence_data'].shape}")
        print(f"  Label: {first_sample['label']}")

except Exception as e:
    print(f"An error occurred during dataset wrapping: {e}")

print("--- PyHealth Dataset Wrapper Demonstration Complete ---")

## 3. Next Steps

From here, the `pyhealth_dataset` object can be passed to PyHealth's data loaders or models, provided they accept the sample format that our `basic_task_fn` defines.

Alternatively, treat this example purely as a demonstration of data integration. You can keep using standard PyTorch DataLoaders (built from the original `patient_sequences` and `patient_labels`) and custom models for downstream tasks such as model training and uncertainty analysis, knowing that the same data is also compatible with PyHealth's data structures.
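
For the plain-PyTorch route, the sequences have different lengths, so the default batching cannot stack them directly. The sketch below is one possible way to handle this, not part of the accompanying code: a hypothetical `SequenceListDataset` exposes the two lists, and a hypothetical `pad_collate` function zero-pads each batch with `torch.nn.utils.rnn.pad_sequence` and returns the true lengths alongside. The toy data stands in for `patient_sequences` and `patient_labels`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class SequenceListDataset(Dataset):
    """Hypothetical minimal Dataset over parallel lists of sequences and labels."""

    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]


def pad_collate(batch):
    """Zero-pad variable-length sequences to the longest one in the batch."""
    sequences, labels = zip(*batch)
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    # Result has shape (batch, max_len, input_dim)
    padded = pad_sequence(list(sequences), batch_first=True)
    return padded, lengths, torch.stack(labels)


# Toy stand-ins for patient_sequences and patient_labels
toy_sequences = [torch.randn(n, 3) for n in (5, 9, 7, 12)]
toy_labels = [torch.tensor([float(i % 2)]) for i in range(4)]

loader = DataLoader(
    SequenceListDataset(toy_sequences, toy_labels),
    batch_size=2,
    shuffle=False,
    collate_fn=pad_collate,
)
padded, lengths, labels = next(iter(loader))
print(padded.shape, lengths.tolist(), labels.shape)
```

Returning the original lengths with each batch lets a downstream model mask or pack the padded positions instead of treating them as real observations.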