# SafeDrug Model Training on MIMIC-III Dataset

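Train the SafeDrug model for medication recommendation on the MIMIC-III dataset. This notebook loads the synthetic MIMIC-III copy hosted by PyHealth in dev mode, so the sample counts and metrics below are only a smoke test.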



In [1]:
from pyhealth.datasets import MIMIC3Dataset

dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    dev=True,
)
dataset.stats()





No config path provided, using default config
Initializing mimic3 dataset from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III (dev mode: True)
Scanning table: patients from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv
Some column names were converted to lowercase
Scanning table: admissions from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv
Some column names were converted to lowercase
Scanning table: icustays from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv
Some column names were converted to lowercase
Scanning table: 

## Set Drug Recommendation Task

Use the `DrugRecommendationMIMIC3` task class, which creates one sample per visit with conditions, procedures, and ATC-3 codes (drugs).



In [2]:
from pyhealth.tasks import DrugRecommendationMIMIC3

task = DrugRecommendationMIMIC3()
samples = dataset.set_task(task, num_workers=4)

print("Sample Dataset Statistics:")
print(f"\t- Dataset: {samples.dataset_name}")
print(f"\t- Task: {samples.task_name}")
print(f"\t- Number of samples: {len(samples)}")

print("\nFirst sample structure:")
print(f"Patient ID: {samples.samples[0]['patient_id']}")
print(f"Number of visits: {len(samples.samples[0]['conditions'])}")
print(f"Sample conditions (first visit): {samples.samples[0]['conditions'][0][:5]}...")
print(f"Sample procedures (first visit): {samples.samples[0]['procedures'][0][:5]}...")
print(f"Sample drugs (target): {samples.samples[0]['drugs'][:10]}...")



Setting task DrugRecommendationMIMIC3 for mimic3 base dataset...
Generating samples with 4 worker(s)...
Generating samples for DrugRecommendationMIMIC3 with 4 workers


Collecting samples for DrugRecommendationMIMIC3 from 4 workers: 100%|██████████| 1000/1000 [00:00<00:00, 4391.46it/s]

Label drugs vocab: {'*NF*': 0, '1/2 ': 1, 'Acet': 2, 'Albu': 3, 'Allo': 4, 'Alte': 5, 'Alum': 6, 'Ambi': 7, 'Amio': 8, 'Amlo': 9, 'Ampi': 10, 'Arip': 11, 'Arti': 12, 'Asco': 13, 'Aspi': 14, 'Aten': 15, 'Ator': 16, 'Atov': 17, 'Atro': 18, 'Aztr': 19, 'Baci': 20, 'Bacl': 21, 'Bisa': 22, 'Brim': 23, 'BuPR': 24, 'Bume': 25, 'Bupi': 26, 'Calc': 27, 'Capt': 28, 'Casp': 29, 'Cefa': 30, 'Cefe': 31, 'Ceft': 32, 'Cepa': 33, 'Ceph': 34, 'Cety': 35, 'Chlo': 36, 'Cipr': 37, 'Cisa': 38, 'Cita': 39, 'Clin': 40, 'Clob': 41, 'Clon': 42, 'Code': 43, 'Cosy': 44, 'Cyan': 45, 'Cycl': 46, 'D5 1': 47, 'D5NS': 48, 'D5W': 49, 'D5W ': 50, 'DOBU': 51, 'DOPa': 52, 'Dapt': 53, 'Dexa': 54, 'Dexm': 55, 'Dext': 56, 'Diaz': 57, 'Dilt': 58, 'Diph': 59, 'Docu': 60, 'Dola': 61, 'Doxe': 62, 'Doxy': 63, 'Dulo': 64, 'Enal': 65, 'Enox': 66, 'Epin': 67, 'Epoe': 68, 'Epti': 69, 'Ergo': 70, 'Eryt': 71, 'Esmo': 72, 'Famo': 73, 'Felb': 74, 'Fent': 75, 'Ferr': 76, 'Fish': 77, 'Fluc': 78, 'Flut': 79, 'FoLI': 80, 'Foli': 81, 'Furo':


Processing samples: 100%|██████████| 30/30 [00:00<00:00, 785.20it/s]

Generated 30 samples for task DrugRecommendationMIMIC3





Sample Dataset Statistics:
	- Dataset: mimic3
	- Task: <pyhealth.tasks.drug_recommendation.DrugRecommendationMIMIC3 object at 0x12f073d70>
	- Number of samples: 30

First sample structure:
Patient ID: 27502
Number of visits: 1
Sample conditions (first visit): tensor([1, 2, 3, 4, 5])...
Sample procedures (first visit): tensor([1, 2, 3, 4, 5])...
Sample drugs (target): tensor([0., 1., 1., 1., 0., 0., 0., 0., 0., 0.])...
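
The `drugs` target printed above is a multi-hot vector: one position per entry in the label vocabulary, set to 1 when that drug was prescribed. The following is a minimal sketch of that encoding, not PyHealth's own processor. The four-entry `vocab` is a slice of the vocabulary shown in the output, and `multi_hot` is a hypothetical helper.

```python
# Hypothetical four-entry slice of the label vocabulary printed above.
vocab = {"*NF*": 0, "1/2 ": 1, "Acet": 2, "Albu": 3}


def multi_hot(labels, vocab):
    """Return a list with 1.0 at the index of every label present, else 0.0."""
    vec = [0.0] * len(vocab)
    for label in labels:
        vec[vocab[label]] = 1.0
    return vec


target = multi_hot(["1/2 ", "Acet", "Albu"], vocab)
print(target)  # [0.0, 1.0, 1.0, 1.0]
```

The result matches the first four entries of the target tensor in the sample above.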


## Split Dataset and Create Data Loaders

Split the dataset by patient to ensure no data leakage between train/validation/test sets.



In [3]:
from pyhealth.datasets import split_by_patient, get_dataloader

train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.1, 0.2]
)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)



Train samples: 20
Validation samples: 2
Test samples: 8
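
To illustrate what splitting by patient guarantees, here is a minimal standalone sketch. It is not the `split_by_patient` implementation. `split_patients` and the toy data are hypothetical, and the only assumption carried over from the notebook is that each sample has a `patient_id` field.

```python
import random


def split_patients(samples, ratios, seed=0):
    """Split samples so that each patient_id lands in exactly one subset."""
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)
    n = len(patient_ids)
    n_train = int(n * ratios[0])
    n_val = int(n * ratios[1])
    groups = (
        set(patient_ids[:n_train]),
        set(patient_ids[n_train:n_train + n_val]),
        set(patient_ids[n_train + n_val:]),
    )
    return [[s for s in samples if s["patient_id"] in g] for g in groups]


# Toy data: 10 hypothetical patients with 2 visits each.
toy = [{"patient_id": p, "visit": v} for p in range(10) for v in range(2)]
train, val, test = split_patients(toy, ratios=[0.7, 0.1, 0.2])

# No patient appears in more than one subset.
ids = [{s["patient_id"] for s in part} for part in (train, val, test)]
assert not (ids[0] & ids[1] or ids[0] & ids[2] or ids[1] & ids[2])
print(len(train), len(val), len(test))  # 14 2 4
```

Because the split is made over patient IDs and not over individual samples, all visits of a patient stay together, which is what prevents leakage.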


## Initialize SafeDrug Model


In [4]:
from pyhealth.models import SafeDrug

model = SafeDrug(
    dataset=samples,
    embedding_dim=128,
    hidden_dim=128,
    num_layers=1,
    dropout=0.5,
)

print(model)





SafeDrug(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(202, 128)
    (procedures): Embedding(81, 128)
    (drugs_hist): Embedding(146, 128)
  ))
  (cond_rnn): GRU(128, 128, batch_first=True)
  (proc_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
  )
  (safedrug): SafeDrugLayer(
    (bipartite_transform): Linear(in_features=128, out_features=0, bias=True)
    (bipartite_output): Linear(in_features=0, out_features=189, bias=True)
    (mpnn): MolecularGraphNeuralNetwork(
      (embed_fingerprint): Embedding(1, 128)
      (W_fingerprint): ModuleList(
        (0-1): 2 x Linear(in_features=128, out_features=128, bias=True)
      )
    )
    (mpnn_output): Linear(in_features=189, out_features=189, bias=True)
    (mpnn_layernorm): LayerNorm((189,), eps=1e-05, elementwise_affine=True)
    (test): Linear(in_features=128, out_features=189, bias=True)
    (los



## Initialize Trainer

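We track four metrics: sample-averaged Jaccard similarity, F1 score, and PR-AUC, plus the drug-drug interaction (DDI) score. The cell below also evaluates the untrained model on the test set as a baseline.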
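
As a rough guide to what the first two metrics measure, the sketch below computes Jaccard and F1 per visit and then averages over visits. It assumes the `_samples` suffix means per-sample averaging, as in scikit-learn's `average="samples"`. The helper functions and the toy label sets are hypothetical and this is not PyHealth's metric code.

```python
def jaccard(true_set, pred_set):
    """Intersection over union of two label sets (0.0 when both are empty)."""
    union = true_set | pred_set
    return len(true_set & pred_set) / len(union) if union else 0.0


def f1(true_set, pred_set):
    """Harmonic mean of precision and recall over label sets."""
    overlap = len(true_set & pred_set)
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_set)
    recall = overlap / len(true_set)
    return 2 * precision * recall / (precision + recall)


# Two hypothetical visits: true drug indices vs. predicted drug indices.
y_true = [{1, 2, 3}, {0, 4}]
y_pred = [{1, 2}, {4, 5}]

jaccard_samples = sum(jaccard(t, p) for t, p in zip(y_true, y_pred)) / len(y_true)
f1_samples = sum(f1(t, p) for t, p in zip(y_true, y_pred)) / len(y_true)
print(round(jaccard_samples, 4), round(f1_samples, 4))  # 0.5 0.65
```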



In [5]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
)

print("Baseline performance before training:")
baseline_results = trainer.evaluate(test_dataloader)
print(baseline_results)



Evaluation: 100%|██████████| 1/1 [00:01<00:00,  1.75s/it]


{'jaccard_samples': 0.20634920634920634, 'f1_samples': 0.3223905953620956, 'pr_auc_samples': 0.20634920634920634, 'ddi_score': 0.0, 'loss': 0.6931473016738892}
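
The baseline loss of 0.6931 matches ln 2 to four decimals. That is the value binary cross-entropy takes when every predicted probability is 0.5, that is, when all logits are zero. Whether SafeDrug's loss here reduces to plain binary cross-entropy is an assumption. The sketch only checks the arithmetic, and `bce` is a hypothetical helper.

```python
import math


def bce(p, y):
    """Binary cross-entropy for one predicted probability p and label y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# With p = 0.5 the loss is ln 2 regardless of the label.
print(round(bce(0.5, 1), 4), round(bce(0.5, 0), 4), round(math.log(2), 4))
# 0.6931 0.6931 0.6931
```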


## Train the Model

Train the model for 5 epochs, monitoring `pr_auc_samples` on the validation set. I use 5 epochs here to keep the run short; a production run will likely need more.



In [6]:
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,
    monitor="pr_auc_samples",
    optimizer_params={"lr": 1e-4},
)



Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x12f09cf20>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 5



Epoch 0 / 5: 100%|██████████| 1/1 [00:02<00:00,  2.35s/it]

--- Train epoch-0, step-1 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 211.14it/s]


--- Eval epoch-0, step-1 ---
jaccard_samples: 0.0238
f1_samples: 0.0465
pr_auc_samples: 0.0455
ddi_score: 0.0000
loss: 0.6931
New best pr_auc_samples score (0.0455) at epoch-0, step-1



Epoch 1 / 5: 100%|██████████| 1/1 [00:00<00:00, 26.44it/s]

--- Train epoch-1, step-2 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 170.98it/s]


--- Eval epoch-1, step-2 ---
jaccard_samples: 0.0238
f1_samples: 0.0465
pr_auc_samples: 0.0459
ddi_score: 0.0000
loss: 0.6931
New best pr_auc_samples score (0.0459) at epoch-1, step-2



Epoch 2 / 5: 100%|██████████| 1/1 [00:00<00:00, 47.51it/s]

--- Train epoch-2, step-3 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 238.79it/s]


--- Eval epoch-2, step-3 ---
jaccard_samples: 0.0238
f1_samples: 0.0465
pr_auc_samples: 0.0474
ddi_score: 0.0000
loss: 0.6930
New best pr_auc_samples score (0.0474) at epoch-2, step-3



Epoch 3 / 5: 100%|██████████| 1/1 [00:00<00:00, 49.67it/s]

--- Train epoch-3, step-4 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 204.54it/s]


--- Eval epoch-3, step-4 ---
jaccard_samples: 0.0238
f1_samples: 0.0465
pr_auc_samples: 0.0472
ddi_score: 0.0000
loss: 0.6930



Epoch 4 / 5: 100%|██████████| 1/1 [00:00<00:00, 50.79it/s]

--- Train epoch-4, step-5 ---
loss: 0.6930



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 242.12it/s]


--- Eval epoch-4, step-5 ---
jaccard_samples: 0.0238
f1_samples: 0.0465
pr_auc_samples: 0.0473
ddi_score: 0.0000
loss: 0.6930
Loaded best model
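
The "New best" and "Loaded best model" lines follow from `monitor="pr_auc_samples"` with the `max` criterion. The sketch below replays that selection rule on the validation scores from the log above. It is a simplified illustration, not the `Trainer` source, and the variable names are hypothetical.

```python
# Validation pr_auc_samples per epoch, copied from the log above.
val_scores = [0.0455, 0.0459, 0.0474, 0.0472, 0.0473]

best_score = float("-inf")
best_epoch = None
for epoch, score in enumerate(val_scores):
    # Criterion "max": keep the checkpoint only when the score improves.
    if score > best_score:
        best_score, best_epoch = score, epoch

print(best_epoch, best_score)  # 2 0.0474
```

This agrees with the log: epochs 0, 1, and 2 each report a new best, while epochs 3 and 4 do not.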


## Evaluate on Test Set

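Evaluate the trained model on the test set to see the final performance metrics.

I get a DDI score of exactly 0, which is suspiciously low for a drug recommendation model. Test Jaccard and F1 are also identical to the untrained baseline, and the loss stays near 0.6931 throughout training, so the model has barely moved from its initial state. The model summary shows `bipartite_transform` with `out_features=0` and a one-entry fingerprint embedding, so the zero may reflect missing molecular or interaction data on this synthetic dev subset rather than genuinely safe recommendations. Treat it as something to verify, not as a result.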

In [7]:
test_results = trainer.evaluate(test_dataloader)
print("Final test set performance:")
print(test_results)

print("\nKey Metrics:")
# Default to NaN: a string default such as 'N/A' would raise ValueError under the :.4f format.
nan = float("nan")
print(f"  PR-AUC: {test_results.get('pr_auc_samples', nan):.4f}")
print(f"  F1 Score: {test_results.get('f1_samples', nan):.4f}")
print(f"  Jaccard: {test_results.get('jaccard_samples', nan):.4f}")
print(f"  DDI Rate: {test_results.get('ddi_score', nan):.4f} (lower is better)")



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 132.90it/s]


Final test set performance:
{'jaccard_samples': 0.20634920634920634, 'f1_samples': 0.3223905953620956, 'pr_auc_samples': 0.3572734935051531, 'ddi_score': 0.0, 'loss': 0.6930726766586304}

Key Metrics:
  PR-AUC: 0.3573
  F1 Score: 0.3224
  Jaccard: 0.2063
  DDI Rate: 0.0000 (lower is better)
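
For context on the DDI rate, a common definition is the fraction of co-recommended drug pairs that appear in a known interaction list. The sketch below implements that definition on toy data. It is an assumption that PyHealth's `ddi` metric follows the same idea, and `ddi_rate`, `interactions`, and `recommended` are all hypothetical.

```python
from itertools import combinations


def ddi_rate(recommended_sets, interacting_pairs):
    """Fraction of co-recommended drug pairs that are known to interact."""
    total = 0
    hits = 0
    for drugs in recommended_sets:
        for a, b in combinations(sorted(drugs), 2):
            total += 1
            if (a, b) in interacting_pairs:
                hits += 1
    return hits / total if total else 0.0


# Hypothetical interaction list, stored as sorted index pairs.
interactions = {(1, 2), (3, 5)}
# Hypothetical recommendations for two visits.
recommended = [{1, 2, 4}, {3, 4}]

print(ddi_rate(recommended, interactions))  # 0.25
print(ddi_rate(recommended, set()))         # 0.0
```

The second call shows that an empty interaction list yields 0 no matter what is recommended. That is one possible explanation for the score above, though nothing in this run confirms it.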
