# SafeDrug Model Training on MIMIC-III Dataset

Train the SafeDrug model for medication recommendation on MIMIC-III, here using PyHealth's synthetic MIMIC-III data loaded in dev mode.



In [1]:
from pyhealth.datasets import MIMIC3Dataset

dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    dev=True,
)
dataset.stats()





No config path provided, using default config
Initializing mimic3 dataset from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III (dev mode: True)
Scanning table: patients from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv
Some column names were converted to lowercase
Scanning table: admissions from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv
Some column names were converted to lowercase
Scanning table: icustays from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv
Some column names were converted to lowercase
Scanning table: 

## Set Drug Recommendation Task

Use the `DrugRecommendationMIMIC3` task, which builds samples from each patient's visits: conditions and procedures as inputs, and the prescribed drugs (ATC-3 codes) as the multi-label target.



In [2]:
from pyhealth.tasks import DrugRecommendationMIMIC3

task = DrugRecommendationMIMIC3()
samples = dataset.set_task(task, num_workers=4)

print("Sample Dataset Statistics:")
print(f"\t- Dataset: {samples.dataset_name}")
print(f"\t- Task: {samples.task_name}")
print(f"\t- Number of samples: {len(samples)}")

print("\nFirst sample structure:")
print(f"Patient ID: {samples.samples[0]['patient_id']}")
print(f"Number of visits: {len(samples.samples[0]['conditions'])}")
print(f"Sample conditions (first visit): {samples.samples[0]['conditions'][0][:5]}...")
print(f"Sample procedures (first visit): {samples.samples[0]['procedures'][0][:5]}...")
print(f"Sample drugs (target): {samples.samples[0]['drugs'][:10]}...")



Setting task DrugRecommendationMIMIC3 for mimic3 base dataset...
Generating samples with 4 worker(s)...
Generating samples for DrugRecommendationMIMIC3 with 4 workers


Collecting samples for DrugRecommendationMIMIC3 from 4 workers: 100%|██████████| 1000/1000 [00:00<00:00, 1420.70it/s]

Label drugs vocab: {'*NF*': 0, '1/2 ': 1, 'Acar': 2, 'Acet': 3, 'Acyc': 4, 'Albu': 5, 'Allo': 6, 'Alpr': 7, 'Alte': 8, 'Alum': 9, 'Amin': 10, 'Amio': 11, 'Amlo': 12, 'Amph': 13, 'Ampi': 14, 'Arga': 15, 'Arti': 16, 'Aspi': 17, 'Aten': 18, 'Ator': 19, 'Atro': 20, 'Augm': 21, 'Azit': 22, 'Baci': 23, 'Bisa': 24, 'Bume': 25, 'BusP': 26, 'Busp': 27, 'Calc': 28, 'Capt': 29, 'Carv': 30, 'Cefa': 31, 'Cefe': 32, 'Ceft': 33, 'Ceph': 34, 'Chlo': 35, 'Cipr': 36, 'Cisa': 37, 'Cita': 38, 'Clop': 39, 'Clot': 40, 'Cosy': 41, 'D10W': 42, 'D5 1': 43, 'D5NS': 44, 'D5W': 45, 'D5W ': 46, 'Daki': 47, 'Dapt': 48, 'Desm': 49, 'Dexa': 50, 'Dext': 51, 'Diaz': 52, 'Digo': 53, 'Dilt': 54, 'Diov': 55, 'Diph': 56, 'Diso': 57, 'Docu': 58, 'Dola': 59, 'DopA': 60, 'Doxa': 61, 'Enox': 62, 'Epin': 63, 'Eryt': 64, 'Famo': 65, 'Fat ': 66, 'Fent': 67, 'Ferr': 68, 'Fexo': 69, 'Filg': 70, 'Fina': 71, 'Fluc': 72, 'Flud': 73, 'Fluo': 74, 'Flut': 75, 'FoLI': 76, 'Foli': 77, 'Fosp': 78, 'Furo': 79, 'Gaba': 80, 'Gemf': 81, 'Gent':


Processing samples: 100%|██████████| 56/56 [00:00<00:00, 932.10it/s]

Generated 56 samples for task DrugRecommendationMIMIC3





Sample Dataset Statistics:
	- Dataset: mimic3
	- Task: <pyhealth.tasks.drug_recommendation.DrugRecommendationMIMIC3 object at 0x133427b90>
	- Number of samples: 56

First sample structure:
Patient ID: 27333
Number of visits: 1
Sample conditions (first visit): tensor([1, 2, 3, 4, 5])...
Sample procedures (first visit): tensor([1, 2, 3, 4, 5])...
Sample drugs (target): tensor([0., 0., 0., 1., 0., 1., 0., 1., 1., 0.])...
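The `drugs` target printed above is a multi-hot vector: position *i* is 1.0 if the drug at index *i* of the label vocabulary was prescribed. The vocabulary printed earlier appears to be sorted. A minimal sketch of that encoding, assuming a sorted vocabulary and using hypothetical helper names and toy codes (PyHealth builds the real vocabulary internally), could look like this:

```python
from typing import Dict, List


def build_vocab(drug_lists: List[List[str]]) -> Dict[str, int]:
    """Map each distinct drug code to an index, in sorted order."""
    codes = sorted({code for drugs in drug_lists for code in drugs})
    return {code: idx for idx, code in enumerate(codes)}


def multi_hot(drugs: List[str], vocab: Dict[str, int]) -> List[float]:
    """Encode a set of drug codes as a 0/1 vector over the vocabulary."""
    vector = [0.0] * len(vocab)
    for code in drugs:
        vector[vocab[code]] = 1.0
    return vector


# Hypothetical visits; the codes are illustrative only.
visits = [["N02B", "B01A"], ["A02B", "N02B", "C07A"]]
vocab = build_vocab(visits)
print(vocab)  # {'A02B': 0, 'B01A': 1, 'C07A': 2, 'N02B': 3}
print(multi_hot(visits[1], vocab))  # [1.0, 0.0, 1.0, 1.0]
```

The model's output layer has one logit per vocabulary entry, so the target and the prediction share the same indexing.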


## Split Dataset and Create Data Loaders

Split the dataset by patient, so that all samples from one patient end up in the same split. This prevents data leakage between the train, validation, and test sets.
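The idea behind a patient-level split can be sketched in plain Python: group samples by patient ID, shuffle the patients (not the samples), and cut the patient list by the given ratios. This is an illustrative sketch, not PyHealth's `split_by_patient` source; the function name `split_indices_by_patient` and its `seed` parameter are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def split_indices_by_patient(
    patient_ids: Sequence[str],
    ratios: Tuple[float, float, float] = (0.7, 0.1, 0.2),
    seed: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
    """Split sample indices so each patient's samples land in exactly one split."""
    if abs(sum(ratios) - 1.0) > 1e-8:
        raise ValueError("ratios must sum to 1")
    # Group sample indices by patient.
    by_patient: Dict[str, List[int]] = defaultdict(list)
    for idx, pid in enumerate(patient_ids):
        by_patient[pid].append(idx)

    # Shuffle patients, then cut the patient list by ratio.
    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)
    n = len(patients)
    n_train = int(n * ratios[0])
    n_val = int(n * (ratios[0] + ratios[1]))
    groups = (patients[:n_train], patients[n_train:n_val], patients[n_val:])

    # Expand each patient group back into sample indices.
    train, val, test = ([i for pid in g for i in by_patient[pid]] for g in groups)
    return train, val, test


# Hypothetical patient IDs, one entry per sample.
ids = ["p1", "p1", "p2", "p3", "p3", "p3", "p4", "p5", "p6", "p7"]
train_idx, val_idx, test_idx = split_indices_by_patient(ids, seed=42)
split_patients = [{ids[i] for i in s} for s in (train_idx, val_idx, test_idx)]
print([len(p) for p in split_patients])  # patients per split: [4, 1, 2]
```

Because whole patients are assigned, the per-split sample counts only approximate the ratios when patients contribute different numbers of samples.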



In [3]:
from pyhealth.datasets import split_by_patient, get_dataloader

train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.1, 0.2]
)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)



Train samples: 38
Validation samples: 4
Test samples: 14


## Initialize SafeDrug Model


In [4]:
from pyhealth.models import SafeDrug

model = SafeDrug(
    dataset=samples,
    embedding_dim=128,
    hidden_dim=128,
    num_layers=1,
    dropout=0.5,
)

print(model)





SafeDrug(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(314, 128)
    (procedures): Embedding(88, 128)
    (drugs_hist): Embedding(166, 128)
  ))
  (cond_rnn): GRU(128, 128, batch_first=True)
  (proc_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
  )
  (safedrug): SafeDrugLayer(
    (bipartite_transform): Linear(in_features=128, out_features=0, bias=True)
    (bipartite_output): Linear(in_features=0, out_features=201, bias=True)
    (mpnn): MolecularGraphNeuralNetwork(
      (embed_fingerprint): Embedding(1, 128)
      (W_fingerprint): ModuleList(
        (0-1): 2 x Linear(in_features=128, out_features=128, bias=True)
      )
    )
    (mpnn_output): Linear(in_features=201, out_features=201, bias=True)
    (mpnn_layernorm): LayerNorm((201,), eps=1e-05, elementwise_affine=True)
    (test): Linear(in_features=128, out_features=201, bias=True)
    (los



## Initialize Trainer

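We evaluate with Jaccard similarity, F1 score, PR-AUC, and the drug-drug interaction (DDI) rate. The first three are computed per sample and then averaged (hence the `_samples` suffix); for the DDI rate, lower is better.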
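A minimal pure-Python sketch of the two set-based metrics, each computed for one sample's true and predicted drug sets and then averaged over samples. The helper names are hypothetical, PR-AUC is omitted because it needs ranked scores rather than sets, and treating a sample with two empty sets as 0.0 is an assumed convention.

```python
from typing import List, Set


def jaccard_samples(y_true: List[Set[str]], y_pred: List[Set[str]]) -> float:
    """Mean over samples of |T & P| / |T | P|."""
    scores = []
    for t, p in zip(y_true, y_pred):
        union = t | p
        # Both sets empty: treated as 0.0 here (an assumed convention).
        scores.append(len(t & p) / len(union) if union else 0.0)
    return sum(scores) / len(scores)


def f1_samples(y_true: List[Set[str]], y_pred: List[Set[str]]) -> float:
    """Mean over samples of 2|T & P| / (|T| + |P|)."""
    scores = []
    for t, p in zip(y_true, y_pred):
        denom = len(t) + len(p)
        scores.append(2 * len(t & p) / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Hypothetical true and predicted drug sets for two samples.
y_true = [{"A", "B", "C"}, {"A"}]
y_pred = [{"A", "B", "D"}, {"A", "B"}]
print(jaccard_samples(y_true, y_pred))  # 0.5
print(round(f1_samples(y_true, y_pred), 4))  # 0.6667
```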



In [5]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
)

print("Baseline performance before training:")
baseline_results = trainer.evaluate(test_dataloader)
print(baseline_results)




Evaluation: 100%|██████████| 1/1 [00:03<00:00,  3.38s/it]


{'jaccard_samples': 0.06751954513148543, 'f1_samples': 0.11837435592508397, 'pr_auc_samples': 0.06751954513148543, 'ddi_score': 0.0, 'loss': 0.6931472420692444}
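One plausible reading of this baseline is that the untrained model scores every drug at about 0.5 and therefore predicts the whole vocabulary. Binary cross-entropy at p = 0.5 is ln 2 ≈ 0.6931, the reported loss. When every drug is predicted, each sample's Jaccard equals the fraction of the vocabulary it actually received, which is also its average precision when all scores tie; that would explain why `jaccard_samples` and `pr_auc_samples` are identical here. A small check of this arithmetic with a made-up label vector (helper names are hypothetical):

```python
import math
from typing import List


def bce_constant(labels: List[int], prob: float) -> float:
    """Mean binary cross-entropy when every label gets the same probability."""
    losses = [-(y * math.log(prob) + (1 - y) * math.log(1 - prob)) for y in labels]
    return sum(losses) / len(losses)


def average_precision(labels: List[int], scores: List[float]) -> float:
    """Step-wise AP: sum over distinct thresholds of (recall gain) * precision."""
    n_pos = sum(labels)
    ap, prev_recall = 0.0, 0.0
    for threshold in sorted(set(scores), reverse=True):
        selected = [y for y, s in zip(labels, scores) if s >= threshold]
        tp = sum(selected)
        recall = tp / n_pos
        ap += (recall - prev_recall) * (tp / len(selected))
        prev_recall = recall
    return ap


# Hypothetical multi-hot target: 2 of 10 vocabulary drugs were prescribed.
labels = [0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
print(round(bce_constant(labels, 0.5), 4))  # 0.6931 (= ln 2)
# Predicting every drug: Jaccard = |T| / |vocabulary|.
print(sum(labels) / len(labels))  # 0.2
# All scores tied at 0.5: AP collapses to the positive rate.
print(average_precision(labels, [0.5] * len(labels)))  # 0.2
```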


## Train the Model

Train the model for a few epochs. I used 5 epochs here, which is enough for a quick demo; a production run will likely need more epochs (and the full dataset instead of the dev subset).
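With `monitor="pr_auc_samples"`, the training log reports the monitor criterion as `max`: after each epoch the trainer compares the validation PR-AUC with the best value so far, and at the end it reloads the best checkpoint ("Loaded best model"). The selection rule itself fits in a few lines; `best_epoch` is a hypothetical helper, not part of PyHealth's API, and the scores are the validation values logged during this run.

```python
from typing import List, Optional, Tuple


def best_epoch(scores: List[float], criterion: str = "max") -> Tuple[int, float]:
    """Return (epoch, score) for the best monitored value; ties keep the earlier epoch."""
    if criterion not in ("max", "min"):
        raise ValueError("criterion must be 'max' or 'min'")
    best: Optional[Tuple[int, float]] = None
    for epoch, score in enumerate(scores):
        improved = best is None or (
            score > best[1] if criterion == "max" else score < best[1]
        )
        if improved:
            best = (epoch, score)
    if best is None:
        raise ValueError("no scores given")
    return best


# Validation pr_auc_samples per epoch, copied from the training log in this notebook.
val_pr_auc = [0.2577, 0.2659, 0.2776, 0.2989, 0.3378]
print(best_epoch(val_pr_auc))  # (4, 0.3378)
```

With only 4 validation samples, the epoch chosen this way is a noisy selection.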



In [6]:
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,
    monitor="pr_auc_samples",
    optimizer_params={"lr": 1e-4},
)



Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x1334c8b00>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 5



Epoch 0 / 5: 100%|██████████| 2/2 [00:04<00:00,  2.36s/it]

--- Train epoch-0, step-2 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 212.49it/s]


--- Eval epoch-0, step-2 ---
jaccard_samples: 0.1530
f1_samples: 0.2510
pr_auc_samples: 0.2577
ddi_score: 0.0000
loss: 0.6931
New best pr_auc_samples score (0.2577) at epoch-0, step-2



Epoch 1 / 5: 100%|██████████| 2/2 [00:00<00:00, 40.41it/s]

--- Train epoch-1, step-4 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 202.72it/s]


--- Eval epoch-1, step-4 ---
jaccard_samples: 0.1530
f1_samples: 0.2510
pr_auc_samples: 0.2659
ddi_score: 0.0000
loss: 0.6930
New best pr_auc_samples score (0.2659) at epoch-1, step-4



Epoch 2 / 5: 100%|██████████| 2/2 [00:00<00:00, 45.36it/s]

--- Train epoch-2, step-6 ---
loss: 0.6930



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 123.52it/s]


--- Eval epoch-2, step-6 ---
jaccard_samples: 0.1530
f1_samples: 0.2510
pr_auc_samples: 0.2776
ddi_score: 0.0000
loss: 0.6930
New best pr_auc_samples score (0.2776) at epoch-2, step-6



Epoch 3 / 5: 100%|██████████| 2/2 [00:00<00:00, 44.84it/s]

--- Train epoch-3, step-8 ---
loss: 0.6929



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 157.26it/s]


--- Eval epoch-3, step-8 ---
jaccard_samples: 0.1530
f1_samples: 0.2510
pr_auc_samples: 0.2989
ddi_score: 0.0000
loss: 0.6929
New best pr_auc_samples score (0.2989) at epoch-3, step-8



Epoch 4 / 5: 100%|██████████| 2/2 [00:00<00:00, 41.89it/s]

--- Train epoch-4, step-10 ---
loss: 0.6928



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 185.05it/s]


--- Eval epoch-4, step-10 ---
jaccard_samples: 0.1530
f1_samples: 0.2510
pr_auc_samples: 0.3378
ddi_score: 0.0000
loss: 0.6928
New best pr_auc_samples score (0.3378) at epoch-4, step-10
Loaded best model


## Evaluate on Test Set

Evaluate the trained model on the test set to see final performance metrics.

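I get a DDI rate of 0, which is abnormally low. It is probably not a sign of safe predictions, though. The model summary above shows `bipartite_transform` with `out_features=0` and a single-entry `embed_fingerprint`, which suggests that none of the drug labels could be mapped to molecular structures: in this synthetic dev subset, the label vocabulary consists of truncated drug names such as `'Acet'` and `'Amio'` rather than ATC-3 codes. If no drug pair can be found in the DDI graph, the DDI rate is 0 regardless of what the model predicts. Jaccard and F1 are also unchanged from the untrained baseline, so these results mainly show that the pipeline runs end to end.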
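The DDI rate is, roughly, the fraction of drug pairs within the predicted sets that are known to interact. A sketch of that definition, pooling pairs over all samples, shows why an empty interaction set always yields 0. The function name and the toy interaction list are hypothetical, and PyHealth's implementation (which works from a DDI adjacency matrix) may differ in detail.

```python
from itertools import combinations
from typing import FrozenSet, List, Set


def ddi_rate(predicted_sets: List[Set[str]], ddi_pairs: Set[FrozenSet[str]]) -> float:
    """Fraction of co-predicted drug pairs (pooled over samples) that interact."""
    total_pairs, bad_pairs = 0, 0
    for drugs in predicted_sets:
        for a, b in combinations(sorted(drugs), 2):
            total_pairs += 1
            if frozenset((a, b)) in ddi_pairs:
                bad_pairs += 1
    return bad_pairs / total_pairs if total_pairs else 0.0


# Hypothetical predictions and a single known interacting pair (A, C).
predictions = [{"A", "B", "C"}, {"A", "C"}]
known_ddi = {frozenset(("A", "C"))}
print(ddi_rate(predictions, known_ddi))  # 0.5
print(ddi_rate(predictions, set()))  # 0.0, whatever is predicted
```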

In [7]:
test_results = trainer.evaluate(test_dataloader)
print("Final test set performance:")
print(test_results)

def fmt(key):
    # Format a metric to 4 decimals; fall back to "N/A" if the key is missing
    # (formatting the string "N/A" with :.4f would raise a ValueError)
    value = test_results.get(key)
    return "N/A" if value is None else f"{value:.4f}"

print("\nKey Metrics:")
print(f"  PR-AUC: {fmt('pr_auc_samples')}")
print(f"  F1 Score: {fmt('f1_samples')}")
print(f"  Jaccard: {fmt('jaccard_samples')}")
print(f"  DDI Rate: {fmt('ddi_score')} (lower is better)")



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 63.79it/s]


Final test set performance:
{'jaccard_samples': 0.06751954513148543, 'f1_samples': 0.11837435592508397, 'pr_auc_samples': 0.19516980311232707, 'ddi_score': 0.0, 'loss': 0.6927605271339417}

Key Metrics:
  PR-AUC: 0.1952
  F1 Score: 0.1184
  Jaccard: 0.0675
  DDI Rate: 0.0000 (lower is better)
