# GAMENet Model Training on MIMIC-III Dataset

Train the GAMENet (Graph Augmented MEmory Networks) model for medication recommendation using the MIMIC-III dataset.


In [2]:
from pyhealth.datasets import MIMIC3Dataset

dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    dev=True,
)
dataset.stats()


ModuleNotFoundError: No module named 'pyhealth'

## Set Drug Recommendation Task

We use the drug recommendation task which predicts medications based on patient conditions and procedures.


In [None]:
from pyhealth.tasks import DrugRecommendationMIMIC3

task = DrugRecommendationMIMIC3()
samples = dataset.set_task(task)


Setting task DrugRecommendationMIMIC3 for mimic3 base dataset...
Generating samples with 1 worker(s)...


Generating samples for DrugRecommendationMIMIC3 with 1 worker: 100%|██████████| 1000/1000 [00:01<00:00, 546.61it/s]

Label drugs vocab: {'*NF*': 0, '0.9%': 1, '1/2 ': 2, '5% D': 3, 'ACD-': 4, 'Acet': 5, 'Acyc': 6, 'Aden': 7, 'Albu': 8, 'Alen': 9, 'Allo': 10, 'Alpr': 11, 'Alte': 12, 'Alum': 13, 'Ambi': 14, 'Amik': 15, 'Amin': 16, 'Amio': 17, 'Amlo': 18, 'Amph': 19, 'Ampi': 20, 'Apro': 21, 'Aqua': 22, 'Arti': 23, 'Asco': 24, 'Aspi': 25, 'Aten': 26, 'Ator': 27, 'Atro': 28, 'Azit': 29, 'Becl': 30, 'Bell': 31, 'Bici': 32, 'Bisa': 33, 'Bism': 34, 'Brom': 35, 'BuPR': 36, 'Bume': 37, 'BusP': 38, 'Calc': 39, 'Capt': 40, 'Carb': 41, 'Carv': 42, 'Casp': 43, 'Cefa': 44, 'Cefe': 45, 'Ceft': 46, 'Cele': 47, 'Ceph': 48, 'Cety': 49, 'Chlo': 50, 'Cipr': 51, 'Cisa': 52, 'Cita': 53, 'Citr': 54, 'Clin': 55, 'Clob': 56, 'Clon': 57, 'Clop': 58, 'Clot': 59, 'Colc': 60, 'Coll': 61, 'Cosy': 62, 'Cyan': 63, 'D5 1': 64, 'D5NS': 65, 'D5W': 66, 'D5W ': 67, 'DOBU': 68, 'Daki': 69, 'Dapt': 70, 'Desm': 71, 'Dexa': 72, 'Dexm': 73, 'Dext': 74, 'Diaz': 75, 'Dida': 76, 'Digo': 77, 'Dilt': 78, 'Diov': 79, 'Diph': 80, 'Diva': 81, 'Dobu':


Processing samples: 100%|██████████| 51/51 [00:00<00:00, 2501.75it/s]

Generated 51 samples for task DrugRecommendationMIMIC3





## Split Dataset

Split the dataset into train, validation, and test sets using patient-level splitting.


In [None]:
from pyhealth.datasets import split_by_patient, get_dataloader

train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.15, 0.15]
)

train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False)


## Initialize GAMENet Model

Create the GAMENet model with specified hyperparameters.


In [None]:
from pyhealth.models import GAMENet

model = GAMENet(
    dataset=samples,
    embedding_dim=128,
    hidden_dim=128,
    num_layers=1,
    dropout=0.5,
)


  import pkg_resources


## Train Model

Train the model using the PyHealth Trainer with relevant metrics for drug recommendation.


In [None]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
)

trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor="jaccard_samples",
)


GAMENet(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(303, 128)
    (procedures): Embedding(115, 128)
    (drugs_hist): Embedding(175, 128)
  ))
  (cond_rnn): GRU(128, 128, batch_first=True)
  (proc_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
  )
  (gamenet): GAMENetLayer(
    (ehr_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (ddi_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (fc): Linear(in_features=384, out_features=253, bias=True)
    (bce_loss_fn): BCEWithLogitsLoss()
  )
)
Metrics: ['jaccard_samples', 'f1_samples', 'pr_auc_samples', 'ddi']
Device: cpu

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: Non

Epoch 0 / 5: 100%|██████████| 1/1 [00:00<00:00, 13.48it/s]

--- Train epoch-0, step-1 ---
loss: 0.6989



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 137.55it/s]


--- Eval epoch-0, step-1 ---
jaccard_samples: 0.0450
f1_samples: 0.0816
pr_auc_samples: 0.0664
ddi_score: 0.0000
loss: 0.6768
New best jaccard_samples score (0.0450) at epoch-0, step-1



Epoch 1 / 5: 100%|██████████| 1/1 [00:00<00:00, 27.67it/s]

--- Train epoch-1, step-2 ---
loss: 0.6772



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 141.89it/s]


--- Eval epoch-1, step-2 ---
jaccard_samples: 0.0450
f1_samples: 0.0816
pr_auc_samples: 0.1666
ddi_score: 0.0000
loss: 0.6528



Epoch 2 / 5: 100%|██████████| 1/1 [00:00<00:00, 27.42it/s]

--- Train epoch-2, step-3 ---
loss: 0.6552



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 93.62it/s]


--- Eval epoch-2, step-3 ---
jaccard_samples: 0.0450
f1_samples: 0.0816
pr_auc_samples: 0.2112
ddi_score: 0.0000
loss: 0.6264



Epoch 3 / 5: 100%|██████████| 1/1 [00:00<00:00, 24.14it/s]

--- Train epoch-3, step-4 ---
loss: 0.6312



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 100.81it/s]


--- Eval epoch-3, step-4 ---
jaccard_samples: 0.0450
f1_samples: 0.0816
pr_auc_samples: 0.2289
ddi_score: 0.0000
loss: 0.5945



Epoch 4 / 5: 100%|██████████| 1/1 [00:00<00:00, 25.92it/s]

--- Train epoch-4, step-5 ---
loss: 0.6022



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 132.22it/s]


--- Eval epoch-4, step-5 ---
jaccard_samples: 0.0450
f1_samples: 0.0816
pr_auc_samples: 0.2485
ddi_score: 0.0000
loss: 0.5580
Loaded best model


## Evaluate on Test Set

Evaluate the trained model on the test set and print the results.


In [None]:
results = trainer.evaluate(test_loader)

print("Test Set Results:")
print(f"  Jaccard (samples): {results['jaccard_samples']:.4f}")
print(f"  F1 (samples): {results['f1_samples']:.4f}")
print(f"  PR-AUC (samples): {results['pr_auc_samples']:.4f}")
ddi_value = results.get("ddi")
if ddi_value is not None:
    print(f"  DDI Rate: {ddi_value:.4f}")
else:
    print("  DDI Rate: N/A (metric not available)")


Evaluation: 100%|██████████| 1/1 [00:00<00:00, 112.65it/s]


Test Set Results:
  Jaccard (samples): 0.1210
  F1 (samples): 0.2091
  PR-AUC (samples): 0.1661
  DDI Rate: N/A (metric not available)
