# MoleRec Model Training on MIMIC-III Dataset

Train the MoleRec model for medication recommendation on the MIMIC-III dataset.


In [None]:
from pyhealth.datasets import MIMIC3Dataset

# Tables required by the drug-recommendation task: diagnoses and procedures
# are the model inputs, prescriptions provide the drug targets.
mimic3_tables = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"]

# Load the synthetic MIMIC-III data hosted by PyHealth; dev=True is used for
# a quick development run.
dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
    tables=mimic3_tables,
    dev=True,
)
dataset.stats()


## Set Drug Recommendation Task

Use the `DrugRecommendationMIMIC3` task, which creates samples with conditions, procedures, and drugs (as ATC-3 codes).


In [None]:
from pyhealth.tasks import DrugRecommendationMIMIC3

# Convert the raw MIMIC-III records into drug-recommendation samples.
task = DrugRecommendationMIMIC3()
samples = dataset.set_task(task, num_workers=4)

print("Sample Dataset Statistics:")
print(f"\t- Dataset: {samples.dataset_name}")
print(f"\t- Task: {samples.task_name}")
print(f"\t- Number of samples: {len(samples)}")

# Inspect the first sample. 'conditions' and 'procedures' are indexed per
# visit (hence the [0] for the first visit); 'drugs' is the target code list.
# Guard against an empty sample list so the cell reports the problem instead
# of failing with an IndexError.
if samples.samples:
    first_sample = samples.samples[0]
    print("\nFirst sample structure:")
    print(f"Patient ID: {first_sample['patient_id']}")
    print(f"Number of visits: {len(first_sample['conditions'])}")
    print(f"Sample conditions (first visit): {first_sample['conditions'][0][:5]}...")
    print(f"Sample procedures (first visit): {first_sample['procedures'][0][:5]}...")
    print(f"Sample drugs (target): {first_sample['drugs'][:10]}...")
else:
    print("\nNo samples were generated; check the loaded tables and the task.")


## Split Dataset and Create Data Loaders

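Split the dataset by patient (70% train / 10% validation / 20% test) so that no patient appears in more than one set, which prevents data leakage between the train, validation, and test sets.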


In [None]:
from pyhealth.datasets import split_by_patient, get_dataloader

# 70% train / 10% validation / 20% test, split at the patient level.
train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.1, 0.2]
)

for split_label, split_data in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_label} samples: {len(split_data)}")

# Only the training loader is shuffled; evaluation order stays fixed.
batch_size = 32
train_dataloader = get_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=batch_size, shuffle=False)


## Initialize MoleRec Model


In [None]:
from pyhealth.models import MoleRec

# MoleRec hyperparameters, kept together so they are easy to tweak.
molerec_config = dict(
    embedding_dim=64,
    hidden_dim=64,
    num_rnn_layers=1,
    num_gnn_layers=4,
    dropout=0.5,
)

# The model is built from the full sample dataset (not just the train split).
model = MoleRec(dataset=samples, **molerec_config)

print(model)


## Initialize Trainer

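We evaluate with sample-averaged Jaccard similarity, F1 score, and PR-AUC, plus the drug–drug interaction (DDI) rate. The untrained model is evaluated on the test set first to provide a baseline.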


In [None]:
from pyhealth.trainer import Trainer

# Metrics reported by the trainer during validation and evaluation.
eval_metrics = ["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"]
trainer = Trainer(model=model, metrics=eval_metrics)

# Score the untrained model on the test split as a reference point.
print("Baseline performance before training:")
baseline_results = trainer.evaluate(test_dataloader)
print(baseline_results)


## Train the Model

Train the model for a few epochs (5 here), monitoring validation PR-AUC. Production use will likely require more epochs.


In [None]:
# Short training run; validation PR-AUC is the monitored metric.
num_epochs = 5
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=num_epochs,
    optimizer_params=dict(lr=1e-4),
    monitor="pr_auc_samples",
)


## Evaluate on Test Set

Evaluate the trained model on the test set to see final performance metrics.


In [None]:
test_results = trainer.evaluate(test_dataloader)
print("Final test set performance:")
print(test_results)


def _format_metric(results, *keys):
    """Return the first metric present under ``keys``, formatted to 4 decimals.

    Returns the string "N/A" when none of the keys is present. Applying the
    ``:.4f`` format spec to a string fallback raises ValueError, so the
    fallback must bypass numeric formatting entirely.
    """
    for key in keys:
        value = results.get(key)
        if value is not None:
            return f"{value:.4f}"
    return "N/A"


print("\nKey Metrics:")
print(f"  PR-AUC: {_format_metric(test_results, 'pr_auc_samples')}")
print(f"  F1 Score: {_format_metric(test_results, 'f1_samples')}")
print(f"  Jaccard: {_format_metric(test_results, 'jaccard_samples')}")
# The "ddi" metric is looked up as 'ddi_score', with 'ddi' as a fallback key.
print(f"  DDI Rate: {_format_metric(test_results, 'ddi_score', 'ddi')} (lower is better)")
