# MoleRec Model Training on MIMIC-III Dataset

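Train the MoleRec model for medication recommendation with PyHealth. This notebook uses the synthetic MIMIC-III data in dev mode, so the sample counts and metrics below are only illustrative.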


In [1]:
from pyhealth.datasets import MIMIC3Dataset

dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    dev=True,
)
dataset.stats()



No config path provided, using default config
Initializing mimic3 dataset from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III (dev mode: True)
Scanning table: patients from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv
Some column names were converted to lowercase
Scanning table: admissions from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv
Some column names were converted to lowercase
Scanning table: icustays from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv
Some column names were converted to lowercase
Scanning table: 

## Set Drug Recommendation Task

Use the `DrugRecommendationMIMIC3` task, which builds one sample per visit from the patient's conditions, procedures, and ATC-3 drug codes. The drugs are the prediction target.


In [2]:
from pyhealth.tasks import DrugRecommendationMIMIC3

task = DrugRecommendationMIMIC3()
samples = dataset.set_task(task, num_workers=4)

print("Sample Dataset Statistics:")
print(f"\t- Dataset: {samples.dataset_name}")
print(f"\t- Task: {samples.task_name}")
print(f"\t- Number of samples: {len(samples)}")

print("\nFirst sample structure:")
print(f"Patient ID: {samples.samples[0]['patient_id']}")
print(f"Number of visits: {len(samples.samples[0]['conditions'])}")
print(f"Sample conditions (first visit): {samples.samples[0]['conditions'][0][:5]}...")
print(f"Sample procedures (first visit): {samples.samples[0]['procedures'][0][:5]}...")
print(f"Sample drugs (target): {samples.samples[0]['drugs'][:10]}...")


Setting task DrugRecommendationMIMIC3 for mimic3 base dataset...
Generating samples with 4 worker(s)...
Generating samples for DrugRecommendationMIMIC3 with 4 workers


Collecting samples for DrugRecommendationMIMIC3 from 4 workers: 100%|██████████| 1000/1000 [00:00<00:00, 2908.84it/s]

Label drugs vocab: {'*NF*': 0, '1/2 ': 1, 'Acet': 2, 'Acyc': 3, 'Albu': 4, 'Alem': 5, 'Allo': 6, 'Alte': 7, 'Alum': 8, 'Ambi': 9, 'Amio': 10, 'Amlo': 11, 'Amox': 12, 'Amph': 13, 'Arip': 14, 'Arti': 15, 'Asco': 16, 'Aspi': 17, 'Aten': 18, 'Ator': 19, 'Atov': 20, 'Atro': 21, 'Azat': 22, 'Azit': 23, 'Bisa': 24, 'BuPR': 25, 'Bupi': 26, 'Calc': 27, 'Capt': 28, 'Carv': 29, 'Casp': 30, 'Cefa': 31, 'Cefe': 32, 'Ceft': 33, 'Cele': 34, 'Cepa': 35, 'Ceph': 36, 'Chlo': 37, 'Cilo': 38, 'Cipr': 39, 'Cisa': 40, 'Cita': 41, 'Clin': 42, 'Clon': 43, 'Clop': 44, 'Colc': 45, 'Cyan': 46, 'Cycl': 47, 'D10W': 48, 'D5 1': 49, 'D5NS': 50, 'D5W': 51, 'D5W ': 52, 'DOBU': 53, 'Daki': 54, 'Desm': 55, 'Dexa': 56, 'Dexm': 57, 'Dext': 58, 'Diaz': 59, 'Digo': 60, 'Dilt': 61, 'Diph': 62, 'Dipy': 63, 'Docu': 64, 'Dola': 65, 'Done': 66, 'DopA': 67, 'Dorz': 68, 'Doxy': 69, 'Drop': 70, 'Emtr': 71, 'Enal': 72, 'Epin': 73, 'Epoe': 74, 'Epti': 75, 'Eryt': 76, 'Famo': 77, 'Fent': 78, 'Ferr': 79, 'Fexo': 80, 'Fish': 81, 'Flec':


Processing samples: 100%|██████████| 38/38 [00:00<00:00, 1234.77it/s]

Generated 38 samples for task DrugRecommendationMIMIC3





Sample Dataset Statistics:
	- Dataset: mimic3
	- Task: <pyhealth.tasks.drug_recommendation.DrugRecommendationMIMIC3 object at 0x130730bf0>
	- Number of samples: 38

First sample structure:
Patient ID: 1033
Number of visits: 1
Sample conditions (first visit): tensor([1, 2, 3, 4, 5])...
Sample procedures (first visit): tensor([1, 2, 3, 4, 0])...
Sample drugs (target): tensor([0., 0., 1., 0., 1., 1., 0., 0., 0., 0.])...
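
The printed target is a multi-hot vector: position `i` is 1.0 when the drug at index `i` of the label vocabulary was prescribed. As a rough illustration of that encoding, here is a plain-Python sketch. It is not PyHealth's own processor; the four-entry vocabulary and the helper name `multi_hot` are made up for this example.

```python
def multi_hot(tokens, vocab):
    """Return a list of 0.0/1.0 flags, one per vocabulary entry."""
    vec = [0.0] * len(vocab)
    for tok in tokens:
        idx = vocab.get(tok)
        if idx is not None:  # assumption: tokens outside the vocabulary are skipped
            vec[idx] = 1.0
    return vec


# Hypothetical vocabulary, in the style of the label vocab printed above.
vocab = {"*NF*": 0, "Acet": 1, "Amio": 2, "Aspi": 3}
target = multi_hot(["Acet", "Aspi"], vocab)
print(target)
```

With this layout, the model can treat drug recommendation as one binary decision per vocabulary entry.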


## Split Dataset and Create Data Loaders

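Split the dataset by patient (70% train, 10% validation, 20% test) so that no patient's visits appear in more than one split. Splitting by sample instead could leak information from training into evaluation.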


In [3]:
from pyhealth.datasets import split_by_patient, get_dataloader

train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.1, 0.2]
)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)


Train samples: 26
Validation samples: 4
Test samples: 8
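
To make the "split by patient" idea concrete, here is a minimal sketch of one way to do it in plain Python. It is not the code behind `split_by_patient`; the function name `split_ids_by_patient`, the `seed` parameter, and the toy data are assumptions for illustration. The key point is that patient IDs are partitioned first, and samples then follow their patient.

```python
import random


def split_ids_by_patient(samples, ratios, seed=0):
    """Partition sample indices so each patient lands in exactly one split."""
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)
    n = len(patient_ids)
    n_train = round(n * ratios[0])
    n_val = round(n * ratios[1])
    groups = (
        set(patient_ids[:n_train]),
        set(patient_ids[n_train:n_train + n_val]),
        set(patient_ids[n_train + n_val:]),  # the test split takes the remainder
    )
    return [
        [i for i, s in enumerate(samples) if s["patient_id"] in group]
        for group in groups
    ]


# Toy data: 10 hypothetical patients with 2 samples each.
toy = [{"patient_id": p} for p in range(10) for _ in range(2)]
train_idx, val_idx, test_idx = split_ids_by_patient(toy, [0.7, 0.1, 0.2])
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the ratios apply to patients rather than samples, the sample counts per split only approximate 70/10/20 when patients have different numbers of visits.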


## Initialize MoleRec Model


In [4]:
from pyhealth.models import MoleRec

model = MoleRec(
    dataset=samples,
    embedding_dim=64,
    hidden_dim=64,
    num_rnn_layers=1,
    num_gnn_layers=4,
    dropout=0.5,
)

print(model)



MoleRec(
  (dropout_fn): Dropout(p=0.5, inplace=False)
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(198, 64)
    (procedures): Embedding(99, 64)
    (drugs_hist): Embedding(171, 64)
  ))
  (substructure_graphs): StaticParaDict()
  (molecule_graphs): StaticParaDict()
  (rnns): ModuleDict(
    (conditions): GRU(64, 64, batch_first=True)
    (procedures): GRU(64, 64, batch_first=True)
  )
  (substructure_relation): Sequential(
    (0): ReLU()
    (1): Linear(in_features=128, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=1, bias=True)
  )
  (layer): MoleRecLayer(
    (substructure_encoder): GINGraph(
      (atom_encoder): AtomEncoder(
        (atom_embedding_list): ModuleList(
          (0): Embedding(119, 64)
          (1): Embedding(5, 64)
          (2-3): 2 x Embedding(12, 64)
          (4): Embedding(10, 64)
          (5-6): 2 x Embedding(6, 64)
          (7-8): 2 x Embedding(2, 64)
        )
   

## Initialize Trainer

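Track four metrics: sample-averaged Jaccard similarity, F1 score, and PR-AUC, plus the drug-drug interaction (DDI) rate. The `ddi` metric is reported under the key `ddi_score`. Evaluating once before training gives a baseline to compare against.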


In [5]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
)

print("Baseline performance before training:")
baseline_results = trainer.evaluate(test_dataloader)
print(baseline_results)


Evaluation: 100%|██████████| 1/1 [00:00<00:00,  1.76it/s]


{'jaccard_samples': 0.1686518337947108, 'f1_samples': 0.2766813916994273, 'pr_auc_samples': 0.18212385799674458, 'ddi_score': 0.0, 'loss': 0.697697639465332}
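
For intuition about the `_samples` metrics, the sketch below computes Jaccard and F1 per sample on sets of drug indices and then averages over samples. This assumes the suffix means per-sample averaging, as in scikit-learn's `average="samples"`. The function names match the metric keys only for readability, the toy label sets are invented, and scoring an empty-vs-empty sample as 0.0 is a choice made for this sketch.

```python
def jaccard_samples(y_true, y_pred):
    """Mean over samples of |true & pred| / |true | pred|."""
    scores = []
    for true, pred in zip(y_true, y_pred):
        union = true | pred
        scores.append(len(true & pred) / len(union) if union else 0.0)
    return sum(scores) / len(scores)


def f1_samples(y_true, y_pred):
    """Mean over samples of 2 * |true & pred| / (|true| + |pred|)."""
    scores = []
    for true, pred in zip(y_true, y_pred):
        denom = len(true) + len(pred)
        scores.append(2 * len(true & pred) / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Two toy patients: true vs. predicted drug index sets.
y_true = [{1, 2, 3}, {4}]
y_pred = [{2, 3, 5}, {4}]
print(jaccard_samples(y_true, y_pred))
print(f1_samples(y_true, y_pred))
```

Both metrics need thresholded predictions, whereas PR-AUC is computed from the predicted probabilities directly.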


## Train the Model

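Train the model for 5 epochs, monitoring validation PR-AUC to pick the best checkpoint. Five epochs is enough for a demonstration; a real training run will likely need more. With only 26 training samples in dev mode, the validation metrics in the log below barely move (PR-AUC stays between 0.1698 and 0.1706).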


In [6]:
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,
    monitor="pr_auc_samples",
    optimizer_params={"lr": 1e-4},
)


Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x1318addc0>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 5



Epoch 0 / 5: 100%|██████████| 1/1 [00:00<00:00, 19.24it/s]

--- Train epoch-0, step-1 ---
loss: 0.7093



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 266.58it/s]


--- Eval epoch-0, step-1 ---
jaccard_samples: 0.1445
f1_samples: 0.2449
pr_auc_samples: 0.1699
ddi_score: 0.0000
loss: 0.6902
New best pr_auc_samples score (0.1699) at epoch-0, step-1



Epoch 1 / 5: 100%|██████████| 1/1 [00:00<00:00, 94.00it/s]

--- Train epoch-1, step-2 ---
loss: 0.7081



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 254.57it/s]


--- Eval epoch-1, step-2 ---
jaccard_samples: 0.1445
f1_samples: 0.2449
pr_auc_samples: 0.1698
ddi_score: 0.0000
loss: 0.6879



Epoch 2 / 5: 100%|██████████| 1/1 [00:00<00:00, 76.76it/s]

--- Train epoch-2, step-3 ---
loss: 0.6987



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 210.76it/s]


--- Eval epoch-2, step-3 ---
jaccard_samples: 0.1445
f1_samples: 0.2449
pr_auc_samples: 0.1706
ddi_score: 0.0000
loss: 0.6856
New best pr_auc_samples score (0.1706) at epoch-2, step-3



Epoch 3 / 5: 100%|██████████| 1/1 [00:00<00:00, 87.71it/s]

--- Train epoch-3, step-4 ---
loss: 0.7032



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 279.96it/s]


--- Eval epoch-3, step-4 ---
jaccard_samples: 0.1447
f1_samples: 0.2452
pr_auc_samples: 0.1698
ddi_score: 0.0000
loss: 0.6834



Epoch 4 / 5: 100%|██████████| 1/1 [00:00<00:00, 88.91it/s]

--- Train epoch-4, step-5 ---
loss: 0.6947



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 246.07it/s]


--- Eval epoch-4, step-5 ---
jaccard_samples: 0.1450
f1_samples: 0.2456
pr_auc_samples: 0.1705
ddi_score: 0.0000
loss: 0.6811
Loaded best model
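
The "New best" lines in the log follow from `monitor="pr_auc_samples"` with the `max` criterion: a checkpoint is kept only when the monitored score improves. The sketch below replays that rule on the validation PR-AUC values from the log. It is an illustration of the selection logic only, not the Trainer's implementation, and `track_best` is a hypothetical helper.

```python
def track_best(scores, criterion="max"):
    """Return (best_epoch, best_score, epochs_where_the_best_improved)."""
    best_epoch, best_score, improved_at = None, None, []
    for epoch, score in enumerate(scores):
        is_better = (
            best_score is None
            or (criterion == "max" and score > best_score)
            or (criterion == "min" and score < best_score)
        )
        if is_better:
            best_epoch, best_score = epoch, score
            improved_at.append(epoch)
    return best_epoch, best_score, improved_at


# Validation pr_auc_samples for epochs 0-4, copied from the log above.
val_pr_auc = [0.1699, 0.1698, 0.1706, 0.1698, 0.1705]
best_epoch, best_score, improved_at = track_best(val_pr_auc)
print(best_epoch, best_score, improved_at)
```

On these values the rule selects epoch 2, which is consistent with the two "New best" messages at epoch-0 and epoch-2.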


## Evaluate on Test Set

Evaluate the best checkpoint, which was reloaded at the end of training, on the held-out test set. Compare the results with the baseline measured before training.


In [7]:
test_results = trainer.evaluate(test_dataloader)
print("Final test set performance:")
print(test_results)

def fmt(name):
    # Format a metric to 4 decimals; fall back to "N/A" if it is missing,
    # since applying ":.4f" to the string "N/A" would raise a ValueError.
    value = test_results.get(name)
    return "N/A" if value is None else f"{value:.4f}"

print("\nKey Metrics:")
print(f"  PR-AUC: {fmt('pr_auc_samples')}")
print(f"  F1 Score: {fmt('f1_samples')}")
print(f"  Jaccard: {fmt('jaccard_samples')}")
print(f"  DDI Rate: {fmt('ddi_score')} (lower is better)")


Evaluation: 100%|██████████| 1/1 [00:00<00:00, 195.81it/s]


Final test set performance:
{'jaccard_samples': 0.16870862480379736, 'f1_samples': 0.2767764829870635, 'pr_auc_samples': 0.18286476395431256, 'ddi_score': 0.0, 'loss': 0.6917362809181213}

Key Metrics:
  PR-AUC: 0.1829
  F1 Score: 0.2768
  Jaccard: 0.1687
  DDI Rate: 0.0000 (lower is better)
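
The DDI rate is commonly defined as the fraction of drug pairs within each recommended set that are known to interact, pooled over all patients. The sketch below implements that definition in plain Python as an assumption about what `ddi_score` measures. The drug labels, the interaction set, and the helper name `ddi_rate` are hypothetical.

```python
from itertools import combinations


def ddi_rate(predicted_sets, interacting_pairs):
    """Share of co-recommended drug pairs that appear in interacting_pairs."""
    total_pairs = 0
    ddi_pairs = 0
    for drugs in predicted_sets:
        for a, b in combinations(sorted(drugs), 2):
            total_pairs += 1
            if frozenset((a, b)) in interacting_pairs:
                ddi_pairs += 1
    # Assumption: with no pairs at all, report 0.0 rather than dividing by zero.
    return ddi_pairs / total_pairs if total_pairs else 0.0


# Hypothetical data: drugs A and B interact; two patients' recommendations.
known_ddi = {frozenset(("A", "B"))}
preds = [{"A", "B", "C"}, {"C", "D"}]
print(ddi_rate(preds, known_ddi))
```

Under this definition a score of 0.0 means no interacting pair was recommended together, which can also happen when few drugs are predicted at all.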
