In [2]:
import pyhealth
import pickle
import os

## dataset lookups

In [3]:
"""
where

- `pids`: contains the patient ids
- `vids`: contains a list of visit ids for each patient
- `hfs`: contains the heart failure label (0: normal, 1: heart failure) for each patient
- `seqs`: contains a list of visits (each visit is a list of ICD9 labels) for each patient
- `types`: contains the map from ICD9 codes to ICD9 labels
- `rtypes`: contains the map from ICD9 labels to ICD9 codes
"""

DATA_PATH = "./"


def _load_pickle(relative_path):
    """Return the object unpickled from `relative_path` under DATA_PATH.

    The file is opened in a `with` block so its handle is closed as soon as
    the object has been read.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point DATA_PATH at pickles from a trusted source.
    with open(os.path.join(DATA_PATH, relative_path), 'rb') as handle:
        return pickle.load(handle)


# Training split: parallel per-patient lists plus the two code/label maps.
pids = _load_pickle('train/pids.pkl')
vids = _load_pickle('train/vids.pkl')
hfs = _load_pickle('train/hfs.pkl')
seqs = _load_pickle('train/seqs.pkl')
types = _load_pickle('train/types.pkl')
rtypes = _load_pickle('train/rtypes.pkl')

# Sanity checks: the four per-patient lists are aligned (1000 patients each)
# and the code map has the expected 619 entries.
assert len(pids) == len(vids) == len(hfs) == len(seqs) == 1000
assert len(types) == 619

In [4]:
# Take the patient at index 3 (i.e. the 4th patient, 0-based) as an example.
example_idx = 3
example_vids = vids[example_idx]
example_seqs = seqs[example_idx]

print("Patient ID:", pids[example_idx])
print("Heart Failure:", hfs[example_idx])
print("# of visits:", len(example_vids))
for visit, visit_id in enumerate(example_vids):
    print(f"\t{visit}-th visit id:", visit_id)
    # Each visit is stored as integer labels; `rtypes` maps them back to code strings.
    visit_labels = example_seqs[visit]
    print(f"\t{visit}-th visit diagnosis labels:", visit_labels)
    print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in visit_labels])

Patient ID: 47537
Heart Failure: 0
# of visits: 2
	0-th visit id: 0
	0-th visit diagnosis labels: [12, 103, 262, 285, 290, 292, 359, 416, 39, 225, 275, 294, 326, 267, 93]
	0-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_518', 'DIAG_560', 'DIAG_567', 'DIAG_569', 'DIAG_707', 'DIAG_785', 'DIAG_155', 'DIAG_456', 'DIAG_537', 'DIAG_571', 'DIAG_608', 'DIAG_529', 'DIAG_263']
	1-th visit id: 1
	1-th visit diagnosis labels: [12, 103, 240, 262, 290, 292, 319, 359, 510, 513, 577, 307, 8, 280, 18, 131]
	1-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_482', 'DIAG_518', 'DIAG_567', 'DIAG_569', 'DIAG_599', 'DIAG_707', 'DIAG_995', 'DIAG_998', 'DIAG_V09', 'DIAG_584', 'DIAG_031', 'DIAG_553', 'DIAG_070', 'DIAG_305']


In [5]:
# Class balance of the training labels: `hfs` holds one 0/1 label per patient,
# so its sum is the number of heart-failure patients.
num_hf_patients = sum(hfs)
print("number of heart failure patients:", num_hf_patients)
print("ratio of heart failure patients: %.2f" % (num_hf_patients / len(hfs)))

number of heart failure patients: 548
ratio of heart failure patients: 0.55


## Step1 & 2: dataset processing and task definition

In [6]:
"""
With this user processed data, we just need to wrap up the data to fit
the pyhealth format.

Task definition: we will use all the visits from the patients as the features,
    then, the labels are given by the hfs list.
"""
# One sample dict per patient:
#   - 'visit_id' is the id of the patient's last visit,
#   - 'diagnoses' holds every visit, with integer labels mapped back to
#     code strings through `rtypes`.
samples = [
    {
        'patient_id': patient_id,
        'visit_id': visit_ids[-1],
        'label': hf_label,
        'diagnoses': [[rtypes[code] for code in one_visit] for one_visit in visit_seq],
    }
    for patient_id, visit_ids, hf_label, visit_seq in zip(pids, vids, hfs, seqs)
]
    

In [7]:
# Wrap the list of per-patient sample dicts in a pyhealth SampleEHRDataset so
# that it can be split and fed to pyhealth dataloaders/models.
# code_vocs=None: no code vocabulary is passed explicitly.
from pyhealth.datasets import SampleEHRDataset
train_dataset = SampleEHRDataset(samples, code_vocs=None)

In [14]:
# Display the first wrapped sample to confirm it carries the expected keys
# ('patient_id', 'visit_id', 'label', 'diagnoses').
train_dataset.samples[0]

{'patient_id': 89571,
 'visit_id': 1,
 'label': 1,
 'diagnoses': [['DIAG_250',
   'DIAG_285',
   'DIAG_682',
   'DIAG_730',
   'DIAG_531',
   'DIAG_996',
   'DIAG_287',
   'DIAG_276',
   'DIAG_E878',
   'DIAG_V45',
   'DIAG_996'],
  ['DIAG_250',
   'DIAG_276',
   'DIAG_285',
   'DIAG_998',
   'DIAG_996',
   'DIAG_078',
   'DIAG_336',
   'DIAG_E878',
   'DIAG_401',
   'DIAG_205']]}

In [8]:
from pyhealth.datasets import split_by_patient, get_dataloader

# Fixed seed so the patient-level split is the same after every
# Restart Kernel -> Run All.
SPLIT_SEED = 42

# Split by patient id: 80% train / 20% validation. The third ratio is 0, so
# `test_ds` is expected to be empty here; the held-out test data is loaded
# from its own pickle files in the evaluation step.
train_ds, val_ds, test_ds = split_by_patient(
    train_dataset, [0.8, 0.2, 0], seed=SPLIT_SEED
)

# Train/val dataloaders (torch.utils.data.DataLoader objects). Only the
# training loader is shuffled; validation order is kept fixed.
train_loader = get_dataloader(train_ds, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=64, shuffle=False)

## Step3: initialize the LSTM model

In [9]:
from pyhealth.models import RNN

# 2-layer LSTM binary classifier over the per-visit "diagnoses" code lists,
# predicting the "label" field of each sample.
# `train_dataset` is the SampleEHRDataset wrapping the training samples.
model = RNN(
    dataset=train_dataset,
    feature_keys=["diagnoses"],
    label_key="label",
    mode="binary",
    rnn_type="LSTM",
    num_layers=2,
    embedding_dim=64,
    hidden_dim=64,
)

## Step4: model training

In [10]:
from pyhealth.trainer import Trainer

# Metrics reported by the pyhealth Trainer at each evaluation.
eval_metrics = ["pr_auc", "roc_auc", "f1"]
trainer = Trainer(model=model, metrics=eval_metrics)

# Train for 20 epochs, evaluating on the validation loader and tracking
# "roc_auc" as the monitored score.
trainer.train(
    train_dataloader=train_loader, val_dataloader=val_loader, epochs=20, monitor="roc_auc"
)

RNN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(601, 64, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (rnn): ModuleDict(
    (diagnoses): RNNLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (rnn): LSTM(64, 64, num_layers=2, batch_first=True, dropout=0.5)
    )
  )
  (fc): Linear(in_features=64, out_features=1, bias=True)
)
Metrics: ['pr_auc', 'roc_auc', 'f1']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f933d6add90>
Monitor: roc_auc
Monitor criterion: max
Epochs: 20



Epoch 0 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-0, step-13 ---
loss: 0.6889


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 111.64it/s]

--- Eval epoch-0, step-13 ---
pr_auc: 0.6721
roc_auc: 0.6805
f1: 0.6993
loss: 0.6820
New best roc_auc score (0.6805) at epoch-0, step-13






Epoch 1 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-1, step-26 ---
loss: 0.6778


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 196.64it/s]

--- Eval epoch-1, step-26 ---
pr_auc: 0.6903
roc_auc: 0.7080
f1: 0.7114
loss: 0.6673
New best roc_auc score (0.7080) at epoch-1, step-26






Epoch 2 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-2, step-39 ---
loss: 0.6681


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 238.66it/s]

--- Eval epoch-2, step-39 ---
pr_auc: 0.6957
roc_auc: 0.7194
f1: 0.7241
loss: 0.6458
New best roc_auc score (0.7194) at epoch-2, step-39






Epoch 3 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-3, step-52 ---
loss: 0.6522


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 234.29it/s]

--- Eval epoch-3, step-52 ---
pr_auc: 0.7096
roc_auc: 0.7310
f1: 0.7365
loss: 0.6176
New best roc_auc score (0.7310) at epoch-3, step-52






Epoch 4 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-4, step-65 ---
loss: 0.6313


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 231.39it/s]

--- Eval epoch-4, step-65 ---
pr_auc: 0.7104
roc_auc: 0.7323
f1: 0.7576
loss: 0.5919
New best roc_auc score (0.7323) at epoch-4, step-65






Epoch 5 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-5, step-78 ---
loss: 0.6152


Evaluation: 100%|█████████████████████████| 4/4 [00:00<00:00, 51.77it/s]

--- Eval epoch-5, step-78 ---
pr_auc: 0.7206
roc_auc: 0.7361
f1: 0.7360
loss: 0.5673
New best roc_auc score (0.7361) at epoch-5, step-78






Epoch 6 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-6, step-91 ---
loss: 0.5898


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 121.60it/s]

--- Eval epoch-6, step-91 ---
pr_auc: 0.7347
roc_auc: 0.7483
f1: 0.7303
loss: 0.5458
New best roc_auc score (0.7483) at epoch-6, step-91






Epoch 7 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-7, step-104 ---
loss: 0.5629


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 236.18it/s]

--- Eval epoch-7, step-104 ---
pr_auc: 0.7378
roc_auc: 0.7527
f1: 0.7364
loss: 0.5317
New best roc_auc score (0.7527) at epoch-7, step-104






Epoch 8 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-8, step-117 ---
loss: 0.5589


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 235.89it/s]

--- Eval epoch-8, step-117 ---
pr_auc: 0.7656
roc_auc: 0.7673
f1: 0.7404
loss: 0.5101
New best roc_auc score (0.7673) at epoch-8, step-117






Epoch 9 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-9, step-130 ---
loss: 0.5472


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 131.29it/s]

--- Eval epoch-9, step-130 ---
pr_auc: 0.7745
roc_auc: 0.7733
f1: 0.7552
loss: 0.5079
New best roc_auc score (0.7733) at epoch-9, step-130






Epoch 10 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-10, step-143 ---
loss: 0.5463


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 102.12it/s]

--- Eval epoch-10, step-143 ---
pr_auc: 0.7879
roc_auc: 0.7839
f1: 0.7642
loss: 0.5002
New best roc_auc score (0.7839) at epoch-10, step-143






Epoch 11 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-11, step-156 ---
loss: 0.5498


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 232.88it/s]

--- Eval epoch-11, step-156 ---
pr_auc: 0.7941
roc_auc: 0.7882
f1: 0.7686
loss: 0.4904
New best roc_auc score (0.7882) at epoch-11, step-156






Epoch 12 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-12, step-169 ---
loss: 0.5162


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 238.60it/s]

--- Eval epoch-12, step-169 ---
pr_auc: 0.7973
roc_auc: 0.7882
f1: 0.7686
loss: 0.4855






Epoch 13 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-13, step-182 ---
loss: 0.5113


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 239.53it/s]

--- Eval epoch-13, step-182 ---
pr_auc: 0.7970
roc_auc: 0.7880
f1: 0.7635
loss: 0.4821






Epoch 14 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-14, step-195 ---
loss: 0.5085


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 177.25it/s]

--- Eval epoch-14, step-195 ---
pr_auc: 0.7967
roc_auc: 0.7898
f1: 0.7737
loss: 0.4798
New best roc_auc score (0.7898) at epoch-14, step-195






Epoch 15 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-15, step-208 ---
loss: 0.5254


Evaluation: 100%|████████████████████████| 4/4 [00:00<00:00, 238.35it/s]

--- Eval epoch-15, step-208 ---
pr_auc: 0.7895
roc_auc: 0.7894
f1: 0.7692
loss: 0.4850






Epoch 16 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-16, step-221 ---
loss: 0.4883


Evaluation: 100%|█████████████████████████| 4/4 [00:00<00:00, 52.19it/s]

--- Eval epoch-16, step-221 ---
pr_auc: 0.7896
roc_auc: 0.7889
f1: 0.7718
loss: 0.4810






Epoch 17 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-17, step-234 ---
loss: 0.4866


Evaluation: 100%|█████████████████████████| 4/4 [00:00<00:00, 51.81it/s]

--- Eval epoch-17, step-234 ---
pr_auc: 0.7851
roc_auc: 0.7879
f1: 0.7755
loss: 0.4894






Epoch 18 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-18, step-247 ---
loss: 0.5024


Evaluation: 100%|█████████████████████████| 4/4 [00:00<00:00, 72.17it/s]

--- Eval epoch-18, step-247 ---
pr_auc: 0.7900
roc_auc: 0.7911
f1: 0.7782
loss: 0.4830
New best roc_auc score (0.7911) at epoch-18, step-247






Epoch 19 / 20:   0%|          | 0/13 [00:00<?, ?it/s]

--- Train epoch-19, step-260 ---
loss: 0.4758


Evaluation: 100%|█████████████████████████| 4/4 [00:00<00:00, 51.86it/s]

--- Eval epoch-19, step-260 ---
pr_auc: 0.7912
roc_auc: 0.7926
f1: 0.7737
loss: 0.4866
New best roc_auc score (0.7926) at epoch-19, step-260
Loaded best model





## Step5: model evaluation

In [11]:
DATA_PATH = "./"

# Read every test-split pickle inside a `with` block so each file handle is
# closed right after it is read.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load pickles from a trusted source.
_test_split = {}
for _name in ('pids', 'vids', 'hfs', 'seqs', 'types', 'rtypes'):
    with open(os.path.join(DATA_PATH, f'test/{_name}.pkl'), 'rb') as _handle:
        _test_split[_name] = pickle.load(_handle)

# NOTE: these names are shared with the training split loaded earlier, so the
# training lists are no longer reachable under them after this cell runs.
# The names are kept because the following cell reads them.
pids = _test_split['pids']
vids = _test_split['vids']
hfs = _test_split['hfs']
seqs = _test_split['seqs']
types = _test_split['types']
rtypes = _test_split['rtypes']

In [12]:
# One sample dict per test patient, in the same layout as the training
# samples: last visit id, 0/1 label, and every visit's labels mapped back to
# code strings through `rtypes`.
test_samples = [
    {
        'patient_id': patient_id,
        'visit_id': visit_ids[-1],
        'label': hf_label,
        'diagnoses': [[rtypes[code] for code in one_visit] for one_visit in visit_seq],
    }
    for patient_id, visit_ids, hf_label, visit_seq in zip(pids, vids, hfs, seqs)
]

# Wrap the test samples and build an unshuffled loader for evaluation.
test_dataset = SampleEHRDataset(test_samples, code_vocs=None)
test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False)

In [13]:
# Score the trained model on the held-out test loader; the returned dict maps
# each metric name (plus 'loss') to its value.
result = trainer.evaluate(test_loader)
print(result)

Evaluation: 100%|█████████████████████████| 4/4 [00:00<00:00, 61.53it/s]

{'pr_auc': 0.7694151513845636, 'roc_auc': 0.7851632456261194, 'f1': 0.7318840579710146, 'loss': 0.6059411615133286}



