In [1]:
import re
import csv
import operator
import json
from pathlib import Path
from tqdm import tqdm
from collections import Counter

import pandas as pd
import numpy as np
from striprtf.striprtf import rtf_to_text

# Raw MIMIC-III tables (the path points at a physionet 1.4 download).
MIMIC_3_DIR = '../../data/mimic-3/raw/physionet.org/files/mimiciii/1.4'
# Title tables: read by preprocess_rtf_icd9 to attach long/short titles to codes.
D_ICD_DIAGNOSES = 'D_ICD_DIAGNOSES.csv.gz'
D_ICD_PROCEDURES = 'D_ICD_PROCEDURES.csv.gz'
# NOTE(review): the three names below are not referenced in the cells visible here.
DIAGNOSES_ICD = 'DIAGNOSES_ICD.csv.gz'
PROCEDURES_ICD = 'PROCEDURES_ICD.csv.gz'
NOTEEVENTS = 'NOTEEVENTS.csv.gz'

# Output directories: PREPROCESSED_DIR receives ICD-9-CM.jsonl, PREGENERATED_DIR
# receives the per-split jsonl files and the label / parent-child text files.
PREPROCESSED_DIR = '../../data/mimic-3/preprocessed'
PREGENERATED_DIR = '../../data/mimic-3/pregenerated'
# NOTE(review): the file names below are not referenced in the cells visible here.
ALL_CODES = 'ALL_CODES.csv'
DISCHARGE_SUMMARIES = 'disch_full.csv'
ALL_CODES_FILTERED = 'ALL_CODES_filtered.csv'
NOTES_LABELED = 'notes_labeled.csv'
DISCHARGE_SPLIT_TRAIN = 'disch_train_split.csv'
DISCHARGE_SPLIT_DEV = 'disch_dev_split.csv'
DISCHARGE_SPLIT_TEST = 'disch_test_split.csv'

# ICD-9-CM tabular lists in RTF form, parsed by get_diagnoses_raw_data /
# get_procedures_raw_data.
ICD9_DIR = '../../data/mimic-3/raw_cdc'
DIAGNOSES_TABULAR_LIST = 'Dtab12.rtf'
PROCEDURES_TABULAR_LIST = 'Ptab12.RTF'

In [2]:
class ICD_Node:
    """One entry of the code hierarchy, linked to its parent by code string.

    Attributes are plain values so a node can be dumped with ``todict`` and
    rebuilt with ``fromdict``.
    """

    # Serialised attributes, in the order they appear in ``todict`` / ``repr``.
    _FIELDS = (
        'code',
        'classname',
        'short_classname',
        'description',
        'assignable',
        'parent',
    )

    def __init__(
        self,
        code,
        classname='',
        short_classname=None,
        assignable=False,
        parent=None,
        **kwarg
    ):
        """Create a node with an empty description list.

        Extra keyword arguments (e.g. the ``description`` key of a serialised
        node) are accepted and ignored.
        """
        self.code = code
        self.classname = classname
        self.short_classname = short_classname
        # Always starts empty; callers append/extend lines afterwards.
        self.description = []
        self.assignable = assignable
        self.parent = parent

    def todict(self):
        """Return the node's attributes as a dict (the description list is not copied)."""
        return {field: getattr(self, field) for field in self._FIELDS}

    @classmethod
    def fromdict(cls, data):
        """Rebuild a node from a ``todict`` style mapping; ``data['description']`` is required."""
        restored = cls(**data)
        # __init__ ignores the description keyword, so copy the lines over here.
        restored.description += data['description']
        return restored

    def __repr__(self):
        """Show the node as the string form of its ``todict`` mapping."""
        return str(self.todict())

In [3]:

# Category title line: optional "<digits>." prefix, a title with no lowercase
# ASCII letters, and a trailing parenthesised code made of digits, 'E', 'V'
# and '-' (e.g. "(000-V99)"). Groups: (prefix, title, code).
# NOTE(review): `\s?+` is a possessive quantifier; Python's `re` accepts that
# syntax only from 3.11 on -- confirm the target interpreter.
CATEGORY_REGEX = r'^(\d+\.)?\s?([^a-z]*?)\s?+\(([EV\d\-]+)\)$'
# Super class title line: a 2-3 digit code (optionally prefixed by E or V), a
# tab, then a title whose first character is not a lowercase letter.
# Groups: (code, title).
SUPER_REGEX = r'^([EV]?\d{2,3})\s?\t\s?([^a-z]{1}.*)'
# Sub class title line: same code form followed by '.' and 1-2 digits, then a
# title whose first character is not a lowercase letter. Groups: (code, title).
SUB_REGEX = r'^([EV]?\d{2,3}\.\d{1,2})\s*([^a-z]{1}.*)'
# Disabled pattern for "Note:/Includes:/Excludes:" content lines.
# CONTENT_REGEX = r'(Note|Includes|Excludes)\:\t?\s*(.*)'

def get_parent_code(code):
    """Return the code one level up, obtained by dropping the last character.

    If the truncated code ends with '.', the dot is dropped as well, e.g.
    '290.10' -> '290.1' and '290.1' -> '290'.

    Args:
        code: code string.

    Returns:
        The shortened code; '' when ``code`` has fewer than two characters.
    """
    parent = code[:-1]
    # endswith() rather than parent[-1]: a one-character (or empty) code leaves
    # '' here and indexing it would raise IndexError.
    if parent.endswith('.'):
        parent = parent[:-1]
    return parent

def chunk_data(data, regex):
    """Split ``data`` into runs that each start at a line matching ``regex``.

    Lines before the first match belong to no chunk.

    Args:
        data: list of strings.
        regex: pattern tested against every line.

    Returns:
        ``(chunks, chunk_ids)``: the list of sub-lists, and the indices of the
        matching lines followed by a trailing ``None`` (open end of the last
        chunk). With no match this is ``([], [None])``.
    """
    chunk_ids = []
    for index, line in enumerate(data):
        if re.search(regex, line) is not None:
            chunk_ids.append(index)
    # Sentinel so the last chunk runs to the end of ``data``.
    chunk_ids.append(None)

    # Consecutive index pairs delimit one chunk each.
    boundaries = zip(chunk_ids, chunk_ids[1:])
    result = [data[start:stop] for start, stop in boundaries]
    return result, chunk_ids

def _read_tabular_list(filename, header):
    """Return the plain-text lines of an RTF tabular list that follow ``header``.

    Lines are skipped until one whose RTF-stripped text contains ``header``.
    The remaining lines are converted with ``rtf_to_text`` (a leading '}' is
    dropped first), joined, stripped and split on newlines; the final line is
    discarded.

    Raises:
        ValueError: if ``header`` never occurs in the file.
    """
    path = Path(f'{ICD9_DIR}/{filename}')
    # Context manager so the handle is released even if conversion fails.
    with path.open() as f:
        for line in f:
            if header in rtf_to_text(line):
                break
        else:
            # Reaching end of file without the marker means there is nothing to parse.
            raise ValueError(f'{header!r} not found in {path}')
        # Continue on the same iterator: only lines after the header remain.
        raw_data = [rtf_to_text(line[1:] if line.startswith('}') else line) for line in f]
    return ''.join(raw_data).strip().split('\n')[:-1]


def get_diagnoses_raw_data():
    """Load the diagnoses tabular list as text lines ready for ``CATEGORY_REGEX``.

    A code range is appended to the first line and a synthetic top-level title
    is inserted before it.
    """
    raw_data = _read_tabular_list(
        DIAGNOSES_TABULAR_LIST, 'ICD-9-CM Tabular List of Diseases (FY12)')
    raw_data[0] += ' (000-999)'
    raw_data.insert(0, 'CLASSIFICATION OF DIAGNOSIS (000-V99)')
    return raw_data


def get_procedures_raw_data():
    """Load the procedures tabular list as text lines ready for ``CATEGORY_REGEX``.

    A code range is appended to the first line.
    """
    raw_data = _read_tabular_list(
        PROCEDURES_TABULAR_LIST, 'ICD-9-CM TABULAR LIST OF PROCEDURES (FY12)')
    raw_data[0] += ' (00-99)'
    return raw_data

def preprocess_rtf_icd9(is_diag=True):
    """Build a ``{code: ICD_Node}`` mapping for ICD-9-CM diagnoses or procedures.

    The tabular list returned by ``get_diagnoses_raw_data`` /
    ``get_procedures_raw_data`` is chunked three times (``CATEGORY_REGEX``,
    ``SUPER_REGEX``, ``SUB_REGEX``). Every matching title line becomes a node
    and the lines between two titles become that node's description. A few
    codes are then added by hand, and finally every row of the MIMIC-III title
    table (``D_ICD_DIAGNOSES`` / ``D_ICD_PROCEDURES``) is merged in and marked
    assignable.

    Args:
        is_diag: True to process diagnoses, False to process procedures.

    Returns:
        dict mapping code string to ``ICD_Node``.
    """
    ### prepare raw data
    if is_diag:
        raw_data = get_diagnoses_raw_data()
    else:
        raw_data = get_procedures_raw_data()

    ### chunk data for category classes
    category_classes, _ = chunk_data(raw_data, CATEGORY_REGEX)

    icd9_dict = dict()
    # Chain of enclosing category ranges for the category being processed.
    code_stack = []
    ### for each category class lines
    for cat_lines in category_classes:
        ### create category class node
        cat_title, *cat_body = cat_lines
        _, classname, code = re.findall(CATEGORY_REGEX, cat_title)[0]
        node = ICD_Node(code, classname=classname)

        ### chunk data for super classes
        super_classes, super_ids = chunk_data(cat_body, SUPER_REGEX)
        # Lines before the first super class title describe the category itself.
        cat_contents = cat_body[:super_ids[0]] if super_classes else cat_body
        ### set category class description
        node.description.extend(cat_contents)

        ### set category class parent
        startcode, *_ = code.split('-')
        if not code_stack:
            # The very first category has no parent and seeds the stack.
            code_stack.append(code)
        else:
            # Pop ranges until the top one contains this category's start code
            # (plain string comparison); that range becomes the parent.
            while node.parent is None:
                tmp_startcode, *tmp_endcode = code_stack[-1].split('-')
                if tmp_startcode <= startcode and (tmp_endcode and startcode <= tmp_endcode[0]):
                    node.parent = code_stack[-1]
                    code_stack.append(code)
                else:
                    code_stack.pop()

        ### assign category class node
        icd9_dict[code] = node

        ### for each super class lines
        for super_lines in super_classes:
            ### create super class node and set parent
            super_title, *super_body = super_lines
            super_code, super_classname = re.findall(SUPER_REGEX, super_title.strip())[0]

            if super_code != code:
                super_node = ICD_Node(super_code, classname=super_classname)
                super_node.parent = node.code
            else:
                # The category code is a single code, not a range: reuse its node.
                super_node = node
                super_node.classname = super_classname

            ### chunk data for sub classes
            sub_classes, sub_ids = chunk_data(super_body, SUB_REGEX)

            super_contents = super_body[:sub_ids[0]] if sub_classes else super_body
            ### set super class description
            super_node.description.extend(super_contents)

            ### assign super class node
            icd9_dict[super_code] = super_node
            ### for each sub class lines
            for sub_lines in sub_classes:
                ### create sub class node
                sub_title, *sub_body = sub_lines
                sub_code, sub_classname = re.findall(SUB_REGEX, sub_title)[0]
                sub_node = ICD_Node(sub_code, classname=sub_classname)
                sub_node.description.extend(sub_body)
                sub_node.parent = get_parent_code(sub_code)
                icd9_dict[sub_code] = sub_node

    if is_diag:
        d_icd_filename = D_ICD_DIAGNOSES
        ### handle irregulars in diagnoses
        category = icd9_dict['V30-V39']
        # Description lines 2..4 of the V30-V39 category are split on a tab into
        # (suffix, name); each suffix is applied to every existing V30..V39 code.
        digit_description = dict()
        for d in category.description[2:5]:
            k, v = d.split('\t')
            digit_description[k] = v
        for number in range(30, 40):
            super_code = f'V{number}'
            # Explicit lookup instead of a bare ``except``: only a missing code
            # is skipped, any other error still surfaces.
            parent = icd9_dict.get(super_code)
            if parent is None:
                continue
            for d, name in digit_description.items():
                sub_code = f'{parent.code}.{d}'
                sub_classname = f'{parent.classname}, {name}'
                sub_node = ICD_Node(sub_code,
                                    classname=sub_classname,
                                    parent=parent.code)
                icd9_dict[sub_code] = sub_node
        icd9_dict['719.70'] = ICD_Node('719.70',
                                       classname='Difficulty in walking involving joint site unspecified',
                                       assignable=True,
                                       parent='719.7')
        icd9_dict['719.70'].description = ['Deprecated code. Changed to (719.7) in 2003, but exist in mimic-iii.']

    else:
        d_icd_filename = D_ICD_PROCEDURES
        ### handle irregulars in procedures
        icd9_dict['14.8'] = ICD_Node('14.8', classname='Operations On Epiretinal Visual Prosthesis', parent='14')
        deprecated_codes = ['36.01', '36.02', '36.05']
        deprecated_classname = [
            'Single vessel percutaneous transluminal coronary angioplasty [PTCA] or coronary atherectomy without mention of thrombolytic agent',
            'Single vessel percutaneous transluminal coronary angioplasty [PTCA] or coronary atherectomy with mention of thrombolytic agent',
            'Multiple vessel percutaneous transluminal coronary angioplasty [PTCA] or coronary atherectomy performed during the same operation, with or without mention of thrombolytic agent']
        for code, classname in zip(deprecated_codes, deprecated_classname):
            icd9_dict[code] = ICD_Node(code, classname=classname, assignable=True, parent='36.0')
            icd9_dict[code].description = ['Deprecated code. Changed to (00.66) code in 2005, but exist in mimic-iii.']

    ### merge the MIMIC-III title table
    converters = {"ROW_ID": int, "ICD9_CODE": str, "SHORT_TITLE": str, "LONG_TITLE": str}
    icd9_diagnoses_title = pd.read_csv(f'{MIMIC_3_DIR}/{d_icd_filename}', converters=converters)

    for _, data in icd9_diagnoses_title.iterrows():
        # The table stores codes without the decimal point.
        code = reformat(data.ICD9_CODE, is_diag)
        if code in icd9_dict:
            node = icd9_dict[code]
        else:
            # Code absent from the tabular list: hang it under its textual parent.
            node = ICD_Node(code,
                            data.LONG_TITLE,
                            parent=get_parent_code(code))

        # Titles from the table override whatever was parsed from the list.
        node.classname = data.LONG_TITLE
        node.short_classname = data.SHORT_TITLE
        node.assignable = True
        icd9_dict[code] = node
        if node.parent not in icd9_dict:
            # Unknown parent: climb one more level (done once, not repeated).
            node.parent = get_parent_code(node.parent)

    print('DIAGNOSES:' if is_diag else 'PROCEDURES:', len(icd9_dict), 'nodes preprocessed.')
    return icd9_dict

def reformat(code, is_diag):
    """Insert the decimal point into an undotted code.

    The integer part is 2 characters for procedures, 4 for diagnosis codes
    starting with 'E' and 3 for all other diagnosis codes. Codes no longer
    than the integer part are returned unchanged.
    """
    if not is_diag:
        head_len = 2
    elif code.startswith('E'):
        head_len = 4
    else:
        head_len = 3

    if len(code) <= head_len:
        return code
    return f'{code[:head_len]}.{code[head_len:]}'

In [4]:
## create icd_9_cm file

icd9_diagnoses = preprocess_rtf_icd9(is_diag=True)
icd9_procedures = preprocess_rtf_icd9(is_diag=False)
print(len(icd9_diagnoses) + len(icd9_procedures))

# Single tree: a synthetic ROOT node above both hierarchies.
icd9 = dict()
icd9['ROOT'] = ICD_Node('ROOT', 'ICD-9-CM')
icd9['ROOT'].description.append('ICD-9-CM in MIMIC-III')

for tag, subtree in (('[DIAGNOSE]', icd9_diagnoses), ('[PROCEDURE]', icd9_procedures)):
    # The first inserted node of each hierarchy is attached to ROOT.
    next(iter(subtree.values())).parent = 'ROOT'
    # Every node's description starts with a tag naming its hierarchy.
    for v in subtree.values():
        v.description.insert(0, tag)
    icd9.update(subtree)

## hierarchy test
# Walk from every assignable node up to a node without parent; a parent code
# that is not a key of icd9 raises KeyError here.
hierarchy_nodes = set()
for k, v in tqdm(icd9.items()):
    if v.assignable:
        tmp = v
        hierarchy_nodes.add(tmp.code)
        while tmp.parent is not None:
            tmp = icd9[tmp.parent]
            hierarchy_nodes.add(tmp.code)
print(len(hierarchy_nodes))

# One JSON object per line; the context manager closes the file even if
# serialisation of a node fails.
with Path(f'{PREPROCESSED_DIR}/ICD-9-CM.jsonl').open('w') as f_hierarchy:
    for v in icd9.values():
        f_hierarchy.write(f'{json.dumps(v.todict())}\n')

DIAGNOSES: 17720 nodes preprocessed.
PROCEDURES: 4671 nodes preprocessed.
22391


100%|██████████| 22392/22392 [00:00<00:00, 1117443.07it/s]

22392





In [5]:
## read icd_9_cm file
icd9 = dict()

# Iterate the file directly: iteration stops at end of file, and a blank line
# is skipped rather than silently ending the load.
with Path(f'{PREPROCESSED_DIR}/ICD-9-CM.jsonl').open() as f_hierarchy:
    for line in f_hierarchy:
        line = line.strip()
        if not line:
            continue
        node = ICD_Node.fromdict(json.loads(line))
        icd9[node.code] = node
len(icd9)

22392

In [9]:
# Longest parent chain (counted in nodes) starting from an assignable code,
# and the code at which that chain starts.
max_depth, max_node = 0, None
for code, node in icd9.items():
    if not node.assignable:
        continue
    depth, cursor = 1, node
    # Follow parent links until a node without parent is reached.
    while cursor.parent:
        cursor = icd9[cursor.parent]
        depth += 1
    # Strict comparison: the first code reaching the maximum is kept.
    if depth > max_depth:
        max_depth, max_node = depth, code
max_depth

9

In [11]:
# Print the chain from the deepest assignable node upwards; the final node
# (the one without a parent) is not printed.
current = icd9[max_node]
while (parent_code := current.parent):
    print(current)
    current = icd9[parent_code]
    

{'code': '290.10', 'classname': 'Presenile dementia, uncomplicated', 'short_classname': 'Presenile dementia', 'description': ['[DIAGNOSE]', 'Presenile dementia:', 'NOS', 'simple type'], 'assignable': True, 'parent': '290.1'}
{'code': '290.1', 'classname': 'Presenile dementia', 'short_classname': None, 'description': ['[DIAGNOSE]', 'Brain syndrome with presenile brain disease', 'Excludes:\tarteriosclerotic dementia (290.40-290.43)', 'dementia associated with other cerebral conditions (294.10-294.11)'], 'assignable': False, 'parent': '290'}
{'code': '290', 'classname': 'Dementias', 'short_classname': None, 'description': ['[DIAGNOSE]', 'Code first the associated neurological condition', 'Excludes:\tdementia due to alcohol (291.0-291.2)', 'dementia due to drugs (292.82)', 'dementia not classified as senile, presenile, or arteriosclerotic (294.10-294.11)', 'psychoses classifiable to 295-298 occurring in the senium without dementia or delirium (295.0-298.8)', 'senility with mental changes o

In [6]:
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.feather as feather

# Special tokens (not referenced in the cells visible here).
PAD_TOKEN = "<PAD>"
UNKNOWN_TOKEN = "<UNK>"

# Column names of the feather dataset.
ID_COLUMN = "_id"
TEXT_COLUMN = "text"
TARGET_COLUMN = "target"
SUBJECT_ID_COLUMN = "subject_id"
code_column_names = ["icd9_diag", "icd9_proc"]

# Columns loaded from the dataset file.
feather_columns = [
    ID_COLUMN,
    TEXT_COLUMN,
    TARGET_COLUMN,
    "num_words",
    "num_targets",
    *code_column_names,
]
df = feather.read_feather(
    "../../data/mimic-3/medical-coding-reproducibility/mimiciii_clean/mimiciii_clean.feather",
    columns=feather_columns,
)
# Alternative split file, currently disabled:
# "../../data/mimic-3/medical-coding-reproducibility/mimiciii_clean/mimiciii_clean_subsplit_0.2.feather"
splits = feather.read_feather(
    "../../data/mimic-3/medical-coding-reproducibility/mimiciii_clean/mimiciii_clean_splits.feather"
)
# Inner join: keep only the rows whose id appears in both frames.
df = df.merge(splits, on=ID_COLUMN, how="inner")

# Columns carried into the arrow table, with their arrow types.
export_columns = [ID_COLUMN, TEXT_COLUMN, TARGET_COLUMN, "split"]
schema = pa.schema(
    [
        pa.field(ID_COLUMN, pa.int64()),
        pa.field(TEXT_COLUMN, pa.large_utf8()),
        pa.field(TARGET_COLUMN, pa.list_(pa.large_string())),
        pa.field("split", pa.large_string()),
    ]
)
new_names = ["did", "text", "labels", "split"]
final_df = pa.Table.from_pandas(df[export_columns], schema=schema).rename_columns(new_names)
print(final_df.column_names)
print(final_df.num_rows)


['did', 'text', 'labels', 'split']
52712


In [7]:
train_name = f'{PREGENERATED_DIR}/mimic3_clean_train.jsonl'
dev_name = f'{PREGENERATED_DIR}/mimic3_clean_dev.jsonl'
test_name = f'{PREGENERATED_DIR}/mimic3_clean_test.jsonl'


def write_split_jsonl(table, split, path):
    """Write the rows of ``table`` whose ``split`` column equals ``split`` to ``path``.

    Each row becomes one JSON object per line, without the ``split`` key.

    Args:
        table: arrow table with a ``split`` column.
        split: split value to select (e.g. "train").
        path: output file path; an existing file is overwritten.

    Returns:
        The filtered table.
    """
    split_table = table.filter(pc.field("split") == split)
    # Context manager so the file is closed even if serialisation fails.
    with Path(path).open('w') as out_file:
        # One-row batches: every value list holds exactly one element.
        for batch in split_table.to_reader(max_chunksize=1):
            record = {k: v[0] for k, v in batch.to_pydict().items()}
            del record['split']
            out_file.write(f"{json.dumps(record)}\n")
    return split_table


# The validation split is stored as "val" but written to the dev file.
train_data = write_split_jsonl(final_df, "train", train_name)
dev_data = write_split_jsonl(final_df, "val", dev_name)
test_data = write_split_jsonl(final_df, "test", test_name)

In [8]:
## create target_labels file

# Union of all codes found in the ``labels`` column.
target_labels = set()
for target in final_df['labels']:
    target_labels.update(target.as_py())
print(len(target_labels))

# sorted(): a set of strings has no stable iteration order between interpreter
# runs, so writing it directly would produce a differently ordered file each time.
with Path(PREGENERATED_DIR, 'target_labels.txt').open('w') as f:
    for label in sorted(target_labels):
        f.write(label + '\n')

3681


In [9]:
## create simple label
# Codes of all target labels plus all of their ancestors. Kept separate from
# ``labels`` (the node list built below) so update_labels stays usable after
# this cell has run.
label_codes = set()

def update_labels(label):
    """Add ``label`` and every ancestor reachable through ``parent`` to ``label_codes``.

    Raises:
        KeyError: if ``label`` or one of its ancestors is not a key of ``icd9``.
    """
    node = icd9[label]
    while True:
        label_codes.add(node.code)
        if not node.parent:
            break
        node = icd9[node.parent]

for label in target_labels:
    update_labels(label)

print(len(label_codes))

# Node objects in icd9 insertion order.
labels = [v for v in icd9.values() if v.code in label_codes]
with Path(PREGENERATED_DIR, 'labels.txt').open('w') as f:
    for label in labels:
        f.write(f"{label.code}\t{label.classname}\n")

5456


In [10]:
## create pcmap file

# (parent, child) code pairs for every selected node that has a parent.
edges = [(label.parent, label.code) for label in labels if label.parent]
# Context manager so the file is closed even if a write fails.
with Path(PREGENERATED_DIR, 'parent_child_map.txt').open('w') as f:
    for parent_code, child_code in edges:
        f.write(f'{parent_code}\t{child_code}\n')
# Number of edges written, shown as the cell output.
i = len(edges)
i

5455

In [11]:
# ## create simple label
# labels = set()

# def update_labels(label):
#     label_node = icd9[label]
#     labels.add(label_node.code)
#     if p:=label_node.parent:
#         update_labels(p)
        
# for label in target_labels:
#     update_labels(label)
    
# print(len(labels))

# f = Path(PREGENERATED_DIR,'labels_simple.txt').open('w')
# labels = [v for v in icd9.values() if v.code in labels]
# for label in labels:
#     if '-' not in label.code:
#         f.write(f"{label.code}\t{label.classname}\n")
# f.close()

# ## create pcmap file

# f = Path(PREGENERATED_DIR,'parent_child_map_simple.txt').open('w')
# i=0
# for label in labels:
#     if '-' not in label.code:
#         i += 1
#         if label.parent:
#             if '-' in label.parent:
#                 f.write(f'ROOT\t{label.code}\n')
#             else:
#                 f.write(f'{label.parent}\t{label.code}\n')
# f.close()
# i