# Drug Recommendation using MICRON Model on MIMIC-III Dataset

This notebook demonstrates drug recommendation with the MICRON model on the MIMIC-III dataset. The model is implemented in the PyHealth 2.0 framework.

MICRON (Medication reCommendation using Recurrent ResIdual Networks) is designed to predict medications based on patient diagnoses and procedures.

## 1. Setup Google Drive and Environment

First, we optionally mount Google Drive to access and save data (the mount commands are commented out below; uncomment them when running in Colab). We then install PyHealth from the forked repository at https://github.com/naveenkcb/PyHealth, along with its dependencies.

In [1]:
# Mount Google Drive
#from google.colab import drive
#drive.mount('/content/drive')

# Install PyHealth from your forked repository
!pip install git+https://github.com/naveenkcb/PyHealth.git
# Install other required packages
!pip install torch scikit-learn pandas numpy tqdm

Collecting git+https://github.com/naveenkcb/PyHealth.git
  Cloning https://github.com/naveenkcb/PyHealth.git to /tmp/pip-req-build-e8mrcm25
  Running command git clone --filter=blob:none --quiet https://github.com/naveenkcb/PyHealth.git /tmp/pip-req-build-e8mrcm25
  Resolved https://github.com/naveenkcb/PyHealth.git to commit bb3da26de8c9747fd67b6842096a9a07b60310eb
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting mne~=1.10.0 (from pyhealth==2.0a8)
  Downloading mne-1.10.2-py3-none-any.whl.metadata (21 kB)
Collecting numpy~=1.26.4 (from pyhealth==2.0a8)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting pandarallel~=1.6.5 (from pyhealth==2.0a8)
  Downloading pandarallel
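
After installation, it can help to confirm which PyHealth build was actually picked up, since the log above resolves to `pyhealth==2.0a8`. The following is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# The distribution name "pyhealth" is taken from the pip log above.
print("pyhealth:", installed_version("pyhealth"))
```

If this prints `None`, the install step did not complete in the current environment.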



## 2. Import Required Libraries and Setup Configuration

Now we'll import the necessary libraries and set up our configuration for the MICRON model.

In [1]:
import os
import numpy as np
import pandas as pd
import torch
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models import MICRON
from pyhealth.trainer import Trainer
# from pyhealth.metrics import multilabel_metrics # Removed this import

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)



In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


## 3. Load and Process MIMIC-III Dataset

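We'll load the MIMIC-III dataset using PyHealth's built-in dataset loader and prepare it for training. The dataset will be processed to include patient diagnoses, procedures, and medications. This run points at the public MIMIC-III demo subset (100 patients), so the results below are illustrative only.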

In [11]:
# Configuration
#MIMIC3_PATH = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III"
MIMIC3_PATH = "https://physionet.org/files/mimiciii-demo/1.4/"    #update this dataset path to your environment

# Load MIMIC-III dataset
dataset = MIMIC3Dataset(
    root=MIMIC3_PATH,
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    dev=True
)
dataset.stats()

No config path provided, using default config
Initializing mimic3 dataset from https://physionet.org/files/mimiciii-demo/1.4/ (dev mode: True)
Scanning table: patients from https://physionet.org/files/mimiciii-demo/1.4/PATIENTS.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/PATIENTS.csv
Scanning table: admissions from https://physionet.org/files/mimiciii-demo/1.4/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/ADMISSIONS.csv
Scanning table: icustays from https://physionet.org/files/mimiciii-demo/1.4/ICUSTAYS.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/ICUSTAYS.csv
Scanning table: diagnoses_icd from https://physionet.org/files/mimiciii-demo/1.4/DIAGNOSES_ICD.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/DIAGNOSES_ICD.csv
Joining with table: https://physionet.org/files/mimiciii-demo/1.4/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/ADMISSIONS.csv
Scanning table: procedures_icd from https://physionet.org/files/mimiciii-demo/1.4/PROCEDURES_ICD.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/PROCEDURES_ICD.csv
Joining with table: https://physionet.org/files/mimiciii-demo/1.4/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/ADMISSIONS.csv
Scanning table: prescriptions from https://physionet.org/files/mimiciii-demo/1.4/PRESCRIPTIONS.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/PRESCRIPTIONS.csv
Joining with table: https://physionet.org/files/mimiciii-demo/1.4/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://physionet.org/files/mimiciii-demo/1.4/ADMISSIONS.csv
Collecting global event dataframe...
Dev mode enabled: limiting to 1000 patients
Collected dataframe with shape: (13030, 49)


Dataset: mimic3
Dev mode: True
Number of patients: 100
Number of events: 13030


## Set Drug Recommendation Task

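Use the DrugRecommendationMIMIC3 task, which creates samples containing conditions, procedures, and target drugs. The drugs are described as ATC-3 codes, but note that in this run the label vocabulary printed in the output consists of short drug-name prefixes (for example 'Acet' and 'Aspi').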

In [4]:
from pyhealth.tasks import DrugRecommendationMIMIC3

task = DrugRecommendationMIMIC3()
samples = dataset.set_task(task, num_workers=4)

print(f"Sample Dataset Statistics:")
print(f"\t- Dataset: {samples.dataset_name}")
print(f"\t- Task: {samples.task_name}")
print(f"\t- Number of samples: {len(samples)}")

print("\nFirst sample structure:")
print(f"Patient ID: {samples.samples[0]['patient_id']}")
print(f"Number of visits: {len(samples.samples[0]['conditions'])}")
print(f"Sample conditions (first visit): {samples.samples[0]['conditions'][0][:5]}...")
print(f"Sample procedures (first visit): {samples.samples[0]['procedures'][0][:5]}...")
print(f"Sample drugs (target): {samples.samples[0]['drugs'][:10]}...")

Setting task DrugRecommendationMIMIC3 for mimic3 base dataset...
Generating samples with 4 worker(s)...
Generating samples for DrugRecommendationMIMIC3 with 4 workers
Collecting samples for DrugRecommendationMIMIC3 from 4 workers: 100%|██████████| 100/100 [00:00<00:00, 1220.60it/s]

Label drugs vocab: {'*NF*': 0, '0.45': 1, '0.9%': 2, '1/2 ': 3, '5% D': 4, 'AMP': 5, 'Acet': 6, 'Acyc': 7, 'Albu': 8, 'Alen': 9, 'Allo': 10, 'Alpr': 11, 'Alte': 12, 'Alum': 13, 'Amin': 14, 'Ampi': 15, 'Arti': 16, 'Asco': 17, 'Aspi': 18, 'Aten': 19, 'Ator': 20, 'Atov': 21, 'Azit': 22, 'Baci': 23, 'Bacl': 24, 'Bag': 25, 'Bisa': 26, 'Brim': 27, 'Bupi': 28, 'Cabe': 29, 'Calc': 30, 'Caph': 31, 'Caps': 32, 'Capt': 33, 'Carb': 34, 'CefT': 35, 'Cefa': 36, 'Cefe': 37, 'Ceft': 38, 'Cepa': 39, 'Chlo': 40, 'Cipr': 41, 'Cita': 42, 'Clin': 43, 'Clop': 44, 'Clot': 45, 'Coll': 46, 'Cosy': 47, 'Creo': 48, 'Crom': 49, 'Cycl': 50, 'Cypr': 51, 'Cyta': 52, 'D5 1': 53, 'D5NS': 54, 'D5W': 55, 'DOBU': 56, 'DOPa': 57, 'DOXO': 58, 'Daps': 59, 'Dapt': 60, 'Desi': 61, 'Dexa': 62, 'Dext': 63, 'Diaz': 64, 'Dilt': 65, 'Diph': 66, 'Docu': 67, 'Dola': 68, 'Done': 69, 'DopA': 70, 'Dorz': 71, 'Dost': 72, 'Doxy': 73, 'Dulo': 74, 'Enal': 75, 'Enox': 76, 'Epid': 77, 'Epoe': 78, 'Eryt': 79, 'Etop': 80, 'Famo': 81, 'Fat ': 8


Generated 36 samples for task DrugRecommendationMIMIC3


Sample Dataset Statistics:
	- Dataset: mimic3
	- Task: <pyhealth.tasks.drug_recommendation.DrugRecommendationMIMIC3 object at 0x7f075ff85f70>
	- Number of samples: 36

First sample structure:
Patient ID: 40124
Number of visits: 1
Sample conditions (first visit): tensor([1, 2, 3, 4, 5])...
Sample procedures (first visit): tensor([1, 0, 0, 0, 0])...
Sample drugs (target): tensor([0., 0., 1., 0., 0., 0., 1., 0., 1., 0.])...
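
The `drugs` target printed above is a multi-hot vector: one position per entry in the label vocabulary, set to 1 when that drug appears in the target set. A minimal sketch of that encoding in plain Python follows, assuming a toy vocabulary. The names `build_vocab` and `multi_hot` and the sample drug lists are hypothetical; this is not PyHealth's label processor.

```python
def build_vocab(label_sets):
    """Map each distinct label to an index, in sorted order."""
    labels = sorted({label for group in label_sets for label in group})
    return {label: idx for idx, label in enumerate(labels)}


def multi_hot(labels, vocab):
    """Encode a collection of labels as 0.0/1.0 flags over the vocabulary."""
    vector = [0.0] * len(vocab)
    for label in labels:
        vector[vocab[label]] = 1.0
    return vector


# Toy drug lists for three visits (illustrative only).
visits = [["Aspi", "Acet"], ["Acet", "Cipr"], ["Aspi"]]
vocab = build_vocab(visits)
encoded = [multi_hot(v, vocab) for v in visits]
print(vocab)
print(encoded[0])
```

Each encoded vector has the same length as the vocabulary, which is why the model's output layer size is tied to the number of distinct drug labels.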


## Split Dataset and Create Data Loaders

In [5]:
from pyhealth.datasets import split_by_patient, get_dataloader

train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.1, 0.2]
)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

Train samples: 27
Validation samples: 3
Test samples: 6
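
Splitting by patient keeps all samples from one patient in the same partition, so the test set never contains a patient seen during training. The idea can be sketched in plain Python as below. This is an illustration under assumed names (`split_ids_by_patient`, the toy `patient_ids` list), not PyHealth's implementation of `split_by_patient`.

```python
import random


def split_ids_by_patient(patient_ids, ratios, seed=42):
    """Return three lists of sample indices with disjoint patient sets."""
    patients = sorted(set(patient_ids))
    random.Random(seed).shuffle(patients)

    n_train = int(len(patients) * ratios[0])
    n_val = int(len(patients) * ratios[1])
    groups = (
        set(patients[:n_train]),
        set(patients[n_train:n_train + n_val]),
        set(patients[n_train + n_val:]),
    )
    # Assign each sample index to the split that owns its patient.
    return [
        [i for i, pid in enumerate(patient_ids) if pid in group]
        for group in groups
    ]


# Toy example: 10 patients, some with several samples.
patient_ids = [p for p in range(10) for _ in range(1 + p % 3)]
train_idx, val_idx, test_idx = split_ids_by_patient(patient_ids, [0.7, 0.1, 0.2])
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the ratios apply to patients rather than samples, the sample counts per split only approximate the requested proportions.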


## 4. Initialize and Configure MICRON Model

Now we'll set up the MICRON model with appropriate hyperparameters for drug recommendation.

In [7]:
# Model hyperparameters
model_params = {
    "embedding_dim": 128,
    "hidden_dim": 128,
    "lam": 0.1  # Regularization parameter for reconstruction loss
}

# Initialize MICRON model
model = MICRON(
    dataset=samples,
    **model_params
).to(DEVICE)

print(model)


# Configure trainer
trainer = Trainer(
    model=model,
    #metrics=["jaccard_samples", "f1_samples", "pr_auc_samples"]
    metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"]
)

print("Baseline performance before training:")
baseline_results = trainer.evaluate(test_dataloader)
print(baseline_results)




MICRON(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(229, 128)
    (procedures): Embedding(59, 128)
    (drugs_hist): Embedding(198, 128)
  ))
  (micron): MICRONLayer(
    (health_net): Linear(in_features=384, out_features=128, bias=True)
    (prescription_net): Linear(in_features=128, out_features=128, bias=True)
    (fc): Linear(in_features=128, out_features=216, bias=True)
    (bce_loss_fn): BCEWithLogitsLoss()
  )
)

Metrics: ['jaccard_samples', 'f1_samples', 'pr_auc_samples', 'ddi']
Device: cuda


Baseline performance before training:


Evaluation: 100%|██████████| 1/1 [00:00<00:00, 204.26it/s]

{'jaccard_samples': 0.14468260340523165, 'f1_samples': 0.25038132615979186, 'pr_auc_samples': 0.16410485822712173, 'ddi_score': 0.0, 'loss': 26.26311683654785}
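
The layer shapes printed above give a rough sense of model size. As a worked check, the sketch below tallies trainable parameters from those shapes, assuming the printed modules are the only ones holding parameters and that every `Linear` layer has a bias (as the printout states).

```python
# (rows, cols) of each embedding table, read from the model printout.
embeddings = {
    "conditions": (229, 128),
    "procedures": (59, 128),
    "drugs_hist": (198, 128),
}
# (in_features, out_features) of each linear layer, read from the model printout.
linears = {
    "health_net": (384, 128),
    "prescription_net": (128, 128),
    "fc": (128, 216),
}

embedding_params = sum(rows * cols for rows, cols in embeddings.values())
# A linear layer holds an in x out weight matrix plus one bias per output.
linear_params = sum(n_in * n_out + n_out for n_in, n_out in linears.values())

total = embedding_params + linear_params
print(f"embeddings: {embedding_params}, linear: {linear_params}, total: {total}")
```

With the model in memory, `sum(p.numel() for p in model.parameters())` reports the actual figure for comparison.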





## 5. Train the Model

Let's train the MICRON model on our processed MIMIC-III dataset.

In [8]:
# Train the model
history = trainer.train(
    train_dataloader, # Pass as positional argument
    val_dataloader,   # Pass as positional argument
    epochs=5,
    monitor="pr_auc_samples"
)

# Save the trained model
#torch.save(model.state_dict(), "/content/drive/MyDrive/micron_model.pt")

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f05e45c9d00>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 5


Epoch 0 / 5:   0%|          | 0/1 [00:00<?, ?it/s]
--- Train epoch-0, step-1 ---
loss: 26.7790
Evaluation: 100%|██████████| 1/1 [00:00<00:00, 261.21it/s]
--- Eval epoch-0, step-1 ---
jaccard_samples: 0.1060
f1_samples: 0.1902
pr_auc_samples: 0.1278
ddi_score: 0.0000
loss: 12.1593
New best pr_auc_samples score (0.1278) at epoch-0, step-1


Epoch 1 / 5:   0%|          | 0/1 [00:00<?, ?it/s]
--- Train epoch-1, step-2 ---
loss: 12.1609
Evaluation: 100%|██████████| 1/1 [00:00<00:00, 231.73it/s]
--- Eval epoch-1, step-2 ---
jaccard_samples: 0.1043
f1_samples: 0.1881
pr_auc_samples: 0.1253
ddi_score: 0.0000
loss: 8.5477


Epoch 2 / 5:   0%|          | 0/1 [00:00<?, ?it/s]
--- Train epoch-2, step-3 ---
loss: 8.3976
Evaluation: 100%|██████████| 1/1 [00:00<00:00, 235.13it/s]
--- Eval epoch-2, step-3 ---
jaccard_samples: 0.0841
f1_samples: 0.1539
pr_auc_samples: 0.1289
ddi_score: 0.0000
loss: 8.2150
New best pr_auc_samples score (0.1289) at epoch-2, step-3


Epoch 3 / 5:   0%|          | 0/1 [00:00<?, ?it/s]
--- Train epoch-3, step-4 ---
loss: 8.2572
Evaluation: 100%|██████████| 1/1 [00:00<00:00, 225.23it/s]
--- Eval epoch-3, step-4 ---
jaccard_samples: 0.1594
f1_samples: 0.2718
pr_auc_samples: 0.1652
ddi_score: 0.0000
loss: 6.2571
New best pr_auc_samples score (0.1652) at epoch-3, step-4


Epoch 4 / 5:   0%|          | 0/1 [00:00<?, ?it/s]
--- Train epoch-4, step-5 ---
loss: 6.6966
Evaluation: 100%|██████████| 1/1 [00:00<00:00, 266.20it/s]
--- Eval epoch-4, step-5 ---
jaccard_samples: 0.1548
f1_samples: 0.2655
pr_auc_samples: 0.1809
ddi_score: 0.0000
loss: 5.0584
New best pr_auc_samples score (0.1809) at epoch-4, step-5

Loaded best model
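
The log shows the trainer monitoring `pr_auc_samples` with a `max` criterion and reporting a new best score at epochs 0, 2, 3, and 4. That selection logic can be sketched as below, replaying the validation scores from the log. The function name `best_epochs` is hypothetical, and this is not the PyHealth `Trainer` code.

```python
def best_epochs(scores, criterion="max"):
    """Return (epochs where the monitored score improved, best epoch or None)."""
    improved = []
    best = None
    for epoch, score in enumerate(scores):
        better = (
            best is None
            or (criterion == "max" and score > best)
            or (criterion == "min" and score < best)
        )
        if better:
            best = score
            improved.append(epoch)
    return improved, (improved[-1] if improved else None)


# Validation pr_auc_samples per epoch, copied from the training log.
val_pr_auc = [0.1278, 0.1253, 0.1289, 0.1652, 0.1809]
improved, best_epoch = best_epochs(val_pr_auc)
print(improved, best_epoch)  # [0, 2, 3, 4] 4
```

Epoch 1 is skipped because its score (0.1253) is below the running best from epoch 0 (0.1278).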


## 6. Evaluate Model Performance

Finally, let's evaluate our trained model on the test set and print the key metrics.

In [10]:
test_results = trainer.evaluate(test_dataloader)
print("Final test set performance:")
print(test_results)

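print("\nKey Metrics:")
# Default to NaN so the ".4f" format spec still works if a metric is missing
# (formatting the string 'N/A' with ".4f" would raise a ValueError).
nan = float("nan")
print(f"  PR-AUC: {test_results.get('pr_auc_samples', nan):.4f}")
print(f"  F1 Score: {test_results.get('f1_samples', nan):.4f}")
print(f"  Jaccard: {test_results.get('jaccard_samples', nan):.4f}")
print(f"  DDI Rate: {test_results.get('ddi_score', nan):.4f} (lower is better)")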


Evaluation: 100%|██████████| 1/1 [00:00<00:00, 197.05it/s]

Final test set performance:
{'jaccard_samples': 0.20728042582268005, 'f1_samples': 0.33950176113659936, 'pr_auc_samples': 0.27233906976425537, 'ddi_score': 0.0, 'loss': 4.690524578094482}

Key Metrics:
  PR-AUC: 0.2723
  F1 Score: 0.3395
  Jaccard: 0.2073
  DDI Rate: 0.0000 (lower is better)





In [None]:
# before training
# {'jaccard_samples': 0.14468260340523165, 'f1_samples': 0.25038132615979186, 'pr_auc_samples': 0.16410485822712173, 'ddi_score': 0.0, 'loss': 26.26311683654785}

# Post training
# {'jaccard_samples': 0.20728042582268005, 'f1_samples': 0.33950176113659936, 'pr_auc_samples': 0.27233906976425537, 'ddi_score': 0.0, 'loss': 4.690524578094482}
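
To make the before/after comparison easier to read, the two result dictionaries can be compared directly. A short sketch follows, using the test-set values reported above rounded to four decimals; the variable names are illustrative.

```python
before = {
    "jaccard_samples": 0.1447,
    "f1_samples": 0.2504,
    "pr_auc_samples": 0.1641,
    "loss": 26.2631,
}
after = {
    "jaccard_samples": 0.2073,
    "f1_samples": 0.3395,
    "pr_auc_samples": 0.2723,
    "loss": 4.6905,
}

# Print each metric with its value before training, after training, and the change.
for name in before:
    delta = after[name] - before[name]
    print(f"{name:>16}: {before[name]:8.4f} -> {after[name]:8.4f} ({delta:+.4f})")
```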

## Conclusion

We have implemented and trained a MICRON model for drug recommendation on the MIMIC-III demo dataset. The trainer reports the following metrics:

1. DDI score: the rate of drug-drug interactions among the recommended drugs. Lower is better.
2. Jaccard: the overlap between the predicted and actual drug sets (intersection over union), averaged over samples.
3. F1 Score: the harmonic mean of precision and recall, averaged over samples.
4. PR-AUC: the area under the precision-recall curve, averaged over samples.

After five epochs, test PR-AUC rose from 0.1641 to 0.2723, F1 from 0.2504 to 0.3395, and Jaccard from 0.1447 to 0.2073, while the loss fell from 26.26 to 4.69. The DDI score stayed at 0.0 throughout. With only 36 samples (6 in the test set), these numbers illustrate the workflow rather than the model's real performance.
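
To make the sample-averaged metrics reported by the trainer concrete, the sketch below computes Jaccard, F1, and a DDI rate for toy predictions in plain Python. It illustrates the definitions only, under assumed names and a made-up interaction set; PyHealth's own metric functions may differ in edge-case handling.

```python
from itertools import combinations


def jaccard(true, pred):
    """Size of the intersection over size of the union of two drug sets."""
    union = true | pred
    return len(true & pred) / len(union) if union else 0.0


def f1(true, pred):
    """Harmonic mean of precision and recall for two drug sets."""
    hits = len(true & pred)
    if hits == 0:
        return 0.0
    precision, recall = hits / len(pred), hits / len(true)
    return 2 * precision * recall / (precision + recall)


def ddi_rate(pred_sets, interacting_pairs):
    """Fraction of co-recommended drug pairs that are known to interact."""
    total = flagged = 0
    for pred in pred_sets:
        for pair in combinations(sorted(pred), 2):
            total += 1
            flagged += pair in interacting_pairs
    return flagged / total if total else 0.0


# Toy data: two patients, drugs identified by integer index.
y_true = [{0, 1, 2}, {3, 4}]
y_pred = [{1, 2, 5}, {3}]
known_ddi = {(1, 5)}  # hypothetical interacting pair, stored in sorted order

print(sum(jaccard(t, p) for t, p in zip(y_true, y_pred)) / len(y_true))
print(sum(f1(t, p) for t, p in zip(y_true, y_pred)) / len(y_true))
print(ddi_rate(y_pred, known_ddi))
```

A DDI rate of 0.0, as in the runs above, means none of the co-recommended pairs matched a known interaction (or no pairs were recommended at all).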

Next steps could include:
- Hyperparameter tuning to improve performance
- Testing with different model architectures
- Analyzing specific cases where the model performs well or poorly
- Incorporating additional patient features