The script is used to clean MIMIC-III dataset, 
* remove procedure codes in MIMIC-III dataset
* nan label in MIMIC-III dataset
* labels not appearing in description text

In [15]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from collections import Counter, defaultdict
import csv
import math
import operator
import json
import pandas as pd
import os

### Load Dataset

In [16]:
### root of this paper ###
save_path = '../dataset/MIMIC3'

with open(os.path.join(save_path, 'ICD9CODES.json'), 'r') as f:
    label2desc = json.load(f)
label = list(label2desc.keys())

with open(os.path.join(save_path, 'ICD9_descriptions.json'), 'r') as f:
    ICD9_desc = json.load(f)

In [17]:
### root of caml-mimic-master/mimicdata/mimic3 ###
MIMIC_3_DIR = 'your root/caml-mimic-master/mimicdata/mimic3'

train = pd.read_csv(os.path.join(MIMIC_3_DIR, 'train_full.csv'))
val = pd.read_csv(os.path.join(MIMIC_3_DIR, 'dev_full.csv'))
test = pd.read_csv(os.path.join(MIMIC_3_DIR, 'test_full.csv'))

In [18]:
print(len(train))
print(len(val))
print(len(test))

47723
1631
3372


### Get ICD9 procedure codes

In [20]:
import sys
### root of caml-mimic-master ###
sys.path.append(MIMIC_3_DIR[:-17])
import datasets
import log_reg
from dataproc import extract_wvs
from dataproc import get_discharge_summaries
from dataproc import concat_and_split
from dataproc import build_vocab
from dataproc import vocab_index_descriptions
from dataproc import word_embeddings

dfproc = pd.read_csv('%s/PROCEDURES_ICD.csv' % MIMIC_3_DIR)
dfproc['absolute_code'] = dfproc.apply(lambda row: str(datasets.reformat(str(row[4]), False)), axis=1)
proc_ICD9 = sorted(list(set(dfproc['absolute_code'].values)))

### Drop labels
* caculate how many labels to be droped
* drop and re-organize the dataframe

In [21]:
label_dict = dict()

for i in range(len(train)):
    labels = train['LABELS'][i]
    if not pd.isna(labels):
        for l in labels.split(';'):
            label_dict[l] = label_dict.get(l, 0) + 1
    else:
        print('na label found in: ', i)
        
        
for i in range(len(val)):
    labels = val['LABELS'][i]
    if not pd.isna(labels):
        for l in labels.split(';'):
            label_dict[l] = label_dict.get(l, 0) + 1
    else:
        print('na label found in: ', i)
        
        
for i in range(len(test)):
    labels = test['LABELS'][i]
    if not pd.isna(labels):
        for l in labels.split(';'):
            label_dict[l] = label_dict.get(l, 0) + 1
    else:
        print('na label found in: ', i)
        
num = 0 
drop_label = []
for l, cnt in label_dict.items():
    ### remove ICD9 codes don't have a description
    ### remove the ICD9 procedure codes
    if l not in ICD9_desc or l in proc_ICD9:
        drop_label.append(l)
        num += 1
print('#labels to be drop: ', num)
print('#total labels: ', len(label_dict.keys()))

na label found in:  3368
na label found in:  25712
na label found in:  28087
na label found in:  29322
#labels to be drop:  2253
#total labels:  8921


In [22]:
def clean_df(df, drop_list):
    df_ = []
    for i in tqdm(range(len(df))):
        sub_id = df['SUBJECT_ID'][i]
        hadm_id = df['HADM_ID'][i]
        text = df['TEXT'][i]
        labels = df['LABELS'][i]
        length = df['length'][i]
        new_label = []
        if not pd.isna(labels):
            for l in labels.split(';'):
                if l in drop_list or l in new_label:
                    continue
                else:
                    new_label.append(l.strip())
                string = ';'.join(new_label)
        if new_label:
            row = [sub_id, hadm_id, text, string, length]
            df_.append(row)
    return df_
        

clean_train_list = clean_df(train, drop_label)
clean_val_list = clean_df(val, drop_label)
clean_test_list = clean_df(test, drop_label)

clean_train = pd.DataFrame(clean_train_list, columns=['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'length'])
clean_val = pd.DataFrame(clean_val_list, columns=['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'length'])
clean_test = pd.DataFrame(clean_test_list, columns=['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'length'])

clean_train.dropna(subset=['LABELS'], inplace=True)
clean_train.reset_index(drop=True, inplace=True)
clean_val.dropna(subset=['LABELS'], inplace=True)
clean_val.reset_index(drop=True, inplace=True)
clean_test.dropna(subset=['LABELS'], inplace=True)
clean_test.reset_index(drop=True, inplace=True)

100%|█████████████████████████████████████████████████████████| 47723/47723 [00:06<00:00, 7517.43it/s]
100%|███████████████████████████████████████████████████████████| 1631/1631 [00:00<00:00, 6579.63it/s]
100%|███████████████████████████████████████████████████████████| 3372/3372 [00:00<00:00, 6311.65it/s]


### sanity check

In [23]:
clean_label_dict = dict()

for i in range(len(clean_train)):
    labels = clean_train['LABELS'][i]
    if not pd.isna(labels):
        for l in labels.split(';'):
            clean_label_dict[l] = clean_label_dict.get(l, 0) + 1
    else:
        print(i)
        
        
for i in range(len(clean_val)):
    labels = clean_val['LABELS'][i]
    if not pd.isna(labels):
        for l in labels.split(';'):
            clean_label_dict[l] = clean_label_dict.get(l, 0) + 1
    else:
        print(i)
        
        
for i in range(len(clean_test)):
    labels = clean_test['LABELS'][i]
    if not pd.isna(labels):
        for l in labels.split(';'):
            clean_label_dict[l] = clean_label_dict.get(l, 0) + 1
    else:
        print(i)
        
assert len(clean_label_dict.keys()) == len(label_dict.keys()) - num

### save_clean_train/val/test

In [24]:
clean_train.to_csv(os.path.join(save_path, 'clean_train.csv'))
clean_val.to_csv(os.path.join(save_path, 'clean_val.csv'))
clean_test.to_csv(os.path.join(save_path, 'clean_test.csv'))

In [25]:
df = pd.concat([clean_train, clean_val, clean_test]).reset_index(drop=True)
df

Unnamed: 0,SUBJECT_ID,HADM_ID,TEXT,LABELS,length
0,158,169433,admission date discharge date date of birth se...,532.40;493.20;V45.81;412;401.9,51
1,2896,178124,name known lastname known firstname unit no nu...,211.3;427.31;578.9;560.1;496;584.9;428.0;276.5...,55
2,6495,139808,admission date discharge date date of birth se...,998.59;998.32;905.4;E929.0;041.85,60
3,3564,117638,admission date discharge date service doctor l...,038.49;041.6;785.59;518.81;507.0;592.1;591;276...,68
4,7995,190945,admission date discharge date date of birth se...,440.22;492.8;401.9;714.0,74
...,...,...,...,...,...
52716,96777,176399,admission date discharge date date of birth se...,480.1;996.85;780.39;117.7;204.01;117.3;078.5;2...,5890
52717,95323,142423,admission date discharge date date of birth se...,518.81;486;507.0;292.0;276.1;453.42;112.0;292....,6116
52718,91074,106110,admission date discharge date date of birth se...,486;518.81;584.9;518.0;491.21;428.32;112.2;276...,6117
52719,92316,158581,admission date discharge date date of birth se...,427.41;785.51;570;807.4;584.9;861.21;276.2;790...,6227


### calculate label freq

In [26]:
def cnt_instance_per_label(df, column_name):
    label_cnt = {}
#     column_name = 'ICD9_DIAG'
    for i in range(len(df)):
        if type(df[column_name][i]) == float:
            continue
        ps = df[column_name][i].strip()
        for p in ps.split(';'):
            p = p.strip()
            label_cnt[p] = label_cnt.get(p, 0) + 1
    return label_cnt

def sortBy(l1, l2, reverse=True):
    x_axis, y_axis = [], []
    if l1 and l2:
        zipped = zip(l1, l2)
        sort_zipped = sorted(zipped, key=lambda x:(x[1], x[0]), reverse=reverse)
        result = zip(*sort_zipped)
        x_axis, y_axis = [list(x) for x in result]
    return x_axis, y_axis

In [27]:
label_cnt = cnt_instance_per_label(df, column_name='LABELS')
total_num = list(label_cnt.values())
ICD9CODE = list(label_cnt.keys())

sort_ICD9CODE, sort_total_num = sortBy(ICD9CODE, total_num, reverse=True)

sorted_label_cnt = {}
for i in range(len(sort_ICD9CODE)):
    code = sort_ICD9CODE[i]
    num = sort_total_num[i]
    sorted_label_cnt[code] = num

with open(os.path.join(save_path, 'MIMIC3_Label_cnt.json'), 'w') as f:
    json.dump(sorted_label_cnt, f, indent=4)

### Calculate everage number of disease in MIMIC3 dataset

In [28]:
import pandas as pd
import json
import numpy as np
import os


clean_train = pd.read_csv(os.path.join(save_path, "clean_train.csv"))
clean_val = pd.read_csv(os.path.join(save_path, "clean_val.csv"))
clean_test = pd.read_csv(os.path.join(save_path, "clean_test.csv"))


df = pd.concat([clean_train, clean_val, clean_test]).reset_index(drop=True)
all_list_len = []
for ls in df['LABELS']:
    l_list = [l for l in ls.split(';')]
    all_list_len.append(len(l_list))
np.mean(all_list_len)

11.724151666318924

### label hierarchy dict

In [29]:
from collections import defaultdict
hierarchy2ICD9CODE = defaultdict(list)

for ICD9_CODE, description in label2desc.items():
    
    if ICD9_CODE.startswith('E') or ICD9_CODE.startswith('V'):
        hierarchy2ICD9CODE['external causes of injury and supplemental classification'].append(ICD9_CODE)
    
    else:
        if 0 <= float(ICD9_CODE) < 140:
            hierarchy2ICD9CODE['infectious and parasitic diseases'].append(ICD9_CODE)
        elif 140 <= float(ICD9_CODE) < 240:
            hierarchy2ICD9CODE['neoplasms'].append(ICD9_CODE)
        elif 240 <= float(ICD9_CODE) < 280:
            hierarchy2ICD9CODE['endocrine, nutritional and metabolic diseases, and immunity disorders'].append(ICD9_CODE)
        elif 280 <= float(ICD9_CODE) < 290:
            hierarchy2ICD9CODE['diseases of the blood and blood-forming organs'].append(ICD9_CODE)
        elif 290 <= float(ICD9_CODE) < 320:
            hierarchy2ICD9CODE['mental disorders'].append(ICD9_CODE)
        elif 320 <= float(ICD9_CODE) < 390:
            hierarchy2ICD9CODE['diseases of the nervous system and sense organs'].append(ICD9_CODE)
        elif 390 <= float(ICD9_CODE) < 460:
            hierarchy2ICD9CODE['diseases of the circulatory system'].append(ICD9_CODE)
        elif 460 <= float(ICD9_CODE) < 520:
            hierarchy2ICD9CODE['diseases of the respiratory system'].append(ICD9_CODE)
        elif 520 <= float(ICD9_CODE) < 580:
            hierarchy2ICD9CODE['diseases of the digestive system'].append(ICD9_CODE)
        elif 580 <= float(ICD9_CODE) < 630:
            hierarchy2ICD9CODE['diseases of the genitourinary system'].append(ICD9_CODE)
        elif 630 <= float(ICD9_CODE) < 680:
            hierarchy2ICD9CODE['complications of pregnancy, childbirth, and the puerperium'].append(ICD9_CODE)
        elif 680 <= float(ICD9_CODE) < 710:
            hierarchy2ICD9CODE['diseases of the skin and subcutaneous tissue'].append(ICD9_CODE)
        elif 710 <= float(ICD9_CODE) < 740:
            hierarchy2ICD9CODE['diseases of the musculoskeletal system and connective tissue'].append(ICD9_CODE)
        elif 740 <= float(ICD9_CODE) < 760:
            hierarchy2ICD9CODE['congenital anomalies'].append(ICD9_CODE)
        elif 760 <= float(ICD9_CODE) < 780:
            hierarchy2ICD9CODE['certain conditions originating in the perinatal period'].append(ICD9_CODE)
        elif 780 <= float(ICD9_CODE) < 800:
            hierarchy2ICD9CODE['symptoms, signs, and ill-defined conditions'].append(ICD9_CODE)
        elif 800 <= float(ICD9_CODE) < 1000:
            hierarchy2ICD9CODE['injury and poisoning'].append(ICD9_CODE)
        else:
            print('anomaly code {}'.format(ICD9_CODE))
            break
            
            
ICD9CODE2hierarchy = {}
for hier, ICD9_CODES in hierarchy2ICD9CODE.items():
    for ICD9_CODE in ICD9_CODES:
        ICD9CODE2hierarchy[ICD9_CODE] = hier

with open(os.path.join(save_path, 'p2hier.json'), 'w') as f:
    json.dump(ICD9CODE2hierarchy, f, indent=4)
    
with open(os.path.join(save_path, 'hier2p.json'), 'w') as f:
    json.dump(hierarchy2ICD9CODE, f, indent=4)

### generate pretrained embedding

this is to reproduce SOTA (ZAGCNN).

**you need to download the BioWordVec_PubMed_MIMICIII_d200.vec.bin to run the following codes**

In [9]:
from gensim.models.keyedvectors import KeyedVectors
import nltk
import re
import pickle
import json
import numpy as np
from tqdm import tqdm
nltk.download('stopwords')
from nltk.corpus import stopwords
cahcedStopwords = stopwords.words('english')
from nltk.tokenize import word_tokenize
from collections import defaultdict

def getEmbedding(entity, desc, model):
    embed_dict = {}
    no_embed_list = []
    for each in tqdm(entity):
        if each in desc:
            entity = desc[each].lower()
        else:
            entity = each
#             print(e)
        tokenized_word = word_tokenize(entity)
        tmp_list = []
        for w in tokenized_word:
            if w not in model:
                continue
            embedding = model[w]
            tmp_list.append(embedding)
        if tmp_list:
            embed = np.mean(np.array(tmp_list), axis=0).tolist()
            embed_dict[each] = embed
        else:
            embed_dict[each] = model['unk']
            no_embed_list.append(each)
    print(no_embed_list, len(no_embed_list))
    return embed_dict

[nltk_data] Downloading package stopwords to /home/xueren/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
### your root to BioWordVec_PubMed_MIMICIII_d200.vec.bin ###
pretrain_wv = 'your root/BioWordVec_PubMed_MIMICIII_d200.vec.bin'
model = KeyedVectors.load_word2vec_format(pretrain_wv, binary=True)

In [10]:
print("generating embedding for icd9 disease codes")


desc = ['international statistical classification of diseases and related health problems']

with open("../dataset/MIMIC3/hier2p.json", 'r') as f:
    hier2p = json.load(f)

with open("../dataset/MIMIC3/ICD9CODES.json", 'r') as f:
    desc_dict = json.load(f)
label = list(desc_dict.keys())

    
hier = list(set(list(hier2p.keys())))
desc.extend(hier)
desc.extend(label)

disease_embed_dict = getEmbedding(desc, desc_dict, model)

if not os.path.exists("./pre-trained embedding"):
    os.makedirs("./pre-trained embedding")

node_attr_cache = "../pre-trained embedding/MIMIC3_label_embedding_BioWordVec_PubMed_MIMICIII_d200.npy"
np.save(node_attr_cache, disease_embed_dict)

generating embedding for icd9 disease codes


100%|████████████████████████████████████████████████| 6687/6687 [00:00<00:00, 12340.68it/s]

['V44.51'] 1





In [11]:
len(hier)

18