# GAMENet Model Training on MIMIC-III Dataset

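Train the GAMENet (Graph Augmented MEmory Networks) model for medication recommendation using the MIMIC-III dataset. The example below loads the publicly hosted synthetic MIMIC-III copy, so it can be run without credentialed access to the real data.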


In [None]:
from pyhealth.datasets import MIMIC3Dataset

# Location of the hosted synthetic MIMIC-III files and the tables to read.
SYNTHETIC_MIMIC3_ROOT = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III"
MIMIC3_TABLES = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"]

# NOTE(review): dev=True presumably restricts loading to a small subset for
# quick experimentation -- confirm against the PyHealth dataset docs.
dataset = MIMIC3Dataset(root=SYNTHETIC_MIMIC3_ROOT, tables=MIMIC3_TABLES, dev=True)

# Report summary statistics for the loaded dataset.
dataset.stats()


## Set Drug Recommendation Task

We use the drug recommendation task, which predicts medications based on patient conditions and procedures.


In [None]:
from pyhealth.tasks import DrugRecommendationMIMIC3

# Build the drug recommendation task and apply it to the loaded dataset.
# `samples` is what the later cells split, batch, and hand to the model.
task = DrugRecommendationMIMIC3()
samples = dataset.set_task(task)


## Split Dataset

Split the dataset into train, validation, and test sets using patient-level splitting.


In [None]:
from pyhealth.datasets import split_by_patient, get_dataloader

# Train / validation / test proportions and the batch size shared by all loaders.
SPLIT_RATIOS = [0.7, 0.15, 0.15]
BATCH_SIZE = 64

# Partition the task samples into the three subsets by patient.
train_dataset, val_dataset, test_dataset = split_by_patient(
    samples,
    ratios=SPLIT_RATIOS,
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = get_dataloader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
val_loader = get_dataloader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False
)
test_loader = get_dataloader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)


## Initialize GAMENet Model

Create the GAMENet model with the specified hyperparameters.


In [None]:
from pyhealth.models import GAMENet

# Hyperparameters for this run, gathered in one place so they are easy to tweak.
gamenet_hparams = {
    "embedding_dim": 128,
    "hidden_dim": 128,
    "num_layers": 1,
    "dropout": 0.5,
}

# The model is constructed from the task samples plus the settings above.
model = GAMENet(dataset=samples, **gamenet_hparams)


## Train Model

Train the model using the PyHealth Trainer with relevant metrics for drug recommendation.


In [None]:
from pyhealth.trainer import Trainer

# Metrics requested from the trainer, and the one passed as `monitor` during training.
eval_metrics = ["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"]
monitored_metric = "jaccard_samples"

trainer = Trainer(model=model, metrics=eval_metrics)

# Train for 5 epochs, using the validation loader alongside the training loader.
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor=monitored_metric,
)


## Evaluate on Test Set

Evaluate the trained model on the test set and print the results.


In [None]:
# Score the trained model on the held-out test loader.
results = trainer.evaluate(test_loader)

print("Test Set Results:")

# These three metrics are indexed directly, so a missing key raises KeyError.
for metric_label, metric_key in (
    ("Jaccard (samples)", "jaccard_samples"),
    ("F1 (samples)", "f1_samples"),
    ("PR-AUC (samples)", "pr_auc_samples"),
):
    print(f"  {metric_label}: {results[metric_key]:.4f}")

# DDI is fetched with .get() so that its absence is reported rather than raised.
ddi_value = results.get("ddi")
if ddi_value is None:
    print("  DDI Rate: N/A (metric not available)")
else:
    print(f"  DDI Rate: {ddi_value:.4f}")
