In [8]:
import pickle

# Deserialize the pickled dataset object into `sample_dataset`.
# NOTE(review): pickle.load can execute arbitrary code during deserialization;
# only point this at files from a trusted source.
with open('/data/pj20/exp_data/icd9cm_icd9proc/drugrec_dataset_umls.pkl', 'rb') as pickle_file:
    sample_dataset = pickle.load(pickle_file)

In [9]:
import csv

condition_mapping_file = "../../resources/ICD9CM.csv"
procedure_mapping_file = "../../resources/ICD9PROC.csv"
drug_file = "../../resources/ATC.csv"


def _load_code_names(path, level=None):
    """Read a CSV with `code` and `name` columns into a ``{code: name}`` dict.

    Dots are stripped from each code and each name is lower-cased.

    Args:
        path: Path of the CSV file to read (must have a header row).
        level: When given, only rows whose `level` column equals this string
            are kept; rows are not filtered when it is None.

    Returns:
        dict mapping the dot-less code string to the lower-cased name. If a
        code occurs more than once, the last row wins.
    """
    mapping = {}
    with open(path, newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            if level is not None and row['level'] != level:
                continue
            mapping[row['code'].replace('.', '')] = row['name'].lower()
    return mapping


condition_dict = _load_code_names(condition_mapping_file)
procedure_dict = _load_code_names(procedure_mapping_file)
# Only rows whose `level` column is the literal string '3.0' are kept for drugs.
drug_dict = _load_code_names(drug_file, level='3.0')


In [10]:
def flatten(lst):
    """Return a new flat list with the elements of arbitrarily nested lists.

    Only `list` instances are descended into; every other element (including
    tuples and strings) is kept as-is. Element order is preserved.
    """
    flat = []
    for element in lst:
        # Nested lists are expanded recursively; anything else is a leaf.
        flat += flatten(element) if isinstance(element, list) else [element]
    return flat

In [11]:
condition_dict_sample = {}
procedure_dict_sample = {}

# Collect code -> name lookups for every condition/procedure code that occurs
# in samples 0..100 of the dataset.
for i in range(101):
    sample = sample_dataset[i]
    for condition in flatten(sample['conditions']):
        if condition not in condition_dict_sample:
            condition_dict_sample[condition] = condition_dict[condition]
    for procedure in flatten(sample['procedures']):
        if procedure not in procedure_dict_sample:
            try:
                procedure_dict_sample[procedure] = procedure_dict[procedure]
            except KeyError:
                # Code not present in the mapping: retry with the last
                # character dropped and store the entry under that shorter
                # code. Only KeyError is caught so that interrupts and
                # unrelated bugs are not silently swallowed.
                procedure_dict_sample[procedure[:-1]] = procedure_dict[procedure[:-1]]

In [12]:
import re 
from ChatGPT import ChatGPT
import json

def extract_data_in_brackets(input_string):
    """Return the text found between each `[` and the next `]` on the same line.

    Matching is non-greedy and does not cross newlines; the enclosing
    brackets themselves are not included in the returned strings.
    """
    return [found.group(1) for found in re.finditer(r"\[(.*?)\]", input_string)]

def divide_text(long_text, max_len=800):
    """Split `long_text` into consecutive chunks of at most `max_len` items.

    Args:
        long_text: The string (or other sliceable sequence) to split.
        max_len: Maximum length of each chunk; must be a positive integer.

    Returns:
        List of chunks in original order. Every chunk has length `max_len`
        except possibly the last; an empty input yields an empty list.

    Raises:
        ValueError: If `max_len` is not positive. A non-positive value would
            otherwise never advance through the text and loop forever.
    """
    if max_len <= 0:
        raise ValueError(f"max_len must be positive, got {max_len}")
    return [
        long_text[start_idx:start_idx + max_len]
        for start_idx in range(0, len(long_text), max_len)
    ]

In [13]:
from ChatGPT import ChatGPT
import json

def graph_gen(term: str, mode: str):
    """Ask ChatGPT for relationship triples about `term` and return them as text.

    A few-shot example matching `mode` is embedded in the prompt. The reply is
    converted with `str()`, parsed as JSON, and its 'content' field is scanned
    for bracketed items.

    Args:
        term: Name of the medical concept to expand.
        mode: One of "condition", "procedure" or "drug"; selects the example.

    Returns:
        A string with one extracted item per line (each line newline-terminated),
        where every ", " separator has been replaced by a tab.

    Raises:
        ValueError: If `mode` is not one of the supported values. This is
            checked before any request is made.
    """
    examples = {
        "condition": """
        Example:
        prompt: systemic lupus erythematosus
        updates: [[systemic lupus erythematosus, is an, autoimmune condition], [systemic lupus erythematosus, may cause, nephritis], [anti-nuclear antigen, is a test for, systemic lupus erythematosus], [systemic lupus erythematosus, is treated with, steroids], [methylprednisolone, is a, steroid]]
        """,
        "procedure": """
        Example:
        prompt: endoscopy
        updates: [[endoscopy, is a, medical procedure], [endoscopy, used for, diagnosis], [endoscopic biopsy, is a type of, endoscopy], [endoscopic biopsy, can detect, ulcers]]
        """,
        "drug": """
        Example:
        prompt: iobenzamic acid
        updates: [[iobenzamic acid, is a, drug], [iobenzamic acid, may have, side effects], [side effects, can include, nausea], [iobenzamic acid, used as, X-ray contrast agent], [iobenzamic acid, formula, C16H13I3N2O3]]
        """,
    }
    if mode not in examples:
        # Without this guard `example` would be unbound and the failure would
        # surface later as an UnboundLocalError while building the prompt.
        raise ValueError(
            f"unsupported mode {mode!r}; expected one of {sorted(examples)}"
        )
    example = examples[mode]

    prompt = f"""
            Given a prompt (a medical condition/procedure/drug), extrapolate as many relationships as possible of it and provide a list of updates.
            The relationships should be helpful for healthcare prediction (e.g., drug recommendation, mortality prediction, readmission prediction …)
            Each update should be exactly in format of [ENTITY 1, RELATIONSHIP, ENTITY 2]. The relationship is directed, so the order matters.
            Both ENTITY 1 and ENTITY 2 should be noun.
            Any element in [ENTITY 1, RELATIONSHIP, ENTITY 2] should be conclusive, make it as short as possible.
            Do this in both breadth and depth. Expand [ENTITY 1, RELATIONSHIP, ENTITY 2] until the size reaches 100.

            {example}

            prompt: {term}
            updates:
        """
    chatgpt = ChatGPT()
    response = chatgpt.chat(prompt)
    json_data = json.loads(str(response))

    triples = extract_data_in_brackets(json_data['content'])
    # The non-greedy bracket match can leave a stray '[' from the outer list,
    # so brackets are stripped before ", " is turned into the tab delimiter.
    return "".join(
        triple.replace('[', '').replace(']', '').replace(', ', '\t') + '\n'
        for triple in triples
    )

In [17]:
from tqdm import tqdm
import os

# For every sampled condition code, make sure its triple file exists and holds
# enough lines; otherwise query the model and (re)write the file.
for key in tqdm(condition_dict_sample.keys()):
    file = f'../../graphs/condition/ICD9CM_base_gpt/{key}.txt'
    prev_triples = ""
    if os.path.exists(file):
        with open(file=file, mode="r", encoding='utf-8') as f:
            prev_triples = f.read()
        # Files that already split into 50 or more '\n'-separated pieces are
        # considered complete and are left untouched.
        if len(prev_triples.split('\n')) >= 50:
            continue
    # Generate first so a failing request never truncates an existing file.
    outstr = prev_triples + graph_gen(term=condition_dict_sample[key], mode="condition")
    # `with` guarantees the handle is flushed and closed; the previous bare
    # open() left one unclosed file object per key.
    with open(file=file, mode='w', encoding='utf-8') as outfile:
        outfile.write(outstr)

100%|██████████| 449/449 [1:25:51<00:00, 11.47s/it]
