## Step 1: Load the MIMIC-III Dataset

We'll load the MIMIC-III dataset using PyHealth 2's new `MIMIC3Dataset` class. We need to specify:
- `root`: Path to the MIMIC-III data directory
- `tables`: Clinical tables to load (diagnoses, procedures, prescriptions)
- `dev`: Set to `True` for development/testing with a small subset of data

In [None]:
from pyhealth.datasets import MIMIC3Dataset

# Load MIMIC-III dataset
dataset = MIMIC3Dataset(
    root=r"F:\coding_projects\pyhealth\downloads\mimic-iii-demo",  # point this at your local MIMIC-III (or demo) directory
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
    dev=True,  # Set to False for full dataset
)

dataset.stats()

## Step 2: Define the Mortality Prediction Task

PyHealth 2 uses task classes to define how samples are extracted from the raw EHR data. The `MortalityPredictionMIMIC3` task:
- Extracts diagnosis codes (ICD-9), procedure codes (ICD-9), and drug names from each hospital admission
- Labels each admission by whether the patient dies during the *next* admission, so a patient's last admission yields no sample
- Skips admissions that have no diagnosis, procedure, or drug codes
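The labeling rule above can be sketched in plain Python. This is a simplified illustration, not PyHealth's implementation: the admission records, their field names (`hadm_id`, `expire_flag`, and the code lists), and the helper `build_mortality_samples` are all hypothetical. It assumes one patient's admissions are already sorted by time and that `expire_flag` is 1 when the patient died during that admission.

```python
from typing import Dict, List


def build_mortality_samples(admissions: List[Dict]) -> List[Dict]:
    """Turn one patient's time-ordered admissions into next-visit mortality samples (hypothetical sketch)."""
    result = []
    # Pair each admission with the one after it; the last admission has no successor.
    for current, nxt in zip(admissions[:-1], admissions[1:]):
        # Skip admissions missing diagnosis, procedure, or drug codes.
        if not (current["conditions"] and current["procedures"] and current["drugs"]):
            continue
        result.append(
            {
                "visit_id": current["hadm_id"],
                "conditions": current["conditions"],
                "procedures": current["procedures"],
                "drugs": current["drugs"],
                # The label describes the *next* admission's outcome.
                "mortality": int(nxt["expire_flag"] == 1),
            }
        )
    return result


toy_admissions = [
    {"hadm_id": "A1", "conditions": ["4019"], "procedures": ["3893"], "drugs": ["Heparin"], "expire_flag": 0},
    {"hadm_id": "A2", "conditions": ["42731"], "procedures": [], "drugs": ["Warfarin"], "expire_flag": 0},
    {"hadm_id": "A3", "conditions": ["5849"], "procedures": ["3995"], "drugs": ["Insulin"], "expire_flag": 0},
    {"hadm_id": "A4", "conditions": ["486"], "procedures": ["9604"], "drugs": ["Vancomycin"], "expire_flag": 1},
]
toy_samples = build_mortality_samples(toy_admissions)
for s in toy_samples:
    print(s["visit_id"], s["mortality"])  # A1 0, then A3 1
```

Only `A1` and `A3` produce samples: `A2` is skipped because it has no procedure codes, and `A4`, the last admission, has no following visit to take a label from.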

You can optionally specify a `cache_dir` to save processed samples for faster future loading.

In [None]:
from pyhealth.tasks import MortalityPredictionMIMIC3

# Define the mortality prediction task
task = MortalityPredictionMIMIC3()

# Apply the task to generate samples
samples = dataset.set_task(
    task=task,
    cache_dir="./cache_mortality_mimic3"  # Cache processed samples
)

print(f"Generated {len(samples)} samples")
print(f"\nInput schema: {samples.input_schema}")
print(f"Output schema: {samples.output_schema}")

## Step 3: Explore a Sample

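Let's examine what a single sample looks like. Each sample represents one hospital admission (visit) with:
- **conditions**: List of ICD-9 diagnosis codes
- **procedures**: List of ICD-9 procedure codes
- **drugs**: List of drug names
- **mortality**: Binary label (1 = the patient died during the next admission, 0 = otherwise)

After `set_task`, PyHealth 2 runs each field through a processor, so `samples[0]` may show tensors of integer indices rather than the raw codes and names.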
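Sequence models consume integer indices rather than code strings, so each distinct code is mapped into a vocabulary before training. The sketch below illustrates that mapping in plain Python; `build_vocab`, `encode`, and the reserved `<pad>`/`<unk>` entries are hypothetical and only approximate what a sequence processor does. They are not PyHealth's processor API.

```python
from typing import Dict, List

PAD, UNK = "<pad>", "<unk>"  # reserved entries (an assumed convention)


def build_vocab(sequences: List[List[str]]) -> Dict[str, int]:
    """Give every distinct code a stable integer index, after the reserved entries."""
    vocab = {PAD: 0, UNK: 1}
    for seq in sequences:
        for code in seq:
            if code not in vocab:
                vocab[code] = len(vocab)
    return vocab


def encode(seq: List[str], vocab: Dict[str, int]) -> List[int]:
    """Map codes to indices; codes absent from the vocabulary map to the UNK index."""
    return [vocab.get(code, vocab[UNK]) for code in seq]


toy_conditions = [["4019", "42731"], ["5849", "4019"]]
toy_vocab = build_vocab(toy_conditions)
print(toy_vocab)                              # {'<pad>': 0, '<unk>': 1, '4019': 2, '42731': 3, '5849': 4}
print(encode(["42731", "25000"], toy_vocab))  # [3, 1]
```

Because every distinct code gets exactly one index, counting distinct indices gives the same number as counting distinct codes.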

In [None]:
# Display a sample
print("Sample structure:")
print(samples[0])

# Show statistics
print("\n" + "="*50)
print("Dataset Statistics:")
print("="*50)

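# Count unique codes. After set_task, fields may be tensors of vocabulary
# indices rather than Python lists, so convert them to plain lists first
# (set.update on a tensor would hash tensor objects, not their values).
def to_list(values):
    return values.tolist() if hasattr(values, "tolist") else list(values)

all_conditions = set()
all_procedures = set()
all_drugs = set()
mortality_count = 0
for sample in samples:
    all_conditions.update(to_list(sample['conditions']))
    all_procedures.update(to_list(sample['procedures']))
    all_drugs.update(to_list(sample['drugs']))
    mortality_count += int(float(sample['mortality']))  # label may be a 1-element tensor

print(f"Unique diagnosis codes: {len(all_conditions)}")
print(f"Unique procedure codes: {len(all_procedures)}")
print(f"Unique drugs: {len(all_drugs)}")
print(f"\nMortality rate: {mortality_count / len(samples) * 100:.2f}%")
print(f"Positive samples: {mortality_count}")
print(f"Negative samples: {len(samples) - mortality_count}")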

## Step 4: Split the Dataset

We split the data into training, validation, and test sets using a 70-10-20 split.

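**Note:** `split_by_sample` assigns individual samples (admissions) to splits at random. Because a patient with several admissions contributes several samples, the same patient can appear in both the training and test sets, which can inflate test scores. `split_by_patient` keeps all of a patient's samples in a single split; a temporal split is another option when evaluation should mimic deployment on future data.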
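The standalone sketch below shows the idea behind a patient-level split: shuffle patient IDs, cut the shuffled list by the given ratios, and send every sample of a patient to the same split. The helper `split_by_patient_sketch`, the `patient_id` field, and the toy data are hypothetical; this is not PyHealth's implementation.

```python
import random
from typing import Dict, List, Sequence, Set, Tuple


def split_by_patient_sketch(
    samples: List[Dict], ratios: Sequence[float], seed: int = 42
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Shuffle patient IDs, then send every sample of a patient to the same split."""
    patient_ids = sorted({s["patient_id"] for s in samples})  # sort so the seed fully determines the order
    random.Random(seed).shuffle(patient_ids)
    n = len(patient_ids)
    # round() guards against float error, e.g. 10 * (0.7 + 0.1) == 7.999999999999999
    train_end = round(n * ratios[0])
    val_end = round(n * (ratios[0] + ratios[1]))
    train_ids = set(patient_ids[:train_end])
    val_ids = set(patient_ids[train_end:val_end])
    test_ids = set(patient_ids[val_end:])
    train = [s for s in samples if s["patient_id"] in train_ids]
    val = [s for s in samples if s["patient_id"] in val_ids]
    test = [s for s in samples if s["patient_id"] in test_ids]
    return train, val, test


def patient_ids_of(split: List[Dict]) -> Set[str]:
    return {s["patient_id"] for s in split}


# Toy data: 10 patients with 1 to 3 visits each (19 samples in total)
toy_visit_samples = [
    {"patient_id": f"P{p}", "visit_id": f"P{p}-V{v}"}
    for p in range(10)
    for v in range(1 + p % 3)
]
toy_train, toy_val, toy_test = split_by_patient_sketch(toy_visit_samples, ratios=[0.7, 0.1, 0.2])
print(len(patient_ids_of(toy_train)), len(patient_ids_of(toy_val)), len(patient_ids_of(toy_test)))  # 7 1 2
print(patient_ids_of(toy_train) & patient_ids_of(toy_test))  # set()
```

The ratios apply to patients, so the number of *samples* per split varies with how many visits each patient has.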

In [None]:
from pyhealth.datasets import split_by_sample

# Split dataset: 70% train, 10% validation, 20% test (fixed seed for reproducibility)
train_dataset, val_dataset, test_dataset = split_by_sample(
    dataset=samples,
    ratios=[0.7, 0.1, 0.2],
    seed=42,
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## Step 5: Create Data Loaders

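Data loaders group the samples into mini-batches and feed them to the model during training and evaluation. Only the training loader shuffles (`shuffle=True`), so the model sees the training samples in a new order each epoch.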
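Each batch stacks several samples, and variable-length code sequences must first be padded to a common length. The plain-Python sketch below illustrates this; `make_batches`, the padding value 0, and the 1/0 mask convention are assumptions, and PyHealth's actual collate function may differ in detail. It also shows why the number of batches is the sample count divided by `batch_size`, rounded up, when the final partial batch is kept (PyTorch's `DataLoader` default).

```python
import math
from typing import List, Tuple

Batch = Tuple[List[List[int]], List[List[int]]]  # (padded index sequences, mask)


def make_batches(sequences: List[List[int]], batch_size: int, pad_value: int = 0) -> List[Batch]:
    """Split index sequences into batches and right-pad each batch to its longest sequence."""
    batches = []
    for start in range(0, len(sequences), batch_size):
        chunk = sequences[start:start + batch_size]
        max_len = max(len(seq) for seq in chunk)
        padded = [seq + [pad_value] * (max_len - len(seq)) for seq in chunk]
        # mask: 1 marks a real code, 0 marks padding the model should ignore
        mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in chunk]
        batches.append((padded, mask))
    return batches


toy_encoded = [[2, 3], [4], [2, 4, 5], [3], [5, 2]]  # 5 samples of varying length
toy_batches = make_batches(toy_encoded, batch_size=2)
print(len(toy_batches), math.ceil(len(toy_encoded) / 2))  # 3 3
print(toy_batches[0])  # ([[2, 3], [4, 0]], [[1, 1], [1, 0]])
```

Padding is per batch, so a batch of short sequences stays short even if other batches contain long ones.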

In [None]:
from pyhealth.datasets import get_dataloader

# Create data loaders (batch_size=2 suits the tiny dev subset; use a larger value such as 32 or 64 on the full dataset)
train_dataloader = get_dataloader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=2, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=2, shuffle=False)

print(f"Training batches: {len(train_dataloader)}")
print(f"Validation batches: {len(val_dataloader)}")
print(f"Test batches: {len(test_dataloader)}")

## Step 6: Initialize the AdaCare Model

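The AdaCare model in PyHealth 2 builds its inputs from the task's schema:
- **Sequence features** (diagnosis, procedure, and drug codes) are embedded using learned embeddings
- **Each feature key** is processed by its own AdaCare layer, and the resulting representations are concatenated before the final prediction layer
- **Interpretability** comes from AdaCare's feature recalibration weights, which indicate how strongly each input dimension is emphasized at each visit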
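AdaCare's interpretability comes from feature recalibration in the spirit of squeeze-and-excitation networks: a small bottleneck network produces one weight in (0, 1) per input feature, and the features are rescaled by those weights. The sketch below is a generic, plain-Python illustration with made-up weights and a sigmoid gate (an assumption); it is not PyHealth's `AdaCare` code.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W: Matrix, x: List[float]) -> List[float]:
    """Multiply matrix W (rows x len(x)) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def recalibrate(x: List[float], W1: Matrix, W2: Matrix) -> Tuple[List[float], List[float]]:
    """Squeeze-and-excitation style gate: bottleneck + ReLU, then one sigmoid weight per feature."""
    hidden = [max(0.0, h) for h in matvec(W1, x)]       # compress to a smaller hidden size
    weights = [sigmoid(z) for z in matvec(W2, hidden)]  # one weight in (0, 1) per input feature
    rescaled = [xi * wi for xi, wi in zip(x, weights)]  # emphasize or suppress each feature
    return rescaled, weights


toy_x = [1.0, 2.0, 0.5]          # one visit's input vector (3 features)
toy_W1 = [[0.5, 0.5, 0.5]]       # 3 -> 1 bottleneck (made-up weights)
toy_W2 = [[2.0], [-2.0], [0.0]]  # 1 -> 3 expansion (made-up weights)
toy_rescaled, toy_weights = recalibrate(toy_x, toy_W1, toy_W2)
print([round(w, 3) for w in toy_weights])  # [0.971, 0.029, 0.5]
```

Here the first feature is almost fully kept, the second is nearly suppressed, and the third is halved; weights like these are what make the recalibration step readable as feature importance.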

### Key Parameters:
- `embedding_dim`: Dimension of code embeddings (default: 128)
- `hidden_dim`: Hidden dimension of GRU layers (default: 128)
- `kernel_size`: Kernel size of the dilated causal convolutions (default: 2)
- `kernel_num`: Number of convolution kernels (default: 64)
- `dropout`: Dropout rate for regularization (default: 0.5)
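A causal convolution computes the output for visit *t* from visit *t* and earlier visits only, so it never peeks at future visits; dilation spaces out the visits the kernel looks at. The AdaCare paper combines several such convolutions with different dilation rates to capture patterns at multiple time scales. Below is a minimal plain-Python sketch of one causal, optionally dilated, convolution over a single feature; `causal_conv1d` and the kernel values are illustrative, not PyHealth's layer.

```python
from typing import List


def causal_conv1d(x: List[float], kernel: List[float], dilation: int = 1) -> List[float]:
    """Causal 1-D convolution: output[t] uses x[t], x[t - dilation], x[t - 2*dilation], ...

    kernel[0] weights the current visit, kernel[j] the visit j * dilation steps back.
    Positions before the first visit are treated as zeros (left padding).
    """
    out = []
    for t in range(len(x)):
        acc = 0.0
        for j, w in enumerate(kernel):
            idx = t - j * dilation
            if idx >= 0:
                acc += w * x[idx]
        out.append(acc)
    return out


toy_visits = [1.0, 2.0, 3.0, 4.0]  # one feature observed over four visits
print(causal_conv1d(toy_visits, [1.0, 0.5]))              # [1.0, 2.5, 4.0, 5.5]
print(causal_conv1d(toy_visits, [1.0, 0.5], dilation=2))  # [1.0, 2.0, 3.5, 5.0]
```

With `kernel_size=2` (two kernel weights), dilation 1 combines each visit with the one before it, while dilation 2 reaches two visits back.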

In [None]:
from pyhealth.models import AdaCare

# Initialize AdaCare model
model = AdaCare(
    dataset=samples,
    embedding_dim=128,
    hidden_dim=128,
)

print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
print("\nModel architecture:")
print(model)

## Step 7: Train the Model

We use PyHealth's `Trainer` class, which handles:
- The training loop over mini-batches
- Evaluation on the validation set during training
- Model checkpointing: the weights with the best score on the monitored validation metric are kept

We monitor **ROC-AUC** on the validation set and select the checkpoint with the highest value.
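ROC-AUC is the probability that a randomly chosen positive sample receives a higher predicted probability than a randomly chosen negative one, so it measures ranking quality without a decision threshold. The plain-Python sketch below computes it by comparing all positive/negative pairs; `roc_auc_pairwise` is a hypothetical helper that is quadratic in the number of samples and meant for illustration only. It also shows a practical pitfall: the metric is undefined when a split contains only one class, which can happen with the small `dev=True` subset.

```python
from typing import Sequence


def roc_auc_pairwise(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC is undefined unless both classes are present")
    correct = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(pos) * len(neg))


toy_y_true = [0, 0, 1, 1, 0]
toy_y_prob = [0.1, 0.4, 0.35, 0.8, 0.2]
print(roc_auc_pairwise(toy_y_true, toy_y_prob))  # 0.8333333333333334 (5 of 6 pairs ranked correctly)
```

The one misordered pair is the positive scored 0.35 against the negative scored 0.4.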

In [None]:
from pyhealth.trainer import Trainer

# Initialize trainer
trainer = Trainer(
    model=model,
    metrics=["roc_auc", "pr_auc", "accuracy", "f1"]  # Track multiple metrics
)

# Train the model
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="roc_auc",  # Use ROC-AUC for model selection
    optimizer_params={"lr": 1e-3},  # Learning rate
)

## Step 8: Evaluate on Test Set

After training, we evaluate the model on the held-out test set to measure its generalization performance.
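Accuracy and F1 depend on a decision threshold applied to the predicted probabilities, while ROC-AUC and PR-AUC do not. With rare positive labels such as death, accuracy can look high even when the model never predicts the positive class. The sketch below computes the threshold-based metrics by hand, assuming a 0.5 cut-off; `threshold_metrics` is a hypothetical helper, not PyHealth's metric code.

```python
from typing import Dict, Sequence


def threshold_metrics(y_true: Sequence[int], y_prob: Sequence[float], threshold: float = 0.5) -> Dict[str, float]:
    """Accuracy, precision, recall and F1 after converting probabilities to 0/1 predictions."""
    y_pred = [int(p >= threshold) for p in y_prob]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / len(pairs), "precision": precision, "recall": recall, "f1": f1}


# Toy example: 1 death among 10 visits, and the model's score never crosses 0.5
toy_labels = [0] * 9 + [1]
toy_probs = [0.1] * 9 + [0.4]
print(threshold_metrics(toy_labels, toy_probs))
# {'accuracy': 0.9, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
print(threshold_metrics(toy_labels, toy_probs, threshold=0.3)["f1"])  # 1.0
```

ROC-AUC on these same scores would be 1.0, because the single positive is ranked above every negative; only the threshold hides it. This is why the threshold-free metrics are the more informative summary here.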

In [None]:
# Evaluate on test set
test_results = trainer.evaluate(test_dataloader)

print("\n" + "="*50)
print("Test Set Performance")
print("="*50)
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")

## Step 9: Model Interpretability (Optional)

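One of AdaCare's key features is interpretability: its feature recalibration weights indicate how strongly the model emphasizes each input dimension at each visit. For dense clinical variables such as lab values, those dimensions are the variables themselves; for the embedded code sequences used in this tutorial, they are embedding dimensions, so the weights do not map one-to-one onto individual codes.

Let's examine these weights for a batch of test samples.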
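A common way to summarize per-visit importance weights for one patient is to average them over the real (non-padded) visits and rank the dimensions. The sketch below does this on made-up numbers; `rank_features`, the (visits, features) layout, and the visit mask are assumptions about how such weights might be organized, not a description of AdaCare's exact output.

```python
from typing import List, Tuple


def rank_features(importance: List[List[float]], mask: List[int], top_k: int = 3) -> List[Tuple[int, float]]:
    """Average importance over real visits (mask == 1) and return the top_k (feature index, mean) pairs."""
    valid_rows = [row for row, keep in zip(importance, mask) if keep == 1]
    if not valid_rows:
        return []
    n_features = len(valid_rows[0])
    means = [sum(row[f] for row in valid_rows) / len(valid_rows) for f in range(n_features)]
    order = sorted(range(n_features), key=lambda f: means[f], reverse=True)
    return [(f, means[f]) for f in order[:top_k]]


# Made-up weights: 3 visit slots x 4 features; the last slot is padding
toy_importance = [
    [0.9, 0.2, 0.5, 0.1],
    [0.7, 0.4, 0.5, 0.3],
    [0.0, 0.0, 0.0, 0.0],
]
toy_mask = [1, 1, 0]
toy_ranked = rank_features(toy_importance, toy_mask, top_k=2)
print([f for f, _ in toy_ranked])  # [0, 2]
```

Masking matters: averaging padded slots in would pull every mean toward zero and, for patients with few visits, could change the ranking.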

In [None]:
import torch

# Get a batch from the test set
test_batch = next(iter(test_dataloader))

# Run the model in evaluation mode without tracking gradients
model.eval()
with torch.no_grad():
    output = model(**test_batch)

# Extract interpretability information, if the model returns it
if 'feature_importance' in output:
    importance = output['feature_importance']
    # The model may return one tensor per feature key or a single tensor
    if isinstance(importance, (list, tuple)):
        groups = list(zip(samples.input_schema, importance))
    else:
        groups = [("all features", importance)]

    print("Feature importance available!")
    for name, weights in groups:
        print(f"{name}: shape {tuple(weights.shape)}")

    # Recalibration weights of the first sample for the first feature group
    first_name, first_weights = groups[0]
    print(f"\nFeature importance for the first sample ({first_name}):")
    print(first_weights[0].cpu())
else:
    print("Feature importance not available in model output.")

# Display predictions; flatten (batch, 1) outputs to 1-D arrays
print("\n" + "="*50)
print("Sample Predictions:")
print("="*50)
predictions = output['y_prob'].cpu().numpy().reshape(-1)
true_labels = output['y_true'].cpu().numpy().reshape(-1)

for i in range(min(5, len(predictions))):
    pred = float(predictions[i])
    true = int(true_labels[i])
    label = 'Mortality' if pred > 0.5 else 'Survival'
    print(f"Sample {i+1}: Predicted={pred:.3f}, True={true}, Prediction={label}")