In [155]:
import pandas as pd
import os
import pickle
from tqdm import tqdm
import numpy as np

In [2]:
DATA_PATH = 'mimic-iii'

In [3]:
DIAGNOSES_ICD = pd.read_csv(os.path.join(DATA_PATH, 'DIAGNOSES_ICD.csv'))
print('Length of DIAGNOSES_ICD: {}'.format(len(DIAGNOSES_ICD)))

Length of DIAGNOSES_ICD: 651047


In [6]:
# find all unique pid
# count maxium seq for a patient

In [7]:
pids = list(set(DIAGNOSES_ICD['SUBJECT_ID']))
print('Number of unique patients: {}'.format(len(pids)))

Number of unique patients: 46520


In [19]:
DIAGNOSES_ICD.head()

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,1297,109,172335,1.0,40301
1,1298,109,172335,2.0,486
2,1299,109,172335,3.0,58281
3,1300,109,172335,4.0,5855
4,1301,109,172335,5.0,4254


In [33]:
D_ICD_DIAGNOSES = pd.read_csv(os.path.join(DATA_PATH, 'D_ICD_DIAGNOSES.csv'))
print('Length of D_ICD_DIAGNOSES: {}'.format(len(D_ICD_DIAGNOSES)))
D_ICD_DIAGNOSES.head(10)

Length of D_ICD_DIAGNOSES: 14567


Unnamed: 0,ROW_ID,ICD9_CODE,SHORT_TITLE,LONG_TITLE
0,174,1166,TB pneumonia-oth test,"Tuberculous pneumonia [any form], tubercle bac..."
1,175,1170,TB pneumothorax-unspec,"Tuberculous pneumothorax, unspecified"
2,176,1171,TB pneumothorax-no exam,"Tuberculous pneumothorax, bacteriological or h..."
3,177,1172,TB pneumothorx-exam unkn,"Tuberculous pneumothorax, bacteriological or h..."
4,178,1173,TB pneumothorax-micro dx,"Tuberculous pneumothorax, tubercle bacilli fou..."
5,179,1174,TB pneumothorax-cult dx,"Tuberculous pneumothorax, tubercle bacilli not..."
6,180,1175,TB pneumothorax-histo dx,"Tuberculous pneumothorax, tubercle bacilli not..."
7,181,1176,TB pneumothorax-oth test,"Tuberculous pneumothorax, tubercle bacilli not..."
8,182,1180,Pulmonary TB NEC-unspec,"Other specified pulmonary tuberculosis, unspec..."
9,183,1181,Pulmonary TB NEC-no exam,"Other specified pulmonary tuberculosis, bacter..."


## Preprocessing

In [15]:
# This funciton is to find out the maximum number of visits?
p_seqs = []
visit_counts = []
for p_id in tqdm(pids):
    patient = DIAGNOSES_ICD.loc[DIAGNOSES_ICD['SUBJECT_ID'] == p_id]
    visit_counts.append(len(patient))
    
    # iterate through each visit (hospital - HADM_ID)
    #for visit_id in patient.HADM_ID.unique():
    #    visit = patient.loc[patient['HADM_ID'] == visit_id]
    #    diagnosis_length = len(visit)

100%|███████████████████████████████| 46520/46520 [00:24<00:00, 1868.91it/s]


In [23]:
print('Maximum visit of a patient: {}'.format(max(visit_counts)))
print('Average visit per patient: {}'.format(sum(visit_counts)/len(visit_counts)))

Maximum visit of a patient: 540
Average visit per patient: 13.994991401547722


In [21]:
diagnosis_codes = {} # key: pid, value: array of visits, each visit contains ICD-9 codes
for p_id in tqdm(pids):
    patient = DIAGNOSES_ICD.loc[DIAGNOSES_ICD['SUBJECT_ID'] == p_id]
    curr = []
    for visit_id in patient.HADM_ID.unique():
        visit = patient.loc[patient['HADM_ID'] == visit_id]
        curr.append(visit.ICD9_CODE.tolist())
    
    diagnosis_codes[p_id] = curr

100%|███████████████████████████████| 46520/46520 [00:33<00:00, 1369.09it/s]


In [39]:
len(diagnosis_codes)

46520

#### Find out the top diagnosis codes

In [26]:
# per patient base
# count diagnosis codes with 3 or more visits
diag_codes = {}
for p_id in tqdm(pids):
    if len(diagnosis_codes[p_id]) < 3:
        continue
    
    # count codes
    for visit in diagnosis_codes[p_id][3:]:
        for code in visit:
            if code not in diag_codes:
                diag_codes[code] = 0
            
            diag_codes[code] += 1

100%|████████████████████████████| 46520/46520 [00:00<00:00, 2184347.47it/s]


In [129]:
sorted_code_freq = sorted(diag_codes.items(), key=lambda item: item[1], reverse=True)
sorted_code_freq[:10]

[('4280', 1067),
 ('4019', 728),
 ('5849', 641),
 ('42731', 629),
 ('41401', 485),
 ('51881', 458),
 ('25000', 450),
 ('5856', 437),
 ('5990', 427),
 ('40391', 416)]

### Create Mapping between ICD9_CODE and titles

In [34]:
D_ICD_DIAGNOSES.head()

Unnamed: 0,ROW_ID,ICD9_CODE,SHORT_TITLE,LONG_TITLE
0,174,1166,TB pneumonia-oth test,"Tuberculous pneumonia [any form], tubercle bac..."
1,175,1170,TB pneumothorax-unspec,"Tuberculous pneumothorax, unspecified"
2,176,1171,TB pneumothorax-no exam,"Tuberculous pneumothorax, bacteriological or h..."
3,177,1172,TB pneumothorx-exam unkn,"Tuberculous pneumothorax, bacteriological or h..."
4,178,1173,TB pneumothorax-micro dx,"Tuberculous pneumothorax, tubercle bacilli fou..."


In [55]:
for code, freq in sorted_code_freq[:10]:
    print(D_ICD_DIAGNOSES.loc[D_ICD_DIAGNOSES['ICD9_CODE'] == code][['ICD9_CODE', 'SHORT_TITLE']])

     ICD9_CODE SHORT_TITLE
4472      4280     CHF NOS
     ICD9_CODE       SHORT_TITLE
4303      4019  Hypertension NOS
     ICD9_CODE               SHORT_TITLE
5906      5849  Acute kidney failure NOS
     ICD9_CODE          SHORT_TITLE
4461     42731  Atrial fibrillation
     ICD9_CODE               SHORT_TITLE
4373     41401  Crnry athrscl natve vssl
     ICD9_CODE               SHORT_TITLE
5550     51881  Acute respiratry failure
     ICD9_CODE               SHORT_TITLE
1588     25000  DMII wo cmp nt st uncntr
     ICD9_CODE              SHORT_TITLE
5912      5856  End stage renal disease
     ICD9_CODE               SHORT_TITLE
6550      5990  Urin tract infection NOS
     ICD9_CODE             SHORT_TITLE
4315     40391  Hyp kid NOS w cr kid V


In [60]:
c1_pids = DIAGNOSES_ICD.loc[DIAGNOSES_ICD['ICD9_CODE'] == '4280'].SUBJECT_ID.unique()
print(len(c1_pids))
c2_pids = DIAGNOSES_ICD.loc[DIAGNOSES_ICD['ICD9_CODE'] == '4019'].SUBJECT_ID.unique()
print(len(c2_pids))
c3_pids = DIAGNOSES_ICD.loc[DIAGNOSES_ICD['ICD9_CODE'] == '5849'].SUBJECT_ID.unique()
print(len(c3_pids))

9843
17613
7687


In [64]:
def count_selected_patients(pids):
    nums = 0
    for p_id in pids:
        if len(diagnosis_codes[p_id]) < 4:
            continue
        nums += 1
    print(nums)

count_selected_patients(c1_pids)
count_selected_patients(c2_pids)
count_selected_patients(c3_pids)

674
624
651


In [66]:
def select_patients(pids):
    out = []
    for p_id in pids:
        if len(diagnosis_codes[p_id]) < 4:
            continue
        out.append(p_id)
    return out

c1 = select_patients(c1_pids)
c2 = select_patients(c2_pids)
c3 = select_patients(c3_pids)

In [68]:
def max_visit(pids):
    counts = []
    for p_id in pids:
        counts.append(len(diagnosis_codes[p_id]))
    print(max(counts))

max_visit(c1)
max_visit(c2)
max_visit(c3)

42
42
42


In [69]:
def avg_visit(pids):
    counts = []
    for p_id in pids:
        counts.append(len(diagnosis_codes[p_id]))
    print(sum(counts)/len(counts))
avg_visit(c1)
avg_visit(c2)
avg_visit(c3)

5.551928783382789
5.551282051282051
5.723502304147465


In [139]:
# any shared particpants, mutually exclusive
c1_set = set(c1)
c2_set = set(c2)
c3_set = set(c3)

common_pids = set()
common_pids = common_pids.union(c1_set.intersection(c2_set))
common_pids = common_pids.union(c2_set.intersection(c3_set))
common_pids = common_pids.union(c3_set.intersection(c1_set))
# common_pids = common_pids.union(c1_set.intersection(c2_set, c3_set))
len(common_pids)

683

In [140]:
def remove_common_patients(pids, common):
    return list(filter(lambda x: x not in common, pids))

c1_unique = remove_common_patients(c1, common_pids)
c2_unique = remove_common_patients(c2, common_pids)
c3_unique = remove_common_patients(c3, common_pids)
print(len(c1_unique))
print(len(c2_unique))
print(len(c3_unique))

101
90
82


In [142]:
# rebalance each group
# who ever first encounter
c1_sub = []
c2_sub = []
c3_sub = []
for p_id in common_pids:
    for visit in diagnosis_codes[p_id][3:]:
        if '5849' in visit:
            c3_sub.append(p_id)
            break
        if '4019' in visit:
            c2_sub.append(p_id)
            break
        if '4280' in visit:
            c1_sub.append(p_id)
            break

print(len(c1_sub))
print(len(c2_sub))
print(len(c3_sub))

156
205
263


In [143]:
c1_clean = c1_unique + c1_sub
c2_clean = c2_unique + c2_sub
c3_clean = c3_unique + c3_sub
print(len(c1_clean))
print(len(c2_clean))
print(len(c3_clean))

print('Max Visit')
max_visit(c1_clean)
max_visit(c2_clean)
max_visit(c3_clean)

print('Average Visit')
avg_visit(c1_clean)
avg_visit(c2_clean)
avg_visit(c3_clean)

257
295
345
Max Visit
42
31
34
Average Visit
5.918287937743191
5.216949152542373
5.684057971014493


In [198]:
def unique_codes(pids):
    unique = []
    for p_id in pids:
        for visit in diagnosis_codes[p_id]:
            unique += visit
    
    print(len(set(unique)))

print('Unique Codes')
unique_codes(c1_clean)
unique_codes(c2_clean)
unique_codes(c3_clean)

Unique Codes
1703
1913
2061


In [201]:
print("Unique codes per patient")
print(1703 / 257.0)
print(1913 / 295.0)
print(2061 / 345.0)

Unique codes per patient
6.626459143968872
6.4847457627118645
5.973913043478261


In [200]:
def total_events(pids):
    unique = []
    for p_id in pids:
        for visit in diagnosis_codes[p_id]:
            unique += visit
    print(len(unique))

print('Total Events')
total_events(c1_clean)
total_events(c2_clean)
total_events(c3_clean)

Total Events
22573
21775
30498


In [202]:
print("Avg # event per patient")
print(22573 / 257.0)
print(21775 / 295.0)
print(30498 / 345.0)

Avg # event per patient
87.83268482490273
73.8135593220339
88.4


In [144]:
# verify that each group is mutually exclusive
c1_set = set(c1_clean)
c2_set = set(c2_clean)
c3_set = set(c3_clean)

common_pids = set()
common_pids = common_pids.union(c1_set.intersection(c2_set))
common_pids = common_pids.union(c2_set.intersection(c3_set))
common_pids = common_pids.union(c3_set.intersection(c1_set))
len(common_pids)

0

In [146]:
combined = c1_clean + c2_clean + c3_clean
print('Number of selected patient: {}'.format(len(combined)))
combined_code  = []
combined_label = []

# extract dataframe only contains the selected patients
combined_df = DIAGNOSES_ICD[DIAGNOSES_ICD['SUBJECT_ID'].isin(combined)]
combined_code = combined_df['ICD9_CODE'].unique().tolist()
combined_pids = combined_df['SUBJECT_ID'].unique().tolist()
print(len(combined_pids))

Number of selected patient: 897
897


#### create mapping for diagnosis code label
Raw ICD-9 to with prefix-ICD-9

In [147]:
text_code = []
for icd_code in combined_code:
    text_code.append('DIAG_{}'.format(icd_code))

tcode_to_idx = {}
cidx_to_type = {}
for idx, tcode in tqdm(enumerate(text_code)):
    tcode_to_idx[tcode] = idx
    cidx_to_type[idx] = tcode

2835it [00:00, 692133.40it/s]


In [194]:
print(tcode_to_idx['DIAG_{}'.format(4280)])
print(tcode_to_idx['DIAG_{}'.format(4019)])
print(tcode_to_idx['DIAG_{}'.format(5849)])

53
29
32


#### remap patient id

In [150]:
opid_to_idx = {} # original patient id to new index
pidx_to_opid = {}
for idx, old_id in enumerate(combined):
    opid_to_idx[old_id] = idx
    pidx_to_opid[idx] = old_id

#### create lists to identify patients for each group
getting the new index
#### 4280 4019 5849

In [151]:
g1_pids = [opid_to_idx[p_id] for p_id in c1_clean] # 4280
g2_pids = [opid_to_idx[p_id] for p_id in c2_clean] # 4019
g3_pids = [opid_to_idx[p_id] for p_id in c3_clean] # 5849
group_pids = {'4280': g1_pids, '4019': g2_pids, '5849': g3_pids}

#### repack the dataset

In [152]:
# the reason we need for 
dataset_clean = {}
for old_pid in combined:
    raw_seqs = diagnosis_codes[old_pid]
    curr_patient = []
    for visit in raw_seqs:
        # visit_tcode = ['DIAG_{}'.format(x) for x in visit]
        
        # map raw code to new index
        visit_mod = [tcode_to_idx['DIAG_{}'.format(x)] for x in visit]
        curr_patient.append(visit_mod)
    
    dataset_clean[opid_to_idx[old_pid]] = curr_patient

# convert from dict to array
seqs = [value for key, value in sorted(dataset_clean.items(), key=lambda item: item[0])]
pids = [key   for key, value in sorted(dataset_clean.items(), key=lambda item: item[0])]

In [153]:
with open('dataset/t1/pids.pkl', 'wb') as f:
    pickle.dump(pids, f, protocol=4)
with open('dataset/t1/seqs.pkl', 'wb') as f:
    pickle.dump(seqs, f, protocol=4)
with open('dataset/t1/rtypes.pkl', 'wb') as f:
    pickle.dump(cidx_to_type, f, protocol=4)
with open('dataset/t1/group_pids.pkl', 'wb') as f:
    pickle.dump(group_pids, f, protocol=4)