In [2]:
#Import pyhealth 
import numpy as np
import torch
from pyhealth.datasets import MIMIC3Dataset 
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC3
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer
from pyhealth.interpret.methods import CheferRelevance


#Load the data
dataset = MIMIC3Dataset(
    root=r"C:/Users/johnn/PyHealth_Data/Synthetic_MIMIC-III",
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
    dev=False, 
)
print(dataset.stats())

No config path provided, using default config
Initializing mimic3 dataset from C:/Users/johnn/PyHealth_Data/Synthetic_MIMIC-III (dev mode: False)
Scanning table: patients from C:\Users\johnn\PyHealth_Data\Synthetic_MIMIC-III\PATIENTS.csv.gz
Original path does not exist. Using alternative: C:\Users\johnn\PyHealth_Data\Synthetic_MIMIC-III\PATIENTS.csv
Scanning table: admissions from C:\Users\johnn\PyHealth_Data\Synthetic_MIMIC-III\ADMISSIONS.csv.gz
Original path does not exist. Using alternative: C:\Users\johnn\PyHealth_Data\Synthetic_MIMIC-III\ADMISSIONS.csv
Scanning table: icustays from C:\Users\johnn\PyHealth_Data\Synthetic_MIMIC-III\ICUSTAYS.csv.gz
Original path does not exist. Using alternative: C:\Users\johnn\PyHealth_Data\Synthetic_MIMIC-III\ICUSTAYS.csv
Scanning table: diagnoses_icd from C:\Users\johnn\PyHealth_Data\Synthetic_MIMIC-III\DIAGNOSES_ICD.csv.gz
Original path does not exist. Using alternative: C:\Users\johnn\PyHealth_Data\Synthetic_MIMIC-III\DIAGNOSES_ICD.csv
Joining w



In [3]:
from pyhealth.tasks import MortalityPredictionMIMIC3

# Define the mortality prediction task
task = MortalityPredictionMIMIC3()

samples = dataset.set_task(task)

#Take a look at random sample -> run again for other patients to check their mortality
randomsample = samples.samples[0]
device = randomsample['mortality'].device.type if isinstance(randomsample['mortality'], torch.Tensor) else 'cpu'

print(f"VISIT SAMPLE INSPECTION (Patient: {randomsample['patient_id']})")
print(f"Patient ID:    {randomsample['patient_id']}")

print("\nINPUT(Codes from the current visit):")

# CONDITIONS (Diagnoses)
conditions = randomsample['conditions']
print(f"CONDITIONS (Diagnoses, ICD-9):")
print(f"Total Codes: {len(conditions)}")
print(f"Codes (Token IDs): {conditions}")

# PROCEDURES
procedures = randomsample['procedures']
print(f"PROCEDURES (ICD-9):")
print(f"Total Codes: {len(procedures)}")
print(f"Codes (Token IDs): {procedures}")

# DRUGS
drugs = randomsample['drugs']
print(f"DRUGS (Prescriptions):")
print(f"Total Codes: {len(drugs)}")
print(f"Codes (Token IDs): {drugs}")


print("\nOUTPUT")
mortality_label = randomsample['mortality'].item() if isinstance(randomsample['mortality'], torch.Tensor) else randomsample['mortality']
print(f"MORTALITY: {mortality_label} (0 = Patient survived in next visit, 1 = Mortality in next visit)")


Setting task MortalityPredictionMIMIC3 for mimic3 base dataset...
Generating samples with 1 worker(s)...


Generating samples for MortalityPredictionMIMIC3 with 1 worker: 100%|██████████| 100/100 [00:00<00:00, 480.06it/s]

Label mortality vocab: {0: 0, 1: 1}



Processing samples: 100%|██████████| 26/26 [00:00<00:00, 13334.79it/s]

Generated 26 samples for task MortalityPredictionMIMIC3
VISIT SAMPLE INSPECTION (Patient: 10088)
Patient ID:    10088

INPUT(Codes from the current visit):
CONDITIONS (Diagnoses, ICD-9):
Total Codes: 17
Codes (Token IDs): tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])
PROCEDURES (ICD-9):
Total Codes: 5
Codes (Token IDs): tensor([1, 2, 3, 4, 5])
DRUGS (Prescriptions):
Total Codes: 94
Codes (Token IDs): tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 11, 20, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 29, 30, 29, 28,
        29, 31, 32, 28, 30, 33, 33, 28, 28, 34, 20, 23, 28, 29, 28, 32, 28, 28,
        35, 36, 34, 37, 36, 36, 26, 33, 38, 39, 40, 37, 41, 42, 43, 34, 23, 23,
        44, 44, 32, 45, 45, 46, 37, 36, 34, 19, 47, 28, 29, 48, 16, 37, 42, 43,
        34, 23, 28, 34])

OUTPUT
MORTALITY: 0.0 (0 = Patient survived in next visit, 1 = Mortality in next visit)





In [4]:
#Check Mortality statistics
import torch

mortality_count = 0
total_samples = len(samples)
deceased_patient_ids = set() 

for sample in samples.samples:
    mortality_label = sample.get('mortality')
    if isinstance(mortality_label, torch.Tensor):
        mortality_value = mortality_label.item()
    else:
        mortality_value = float(mortality_label)
    mortality_count += mortality_value
    if mortality_value == 1.0:
        deceased_patient_ids.add(sample.get('patient_id'))

print(f"Total Samples Processed: {total_samples}")
print(f"Patients Died: {mortality_count}")
print(f"Patients Survived : {total_samples - mortality_count}")
print(f"Mortality Rate: {mortality_count / total_samples * 100}")
print(f"Patient IDs who died: {list(deceased_patient_ids)}")

Total Samples Processed: 26
Patients Died: 5.0
Patients Survived : 21.0
Mortality Rate: 19.230769230769234
Patient IDs who died: ['40310', '10059', '10124', '10094', '42135']


In [5]:

from pyhealth.datasets import split_by_sample
from pyhealth.datasets import get_dataloader

#Split the data for training, validation, and testing
train_dataset, val_dataset, test_dataset = split_by_sample(dataset = samples, ratios = [0.8, 0.1, 0.1])

#Define dataloaders for training
train_dataloader = get_dataloader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=2, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=2, shuffle=False)

In [6]:
#Initialize the Transformer Model from pyhealth 

model = Transformer( 
    dataset=samples,
    embedding_dim=128,
    heads=2,
    num_layers=2,
    dropout=0.1
)
print(model)

Transformer(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(191, 128, padding_idx=0)
    (procedures): Embedding(46, 128, padding_idx=0)
    (drugs): Embedding(298, 128, padding_idx=0)
  ))
  (transformer): ModuleDict(
    (conditions): TransformerLayer(
      (transformer): ModuleList(
        (0-1): 2 x TransformerBlock(
          (attention): MultiHeadedAttention(heads=2, d_model=128, dropout=0.1)
          (feed_forward): PositionwiseFeedForward(
            (w_1): Linear(in_features=128, out_features=512, bias=True)
            (w_2): Linear(in_features=512, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (activation): GELU(approximate='none')
          )
          (input_sublayer): SublayerConnection(
            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output_sublayer): SublayerConnection(
   

In [7]:
#Initialize Trainer
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["roc_auc", "pr_auc", "accuracy", "f1"]
)

Transformer(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(191, 128, padding_idx=0)
    (procedures): Embedding(46, 128, padding_idx=0)
    (drugs): Embedding(298, 128, padding_idx=0)
  ))
  (transformer): ModuleDict(
    (conditions): TransformerLayer(
      (transformer): ModuleList(
        (0-1): 2 x TransformerBlock(
          (attention): MultiHeadedAttention(heads=2, d_model=128, dropout=0.1)
          (feed_forward): PositionwiseFeedForward(
            (w_1): Linear(in_features=128, out_features=512, bias=True)
            (w_2): Linear(in_features=512, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (activation): GELU(approximate='none')
          )
          (input_sublayer): SublayerConnection(
            (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output_sublayer): SublayerConnection(
   

In [8]:
#Run the training
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,
    monitor="roc_auc",  
    optimizer_params={"lr": 1e-4}, 
)

Training:
Batch size: 2
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x0000018955783370>
Monitor: roc_auc
Monitor criterion: max
Epochs: 5
Patience: None



Epoch 0 / 5: 100%|██████████| 10/10 [00:00<00:00, 25.65it/s]

--- Train epoch-0, step-10 ---
loss: 0.7089



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 145.94it/s]

--- Eval epoch-0, step-10 ---
roc_auc: 0.5000
pr_auc: 0.5000
accuracy: 0.6667
f1: 0.0000
loss: 0.5723
New best roc_auc score (0.5000) at epoch-0, step-10




Epoch 1 / 5: 100%|██████████| 10/10 [00:00<00:00, 33.33it/s]

--- Train epoch-1, step-20 ---
loss: 0.4792



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 177.39it/s]

--- Eval epoch-1, step-20 ---
roc_auc: 0.5000
pr_auc: 0.5000
accuracy: 0.6667
f1: 0.0000
loss: 0.5503




Epoch 2 / 5: 100%|██████████| 10/10 [00:00<00:00, 30.93it/s]

--- Train epoch-2, step-30 ---
loss: 0.3979



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 194.19it/s]

--- Eval epoch-2, step-30 ---
roc_auc: 0.5000
pr_auc: 0.5000
accuracy: 0.6667
f1: 0.0000
loss: 0.5531




Epoch 3 / 5: 100%|██████████| 10/10 [00:00<00:00, 27.30it/s]

--- Train epoch-3, step-40 ---
loss: 0.2630



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 214.60it/s]

--- Eval epoch-3, step-40 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.6667
f1: 0.0000
loss: 0.5584
New best roc_auc score (1.0000) at epoch-3, step-40




Epoch 4 / 5: 100%|██████████| 10/10 [00:00<00:00, 29.52it/s]

--- Train epoch-4, step-50 ---
loss: 0.2075



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 150.89it/s]

--- Eval epoch-4, step-50 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.6667
f1: 0.0000
loss: 0.5536
Loaded best model





In [9]:
#Evaluate the results
results = trainer.evaluate(test_dataloader)
print("Test Results:")
for metric, value in results.items():
    print(f"{metric}: {value:.4f}")

Evaluation: 100%|██████████| 2/2 [00:00<00:00, 145.35it/s]

Test Results:
roc_auc: nan
pr_auc: 0.0000
accuracy: 1.0000
f1: 0.0000
loss: 0.1447



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [11]:
# Interpreting model predictions using Pyhealth's Chefer and IntegratedGradients
# Our original ablation from the reproduced paper introduces the
# AttInGrad-Weighted Average (AttInWAvg), defined as:
# 
# AttInWAvg_j = λ · A_j  +  (1 − λ) · NormInputXGrad_j

# Where:
#   A_j = attention weight for token j
#   NormInputXGrad_j = normalized input × gradient attribution
#   λ ∈ [0, 1] = interpolation coefficient
#
# In this example, we replace the original attention weights A_j
# with Chefer's relevance scores obtained from PyHealth's
# CheferRelevance method, and we use PyHealth's IntegratedGradients
# method to compute the gradient-based attributions. We normalize them both also for consistency
#
# The final weighted attribution becomes:
#
#       AttInWAvg_j = λ · NormChefer_j  +  (1 − λ) · NormInputXGrad_j
#

from pyhealth.interpret.methods import IntegratedGradients

#1. Chefer
chefer = CheferRelevance(model)
batch = next(iter(val_dataloader))

chefer_scores = chefer.attribute(
    conditions=batch["conditions"],
    procedures=batch["procedures"],
    drugs=batch["drugs"],
    mortality=batch["mortality"],
)


#2. Integrated Gradients
ig = IntegratedGradients(model)

ig_scores = ig.attribute(
    conditions=batch["conditions"],
    procedures=batch["procedures"],
    drugs=batch["drugs"],
    mortality=batch["mortality"],
    steps=20,  
)

#3. Normalize chefer and IG scores
norm_chefer = {}
norm_ig = {}

for key in chefer_scores:
    c = chefer_scores[key][0]
    g = ig_scores[key][0]

    norm_chefer[key] = c / (c.abs().sum() + 1e-9)
    norm_ig[key] = g / (g.abs().sum() + 1e-9)

#4. Weighted average (AttInWAvg)
lambda_val = 0.5
AttInWAvg = {}

for key in chefer_scores:
    AttInWAvg[key] = lambda_val * norm_chefer[key] + (1 - lambda_val) * norm_ig[key]
    
# 5. Print top tokens for all 3 methods
print("Feature keys:", chefer_scores.keys())
for key in chefer_scores:

    c = norm_chefer[key]
    top_c = torch.topk(c, k=min(5, c.size(0)))
    print(f"\nTop 5 Chefer Relevance for {key}:")
    print("Token indices:", top_c.indices.tolist())
    print("Scores:", [f"{v:.4f}" for v in top_c.values.tolist()])

   
    g = norm_ig[key]
    top_g = torch.topk(g, k=min(5, g.size(0)))
    print(f"\nTop 5 Normalized IG for {key}:")
    print("Token indices:", top_g.indices.tolist())
    print("Scores:", [f"{v:.4f}" for v in top_g.values.tolist()])


    w = AttInWAvg[key]
    top_w = torch.topk(w, k=min(5, w.size(0)))
    print(f"\nTop 5 AttInWAvg for {key}:")
    print("Token indices:", top_w.indices.tolist())
    print("Scores:", [f"{v:.4f}" for v in top_w.values.tolist()])






Feature keys: dict_keys(['conditions', 'procedures', 'drugs'])

Top 5 Chefer Relevance for conditions:
Token indices: [0, 6, 9, 5, 7]
Scores: ['0.9340', '0.0160', '0.0147', '0.0096', '0.0084']

Top 5 Normalized IG for conditions:
Token indices: [0, 10, 8, 6, 3]
Scores: ['0.3225', '0.0880', '0.0277', '0.0016', '-0.0121']

Top 5 AttInWAvg for conditions:
Token indices: [0, 10, 8, 6, 9]
Scores: ['0.6283', '0.0441', '0.0144', '0.0088', '-0.0012']

Top 5 Chefer Relevance for procedures:
Token indices: [0, 1, 2, 3]
Scores: ['0.9133', '0.0867', '0.0000', '0.0000']

Top 5 Normalized IG for procedures:
Token indices: [0, 2, 3, 1]
Scores: ['0.7110', '-0.0033', '-0.0204', '-0.2653']

Top 5 AttInWAvg for procedures:
Token indices: [0, 2, 3, 1]
Scores: ['0.8121', '-0.0016', '-0.0102', '-0.0893']

Top 5 Chefer Relevance for drugs:
Token indices: [0, 7, 78, 51, 32]
Scores: ['0.9511', '0.0022', '0.0022', '0.0018', '0.0013']

Top 5 Normalized IG for drugs:
Token indices: [19, 17, 63, 68, 67]
Scores: ['