In [1]:
import pickle
import os, sys
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter

# Work from the project directory so the relative "data/..." paths used in the
# cells below resolve.
# NOTE(review): this is a machine-specific absolute path — adjust before running elsewhere.
os.chdir('/home/baizhiwang/gki_icd/')
# Overwrite the first module-search entry.
# NOTE(review): presumably needed so that the `src` package imports — confirm.
sys.path[0]='../'
# Echo the working directory to confirm the chdir above took effect.
!pwd
from src.utils import write_pickle, read_pickle

/home/baizhiwang/gki_icd


Read Files from MIMIC-III Dataset

In [2]:
# Load every note of the MIMIC-III NOTEEVENTS table. dtype=str keeps all columns
# as text, so HADM_ID compares equal to the split ids, which are read as text too.
note_events = pd.read_csv(
    "data/raw_data/mimic3/NOTEEVENTS.csv",
    dtype=str,
)

Get Official Splits from CAML (hadm_id_list)

In [3]:
# Gather the admission ids of the three "full" splits into a single list and
# report the size of each split.
hadm_ids = []
for split in ("train", "dev", "test"):
    # The split files have no header row; the ids are in the first column.
    split_df = pd.read_csv(
        f"data/splits/mimic3/{split}_full_hadm_ids.csv",
        header=None,
        dtype=str,
    )
    hadm_ids += split_df[0].tolist()
    print(split, len(split_df))

train 47723
dev 1631
test 3372


Get Corresponding Clinical Note for Each Hadm ID

In [4]:
# hadm2report:  admission id -> texts of its discharge summaries, in file order.
# hadm2subject: admission id -> subject id, set only for admissions that have
#               at least one discharge summary.
hadm2subject = {}
hadm2report = {hadm_id: [] for hadm_id in hadm_ids}
note_fields = zip(
    note_events["HADM_ID"],
    note_events["SUBJECT_ID"],
    note_events["CATEGORY"],
    note_events["TEXT"],
)
for hadm_id, subject_id, category, text in tqdm(note_fields):
    # Skip notes of other categories and admissions outside the splits.
    if category != "Discharge summary" or hadm_id not in hadm2report:
        continue
    hadm2report[hadm_id].append(text)
    hadm2subject[hadm_id] = subject_id

2083180it [01:15, 27515.73it/s]


In [5]:
def reformat(code, is_diag):
    """
    Put the period back into an ICD-9 code, because the MIMIC-3 data files
    exclude it.

    Diagnosis codes get the period after the first three characters, or after
    the first four when the code starts with "E". Procedure codes get it after
    the first two characters. Any period already present is removed first, so
    an already formatted code is returned unchanged. A code that has nothing
    after the period position is returned without a period.

    Args:
        code: ICD-9 code as a string, with or without a period.
        is_diag: True for a diagnosis code, False for a procedure code.

    Returns:
        The code as a string with the period in its canonical position.
    """
    code = "".join(code.split("."))
    if is_diag:
        # E-codes have a four-character category (e.g. "E885"), others three.
        category_len = 4 if code.startswith("E") else 3
    else:
        category_len = 2
    # Only insert the period when something follows it. Without this guard a
    # bare two-character procedure code came back as "38." and "" as ".".
    if len(code) > category_len:
        code = code[:category_len] + "." + code[category_len:]
    return code

In [6]:
# Preview the procedure table with every column read as text. Read this way,
# ICD9_CODE keeps its leading zeros (see "0331" in the output of this cell).
dfproc = pd.read_csv(f"data/raw_data/mimic3/PROCEDURES_ICD.csv", dtype=str)
dfproc

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,944,62641,154460,3,3404
1,945,2592,130856,1,9671
2,946,2592,130856,2,3893
3,947,55357,119355,1,9672
4,948,55357,119355,2,0331
...,...,...,...,...,...
240090,228330,67415,150871,5,3736
240091,228331,67415,150871,6,3893
240092,228332,67415,150871,7,8872
240093,228333,67415,150871,8,3893


In [7]:
# Same file with pandas' default type inference, for comparison: the code that
# printed as "0331" in the previous cell now prints as 331 (leading zero lost).
# NOTE(review): this rebinds `dfproc`; the following cell reloads it with dtype=str.
dfproc = pd.read_csv(f"data/raw_data/mimic3/PROCEDURES_ICD.csv")
dfproc

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,944,62641,154460,3,3404
1,945,2592,130856,1,9671
2,946,2592,130856,2,3893
3,947,55357,119355,1,9672
4,948,55357,119355,2,331
...,...,...,...,...,...
240090,228330,67415,150871,5,3736
240091,228331,67415,150871,6,3893
240092,228332,67415,150871,7,8872
240093,228333,67415,150871,8,3893


In [7]:
# Load the diagnosis and procedure code tables, every column as text so that
# the ICD-9 codes keep their leading zeros.
dfdiag, dfproc = (
    pd.read_csv(csv_path, dtype=str)
    for csv_path in (
        "data/raw_data/mimic3/DIAGNOSES_ICD.csv",
        "data/raw_data/mimic3/PROCEDURES_ICD.csv",
    )
)

In [8]:
# Distinct raw codes per table as (procedures, diagnoses); a missing value
# counts as one distinct entry.
(
    dfproc["ICD9_CODE"].nunique(dropna=False),
    dfdiag["ICD9_CODE"].nunique(dropna=False),
)

(2032, 6985)

In [9]:
# Add the dotted ("absolute") form of every code.
# A missing ICD9_CODE is mapped straight to the literal "nan", which is the
# sentinel the label-building loop later filters on. Passing the stringified
# missing value through reformat() is only safe for diagnoses: with
# is_diag=False it would come back as "na.n" and slip past that filter as a
# bogus label.
dfdiag["absolute_code"] = dfdiag["ICD9_CODE"].apply(
    lambda raw: "nan" if pd.isna(raw) else reformat(str(raw), True)
)
dfproc["absolute_code"] = dfproc["ICD9_CODE"].apply(
    lambda raw: "nan" if pd.isna(raw) else reformat(str(raw), False)
)

In [10]:
# Stack the diagnosis rows on top of the procedure rows; each table keeps its
# own row labels.
dfcodes = pd.concat(objs=[dfdiag, dfproc], axis=0)

In [11]:
# Number of distinct dotted codes across both tables.
dfcodes["absolute_code"].nunique(dropna=False)

9017

In [12]:
# hadm2icd: admission id -> set of dotted ICD-9 codes (diagnoses and procedures).
hadm2icd = {hadm_id: set() for hadm_id in hadm_ids}
code_pairs = zip(dfcodes["HADM_ID"], dfcodes["absolute_code"])
for hadm_id, icd_code in tqdm(code_pairs):
    # Rows whose code is the string "nan" (a stringified missing ICD9_CODE) and
    # rows of admissions outside the splits are ignored.
    if icd_code == "nan" or hadm_id not in hadm2icd:
        continue
    hadm2icd[hadm_id].add(icd_code)

0it [00:00, ?it/s]

891142it [00:33, 26312.10it/s]


Make Dataset

In [None]:
# For each split, build one dict per admission and hand the list to
# write_pickle under data/mimic3/.
for split in ("train", "dev", "test"):
    split_df = pd.read_csv(
        f"data/splits/mimic3/{split}_full_hadm_ids.csv", header=None, dtype=str
    )
    split_hadm_ids = split_df[0].tolist()
    samples = [
        {
            "subject_id": hadm2subject[hadm_id],
            "hadm_id": hadm_id,
            # All discharge summaries of the admission, joined by newlines.
            "text": "\n".join(hadm2report[hadm_id]),
            # Sorted so that the label order is deterministic.
            "label": sorted(hadm2icd[hadm_id]),
        }
        for hadm_id in split_hadm_ids
    ]
    write_pickle(samples, f"data/mimic3/{split}.pkl")

Get Code Frequency

In [None]:


# Count how often each code occurs, per split (counter_per_split) and over all
# splits together (label_count).
full_labels = []
counter_per_split = {}
for split in ("train", "dev", "test"):
    split_samples = read_pickle(f"data/mimic3/{split}.pkl")
    split_labels = [code for sample in split_samples for code in sample["label"]]
    full_labels += split_labels
    counter_per_split[split] = Counter(split_labels)
label_count = Counter(full_labels)

In [15]:
# One row per code with its total count and its count in each split.
# The rows are collected first and the frame is built once: growing it with
# df.loc[len(df)] = ... copies the frame for every added row, which is
# quadratic over the thousands of codes handled here.
distribution_rows = [
    {
        "code": code,
        "total": count,
        "train": counter_per_split["train"].get(code, 0),
        "dev": counter_per_split["dev"].get(code, 0),
        "test": counter_per_split["test"].get(code, 0),
    }
    for code, count in label_count.items()
]
# dtype=object reproduces the column types of the row-by-row construction, so
# the sorting and CSV export that follow behave exactly as before.
df = pd.DataFrame(
    distribution_rows,
    columns=["code", "total", "train", "dev", "test"],
    dtype=object,
)
df

Unnamed: 0,code,total,train,dev,test
0,401.9,20046,17897,708,1441
1,412,3203,2866,112,225
2,44.43,590,532,20,38
3,493.20,732,647,30,55
4,532.40,269,238,10,21
...,...,...,...,...,...
8924,487.8,1,0,0,1
8925,695.51,1,0,0,1
8926,529.3,1,0,0,1
8927,014.85,1,0,0,1


In [None]:
# Most frequent codes first, then persist the table without the row index.
df = df.sort_values("total", ascending=False)
df.to_csv(path_or_buf="data/mimic3/distribution.csv", index=False)

MIMIC-III 50

In [17]:
# The 50 ICD-9 codes (dotted form) kept as labels for the MIMIC-III 50 subset.
# Stored as a set because the dataset-building cell intersects it with each
# admission's code set.
top50_codes = {
    "401.9",
    "38.93",
    "428.0",
    "427.31",
    "414.01",
    "96.04",
    "96.6",
    "584.9",
    "250.00",
    "96.71",
    "272.4",
    "518.81",
    "99.04",
    "39.61",
    "599.0",
    "530.81",
    "96.72",
    "272.0",
    "285.9",
    "88.56",
    "244.9",
    "486",
    "38.91",
    "285.1",
    "36.15",
    "276.2",
    "496",
    "99.15",
    "995.92",
    "V58.61",
    "507.0",
    "038.9",
    "88.72",
    "585.9",
    "403.90",
    "311",
    "305.1",
    "37.22",
    "412",
    "33.24",
    "39.95",
    "287.5",
    "410.71",
    "276.1",
    "V45.81",
    "424.0",
    "45.13",
    "V15.82",
    "511.9",
    "37.23",
}

In [None]:
# For each split of the MIMIC-III 50 subset, build one dict per admission and
# hand the list to write_pickle under data/mimic3_50/. Labels are restricted to
# the codes in top50_codes.
for split in ["train", "dev", "test"]:
    split_df = pd.read_csv(
        f"data/splits/mimic3/{split}_50_hadm_ids.csv", header=None, dtype=str
    )
    split_hadm_ids = split_df[0].tolist()
    samples = []
    for hadm_id in split_hadm_ids:
        sample = {
            "subject_id": hadm2subject[hadm_id],
            "hadm_id": hadm_id,
            # All discharge summaries of the admission, joined by newlines.
            "text": "\n".join(hadm2report[hadm_id]),
            # .intersection() accepts any iterable. top50_codes is rebound to a
            # list further down in this notebook, and `set & list` raises
            # TypeError, so the `&` form broke whenever this cell was re-run
            # after that later cell.
            "label": sorted(set(hadm2icd[hadm_id]).intersection(top50_codes)),
        }
        samples.append(sample)
    write_pickle(samples, f"data/mimic3_50/{split}.pkl")

In [None]:
# Count how often each code occurs in the MIMIC-III 50 pickles, per split
# (counter_per_split) and over all splits together (label_count).
full_labels = []
counter_per_split = {}
for split in ("train", "dev", "test"):
    split_samples = read_pickle(f"data/mimic3_50/{split}.pkl")
    split_labels = [code for sample in split_samples for code in sample["label"]]
    full_labels += split_labels
    counter_per_split[split] = Counter(split_labels)
label_count = Counter(full_labels)

In [None]:
# One row per code with its total count and its count in each split, ordered by
# training frequency and saved next to the MIMIC-III 50 pickles.
# The rows are collected first and the frame is built once instead of being
# grown with df.loc[len(df)] = ..., which copies the frame for every added row.
distribution_rows = [
    {
        "code": code,
        "total": count,
        "train": counter_per_split["train"].get(code, 0),
        "dev": counter_per_split["dev"].get(code, 0),
        "test": counter_per_split["test"].get(code, 0),
    }
    for code, count in label_count.items()
]
# dtype=object reproduces the column types of the row-by-row construction, so
# the sort below behaves exactly as before.
df = pd.DataFrame(
    distribution_rows,
    columns=["code", "total", "train", "dev", "test"],
    dtype=object,
)
df = df.sort_values(by="train", ascending=False, ignore_index=True)
df.to_csv("data/mimic3_50/distribution.csv", index=False)
df

Unnamed: 0,code,total,train,dev,test
0,401.9,4719,3233,708,778
1,38.93,2874,2139,333,402
2,428.0,2874,2115,337,422
3,427.31,2858,1992,396,470
4,414.01,2744,1921,388,435
5,96.04,2042,1581,228,233
6,96.6,1978,1525,225,228
7,584.9,2102,1448,292,362
8,250.00,2045,1416,289,340
9,96.71,1937,1395,284,258


MIMIC-III ICD-9 Code Description (including Parent Codes)

In [None]:
# Every code that occurs as a label in any split, as a sorted list.
mimic3_codes = set()
for split in ("train", "dev", "test"):
    for sample in read_pickle(f"data/mimic3/{split}.pkl"):
        mimic3_codes |= set(sample["label"])
mimic3_codes = sorted(mimic3_codes)
len(mimic3_codes)

8929

In [50]:
# rerank, diag first, proc second
# The part before the period decides the type: three or more characters means
# diagnosis, exactly two means procedure. Codes with a shorter prefix are
# dropped. The relative order inside each group is kept.
diag_codes = [code for code in mimic3_codes if len(code.split(".")[0]) >= 3]
proc_codes = [code for code in mimic3_codes if len(code.split(".")[0]) == 2]
mimic3_codes = diag_codes + proc_codes

In [23]:
# Source 1: header-less, tab-separated "code<TAB>description" file.
# dtype=str prevents pandas from parsing all-numeric codes as numbers, which
# would drop leading zeros and make the .strip() calls below fail.
icd_desc1 = pd.read_csv(
    "data/icd9/ICD9_descriptions.tsv",
    sep="\t",
    header=None,
    names=["code", "desc"],
    dtype=str,
)
icd_desc1["code"] = icd_desc1["code"].apply(lambda x: x.strip())
icd_desc1["desc"] = icd_desc1["desc"].apply(lambda x: x.strip().lower())
icd_desc1 = dict(zip(icd_desc1["code"], icd_desc1["desc"]))

# Source 2: the ICD dictionary tables, of which only the ICD-9 rows are used.
# Only icd_code is forced to text (leading zeros matter and reformat() needs a
# str); icd_version keeps its inferred numeric type so the `== 9` test holds.
icd_diagnoses = pd.read_csv("data/icd9/d_icd_diagnoses.csv", dtype={"icd_code": str})
icd_procedures = pd.read_csv("data/icd9/d_icd_procedures.csv", dtype={"icd_code": str})

icd_desc2 = {}
for icd_table, is_diag in ((icd_diagnoses, True), (icd_procedures, False)):
    icd9_rows = icd_table[icd_table["icd_version"] == 9]
    for raw_code, long_title in zip(icd9_rows["icd_code"], icd9_rows["long_title"]):
        # Strip padding so the key matches the dotted codes used as labels,
        # mirroring the strip applied to the codes of source 1.
        icd_desc2[reformat(raw_code.strip(), is_diag)] = long_title.lower()

# icd2desc is the same dict object as icd_desc1; where both sources describe a
# code, the entry from icd_desc2 wins.
icd2desc = icd_desc1
icd2desc.update(icd_desc2)

# These codes are missing (maybe have been removed)
icd2desc.update(
    {
        "36.01": "removal of coronary artery obstruction and insertion of stent(s)",
        "36.02": "removal of coronary artery obstruction and insertion of stent(s)",
        "36.05": "removal of coronary artery obstruction and insertion of stent(s)",
        "719.70": "difficulty in walking",
    }
)
len(icd2desc)

22366

In [None]:
# Description table for every label code, in mimic3_codes order; a code without
# a known description gets an empty string.
# The columns are assembled first and the frame is built once: growing it with
# .loc[len(df)] = ... copies the frame for every added row, which is quadratic
# over the thousands of codes handled here.
mimic3_code_df = pd.DataFrame(
    {
        "code": list(mimic3_codes),
        "desc": [icd2desc.get(code, "") for code in mimic3_codes],
    },
    columns=["code", "desc"],
)
mimic3_code_df.to_csv(
    "data/mimic3/code_description.csv", sep="\t", index=False, encoding="utf-8"
)

In [85]:
# Parent codes (the part before the period) that are not label codes
# themselves, as a sorted list.
parent_codes = {code.split(".")[0] for code in mimic3_codes}
extra_codes = sorted(parent_codes - set(mimic3_codes))
len(extra_codes)

1063

In [None]:
# Description table for the parent codes in extra_codes; a code without a known
# description gets an empty string.
# The columns are assembled first and the frame is built once instead of being
# grown with .loc[len(df)] = ..., which copies the frame for every added row.
mimic3_extra_code_df = pd.DataFrame(
    {
        "code": list(extra_codes),
        "desc": [icd2desc.get(code, "") for code in extra_codes],
    },
    columns=["code", "desc"],
)
mimic3_extra_code_df.to_csv(
    "data/mimic3/extra_code_description.csv", sep="\t", index=False, encoding="utf-8"
)

MIMIC-III 50 Description

In [73]:
# The same 50 codes as the earlier set, now as a list so that their order is
# fixed for the description and hierarchy files written by the following cells.
# NOTE(review): this rebinds `top50_codes` from a set to a list.
top50_codes = ["401.9",
    "38.93",
    "428.0",
    "427.31",
    "414.01",
    "96.04",
    "96.6",
    "584.9",
    "250.00",
    "96.71",
    "272.4",
    "518.81",
    "99.04",
    "39.61",
    "599.0",
    "530.81",
    "96.72",
    "272.0",
    "285.9",
    "88.56",
    "244.9",
    "486",
    "38.91",
    "285.1",
    "36.15",
    "276.2",
    "496",
    "99.15",
    "995.92",
    "V58.61",
    "507.0",
    "038.9",
    "88.72",
    "585.9",
    "403.90",
    "311",
    "305.1",
    "37.22",
    "412",
    "33.24",
    "39.95",
    "287.5",
    "410.71",
    "276.1",
    "V45.81",
    "424.0",
    "45.13",
    "V15.82",
    "511.9",
    "37.23"]

In [None]:
# Description table for the top-50 codes, in top50_codes order; a code without
# a known description gets an empty string.
# The columns are assembled first and the frame is built once instead of being
# grown with .loc[len(df)] = ..., which copies the frame for every added row.
mimic3_50_code_df = pd.DataFrame(
    {
        "code": list(top50_codes),
        "desc": [icd2desc.get(code, "") for code in top50_codes],
    },
    columns=["code", "desc"],
)
mimic3_50_code_df.to_csv("data/mimic3_50/code_description.csv", sep="\t", index=False, encoding="utf-8")

In [88]:
# Parent codes (the part before the period) of the top-50 codes that are not
# top-50 codes themselves, as a sorted list.
top50_parent_codes = {code.split(".")[0] for code in top50_codes}
top50_extra_codes = sorted(top50_parent_codes - set(top50_codes))
len(top50_extra_codes)

35

In [89]:
# Description table for the parent codes in top50_extra_codes; a code without a
# known description gets an empty string.
# The columns are assembled first and the frame is built once instead of being
# grown with .loc[len(df)] = ..., which copies the frame for every added row.
top50_extra_code_df = pd.DataFrame(
    {
        "code": list(top50_extra_codes),
        "desc": [icd2desc.get(code, "") for code in top50_extra_codes],
    },
    columns=["code", "desc"],
)
top50_extra_code_df.to_csv(
    "data/mimic3_50/extra_code_description.csv", sep="\t", index=False, encoding="utf-8"
)

MIMIC3 Code -> Groups

In [77]:
# Load the three header-less, tab-separated group tables. The cells below read
# column 0 as the group range and column 1 as its definition.
diag1, diag2, proc = (
    pd.read_csv(tsv_path, sep="\t", header=None)
    for tsv_path in (
        "data/icd9/diagnosis_list1.tsv",
        "data/icd9/diagnosis_list2.tsv",
        "data/icd9/procedure_list.tsv",
    )
)

In [78]:
# group2def: group range (column 0) -> its definition (column 1), merged over
# the three group tables; a later table overrides an earlier one on a clash.
group2def = {}
for group_table in (diag1, diag2, proc):
    group2def.update(zip(group_table[0], group_table[1]))

In [79]:
def if_code_in_group(code, group_list):
    """
    Find the group range that contains a code.

    Groups are strings of the form "START-END". A code starting with "E" or
    "V" is only compared with groups carrying the same letter, and the letter
    is dropped from the code and from both bounds before the numeric
    comparison. Any other code is only compared with groups that start with
    neither letter.

    Args:
        code: Code without a period, e.g. "401", "V45", "E885" or "38".
        group_list: Iterable of group range strings.

    Returns:
        The first group whose inclusive range contains the code. None is
        returned when no group matches, or when the code or a group cannot be
        parsed; in the latter case the code is printed.
    """
    try:
        letter = code[0] if code[0] in ("E", "V") else ""
        if letter:
            candidates = [group for group in group_list if group[0] == letter]
        else:
            candidates = [group for group in group_list if group[0] not in ("E", "V")]
        # Number of leading characters to drop before converting to int.
        skip = len(letter)
        for group in candidates:
            start, end = group.split("-")
            code_number = int(code[skip:])
            if int(start[skip:]) <= code_number <= int(end[skip:]):
                return group
    except Exception:
        # Best effort: report the offending code and fall through to None.
        # Catching Exception rather than using a bare `except:` lets
        # KeyboardInterrupt and SystemExit propagate.
        print(code)
    return None

In [None]:
# Build the hierarchy table for every label code: the code, its parent (the
# part before the period, if it has one) and the group ranges containing the
# parent. Also collect the groups that are actually used, per group table.
mimic3_diag1, mimic3_diag2, mimic3_proc = set(), set(), set()
code_in_group = pd.DataFrame(columns=["code", "code1", "diag1", "diag2", "proc"])
code_in_group["code"] = mimic3_codes
for row_idx, full_code in enumerate(mimic3_codes):
    parent = full_code.split(".")[0]
    if parent != full_code:
        # Only codes with a sub-classification get a parent entry.
        code_in_group.loc[row_idx, "code1"] = parent
    if len(parent) == 2:
        # A two-character parent marks a procedure code.
        proc_group = if_code_in_group(parent, proc[0])
        mimic3_proc.add(proc_group)
        code_in_group.loc[row_idx, "proc"] = proc_group
        continue
    # Everything else is looked up in both diagnosis group tables.
    for column, ranges, used_groups in (
        ("diag1", diag1[0], mimic3_diag1),
        ("diag2", diag2[0], mimic3_diag2),
    ):
        diag_group = if_code_in_group(parent, ranges)
        used_groups.add(diag_group)
        code_in_group.loc[row_idx, column] = diag_group
mimic3_diag1, mimic3_diag2, mimic3_proc = (
    sorted(mimic3_diag1),
    sorted(mimic3_diag2),
    sorted(mimic3_proc),
)
code_in_group.to_csv("data/mimic3/code_hierarchy.csv", sep="\t", index=False)

In [None]:
# Table of every group in use with its definition: diag1 groups first, then
# diag2, then procedure groups.
group_list = [*mimic3_diag1, *mimic3_diag2, *mimic3_proc]
df = pd.DataFrame({"group": group_list})
df["name"] = [group2def[group] for group in group_list]
df.to_csv("data/mimic3/group_description.csv", sep="\t", index=False)

MIMIC_III-50 code -> group

In [None]:
# Build the hierarchy table for the top-50 codes: the code, its parent (the
# part before the period, if it has one) and the group ranges containing the
# parent. Also collect the groups that are actually used, per group table.
mimic3_diag1, mimic3_diag2, mimic3_proc = set(), set(), set()
code_in_group = pd.DataFrame(columns=["code", "code1", "diag1", "diag2", "proc"])
code_in_group["code"] = top50_codes
for row_idx, full_code in enumerate(top50_codes):
    parent = full_code.split(".")[0]
    if parent != full_code:
        # Only codes with a sub-classification get a parent entry.
        code_in_group.loc[row_idx, "code1"] = parent
    if len(parent) == 2:
        # A two-character parent marks a procedure code.
        proc_group = if_code_in_group(parent, proc[0])
        mimic3_proc.add(proc_group)
        code_in_group.loc[row_idx, "proc"] = proc_group
        continue
    # Everything else is looked up in both diagnosis group tables.
    for column, ranges, used_groups in (
        ("diag1", diag1[0], mimic3_diag1),
        ("diag2", diag2[0], mimic3_diag2),
    ):
        diag_group = if_code_in_group(parent, ranges)
        used_groups.add(diag_group)
        code_in_group.loc[row_idx, column] = diag_group
mimic3_diag1, mimic3_diag2, mimic3_proc = (
    sorted(mimic3_diag1),
    sorted(mimic3_diag2),
    sorted(mimic3_proc),
)
code_in_group.to_csv("data/mimic3_50/code_hierarchy.csv", sep="\t", index=False)

In [None]:
# Table of every group used by the top-50 codes with its definition: diag1
# groups first, then diag2, then procedure groups.
group_list = [*mimic3_diag1, *mimic3_diag2, *mimic3_proc]
df = pd.DataFrame({"group": group_list})
df["name"] = [group2def[group] for group in group_list]
df.to_csv("data/mimic3_50/group_description.csv", sep="\t", index=False)