# RNN: Mortality Prediction on MIMIC-III (Baseline: No code_mapping)

This notebook runs the RNN model for mortality prediction **without** `code_mapping`.
Raw ICD-9 and NDC codes are used as-is for the embedding vocabulary.
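
To make "used as-is" concrete, here is a minimal sketch of how a vocabulary over raw codes could be built. It is not PyHealth's processor implementation; `build_vocab`, `encode`, and the example codes are hypothetical names and values chosen for illustration. The only detail taken from this notebook is that the vocabulary reserves `<pad>` and `<unk>` entries.

```python
from typing import Dict, Iterable, List

PAD, UNK = "<pad>", "<unk>"


def build_vocab(visits: Iterable[List[str]]) -> Dict[str, int]:
    """Give every distinct raw code its own index; 0 and 1 are reserved."""
    vocab = {PAD: 0, UNK: 1}
    for codes in visits:
        for code in codes:
            if code not in vocab:
                vocab[code] = len(vocab)
    return vocab


def encode(codes: List[str], vocab: Dict[str, int]) -> List[int]:
    """Map codes to indices; codes unseen at build time fall back to <unk>."""
    return [vocab.get(code, vocab[UNK]) for code in codes]


# Hypothetical raw ICD-9 codes for two visits (illustrative only).
visits = [["4280", "42731"], ["4280", "5849"]]
vocab = build_vocab(visits)
encoded = encode(["4280", "V1582"], vocab)

print(len(vocab))  # 3 distinct codes + 2 special tokens
print(encoded)     # the unseen code maps to the <unk> index
```

Without a mapping step, every distinct raw code string gets its own embedding row, so closely related codes share nothing in the vocabulary.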

**Model:** RNN (GRU)  
**Task:** In-hospital mortality prediction  
**Dataset:** Synthetic MIMIC-III (`dev=False`)

## Step 1: Load the MIMIC-III Dataset

In [None]:
import tempfile

from pyhealth.datasets import MIMIC3Dataset

base_dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
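    # mkdtemp() creates a directory that persists for the session; the directory behind
    # TemporaryDirectory().name is removed as soon as that object is garbage-collected.
    cache_dir=tempfile.mkdtemp(),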
    dev=False,
)

base_dataset.stats()

## Step 2: Define the Mortality Prediction Task

In [None]:
from pyhealth.tasks import MortalityPredictionMIMIC3

task = MortalityPredictionMIMIC3()

samples = base_dataset.set_task(task)

print(f"Generated {len(samples)} samples")
print(f"\nInput schema: {samples.input_schema}")
print(f"Output schema: {samples.output_schema}")

## Step 3: Dataset Statistics

In [None]:
print("Sample structure:")
print(samples[0])

print("\n" + "=" * 50)
print("Processor Vocabulary Sizes:")
print("=" * 50)
for key, proc in samples.input_processors.items():
    if hasattr(proc, 'code_vocab'):
        print(f"{key}: {len(proc.code_vocab)} codes (including <pad>, <unk>)")

mortality_count = sum(float(s.get("mortality", 0)) for s in samples)
print(f"\nTotal samples: {len(samples)}")
print(f"Mortality rate: {mortality_count / len(samples) * 100:.2f}%")
print(f"Positive samples: {int(mortality_count)}")
print(f"Negative samples: {len(samples) - int(mortality_count)}")

## Step 4: Split the Dataset

In [None]:
from pyhealth.datasets import split_by_patient

train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, [0.8, 0.1, 0.1], seed=42
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")
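
Splitting by patient keeps all samples from one patient in the same partition, so the model is never evaluated on a patient it saw during training. The sketch below illustrates that idea in plain Python. It is an assumption-laden illustration, not the implementation of `split_by_patient`; the function name `split_indices_by_patient` and the toy patient IDs are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def split_indices_by_patient(
    patient_ids: Sequence[str], ratios: Sequence[float], seed: int
) -> Tuple[List[int], List[int], List[int]]:
    """Return (train, val, test) sample indices with no patient shared across splits."""
    patients = sorted(set(patient_ids))
    random.Random(seed).shuffle(patients)

    n_train = int(len(patients) * ratios[0])
    n_val = int(len(patients) * ratios[1])
    groups = (
        set(patients[:n_train]),
        set(patients[n_train:n_train + n_val]),
        set(patients[n_train + n_val:]),  # remainder goes to the test split
    )

    splits: Tuple[List[int], List[int], List[int]] = ([], [], [])
    for idx, pid in enumerate(patient_ids):
        for split, group in zip(splits, groups):
            if pid in group:
                split.append(idx)
                break
    return splits


# Toy data: 10 hypothetical patients with 2 samples each.
patient_ids = [f"p{i // 2}" for i in range(20)]
train_idx, val_idx, test_idx = split_indices_by_patient(patient_ids, [0.8, 0.1, 0.1], seed=42)
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the ratios apply to patients rather than samples, the sample counts per split can deviate from 80/10/10 when patients contribute different numbers of samples.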

## Step 5: Create Data Loaders

In [None]:
from pyhealth.datasets import get_dataloader

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

print(f"Training batches: {len(train_dataloader)}")
print(f"Validation batches: {len(val_dataloader)}")
print(f"Test batches: {len(test_dataloader)}")
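
As a sanity check on the batch counts printed above, the expected number of batches can be computed directly. This sketch assumes the loader keeps the final partial batch (the PyTorch `DataLoader` default of `drop_last=False`); `expected_batches` is a hypothetical helper, and the sample count is illustrative.

```python
import math


def expected_batches(num_samples: int, batch_size: int) -> int:
    """Number of batches when the last, possibly smaller, batch is kept."""
    return math.ceil(num_samples / batch_size)


print(expected_batches(100, 32))  # 3 full batches plus one batch of 4 samples
```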

## Step 6: Initialize the RNN Model

In [None]:
from pyhealth.models import RNN

model = RNN(
    dataset=samples,
    embedding_dim=128,
    hidden_dim=128,
)

print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
print("\nModel architecture:")
print(model)
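
To interpret the parameter count printed above, it can help to estimate the recurrent part by hand. The sketch below assumes each GRU is a single-layer, unidirectional `torch.nn.GRU` with bias terms, which stores three gates of input-to-hidden weights, hidden-to-hidden weights, and two bias vectors. Whether the model here uses exactly that configuration is an assumption; `gru_param_count` is a hypothetical helper.

```python
def gru_param_count(input_dim: int, hidden_dim: int, bias: bool = True) -> int:
    """Parameters in one single-layer, unidirectional GRU (PyTorch layout)."""
    per_gate = input_dim * hidden_dim + hidden_dim * hidden_dim
    if bias:
        per_gate += 2 * hidden_dim  # separate input-side and hidden-side biases
    return 3 * per_gate  # reset, update, and candidate gates


gru_params = gru_param_count(input_dim=128, hidden_dim=128)
print(f"{gru_params:,}")  # 99,072
```

The rest of the total comes from the embedding tables, which scale with vocabulary size times `embedding_dim`, and from the output layer. With raw, unmapped codes the embedding tables are likely to dominate.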

## Step 7: Train the Model

In [None]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["roc_auc", "pr_auc", "accuracy", "f1"],
)

trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="roc_auc",
    optimizer_params={"lr": 1e-3},
)
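
Training monitors `roc_auc`, which can be read as the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one. The sketch below computes it from that definition in plain Python. It is an illustration of the metric, not the code the trainer uses; `roc_auc_pairwise` is a hypothetical name and the labels and scores are toy values.

```python
from typing import Sequence


def roc_auc_pairwise(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """ROC AUC as the fraction of positive/negative pairs ranked correctly (ties count 0.5)."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative sample")

    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


auc = roc_auc_pairwise([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(auc)  # 0.75
```

The pairwise loop is quadratic in the number of samples, so it is suitable only for small checks. Because the metric depends on ranking rather than a threshold, it stays informative when positives are rare, which is why it is a common choice to monitor for mortality prediction.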

## Step 8: Evaluate on Test Set

In [None]:
test_results = trainer.evaluate(test_dataloader)

print("\n" + "=" * 50)
print("Test Set Performance (NO code_mapping)")
print("=" * 50)
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")