In [None]:
"""
ConCare Model Training Example on MIMIC-IV Dataset
===================================================

This example demonstrates how to train the ConCare model for in-hospital mortality
prediction using the MIMIC-IV dataset.

ConCare (Concare: Personalized clinical feature embedding via capturing the
healthcare context) is a model that uses channel-wise GRUs and multi-head
self-attention to capture feature correlations and temporal patterns in EHR data.

Reference:
    Liantao Ma et al. Concare: Personalized clinical feature embedding via
    capturing the healthcare context. AAAI 2020.
"""

## 1. Load MIMIC-IV Dataset

In [None]:
from pyhealth.datasets import MIMIC4Dataset

# Build the MIMIC-IV dataset object.
# Note: point `ehr_root` at your local MIMIC-IV data directory before running.
mimic4_tables = ["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"]
dataset = MIMIC4Dataset(
    ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
    ehr_tables=mimic4_tables,
    dev=True,  # Set to False for full dataset
)

## 2. Define Task and Create Samples

In [None]:
from pyhealth.tasks import InHospitalMortalityMIMIC4

# Instantiate the in-hospital mortality task, then apply it to the dataset
# to produce the samples used by the rest of the notebook.
task = InHospitalMortalityMIMIC4()
samples = dataset.set_task(task, num_workers=10)

## 3. Explore the Data

In [None]:
# Print the first sample so its fields can be inspected
print("Sample structure:")
first_sample = samples[0]
print(first_sample)

## 4. Split Dataset

In [None]:
from pyhealth.datasets import split_by_sample

# Split the samples into train, validation, and test sets (ratios 0.7 / 0.1 / 0.2).
# A fixed seed keeps the partition identical on every Restart & Run All, so the
# metrics reported further down are comparable between runs.
train_dataset, val_dataset, test_dataset = split_by_sample(
    dataset=samples,
    ratios=[0.7, 0.1, 0.2],
    seed=42,
)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 5. Create Data Loaders

In [None]:
from pyhealth.datasets import get_dataloader

# One loader per split, all with the same batch size; only the training
# loader is created with shuffle=True.
BATCH_SIZE = 32
train_dataloader = get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## 6. Initialize ConCare Model

In [None]:
from pyhealth.models import ConCare

# Build the ConCare model from the task samples and its hyperparameters
concare_hparams = dict(
    embedding_dim=128,
    hidden_dim=128,
    num_head=4,
    pe_hidden=64,
    dropout=0.5,
)
model = ConCare(dataset=samples, **concare_hparams)

print(model)

## 7. Initialize Trainer and Evaluate Before Training

In [None]:
from pyhealth.trainer import Trainer

# Set up the trainer to report ROC-AUC, PR-AUC, and accuracy
eval_metrics = ["roc_auc", "pr_auc", "accuracy"]
trainer = Trainer(model=model, metrics=eval_metrics)

# Score the model on the test loader before any training, as a baseline
print("Baseline evaluation (before training):")
baseline_scores = trainer.evaluate(test_dataloader)
print(baseline_scores)

## 8. Train the Model




In [None]:
# Fit the model for 10 epochs with `roc_auc` as the monitored metric
# and a learning rate of 1e-4 passed through to the optimizer.
train_kwargs = dict(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=10,
    monitor="roc_auc",
    optimizer_params=dict(lr=1e-4),
)
trainer.train(**train_kwargs)

## 9. Evaluate on Test Set


In [None]:
# Evaluate on the same test loader used for the baseline, so the two sets of
# scores can be compared directly. The scores are kept in `results`.
print("\nFinal evaluation (after training):")
results = trainer.evaluate(test_dataloader)
print(results)

## 10. Save and Load Model

In [None]:
# Save model
import torch

# Persist only the state dict (not the whole model object)
model_state = model.state_dict()
torch.save(model_state, "./concare_mimic4_mortality.pt")
print("Model saved to ./concare_mimic4_mortality.pt")

In [None]:
# Load model (for inference)
# map_location="cpu" lets a checkpoint saved from a GPU run be read on a
# CPU-only machine; load_state_dict then copies the values into the model's
# existing parameters. weights_only=True restricts unpickling to tensors and
# plain containers, which is all a state dict needs.
state_dict = torch.load(
    "./concare_mimic4_mortality.pt",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()


## 11. Get Patient Embeddings


In [None]:
# Pull one batch from the test loader, pass `embed=True` alongside the batch
# fields, and read the "embed" entry of the model output for downstream analysis.
model.eval()
with torch.no_grad():
    batch = next(iter(test_dataloader))
    batch["embed"] = True
    output = model(**batch)
    embeddings = output["embed"]

print(f"Patient embeddings shape: {embeddings.shape}")

## Alternative: Using MIMIC-III 


In [None]:
"""
# If you want to use MIMIC-III instead, use the following code:

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import InHospitalMortalityMIMIC3

dataset = MIMIC3Dataset(
    ehr_root="/path/to/mimiciii/",
    ehr_tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS", "LABEVENTS"],
    dev=True,
)

task = InHospitalMortalityMIMIC3()
samples = dataset.set_task(task, num_workers=10)

# The rest of the code remains the same
"""