<a href="https://colab.research.google.com/github/sanjayathreya/cs598dl4h-project/blob/main/src/Descriptive-Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Reproducibility summary**

This notebook describes the methods of the paper **Chang Lu, Tian Han, and Yue Ning. 2021a. Context-aware health event prediction via transition functions on dynamic disease graphs (Chet). ArXiv, abs/2112.05195**. The paper makes two claims:

1.   Claim 1: For the heart failure prediction task, Chet outperforms the baseline models on metrics such as AUC and F1 score on the MIMIC III and MIMIC IV data sets.

2. Claim 2: For the diagnosis prediction task, Chet outperforms the baseline models on the MIMIC III and MIMIC IV data sets. The authors compare w-F1, a weighted sum of F1 scores over all medical codes, and R@k, the average ratio of desired medical codes in the top k predictions to the total number of desired medical codes in each visit.
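
The two diagnosis-prediction metrics can be illustrated on toy data. The block below is a minimal sketch, not the authors' evaluation code: the function names `recall_at_k` and `weighted_f1` are hypothetical, and it assumes w-F1 weights each code's F1 by its support (the number of visits in which the code truly occurs), which may differ in detail from the paper's implementation.

```python
import numpy as np

def recall_at_k(y_true, y_score, k):
    """Mean over visits of (true codes found in the top-k scores) / (true codes)."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    # Indices of the k highest-scoring codes for each visit.
    top_k = np.argsort(-y_score, axis=1)[:, :k]
    hits = np.take_along_axis(y_true, top_k, axis=1).sum(axis=1)
    totals = y_true.sum(axis=1)
    valid = totals > 0  # skip visits without any true code
    return float(np.mean(hits[valid] / totals[valid]))

def weighted_f1(y_true, y_pred):
    """Per-code F1, averaged with weights equal to each code's support (assumption)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    tp = (y_true * y_pred).sum(axis=0)
    fp = ((1 - y_true) * y_pred).sum(axis=0)
    fn = (y_true * (1 - y_pred)).sum(axis=0)
    denom = 2 * tp + fp + fn
    # Codes that never occur and are never predicted get F1 = 0.
    f1 = np.divide(2 * tp, denom, out=np.zeros_like(denom), where=denom > 0)
    support = y_true.sum(axis=0)
    return float((f1 * support).sum() / support.sum())

# Toy example: 2 visits, 4 medical codes (rows are visits, columns are codes).
y_true = np.array([[1, 0, 1, 0], [0, 1, 0, 0]])
y_score = np.array([[0.9, 0.8, 0.1, 0.2], [0.3, 0.6, 0.2, 0.1]])
y_pred = (y_score >= 0.5).astype(int)

print('R@2 :', recall_at_k(y_true, y_score, k=2))
print('w-F1:', weighted_f1(y_true, y_pred))
```

In the toy data, the first visit recovers one of its two true codes in the top 2 and the second visit recovers its only true code, so R@2 averages 0.5 and 1.0.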

To verify these claims, we reproduced the results on MIMIC III CareVue (Johnson et al., 2022), which excludes patients that overlap with MIMIC IV, and on MIMIC IV (Johnson et al., 2023). Additionally, we investigated the effectiveness of the model under different experimental setups, such as changing the number of training epochs, using a different pre-processing method to extract data, and running ablation studies that omit the dynamic graph and the transition functions.









In [3]:
!git clone https://github.com/sanjayathreya/cs598dl4h-project
!mv /content/cs598dl4h-project /content/CHET

fatal: destination path 'cs598dl4h-project' already exists and is not an empty directory.


In [6]:
#@title Copy the files patients.csv, admissions.csv and diagnoses_icd.csv to mimic3 and mimic4
from google.colab import drive
drive.mount('/content/drive')
# !mkdir /content/CHET/data/mimic3/raw/
# !mkdir /content/CHET/data/mimic4/raw/
!cp -a /content/drive/MyDrive/CHET/data/mimic3/raw/ /content/CHET/data/mimic3/
!cp -a /content/drive/MyDrive/CHET/data/mimic4/raw/ /content/CHET/data/mimic4/

Mounted at /content/drive


In [9]:
%cd /content/CHET/

/content/CHET


In [10]:
!pip install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyhealth
  Downloading pyhealth-1.1.3-py2.py3-none-any.whl (113 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m113.8/113.8 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Collecting rdkit>=2022.03.4
  Downloading rdkit-2023.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.7/29.7 MB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit, pyhealth
Successfully installed pyhealth-1.1.3 rdkit-2023.3.1


In [None]:
%cd /content/CHET/src
%pwd

/content/CHET


In [None]:
import pandas as pd
from pyhealth.medcode import CrossMap
from pyhealth.datasets import MIMIC4Dataset,MIMIC3Dataset
from pyhealth.medcode import InnerMap
from pyhealth.datasets.utils import flatten_list
from pyhealth.tokenizer import Tokenizer
import os
import pickle
import numpy as np
from collections import OrderedDict
icd9cm = InnerMap.load("ICD9CM")
data_path = 'data'

In [None]:
!cp /content/CHET/ICD10CM_to_ICD9CM.csv /root/.cache/pyhealth/medcode

In [None]:
def create_parsed_datasets(patient_dict, tablename):
  """Build per-patient admission lists and per-admission ICD-9-CM code lists.

  Only patients with at least two visits are kept. A patient is dropped
  entirely if any visit has no diagnosis codes or contains a code that
  cannot be standardized (an empty string or 'NoDx').

  Parameters
  ----------
  patient_dict : dict
      Mapping from patient id to a pyhealth patient object, e.g. the
      ``patients`` attribute of a MIMIC3Dataset or MIMIC4Dataset.
  tablename : str
      Name of the diagnosis table to read codes from.

  Returns
  -------
  tuple of (OrderedDict, OrderedDict)
      ``patient_admission`` maps each patient id to a list of
      ``{'adm_id', 'adm_time'}`` dicts sorted by admission time, and
      ``admission_codes`` maps each admission id to its standardized codes.
  """
  del_pid = set()
  patient_admission = OrderedDict()
  admission_codes = OrderedDict()

  for pid, patient in patient_dict.items():
    visit_dict = patient.visits
    # Keep only patients with at least 2 visits.
    if len(visit_dict) >= 2:
      admissions = []
      for visit_key, visit in visit_dict.items():
        diagnoses = visit.get_code_list(table=tablename)
        diagnoses_std = [icd9cm.standardize(code) for code in diagnoses]
        admissions.append({'adm_id': visit_key, 'adm_time': visit.encounter_time})
        admission_codes[visit_key] = diagnoses_std

        # Mark the patient for removal if a visit has no diagnoses or
        # contains a code with no ICD-9-CM mapping.
        unmapped = sum(1 for code in diagnoses_std if code in ('', 'NoDx'))
        if len(diagnoses) == 0 or unmapped != 0:
          del_pid.add(pid)
      patient_admission[pid] = sorted(admissions, key=lambda admission: admission['adm_time'])

  # Remove flagged patients together with all of their admissions.
  for pid in del_pid:
    del patient_admission[pid]
    for visit_key in patient_dict[pid].visits:
      del admission_codes[visit_key]

  return patient_admission, admission_codes

In [None]:
def get_stats(patient_admission, admission_codes):
  """Print summary statistics for a parsed dataset.

  Reports the number of patients, the maximum and mean number of
  admissions per patient, and the maximum and mean number of codes
  per admission.

  Parameters
  ----------
  patient_admission : dict
      Mapping from patient id to that patient's list of admissions.
  admission_codes : dict
      Mapping from admission id to the list of diagnosis codes.

  Returns
  -------
  None
      The statistics are printed, not returned.
  """
  patient_num = len(patient_admission)
  max_admission_num = max(len(admissions) for admissions in patient_admission.values())
  avg_admission_num = sum(len(admissions) for admissions in patient_admission.values()) / patient_num
  max_visit_code_num = max(len(codes) for codes in admission_codes.values())
  avg_visit_code_num = sum(len(codes) for codes in admission_codes.values()) / len(admission_codes)
  print('patient num: %d' % patient_num)
  print('max admission num: %d' % max_admission_num)
  print('mean admission num: %.2f' % avg_admission_num)
  print('max code num in an admission: %d' % max_visit_code_num)
  print('mean code num in an admission: %.2f' % avg_visit_code_num)

In [None]:
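def save_parsed_files(parsed_path, **kwargs):
  """Pickle each keyword argument to ``<parsed_path>/<name>.pkl``.

  The target directory is created if it does not exist.

  Parameters
  ----------
  parsed_path : str
      Directory in which the pickle files are written.
  **kwargs
      Objects to save; each keyword becomes the file name (without extension).

  Returns
  -------
  None
      The files are written as a side effect.
  """
  os.makedirs(parsed_path, exist_ok=True)
  for key, value in kwargs.items():
    name = key + '.pkl'
    # Use a context manager so the file handle is closed after writing.
    with open(os.path.join(parsed_path, name), 'wb') as f:
      pickle.dump(value, f)
    print(f'saved {key} data ...')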

In [None]:
mimic3_ds = MIMIC3Dataset(
    root="data/mimic3/raw",
    tables=["DIAGNOSES_ICD"]
)
dataset = 'mimic3'  # mimic3, eicu, or mimic4
dataset_path = os.path.join(data_path,dataset)
parsed_path = os.path.join(dataset_path, 'parsed')

patient_dict = mimic3_ds.patients
patient_admission,admission_codes = create_parsed_datasets(patient_dict, "DIAGNOSES_ICD")
get_stats(patient_admission,admission_codes)
save_parsed_files(parsed_path, patient_admission=patient_admission, admission_codes=admission_codes)

Parsing PATIENTS and ADMISSIONS: 100%|██████████| 23692/23692 [00:37<00:00, 637.81it/s]
Parsing DIAGNOSES_ICD: 100%|██████████| 26830/26830 [00:03<00:00, 7451.56it/s]
Mapping codes: 100%|██████████| 23692/23692 [00:00<00:00, 118674.16it/s]


patient num: 2169
max admission num: 23
mean admission num: 2.45
max code num in an admission: 39
mean code num in an admission: 10.70
saved patient_admission data ...
saved admission_codes data ...


In [None]:
mimic4_ds = MIMIC4Dataset(
    root="data/mimic4/raw",
    tables=["diagnoses_icd"],
    code_mapping={"ICD10CM": "ICD9CM"},
)
dataset = 'mimic4' 
dataset_path = os.path.join(data_path,dataset)
parsed_path = os.path.join(dataset_path, 'parsed')
patient_dict = mimic4_ds.patients
patient_admission,admission_codes = create_parsed_datasets(patient_dict, "diagnoses_icd")
get_stats(patient_admission,admission_codes)

patient num: 55875
max admission num: 95
mean admission num: 3.69
max code num in an admission: 39
mean code num in an admission: 9.23


In [None]:
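def generate_samples(sample_num, seed, patient_admission, admission_codes):
  """Draw a random subset of patients together with their admission codes.

  Patients are sampled without replacement after seeding NumPy's global
  random state, so the same seed yields the same sample.

  Parameters
  ----------
  sample_num : int
      Number of patients to draw.
  seed : int
      Seed passed to ``np.random.seed``.
  patient_admission : dict
      Mapping from patient id to that patient's list of admissions.
  admission_codes : dict
      Mapping from admission id to the list of diagnosis codes.

  Returns
  -------
  tuple of (dict, dict)
      The sampled ``patient_admission`` mapping and the ``admission_codes``
      mapping restricted to the admissions of the sampled patients.
  """
  np.random.seed(seed)
  keys = list(patient_admission.keys())
  selected_pids = np.random.choice(keys, sample_num, replace=False)
  patient_admission_sample = {pid: patient_admission[pid] for pid in selected_pids}
  admission_codes_sample = dict()
  for admissions in patient_admission_sample.values():
    for admission in admissions:
      adm_id = admission['adm_id']
      admission_codes_sample[adm_id] = admission_codes[adm_id]
  return patient_admission_sample, admission_codes_sample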
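
The sampling step can be checked on toy data before running it on MIMIC IV. The block below is a minimal, self-contained sketch that mirrors the logic of `generate_samples`; the name `sample_patients` and the toy patient and admission ids are hypothetical and exist only for illustration.

```python
import numpy as np

def sample_patients(patient_admission, admission_codes, sample_num, seed):
    """Sample patients without replacement and keep only their admissions."""
    np.random.seed(seed)
    keys = list(patient_admission.keys())
    # tolist() converts NumPy scalars back to native Python keys.
    selected = np.random.choice(keys, sample_num, replace=False).tolist()
    pa_sample = {pid: patient_admission[pid] for pid in selected}
    ac_sample = {
        adm['adm_id']: admission_codes[adm['adm_id']]
        for admissions in pa_sample.values()
        for adm in admissions
    }
    return pa_sample, ac_sample

# Toy data: 5 patients with 2 admissions each (ids are made up).
toy_patient_admission = {
    pid: [{'adm_id': pid * 10 + i, 'adm_time': i} for i in range(2)]
    for pid in range(1, 6)
}
toy_admission_codes = {
    adm['adm_id']: ['401.9']
    for admissions in toy_patient_admission.values()
    for adm in admissions
}

first = sample_patients(toy_patient_admission, toy_admission_codes, 3, seed=6669)
second = sample_patients(toy_patient_admission, toy_admission_codes, 3, seed=6669)
print(sorted(first[0]) == sorted(second[0]))  # same seed, same patients
```

Because the seed is fixed, repeating the call selects the same patients, which is what makes the five seeded samples used in this notebook reproducible.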

In [None]:
seeds = [6669, 1000, 1050, 2052, 3000]
sample_num = 10000
for idx, seed in enumerate(seeds):
  patient_admission_sample, admission_codes_sample = generate_samples(sample_num, seed, patient_admission, admission_codes)
  parsed_path_sample = os.path.join(parsed_path,str(idx))
  save_parsed_files(parsed_path_sample, patient_admission=patient_admission_sample, admission_codes=admission_codes_sample)
  get_stats(patient_admission_sample,admission_codes_sample)

saving parsed data ...
saved patient_admission data ...
saved admission_codes data ...
patient num: 10000
max admission num: 64
mean admission num: 3.71
max code num in an admission: 39
mean code num in an admission: 9.29
saving parsed data ...
saved patient_admission data ...
saved admission_codes data ...
patient num: 10000
max admission num: 77
mean admission num: 3.68
max code num in an admission: 39
mean code num in an admission: 9.07
saving parsed data ...
saved patient_admission data ...
saved admission_codes data ...
patient num: 10000
max admission num: 94
mean admission num: 3.69
max code num in an admission: 39
mean code num in an admission: 9.23
saving parsed data ...
saved patient_admission data ...
saved admission_codes data ...
patient num: 10000
max admission num: 71
mean admission num: 3.69
max code num in an admission: 39
mean code num in an admission: 9.32
saving parsed data ...
saved patient_admission data ...
saved admission_codes data ...
patient num: 10000
max ad

In [None]:
# for idx, seed in enumerate(seeds):
#   parsed_path_sample = os.path.join(parsed_path,str(idx))
#   patient_admission = pickle.load(open(os.path.join(parsed_path_sample, 'patient_admission.pkl'), 'rb'))
#   admission_codes = pickle.load(open(os.path.join(parsed_path_sample, 'admission_codes.pkl'), 'rb'))

In [None]:
!cp -a /content/CHET/data/mimic3/parsed/ /content/drive/MyDrive/CHET/data/mimic3/parsed/

In [None]:
# !cp -a /content/CHET/data/mimic4/parsed/ /content/drive/MyDrive/CHET/data/mimic4/parsed/

In [None]:
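# Collect the unique diagnosis codes across all admissions.
codes = list(admission_codes.values())
codes = list(set(flatten_list(codes)))
# Build a vocabulary that maps each code to an integer index.
tokenizer = Tokenizer(tokens=codes)
code_map = tokenizer.vocabulary.token2idx
# Encode every admission as a list of code indices.
admission_codes_encoded = {
    admission_id: tokenizer.convert_tokens_to_indices(adm_codes)
    for admission_id, adm_codes in admission_codes.items()
}
code_num = len(code_map)
print('There are %d codes' % code_num)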

There are 2822 codes
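
To make the encoding step concrete, the following is a minimal plain-Python sketch of the same idea without pyhealth. The helper names `build_code_map` and `encode_admissions` and the toy admission ids are hypothetical, and the indices that pyhealth's `Tokenizer` assigns may differ from the sorted order assumed here.

```python
def build_code_map(admission_codes):
    """Map each unique code to an index (sorted order is an assumption)."""
    unique_codes = sorted({code for codes in admission_codes.values() for code in codes})
    return {code: idx for idx, code in enumerate(unique_codes)}

def encode_admissions(admission_codes, code_map):
    """Replace every code in every admission by its integer index."""
    return {
        adm_id: [code_map[code] for code in codes]
        for adm_id, codes in admission_codes.items()
    }

# Toy admissions with made-up ids and a few ICD-9-CM style codes.
toy_admission_codes = {
    'a1': ['428.0', '401.9'],
    'a2': ['401.9', '250.00'],
}
toy_code_map = build_code_map(toy_admission_codes)
toy_encoded = encode_admissions(toy_admission_codes, toy_code_map)
print('There are %d codes' % len(toy_code_map))
print(toy_encoded)
```

Sorting the codes before numbering them keeps the mapping deterministic across runs, which a plain `set` iteration order does not guarantee for strings.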
