# SafeDrug Model Training on MIMIC-III Dataset

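Train the SafeDrug model for medication recommendation on MIMIC-III. This notebook loads the synthetic MIMIC-III copy hosted by PyHealth in dev mode (a small patient subset), so the numbers below are illustrative only.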



In [1]:
from pyhealth.datasets import MIMIC3Dataset

dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    dev=True,
)
dataset.stats()





No config path provided, using default config
Initializing mimic3 dataset from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III (dev mode: True)
Scanning table: patients from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv
Some column names were converted to lowercase
Scanning table: admissions from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv
Some column names were converted to lowercase
Scanning table: icustays from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv
Some column names were converted to lowercase
Scanning table: 

## Set Drug Recommendation Task

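Subclass the `DrugRecommendationMIMIC3` task to build samples with conditions, procedures, and drugs. Each sample carries the patient's visit history up to the current admission, and patients with fewer than two usable admissions are skipped. Note that the drug labels here are the first four characters of each drug name, not true ATC-3 codes.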



In [2]:
from pyhealth.tasks import DrugRecommendationMIMIC3

class FixedDrugRecommendationMIMIC3(DrugRecommendationMIMIC3):
    def __call__(self, patient):
        import polars as pl
        samples = []
        admissions = patient.get_events(event_type="admissions")
        if len(admissions) < 2:
            return []
        
        for admission in admissions:
            try:
                hadm_id = admission.hadm_id
            except AttributeError:
                hadm_id = getattr(admission, "hadm_id", None)
                if hadm_id is None:
                    hadm_id = admission["hadm_id"] if "hadm_id" in admission else None
                if hadm_id is None:
                    continue
            
            diagnoses_icd = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", hadm_id)],
                return_df=True,
            )
            conditions = (
                diagnoses_icd.select(pl.col("diagnoses_icd/icd9_code"))
                .to_series()
                .to_list()
            )
            
            procedures_icd = patient.get_events(
                event_type="procedures_icd",
                filters=[("hadm_id", "==", hadm_id)],
                return_df=True,
            )
            procedures = (
                procedures_icd.select(pl.col("procedures_icd/icd9_code"))
                .to_series()
                .to_list()
            )
            
            prescriptions = patient.get_events(
                event_type="prescriptions",
                filters=[("hadm_id", "==", hadm_id)],
                return_df=True,
            )
            drugs = (
                prescriptions.select(pl.col("prescriptions/drug")).to_series().to_list()
            )
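            # Truncate each drug name to its first four characters. These are
            # name prefixes (e.g. "Acet"), not true ATC-3 codes.
            drugs = [drug[:4] for drug in drugs if drug]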
            
            if len(conditions) * len(procedures) * len(drugs) == 0:
                continue
                
            samples.append({
                "visit_id": hadm_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "drugs_hist": drugs,
            })
        
        if len(samples) < 2:
            return []
        
        samples[0]["conditions"] = [samples[0]["conditions"]]
        samples[0]["procedures"] = [samples[0]["procedures"]]
        samples[0]["drugs_hist"] = [samples[0]["drugs_hist"]]
        
        for i in range(1, len(samples)):
            samples[i]["conditions"] = samples[i - 1]["conditions"] + [samples[i]["conditions"]]
            samples[i]["procedures"] = samples[i - 1]["procedures"] + [samples[i]["procedures"]]
            samples[i]["drugs_hist"] = samples[i - 1]["drugs_hist"] + [samples[i]["drugs_hist"]]
        
        for i in range(len(samples)):
            samples[i]["drugs_hist"][i] = []
        
        return samples

task = FixedDrugRecommendationMIMIC3()
samples = dataset.set_task(task, num_workers=4)

print("Sample Dataset Statistics:")
print(f"\t- Dataset: {samples.dataset_name}")
print(f"\t- Task: {samples.task_name}")
print(f"\t- Number of samples: {len(samples)}")

print("\nFirst sample structure:")
print(f"Patient ID: {samples.samples[0]['patient_id']}")
print(f"Number of visits: {len(samples.samples[0]['conditions'])}")
print(f"Sample conditions (first visit): {samples.samples[0]['conditions'][0][:5]}...")
print(f"Sample procedures (first visit): {samples.samples[0]['procedures'][0][:5]}...")
print(f"Sample drugs (target): {samples.samples[0]['drugs'][:10]}...")



Setting task DrugRecommendationMIMIC3 for mimic3 base dataset...
Generating samples with 4 worker(s)...
Generating samples for DrugRecommendationMIMIC3 with 4 workers


Collecting samples for DrugRecommendationMIMIC3 from 4 workers: 100%|██████████| 1000/1000 [00:00<00:00, 4373.27it/s]

Label drugs vocab: {'*NF*': 0, '1/2 ': 1, 'ACD-': 2, 'Acet': 3, 'Acyc': 4, 'Albu': 5, 'Allo': 6, 'Alpr': 7, 'Alte': 8, 'Alum': 9, 'Ambi': 10, 'Amin': 11, 'Amio': 12, 'Amlo': 13, 'Amox': 14, 'Ampi': 15, 'Arga': 16, 'Arti': 17, 'Asco': 18, 'Aspi': 19, 'Ator': 20, 'Atro': 21, 'Avap': 22, 'Azat': 23, 'Azit': 24, 'Aztr': 25, 'Bacl': 26, 'Beth': 27, 'Bisa': 28, 'Bume': 29, 'Calc': 30, 'Carb': 31, 'Carv': 32, 'Casp': 33, 'Cefa': 34, 'Cefe': 35, 'Ceft': 36, 'Ceph': 37, 'Chlo': 38, 'Cipr': 39, 'Cisa': 40, 'Cita': 41, 'Clin': 42, 'Clon': 43, 'Clop': 44, 'Colc': 45, 'Coli': 46, 'Cosy': 47, 'Cyan': 48, 'D5 1': 49, 'D5W': 50, 'D5W ': 51, 'DOBU': 52, 'Daki': 53, 'Dapt': 54, 'Dexa': 55, 'Dexm': 56, 'Dext': 57, 'Diaz': 58, 'Digo': 59, 'Dilt': 60, 'Diph': 61, 'Diva': 62, 'Docu': 63, 'Dola': 64, 'Done': 65, 'DopA': 66, 'Dorz': 67, 'Doxy': 68, 'Enox': 69, 'Epin': 70, 'Epoe': 71, 'Epti': 72, 'Eryt': 73, 'Euce': 74, 'Exce': 75, 'Ezet': 76, 'Famo': 77, 'Fent': 78, 'Ferr': 79, 'Fexo': 80, 'Fluc': 81, 'Flut':


Processing samples: 100%|██████████| 39/39 [00:00<00:00, 785.35it/s]

Generated 39 samples for task DrugRecommendationMIMIC3





Sample Dataset Statistics:
	- Dataset: mimic3
	- Task: <__main__.FixedDrugRecommendationMIMIC3 object at 0x139193ef0>
	- Number of samples: 39

First sample structure:
Patient ID: 44309
Number of visits: 1
Sample conditions (first visit): tensor([1, 2, 3, 4, 5])...
Sample procedures (first visit): tensor([1, 2, 3, 0, 0])...
Sample drugs (target): tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])...
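
The cumulative history that `FixedDrugRecommendationMIMIC3.__call__` builds can be easier to follow on toy data. The sketch below is a plain-Python illustration, not PyHealth code: `build_visit_history` and the toy visit records are hypothetical names introduced here. It assumes each visit already has non-empty conditions, procedures, and drugs.

```python
def build_visit_history(visits):
    """Turn per-visit records into cumulative samples, one per visit."""
    samples = []
    for i, visit in enumerate(visits):
        earlier = visits[:i]
        samples.append({
            # All visits up to and including the current one.
            "conditions": [v["conditions"] for v in visits[: i + 1]],
            "procedures": [v["procedures"] for v in visits[: i + 1]],
            # Drugs from earlier visits; the current visit is blanked out
            # because its drugs are the prediction target.
            "drugs_hist": [v["drugs"] for v in earlier] + [[]],
            "drugs": visit["drugs"],
        })
    return samples


toy_visits = [
    {"conditions": ["c1"], "procedures": ["p1"], "drugs": ["d1"]},
    {"conditions": ["c2", "c3"], "procedures": ["p2"], "drugs": ["d2", "d3"]},
]
toy_samples = build_visit_history(toy_visits)
for sample in toy_samples:
    print(sample)
```

The second sample sees both visits' conditions and procedures, but only the first visit's drugs, which is what keeps the target from leaking into the input.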


## Split Dataset and Create Data Loaders

Split the dataset by patient to ensure no data leakage between train/validation/test sets.



In [3]:
from pyhealth.datasets import split_by_patient, get_dataloader

train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.1, 0.2]
)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)



Train samples: 26
Validation samples: 4
Test samples: 9
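
Splitting by patient means the ratios apply to patients, not samples, which is why 39 samples do not divide exactly 70/10/20. The following is a rough sketch of the idea in plain Python; `split_indices_by_patient` is a hypothetical helper and is not how `split_by_patient` is necessarily implemented.

```python
import random
from collections import defaultdict


def split_indices_by_patient(samples, ratios, seed=0):
    """Return train/val/test sample indices with no patient shared across splits."""
    by_patient = defaultdict(list)
    for idx, sample in enumerate(samples):
        by_patient[sample["patient_id"]].append(idx)

    # Shuffle patients (not samples) so all of a patient's visits stay together.
    patient_ids = sorted(by_patient)
    random.Random(seed).shuffle(patient_ids)

    n_train = int(len(patient_ids) * ratios[0])
    n_val = int(len(patient_ids) * ratios[1])
    groups = [
        patient_ids[:n_train],
        patient_ids[n_train:n_train + n_val],
        patient_ids[n_train + n_val:],
    ]
    return [[i for pid in group for i in by_patient[pid]] for group in groups]


# Ten toy patients with one to three samples each.
toy = [{"patient_id": pid} for pid in range(10) for _ in range(pid % 3 + 1)]
train_idx, val_idx, test_idx = split_indices_by_patient(toy, [0.7, 0.1, 0.2])
print(len(train_idx), len(val_idx), len(test_idx))
```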


## Initialize SafeDrug Model


In [4]:
from pyhealth.models import SafeDrug

model = SafeDrug(
    dataset=samples,
    embedding_dim=128,
    hidden_dim=128,
    num_layers=1,
    dropout=0.5,
)

print(model)





SafeDrug(
  (embeddings): ModuleDict(
    (conditions): Embedding(255, 128, padding_idx=0)
    (procedures): Embedding(86, 128, padding_idx=0)
    (drugs_hist): Embedding(160, 128, padding_idx=0)
  )
  (cond_rnn): GRU(128, 128, batch_first=True)
  (proc_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
  )
  (safedrug): SafeDrugLayer(
    (bipartite_transform): Linear(in_features=128, out_features=0, bias=True)
    (bipartite_output): Linear(in_features=0, out_features=221, bias=True)
    (mpnn): MolecularGraphNeuralNetwork(
      (embed_fingerprint): Embedding(1, 128)
      (W_fingerprint): ModuleList(
        (0-1): 2 x Linear(in_features=128, out_features=128, bias=True)
      )
    )
    (mpnn_output): Linear(in_features=221, out_features=221, bias=True)
    (mpnn_layernorm): LayerNorm((221,), eps=1e-05, elementwise_affine=True)
    (test): Linear(in_features=128, out_features=221, bias=True)
 



## Initialize Trainer

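The trainer tracks four metrics: sample-averaged Jaccard similarity (`jaccard_samples`), sample-averaged F1 (`f1_samples`), sample-averaged PR-AUC (`pr_auc_samples`), and the drug-drug interaction rate (`ddi`, reported as `ddi_score`). The cell also evaluates the untrained model on the test set as a baseline.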



In [5]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
)

print("Baseline performance before training:")
baseline_results = trainer.evaluate(test_dataloader)
print(baseline_results)




Evaluation: 100%|██████████| 1/1 [00:02<00:00,  2.50s/it]


{'jaccard_samples': 0.13826043237807945, 'f1_samples': 0.21695747906523977, 'pr_auc_samples': 0.13826043237807945, 'ddi_score': 0.0, 'loss': 0.6931472420692444}
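
For reference, here is a minimal sketch of what "sample-averaged" Jaccard and F1 mean for multilabel drug sets: compute the score per sample, then take the mean. The helper names are hypothetical and this is not PyHealth's implementation, which works on thresholded probability arrays.

```python
def jaccard(pred, true):
    """Intersection over union of two label sets."""
    pred, true = set(pred), set(true)
    union = pred | true
    return len(pred & true) / len(union) if union else 0.0


def f1(pred, true):
    """Harmonic mean of precision and recall, written in set form."""
    pred, true = set(pred), set(true)
    denom = len(pred) + len(true)
    return 2 * len(pred & true) / denom if denom else 0.0


def sample_average(metric, preds, trues):
    """Average a per-sample metric over all samples."""
    return sum(metric(p, t) for p, t in zip(preds, trues)) / len(preds)


# Two toy samples: predicted drug indices versus true drug indices.
preds = [{0, 1, 2}, {4}]
trues = [{1, 2, 3}, {4, 5}]
avg_jaccard = sample_average(jaccard, preds, trues)
avg_f1 = sample_average(f1, preds, trues)
print(round(avg_jaccard, 4), round(avg_f1, 4))
```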


## Train the Model

Train the model for a few epochs; I used 5 here. On this tiny dev subset (26 training samples, so one batch per epoch) the loss barely moves, so a real run needs the full dataset and more epochs.



In [6]:
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,
    monitor="pr_auc_samples",
    optimizer_params={"lr": 1e-4},
)



Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x138da0f80>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 5



Epoch 0 / 5: 100%|██████████| 1/1 [00:04<00:00,  4.33s/it]

--- Train epoch-0, step-1 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 170.35it/s]


--- Eval epoch-0, step-1 ---
jaccard_samples: 0.1357
f1_samples: 0.2321
pr_auc_samples: 0.2757
ddi_score: 0.0000
loss: 0.6931
New best pr_auc_samples score (0.2757) at epoch-0, step-1



Epoch 1 / 5: 100%|██████████| 1/1 [00:00<00:00, 24.76it/s]

--- Train epoch-1, step-2 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 178.95it/s]


--- Eval epoch-1, step-2 ---
jaccard_samples: 0.1357
f1_samples: 0.2321
pr_auc_samples: 0.2771
ddi_score: 0.0000
loss: 0.6931
New best pr_auc_samples score (0.2771) at epoch-1, step-2



Epoch 2 / 5: 100%|██████████| 1/1 [00:00<00:00, 36.40it/s]

--- Train epoch-2, step-3 ---
loss: 0.6931



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 137.51it/s]


--- Eval epoch-2, step-3 ---
jaccard_samples: 0.1357
f1_samples: 0.2321
pr_auc_samples: 0.2793
ddi_score: 0.0000
loss: 0.6931
New best pr_auc_samples score (0.2793) at epoch-2, step-3



Epoch 3 / 5: 100%|██████████| 1/1 [00:00<00:00, 23.04it/s]

--- Train epoch-3, step-4 ---
loss: 0.6930



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 94.55it/s]


--- Eval epoch-3, step-4 ---
jaccard_samples: 0.1357
f1_samples: 0.2321
pr_auc_samples: 0.2809
ddi_score: 0.0000
loss: 0.6930
New best pr_auc_samples score (0.2809) at epoch-3, step-4



Epoch 4 / 5: 100%|██████████| 1/1 [00:00<00:00, 37.30it/s]

--- Train epoch-4, step-5 ---
loss: 0.6930



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 146.58it/s]


--- Eval epoch-4, step-5 ---
jaccard_samples: 0.1357
f1_samples: 0.2321
pr_auc_samples: 0.2836
ddi_score: 0.0000
loss: 0.6930
New best pr_auc_samples score (0.2836) at epoch-4, step-5
Loaded best model
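
The loss sits at 0.6931 in every epoch, which matches ln 2. Assuming the reported loss is dominated by a binary cross-entropy over the drug labels, that is the value obtained when every predicted probability is 0.5, so the logits are probably still near zero and the model has learned very little. A quick check of the arithmetic (the `binary_cross_entropy` helper is written here for illustration only):

```python
import math


def binary_cross_entropy(p, y):
    """Cross-entropy for one label with predicted probability p and target y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# With p = 0.5 the loss is the same whether the true label is 1 or 0.
loss_pos = binary_cross_entropy(0.5, 1)
loss_neg = binary_cross_entropy(0.5, 0)
print(round(loss_pos, 4), round(loss_neg, 4), round(math.log(2), 4))
```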


## Evaluate on Test Set

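Evaluate the trained model on the test set to see the final performance metrics.

I get a DDI rate of exactly 0. Lower is better, but an exact zero from a barely trained model is suspicious rather than reassuring. The drug labels here are four-character name prefixes (for example `Acet`), not ATC-3 codes, and the model summary shows `bipartite_transform` with `out_features=0` and a single-entry fingerprint embedding, so the DDI and molecule information was probably never matched to these labels. Treat the DDI figure as an artifact until the drugs are mapped to real ATC-3 codes.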

In [7]:
test_results = trainer.evaluate(test_dataloader)
print("Final test set performance:")
print(test_results)

nan = float("nan")  # numeric fallback so the :.4f format cannot fail on a missing key
print("\nKey Metrics:")
print(f"  PR-AUC: {test_results.get('pr_auc_samples', nan):.4f}")
print(f"  F1 Score: {test_results.get('f1_samples', nan):.4f}")
print(f"  Jaccard: {test_results.get('jaccard_samples', nan):.4f}")
print(f"  DDI Rate: {test_results.get('ddi_score', nan):.4f} (lower is better)")



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 106.38it/s]


Final test set performance:
{'jaccard_samples': 0.13826043237807945, 'f1_samples': 0.21695747906523977, 'pr_auc_samples': 0.31234568776484584, 'ddi_score': 0.0, 'loss': 0.6929892897605896}

Key Metrics:
  PR-AUC: 0.3123
  F1 Score: 0.2170
  Jaccard: 0.1383
  DDI Rate: 0.0000 (lower is better)
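
To see how a DDI rate of exactly 0 can arise without the model being safe, here is a sketch of the usual definition: the fraction of co-predicted drug pairs that appear in a known-interaction table. `ddi_rate` and the toy drug names are hypothetical, and PyHealth's own metric may differ in detail.

```python
from itertools import combinations


def ddi_rate(predicted_sets, interacting_pairs):
    """Fraction of predicted drug pairs that are known to interact."""
    interacting = {frozenset(pair) for pair in interacting_pairs}
    total = hits = 0
    for drugs in predicted_sets:
        for a, b in combinations(sorted(drugs), 2):
            total += 1
            hits += frozenset((a, b)) in interacting
    return hits / total if total else 0.0


preds = [{"A", "B", "C"}, {"A", "D"}]
known = [("A", "B")]

rate = ddi_rate(preds, known)                   # 1 interacting pair out of 4
empty_rate = ddi_rate(preds, [])                # no known interactions at all
no_pair_rate = ddi_rate([set(), {"A"}], known)  # no sample predicts 2+ drugs
print(rate, empty_rate, no_pair_rate)
```

Both degenerate cases return 0: an interaction table that matches none of the label names, and predictions that never include two drugs at once. Either would explain the result above.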
