In [1]:
import pyhealth
import pickle
import os

## dataset lookups

In [2]:
"""
The training split is stored as six pickled objects, where

- `pids`: contains the patient ids
- `vids`: contains a list of visit ids for each patient
- `hfs`: contains the heart failure label (0: normal, 1: heart failure) for each patient
- `seqs`: contains a list of visit (in ICD9 codes) for each patient
- `types`: contains the map from ICD9 codes to ICD-9 labels
- `rtypes`: contains the map from ICD9 labels to ICD9 codes
"""

DATA_PATH = "./"

# Load every training pickle inside a `with` block so each file handle is
# closed as soon as it has been read (the bare `open(...)` calls leaked them).
# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
_train_objects = []
for _name in ('pids', 'vids', 'hfs', 'seqs', 'types', 'rtypes'):
    with open(os.path.join(DATA_PATH, f'train/{_name}.pkl'), 'rb') as _fh:
        _train_objects.append(pickle.load(_fh))
pids, vids, hfs, seqs, types, rtypes = _train_objects

# Sanity checks: the four per-patient lists must be aligned (1000 patients)
# and the code map must contain the expected 619 entries.
assert len(pids) == len(vids) == len(hfs) == len(seqs) == 1000
assert len(types) == 619

In [3]:
# show the patient stored at index 3 as an example
example_idx = 3
example_vids = vids[example_idx]
example_seqs = seqs[example_idx]

print("Patient ID:", pids[example_idx])
print("Heart Failure:", hfs[example_idx])
print("# of visits:", len(example_vids))
for visit in range(len(example_vids)):
    print(f"\t{visit}-th visit id:", example_vids[visit])
    # each visit is a list of integer labels; rtypes maps a label back to its code
    visit_labels = example_seqs[visit]
    print(f"\t{visit}-th visit diagnosis labels:", visit_labels)
    print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in visit_labels])

Patient ID: 47537
Heart Failure: 0
# of visits: 2
	0-th visit id: 0
	0-th visit diagnosis labels: [12, 103, 262, 285, 290, 292, 359, 416, 39, 225, 275, 294, 326, 267, 93]
	0-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_518', 'DIAG_560', 'DIAG_567', 'DIAG_569', 'DIAG_707', 'DIAG_785', 'DIAG_155', 'DIAG_456', 'DIAG_537', 'DIAG_571', 'DIAG_608', 'DIAG_529', 'DIAG_263']
	1-th visit id: 1
	1-th visit diagnosis labels: [12, 103, 240, 262, 290, 292, 319, 359, 510, 513, 577, 307, 8, 280, 18, 131]
	1-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_482', 'DIAG_518', 'DIAG_567', 'DIAG_569', 'DIAG_599', 'DIAG_707', 'DIAG_995', 'DIAG_998', 'DIAG_V09', 'DIAG_584', 'DIAG_031', 'DIAG_553', 'DIAG_070', 'DIAG_305']


In [4]:
# hfs holds one 0/1 label per patient, so its sum is the number of positives
num_heart_failure = sum(hfs)
print("number of heart failure patients:", num_heart_failure)
print("ratio of heart failure patients: %.2f" % (num_heart_failure / len(hfs)))

number of heart failure patients: 548
ratio of heart failure patients: 0.55


## Step1 & 2: dataset processing and task definition

In [5]:
"""
With this user processed data, we just need to wrap up the data to fit
the pyhealth format.

Task definition: we will use all the visits from the patients as the features,
    then, the labels are given by the hfs list.
"""
# One sample per patient: the last visit id identifies the sample, the label is
# the patient's heart-failure flag, and every visit's integer labels are mapped
# back to diagnosis code strings through rtypes.
samples = [
    {
        'patient_id': pid,
        'visit_id': vid[-1],
        'label': hf,
        'diagnoses': [[rtypes[v] for v in visit] for visit in seq],
    }
    for pid, vid, hf, seq in zip(pids, vids, hfs, seqs)
]
    

In [6]:
# load into the dataset
# The list of per-patient sample dicts built above is handed to SampleEHRDataset;
# code_vocs is passed explicitly as None (no code vocabulary is specified here).
from pyhealth.datasets import SampleEHRDataset
train_dataset = SampleEHRDataset(samples, code_vocs=None)

In [7]:
# display the first stored sample to check the structure held by the dataset
train_dataset.samples[0]

{'patient_id': 89571,
 'visit_id': 1,
 'label': 1,
 'diagnoses': [['DIAG_250',
   'DIAG_285',
   'DIAG_682',
   'DIAG_730',
   'DIAG_531',
   'DIAG_996',
   'DIAG_287',
   'DIAG_276',
   'DIAG_E878',
   'DIAG_V45',
   'DIAG_996'],
  ['DIAG_250',
   'DIAG_276',
   'DIAG_285',
   'DIAG_998',
   'DIAG_996',
   'DIAG_078',
   'DIAG_336',
   'DIAG_E878',
   'DIAG_401',
   'DIAG_205']]}

In [8]:
# Both helpers are exported by pyhealth.datasets; the extra import of
# split_by_patient from pyhealth.datasets.splitter was redundant (it was
# immediately shadowed by this one).
from pyhealth.datasets import split_by_patient, get_dataloader

# Fixed seed so the patient-level split (and therefore the validation metrics)
# is the same on every Restart & Run All.
SPLIT_SEED = 42

# dataset split by patient id: 80% train / 20% validation. The third ratio is 0,
# so test_ds is empty; evaluation later uses a separately loaded test set.
train_ds, val_ds, test_ds = split_by_patient(
    train_dataset, [0.8, 0.2, 0], seed=SPLIT_SEED
)

# obtain train/val dataloaders, they are <torch.utils.data.DataLoader> objects
# (only the training loader shuffles its batches)
train_loader = get_dataloader(train_ds, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=64, shuffle=False)

## Step3: initialize the RETAIN model

In [18]:
from pyhealth.models import RETAIN

# Build a RETAIN model that reads the "diagnoses" feature and predicts the
# binary "label" field of each sample. The full train_dataset (not the
# train_ds split) is what gets passed as `dataset`.
# embedding_dim=128 matches the Embedding(601, 128) shown in the printed model.
model = RETAIN(
    dataset=train_dataset,
    feature_keys=["diagnoses"],
    label_key="label",
    mode="binary",
    embedding_dim=128,
)

## Step4: model training

In [19]:
from pyhealth.trainer import Trainer

# use our Trainer to train the model

# the metrics listed here are the ones reported in each "Eval epoch" log block
trainer = Trainer(
    model=model,
    metrics=["pr_auc", "roc_auc", "f1"]
)

# Train for 20 epochs, evaluating on val_loader after each one. `monitor` picks
# roc_auc as the tracked score: the log prints "New best roc_auc score" whenever
# it improves and "Loaded best model" at the end of training.
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=20,
    monitor="roc_auc",
)

RETAIN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(601, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (diagnoses): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=128, out_features=1, bias=True)
)
Metrics: ['pr_auc', 'roc_auc', 'f1']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7fbe2f844520>
Monitor: roc_auc
Monitor criterion: max
Epochs: 20



Epoch 0 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-0, step-13 ---
loss: 0.6890


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 130.91it/s]

--- Eval epoch-0, step-13 ---
pr_auc: 0.6724
roc_auc: 0.6552
f1: 0.7080
loss: 0.6754
New best roc_auc score (0.6552) at epoch-0, step-13






Epoch 1 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-1, step-26 ---
loss: 0.6443


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 85.25it/s]

--- Eval epoch-1, step-26 ---
pr_auc: 0.7390
roc_auc: 0.7326
f1: 0.7063
loss: 0.6501
New best roc_auc score (0.7326) at epoch-1, step-26






Epoch 2 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-2, step-39 ---
loss: 0.6090


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 89.75it/s]

--- Eval epoch-2, step-39 ---
pr_auc: 0.7614
roc_auc: 0.7647
f1: 0.7300
loss: 0.6397
New best roc_auc score (0.7647) at epoch-2, step-39






Epoch 3 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-3, step-52 ---
loss: 0.5938


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 124.18it/s]

--- Eval epoch-3, step-52 ---
pr_auc: 0.7655
roc_auc: 0.7755
f1: 0.7603
loss: 0.5997
New best roc_auc score (0.7755) at epoch-3, step-52






Epoch 4 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-4, step-65 ---
loss: 0.5551


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 144.59it/s]

--- Eval epoch-4, step-65 ---
pr_auc: 0.7727
roc_auc: 0.7817
f1: 0.7932
loss: 0.5812
New best roc_auc score (0.7817) at epoch-4, step-65






Epoch 5 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-5, step-78 ---
loss: 0.5247


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 105.18it/s]

--- Eval epoch-5, step-78 ---
pr_auc: 0.7756
roc_auc: 0.7847
f1: 0.7860
loss: 0.5772
New best roc_auc score (0.7847) at epoch-5, step-78






Epoch 6 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-6, step-91 ---
loss: 0.4954


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 179.29it/s]

--- Eval epoch-6, step-91 ---
pr_auc: 0.7808
roc_auc: 0.7949
f1: 0.8018
loss: 0.5632
New best roc_auc score (0.7949) at epoch-6, step-91






Epoch 7 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-7, step-104 ---
loss: 0.4877


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 178.37it/s]

--- Eval epoch-7, step-104 ---
pr_auc: 0.7768
roc_auc: 0.7962
f1: 0.7822
loss: 0.5572
New best roc_auc score (0.7962) at epoch-7, step-104






Epoch 8 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-8, step-117 ---
loss: 0.4586


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 169.00it/s]

--- Eval epoch-8, step-117 ---
pr_auc: 0.7647
roc_auc: 0.7842
f1: 0.7580
loss: 0.5577






Epoch 9 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-9, step-130 ---
loss: 0.4280


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 156.07it/s]

--- Eval epoch-9, step-130 ---
pr_auc: 0.7607
roc_auc: 0.7788
f1: 0.7580
loss: 0.5780






Epoch 10 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-10, step-143 ---
loss: 0.3954


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 179.59it/s]

--- Eval epoch-10, step-143 ---
pr_auc: 0.7648
roc_auc: 0.7750
f1: 0.7545
loss: 0.5786






Epoch 11 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-11, step-156 ---
loss: 0.3829


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 179.32it/s]

--- Eval epoch-11, step-156 ---
pr_auc: 0.7759
roc_auc: 0.7856
f1: 0.7523
loss: 0.5592






Epoch 12 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-12, step-169 ---
loss: 0.3781


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 167.95it/s]

--- Eval epoch-12, step-169 ---
pr_auc: 0.7788
roc_auc: 0.7907
f1: 0.7489
loss: 0.5554






Epoch 13 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-13, step-182 ---
loss: 0.3242


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 168.76it/s]

--- Eval epoch-13, step-182 ---
pr_auc: 0.7844
roc_auc: 0.7903
f1: 0.7373
loss: 0.5529






Epoch 14 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-14, step-195 ---
loss: 0.3310


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 177.04it/s]

--- Eval epoch-14, step-195 ---
pr_auc: 0.7806
roc_auc: 0.7952
f1: 0.7290
loss: 0.5647






Epoch 15 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-15, step-208 ---
loss: 0.2925


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 178.57it/s]

--- Eval epoch-15, step-208 ---
pr_auc: 0.7746
roc_auc: 0.7932
f1: 0.7489
loss: 0.5785






Epoch 16 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-16, step-221 ---
loss: 0.2872


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 177.03it/s]

--- Eval epoch-16, step-221 ---
pr_auc: 0.7843
roc_auc: 0.7942
f1: 0.7421
loss: 0.6026






Epoch 17 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-17, step-234 ---
loss: 0.2841


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 179.89it/s]

--- Eval epoch-17, step-234 ---
pr_auc: 0.7776
roc_auc: 0.7883
f1: 0.7364
loss: 0.6263






Epoch 18 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-18, step-247 ---
loss: 0.2695


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 181.40it/s]

--- Eval epoch-18, step-247 ---
pr_auc: 0.7726
roc_auc: 0.7843
f1: 0.7453
loss: 0.6259






Epoch 19 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-19, step-260 ---
loss: 0.2427


Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 172.74it/s]

--- Eval epoch-19, step-260 ---
pr_auc: 0.7767
roc_auc: 0.7861
f1: 0.7407
loss: 0.6537
Loaded best model





## Step5: model evaluation

In [20]:
DATA_PATH = "./"

# NOTE: this rebinds pids/vids/hfs/seqs/types/rtypes from the training split to
# the test split. Re-run the training-data cell before re-running any earlier
# cell that expects the training data.
# Each file is read inside a `with` block so its handle is closed right away
# (the bare `open(...)` calls leaked them).
# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
_test_objects = []
for _name in ('pids', 'vids', 'hfs', 'seqs', 'types', 'rtypes'):
    with open(os.path.join(DATA_PATH, f'test/{_name}.pkl'), 'rb') as _fh:
        _test_objects.append(pickle.load(_fh))
pids, vids, hfs, seqs, types, rtypes = _test_objects

In [21]:
# Wrap the test split in the same per-patient sample format used for training:
# last visit id, heart-failure label, and visits mapped to code strings.
test_samples = [
    {
        'patient_id': pid,
        'visit_id': vid[-1],
        'label': hf,
        'diagnoses': [[rtypes[v] for v in visit] for visit in seq],
    }
    for pid, vid, hf, seq in zip(pids, vids, hfs, seqs)
]

# no shuffling for evaluation
test_dataset = SampleEHRDataset(test_samples, code_vocs=None)
test_loader = get_dataloader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

In [22]:
# evaluate on the held-out test loader; the returned dict holds the configured
# metrics plus the loss
result = trainer.evaluate(test_loader)
print(result)

Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 133.42it/s]

{'pr_auc': 0.7582368833149129, 'roc_auc': 0.7667034026725444, 'f1': 0.7500000000000001, 'loss': 0.595910519361496}



