<!-- Original Implementation by Gyubok Lee -->
<!-- Refined by Sunjun Kweon on 2024-01-15. -->
<!-- Note: This Jupyter notebook is tailored to the unique requirements of the EHRSQL project. It includes specific modifications and additional adjustments to cater to the dataset and experiment objectives. -->

# OpenAI Model (ChatGPT) Sample Code for EHRSQL: Reliable Text-to-SQL Modeling on Electronic Health Records

<p align="left" float="left">
  <img src="https://github.com/glee4810/ehrsql-2024/raw/master/image/logo.png" height="100" />
</p>

<!-- ## Task Introduction
The goal of the task is to **develop a reliable text-to-SQL system on EHR**. Unlike standard text-to-SQL tasks, this system must handle all types of questions, including answerable and unanswerable ones with respect to the EHR database structure. For answerable questions, the system must accurately generate SQL queries. For unanswerable questions, the system must correctly identify them as such, thereby preventing incorrect SQL predictions for infeasible questions. The range of questions includes answerable queries about MIMIC-IV, covering topics such as patient demographics, vital signs, and specific disease survival rates ([EHRSQL](https://github.com/glee4810/EHRSQL)). Additionally, there are specially designed unanswerable questions intended to challenge the system. Successfully completing this task will result in the creation of a reliable question-answering system for EHRs, significantly improving the flexibility and efficiency of clinical knowledge exploration in hospitals. -->

## Steps of Baseline Code

- [x] Step 0: Prerequisites (OpenAI API key)
- [x] Step 1: Clone the GitHub Repository and Install Dependencies
- [x] Step 2: Import Global Packages and Define File Paths
- [x] Step 3: Load Data and Prepare Datasets
- [x] Step 4: Building a predictive model using ChatGPT
- [x] Step 5: Submission


## Step 0: Prerequisites (OpenAI API key)



In [1]:
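import getpass
from IPython.display import clear_output

clear_output()
# Replace the placeholder with your API key, or leave it unchanged to be prompted for the key.
PLACEHOLDER = '<your openai api key>'
new_api_key = PLACEHOLDER
# Prompt until a real key is entered (the placeholder and the empty string do not count).
while new_api_key in ('', PLACEHOLDER):
    new_api_key = getpass.getpass("Please input your API key: ")
    clear_output()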

When submitting your code for verification, include your OpenAI API key alongside it, in the same format as `sample_submission_chatgpt_api_key.json`.

## Step 1: Clone the GitHub Repository and Install Dependencies

Before you begin, make sure you're in the correct directory. If you need to reset the repository directory, remove the existing directory by uncommenting and executing the following lines:

Now, clone the repository and install the required Python packages:

Load the `autoreload` extension with the `%load_ext` magic command so that modules are reloaded automatically before each cell runs:

In [30]:
%load_ext autoreload
%autoreload 2

## Step 2: Import Global Packages and Define File Paths

After setting up the repository and dependencies, the next step is to import packages that will be used globally throughout this notebook and to define the file paths to our datasets.

In [2]:
import os
import json
import pandas as pd

# Directory paths for database, results and scoring program
DB_ID = 'mimic_iv'
BASE_DATA_DIR = 'data/mimic_iv'
RESULT_DIR = 'sample_result_submission/'
SCORE_PROGRAM_DIR = 'scoring_program/'

# File paths for the dataset and labels
TABLES_PATH = os.path.join('data', DB_ID, 'tables.json')               # JSON containing database schema
TRAIN_DATA_PATH = os.path.join(BASE_DATA_DIR, 'train', 'data.json')    # JSON file with natural language questions for training data
TRAIN_LABEL_PATH = os.path.join(BASE_DATA_DIR, 'train', 'label.json')  # JSON file with corresponding SQL queries for training data
TRAIN_ANS_PATH = os.path.join(BASE_DATA_DIR, 'train', 'answer.json')   # JSON file with corresponding answers for training data
VALID_DATA_PATH = os.path.join(BASE_DATA_DIR, 'valid', 'data.json')    # JSON file for validation data
DB_PATH = os.path.join('data', DB_ID, f'{DB_ID}.sqlite')               # Database path

TEST_DATA_PATH = os.path.join(BASE_DATA_DIR, 'test', 'data.json')
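
Since every later cell depends on these paths, it can help to confirm that they exist before loading anything. The following is a minimal sketch; `missing_paths` is a hypothetical helper, not part of the EHRSQL repository.

```python
import os

def missing_paths(paths):
    """Return the subset of `paths` that does not exist on disk."""
    return [p for p in paths if not os.path.exists(p)]

# Demo with one path that exists (the current directory) and one made-up path.
demo_missing = missing_paths([
    os.getcwd(),
    os.path.join("data", "no_such_dir_for_demo", "tables.json"),
])
print(demo_missing)
```

In this notebook the call would be `missing_paths([TABLES_PATH, TRAIN_DATA_PATH, TRAIN_LABEL_PATH, VALID_DATA_PATH, TEST_DATA_PATH, DB_PATH])`, and an empty list means everything is in place.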

In [3]:
# This function loads and processes a database schema from a JSON file.

def load_schema(DATASET_JSON): # (TABLES_PATH)
    schema_df = pd.read_json(DATASET_JSON)
    schema_df = schema_df.drop(['column_names','table_names'], axis=1)
    schema = []
    f_keys = []
    p_keys = []
    for index, row in schema_df.iterrows():
        tables = row['table_names_original']
        col_names = row['column_names_original']
        col_types = row['column_types']
        foreign_keys = row['foreign_keys']
        primary_keys = row['primary_keys']
        for col, col_type in zip(col_names, col_types):
            # Each column entry is [table_index, column_name]; index -1 marks the '*' pseudo-column.
            table_index, col_name = col
            if table_index > -1:
                schema.append([row['db_id'], tables[table_index], col_name, col_type])

        for primary_key in primary_keys:
            table_index, column = col_names[primary_key]
            p_keys.append([row['db_id'], tables[table_index], column])

        for foreign_key in foreign_keys:
            first, second = foreign_key
            first_index, first_column = col_names[first]
            second_index, second_column = col_names[second]
            f_keys.append([row['db_id'], tables[first_index], tables[second_index], first_column, second_column])
    db_schema = pd.DataFrame(schema, columns=['Database name', 'Table Name', 'Field Name', 'Type'])
    primary_key = pd.DataFrame(p_keys, columns=['Database name', 'Table Name', 'Primary Key'])
    foreign_key = pd.DataFrame(f_keys,
                        columns=['Database name', 'First Table Name', 'Second Table Name', 'First Table Foreign Key',
                                 'Second Table Foreign Key'])
    return db_schema, primary_key, foreign_key

# Generates a string representation of foreign key relationships in a MySQL-like format for a specific database.
def find_foreign_keys_MYSQL_like(foreign, db_id):
    df = foreign[foreign['Database name'] == db_id]
    # One "table.column = table.column" entry per foreign key; no foreign keys yields "[]".
    pairs = [
        row['First Table Name'] + '.' + row['First Table Foreign Key'] + " = " + row['Second Table Name'] + '.' + row['Second Table Foreign Key']
        for index, row in df.iterrows()
    ]
    return "[" + ", ".join(pairs) + "]"

# Creates a string representation of the fields (columns) in each table of a specific database, formatted in a MySQL-like syntax.
def find_fields_MYSQL_like(db_schema, db_id):
    df = db_schema[db_schema['Database name'] == db_id]
    df = df.groupby('Table Name')
    output = ""
    for name, group in df:
        output += "Table " +name+ ', columns = ['
        for index, row in group.iterrows():
            output += row["Field Name"]+', '
        output = output[:-2]
        output += "]\n"
    return output

# Generates a comprehensive textual prompt describing the database schema, including tables, columns, and foreign key relationships.
def create_schema_prompt(db_id, db_schema, primary_key, foreign_key, is_lower=True):
    prompt = find_fields_MYSQL_like(db_schema, db_id)
    prompt += "Foreign_keys = " + find_foreign_keys_MYSQL_like(foreign_key, db_id)
    if is_lower:
        prompt = prompt.lower()
    return prompt
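
To make the prompt format concrete, here is a simplified sketch that builds the same kind of text directly from one Spider-style `tables.json` entry, without pandas. The function `schema_prompt_from_entry` and the toy entry are illustrative assumptions, not the notebook's implementation, although the output layout is meant to mirror `create_schema_prompt`.

```python
def schema_prompt_from_entry(entry, is_lower=True):
    """Format one Spider-style tables.json entry as schema prompt text."""
    tables = entry["table_names_original"]
    columns = entry["column_names_original"]  # list of [table_index, column_name]

    # Group column names by table, skipping the '*' pseudo-column (index -1).
    by_table = {name: [] for name in tables}
    for table_index, column_name in columns:
        if table_index > -1:
            by_table[tables[table_index]].append(column_name)

    # Emit tables in sorted order, which is also the default order of pandas groupby.
    lines = [f"Table {name}, columns = [{', '.join(by_table[name])}]"
             for name in sorted(by_table)]

    # Each foreign key is a pair of indices into `columns`.
    pairs = []
    for first, second in entry["foreign_keys"]:
        t1, c1 = columns[first]
        t2, c2 = columns[second]
        pairs.append(f"{tables[t1]}.{c1} = {tables[t2]}.{c2}")

    prompt = "\n".join(lines) + "\nForeign_keys = [" + ", ".join(pairs) + "]"
    return prompt.lower() if is_lower else prompt

# Toy entry for illustration only; it is not the real MIMIC-IV schema.
toy_entry = {
    "table_names_original": ["patients", "admissions"],
    "column_names_original": [[-1, "*"], [0, "subject_id"], [0, "gender"],
                              [1, "hadm_id"], [1, "subject_id"]],
    "foreign_keys": [[4, 1]],
}
print(schema_prompt_from_entry(toy_entry))
```

Each table becomes one line listing its columns, and all foreign keys are collected on a final `foreign_keys = [...]` line.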
    

## Step 3: Load Data and Prepare Datasets

Now that we have our environment and paths set up, the next step is to load the data and prepare it for our model. This involves loading the MIMIC-IV database schema, reading the training, validation, and test data from JSON files, and building the schema prompt.

In [4]:
from utils.data_io import read_json as read_data

db_schema, primary_key, foreign_key = load_schema(TABLES_PATH)

train_data = read_data(TRAIN_DATA_PATH)
train_label = read_data(TRAIN_LABEL_PATH)

valid_data = read_data(VALID_DATA_PATH)

table_prompt = create_schema_prompt(DB_ID, db_schema, primary_key, foreign_key)

# Test data
test_data = read_data(TEST_DATA_PATH)
print(table_prompt)


table admissions, columns = [row_id, subject_id, hadm_id, admittime, dischtime, admission_type, admission_location, discharge_location, insurance, language, marital_status, age]
table chartevents, columns = [row_id, subject_id, hadm_id, stay_id, itemid, charttime, valuenum, valueuom]
table cost, columns = [row_id, subject_id, hadm_id, event_type, event_id, chargetime, cost]
table d_icd_diagnoses, columns = [row_id, icd_code, long_title]
table d_icd_procedures, columns = [row_id, icd_code, long_title]
table d_items, columns = [row_id, itemid, label, abbreviation, linksto]
table d_labitems, columns = [row_id, itemid, label]
table diagnoses_icd, columns = [row_id, subject_id, hadm_id, icd_code, charttime]
table icustays, columns = [row_id, subject_id, hadm_id, stay_id, first_careunit, last_careunit, intime, outtime]
table inputevents, columns = [row_id, subject_id, hadm_id, stay_id, starttime, itemid, amount]
table labevents, columns = [row_id, subject_id, hadm_id, itemid, charttime, valu

In [5]:
import tiktoken
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
print(f"Prompt length: {len(encoding.encode(table_prompt))}")
# Prompt length: 751 tokens, well under the 4000-token prompt limit


Prompt length: 751
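
The token count matters because the schema prompt, the question, and the generated SQL all share one budget. A rough sketch of that arithmetic follows, using the 4000-token limit noted in the cell above; the 500-token reserve for the answer is an arbitrary assumption, and `remaining_tokens` is a hypothetical helper.

```python
def remaining_tokens(prompt_tokens, context_limit=4000, reserved_for_answer=500):
    """Tokens left for the question after the schema prompt and a reserved answer budget."""
    return context_limit - prompt_tokens - reserved_for_answer

print(remaining_tokens(751))  # 4000 - 751 - 500 = 2749
```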


### Data Statistics

In [6]:
print("Train data:", (len(train_data['data']), len(train_label)))
print("Valid data:", len(valid_data['data']))
print("Test data:", len(test_data['data']))

Train data: (5124, 5124)
Valid data: 1163
Test data: 1167
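
Matching counts do not guarantee that every question has a label. The sketch below checks this directly, assuming (as the exploration cell further down suggests) that the label file is a dictionary keyed by sample id; `unlabeled_ids` is a hypothetical helper and the data is a toy example.

```python
def unlabeled_ids(data, labels):
    """Return ids of samples in data['data'] that have no entry in `labels`."""
    return [sample["id"] for sample in data["data"] if sample["id"] not in labels]

# Toy data for illustration only.
toy_data = {"data": [{"id": "a1", "question": "q1"}, {"id": "b2", "question": "q2"}]}
toy_labels = {"a1": "SELECT 1"}
print(unlabeled_ids(toy_data, toy_labels))  # ['b2']
```

With the real files, `unlabeled_ids(train_data, train_label)` should return an empty list.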


### Data Format

Before proceeding with the model, it is always a good idea to explore the dataset. This includes checking the keys in the dataset, and viewing the first few entries to understand the structure of the data.



In [16]:
# Explore keys and data structure
# print(train_data.keys())
# print(train_data['version'])
# print(train_data['data'][0])

# # Explore the label structure
# print(train_label.keys())
# print(train_label[list(train_label.keys())[0]])

In [17]:
# Prompt for chatGPT containing meta-data of the database such as columns and foreign keys


In [7]:
from scoring_program.scoring_utils import execute_all
from scoring_program.postprocessing import post_process_sql

pred_dict2 = {'3b9849548e56c59f768d5447': "SELECT MIN(chartevents.valuenum) FROM chartevents WHERE chartevents.stay_id IN ( SELECT icustays.stay_id FROM icustays WHERE icustays.hadm_id IN ( SELECT admissions.hadm_id FROM admissions WHERE admissions.subject_id = 10021118 ) AND icustays.outtime IS NOT NULL ORDER BY icustays.intime ASC LIMIT 1 ) AND chartevents.itemid IN ( SELECT d_items.itemid FROM d_items WHERE d_items.label = 'respiratory rate' AND d_items.linksto = 'chartevents' )"}
pred_dict2 = {id_: post_process_sql(pred_dict2[id_]) for id_ in pred_dict2} # necessary before executing SQL (it replaces placeholders with their actual values)
pred_result2 = execute_all(pred_dict2, DB_PATH, tag='pred') # "execute_all" executes all queries and processes each answer to be "a list of lists".
pred_result2

{'3b9849548e56c59f768d5447': 'error_pred'}
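
The `error_pred` value appears to be a sentinel returned when a predicted query cannot be executed. The sketch below illustrates that general pattern with an in-memory SQLite database. It is not the scoring program's actual code: `run_query` and the toy table are assumptions made for illustration.

```python
import sqlite3

def run_query(connection, sql, error_tag="error_pred"):
    """Execute `sql` and return rows as a list of lists, or `error_tag` on failure."""
    try:
        rows = connection.execute(sql).fetchall()
    except sqlite3.Error:
        return error_tag
    return [list(row) for row in rows]

# Toy table for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE chartevents (stay_id INTEGER, valuenum REAL)")
conn.executemany("INSERT INTO chartevents VALUES (?, ?)", [(1, 18.0), (1, 12.0)])

print(run_query(conn, "SELECT MIN(valuenum) FROM chartevents"))    # [[12.0]]
print(run_query(conn, "SELECT MIN(valuenum) FROM no_such_table"))  # error_pred
```

If a query you expect to be valid returns the sentinel, checking that the database file at `DB_PATH` exists and contains the expected tables is a reasonable first step.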

## Step 4: Building a predictive model using ChatGPT

In [8]:
# Save your API key to a JSON file
import json

api_path = 'sample_submission_chatgpt_api_key.json'
json_data = {}
json_data['key'] = new_api_key
with open(api_path, 'w') as file:
    json.dump(json_data, file, indent=4)

In [20]:
#!pip install -U openai

In [9]:
import os
import re
import json
import tiktoken
from tqdm import tqdm
from openai import OpenAI
import openai

client = OpenAI(api_key=json_data['key'])

def post_process(answer):
    answer = answer.replace('\n', ' ')
    answer = re.sub('[ ]+', ' ', answer)
    return answer

class Model():
    def __init__(self):
        current_real_dir = os.getcwd()
        # current_real_dir = os.path.dirname(os.path.realpath(__file__))
        target_dir = os.path.join(current_real_dir, 'sample_submission_chatgpt_api_key.json')
        if not os.path.isfile(target_dir):
            raise Exception("Error: no API key file found.")
        with open(target_dir, 'rb') as f:
            api_key = json.load(f)['key']
        if api_key == "":
            raise Exception("Error: the API key file contains an empty key.")
        # Build the client from the key file so that requests use the submitted key.
        self.client = OpenAI(api_key=api_key)

    # Fine-tuned model IDs:
    #ft:gpt-3.5-turbo-1106:personal::95amfIJ8 # gpt3.5, SQL generation only
    #ft:gpt-3.5-turbo-1106:personal::94BvwLbf  # gpt3.5, null + SQL generation
    #ft:gpt-3.5-turbo-0125:personal::96BuBjlL # (Sang-ryeol's API) SQL generation
    # The default below selects the fine-tuned model.
    def ask_chatgpt(self, prompt, model="ft:gpt-3.5-turbo-0125:personal::96BuBjlL", temperature=0.0):
        response = self.client.chat.completions.create(
                    model=model,
                    temperature=temperature,
                    messages=prompt,
                    logprobs=True,
                    top_logprobs=5
                )

        # Return the generated text together with the per-token log-probabilities.
        return (response.choices[0].message.content, response.choices[0].logprobs.content)

    def generate(self, input_data, label_T5):
        # label_T5 is not used here; it is kept so that existing two-argument calls still work.
        labels = {}

        for sample in tqdm(input_data, desc="Processing", unit="query"):
            answer = self.ask_chatgpt(sample['input'])
            labels[sample['id']] = answer  # (SQL text, token log-probabilities)
            #processed_answer = post_process(answer)
            #labels[sample_id] = processed_answer

        return labels

In [10]:
myModel = Model()
#data = valid_data["data"] # why valid?
data = test_data["data"]
data

[{'id': '282f008dd8dfb8f4a1dd6999',
  'question': 'Tell me the name of the organism that was detected in the last urine test of patient 10027602 on their first hospital visit?'},
 {'id': '47fd000ef0b1033a8aabfac8',
  'question': 'Show me the top five most frequently prescribed medications since 2100.'},
 {'id': '7f07c59357750ac2b84e5221',
  'question': 'What are the five frequently taken specimens for patients that received extirpation of matter from lung lingula, via natural or artificial opening endoscopic previously within 2 months?'},
 {'id': 'c83990de3e5d2d218c83528a',
  'question': "When does patient 8016's influenza quarantine end?"},
 {'id': 'f430a02bd152c617a86cd2c3',
  'question': 'How much is patient 10038933 needed to pay for the cost of their hospital visit in 2100?'},
 {'id': '34e18d3b9b2f055f37a2d04d',
  'question': 'Call a medical supply company to order more icu equipment.'},
 {'id': 'dfc058cc8e8d30a29c0177b9',
  'question': 'After being diagnosed with thrombocytopenia

In [11]:
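# Prompt variants tried during development. Each assignment overwrites the previous one,
# so only the last system_msg in this cell is actually used.
# Variant: SQL generation + null classification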
system_msg = "You are 'SQLgpt', a sophisticated AI designed to transform user questions into precise SQL queries. Focus on crafting SQL queries that accurately reflect the intentions of the questions. Importantly, the generated SQL queries should adhere to the standard SQL format and not be enclosed in quotes (neither single ' nor double \"). The integrity and accuracy of your responses are paramount. When faced with ambiguity, lack of context, or a query that exceeds the limits of the available data, you must opt for a 'null' response rather than producing a potentially incorrect SQL. Your judgment in these situations is crucial, as we prioritize accuracy and the prevention of misinformation above all."
system_msg = "You are 'SQLgpt', an advanced AI designed to convert user questions into accurate SQL queries. It is crucial to accurately reflect the user's intentions and strictly adhere to the standard SQL format guidelines. The generated SQL queries must adhere to the standard SQL format and should not be enclosed in quotes (neither single ' nor double \"). The integrity and accuracy of responses are paramount. In cases of uncertainty, insufficient context, or when the query exceeds the available data limits, it is essential to respond with 'null' instead of potentially generating inaccurate SQL. Your judgment in such situations is critical, emphasizing the importance of accuracy and error prevention. Responding with 'null' in doubtful cases is highly important"
# Variant: SQL generation + null classification, with schema awareness
system_msg = "You are 'SQLgpt', an advanced AI designed to convert user questions into accurate SQL queries. Your goal is to generate SQL queries that accurately represent the user's intent while strictly adhering to standard SQL format guidelines. Remember, generated SQL queries must not contain quotes, neither single (' ') nor double (\"). The integrity and accuracy of your responses are critical. In situations of uncertainty, insufficient context, or when a query might exceed the data available, opting for 'null' instead of generating a potentially incorrect SQL is imperative. Your judgment is crucial in preventing errors and ensuring accuracy. Always choose 'null' in doubtful situations to avoid generating inaccurate SQL. Additionally, be aware of the database schema and ensure your SQL queries do not go beyond the schema's scope or generate incorrect SQL based on the schema's limitations."
# system_msg = "You are 'SQLgpt', a sophisticated AI designed to transform user questions into precise SQL queries. Focus on crafting SQL queries that accurately reflect the intentions of the questions, with a special emphasis on utilizing the following table structures and their columns effectively:\n\n" 
# system_msg2 = "Importantly, the generated SQL queries should adhere to the standard SQL format and not be enclosed in quotes (neither single ' nor double \"). The integrity and accuracy of your responses are paramount. When faced with ambiguity, lack of context, or a query that exceeds the limits of the available data, you must opt for a 'null' response rather than producing a potentially incorrect SQL. Your judgment in these situations is crucial, as we prioritize accuracy and the prevention of misinformation above all. Make sure to reference the table structures provided to maximize the relevance and accuracy of your SQL queries."
# Variant: SQL generation only (the last assignment, and therefore the one in effect)
system_msg = "You are 'SQLgpt', an AI designed to convert natural language questions into their corresponding SQL queries.Importantly, the generated SQL queries should adhere to the standard SQL format and not be enclosed in quotes (neither single ' nor double \"). Your primary goal is to accurately generate the exact SQL query for each question presented to you."



In [12]:
input_data = []
for sample in data:
    sample_dict = {}
    sample_dict['id'] = sample['id']
    
    #conversation = [{"role": "system", "content": system_msg + table_prompt+'\n\n'+system_msg2}]
    conversation = [{"role": "system", "content": system_msg +'\n\n'}]

    #conversation = [{"role": "system", "content": system_msg}]
    user_question_wrapper = lambda question: '\n\n' + f"""NLQ: \"{question}\"\nSQL: """
    conversation.append({"role": "user", "content": user_question_wrapper(sample['question'])})
    sample_dict['input'] = conversation
    input_data.append(sample_dict)
print(input_data[0])

{'id': '282f008dd8dfb8f4a1dd6999', 'input': [{'role': 'system', 'content': 'You are \'SQLgpt\', an AI designed to convert natural language questions into their corresponding SQL queries.Importantly, the generated SQL queries should adhere to the standard SQL format and not be enclosed in quotes (neither single \' nor double "). Your primary goal is to accurately generate the exact SQL query for each question presented to you.\n\n'}, {'role': 'user', 'content': '\n\nNLQ: "Tell me the name of the organism that was detected in the last urine test of patient 10027602 on their first hospital visit?"\nSQL: '}]}
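
The loop above sends only the system message and the question; the commented-out line hints at a variant that also includes the schema. A minimal sketch of a helper covering both cases follows. `build_conversation` is a hypothetical function, and the question and schema strings are toy values.

```python
def build_conversation(question, system_msg, schema_prompt=None):
    """Return the two-message chat input for one question."""
    system_content = system_msg
    if schema_prompt:
        # Optionally append the schema description to the system message.
        system_content += "\n\n" + schema_prompt
    system_content += "\n\n"
    user_content = f'\n\nNLQ: "{question}"\nSQL: '
    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]

# Toy inputs for illustration only.
messages = build_conversation(
    "How many patients are there?",
    "You are 'SQLgpt'.",
    schema_prompt="table patients, columns = [row_id, subject_id]",
)
print(messages[0]["content"])
print(messages[1]["content"])
```

A fine-tuned model generally behaves best when the prompt matches the format it was trained on, so adding the schema to a model tuned without it may change results and is worth validating first.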


In [13]:
# First message

print(conversation[0]['content'])

You are 'SQLgpt', an AI designed to convert natural language questions into their corresponding SQL queries.Importantly, the generated SQL queries should adhere to the standard SQL format and not be enclosed in quotes (neither single ' nor double "). Your primary goal is to accurately generate the exact SQL query for each question presented to you.




In [14]:
print(input_data)

[{'id': '282f008dd8dfb8f4a1dd6999', 'input': [{'role': 'system', 'content': 'You are \'SQLgpt\', an AI designed to convert natural language questions into their corresponding SQL queries.Importantly, the generated SQL queries should adhere to the standard SQL format and not be enclosed in quotes (neither single \' nor double "). Your primary goal is to accurately generate the exact SQL query for each question presented to you.\n\n'}, {'role': 'user', 'content': '\n\nNLQ: "Tell me the name of the organism that was detected in the last urine test of patient 10027602 on their first hospital visit?"\nSQL: '}]}, {'id': '47fd000ef0b1033a8aabfac8', 'input': [{'role': 'system', 'content': 'You are \'SQLgpt\', an AI designed to convert natural language questions into their corresponding SQL queries.Importantly, the generated SQL queries should adhere to the standard SQL format and not be enclosed in quotes (neither single \' nor double "). Your primary goal is to accurately generate the exact S

In [15]:
data[0]

{'id': '282f008dd8dfb8f4a1dd6999',
 'question': 'Tell me the name of the organism that was detected in the last urine test of patient 10027602 on their first hospital visit?'}

In [17]:
# Generate answers (SQL) with ChatGPT
"""T5 result"""
label_y = myModel.generate(input_data,data)

"""T5 results ensemble."""
#label_y = myModel.generate(input_data,label_T5)
#label_y = myModel.generate(input_data,final_result)

Processing: 100%|██████████| 1167/1167 [44:48<00:00,  2.30s/query] 


'T5 results ensemble.'

In [61]:
label_y['d084d1f3c277e6827087bb44'][1][4]

ChatCompletionTokenLogprob(token='0', bytes=[48], logprob=0.0, top_logprobs=[TopLogprob(token='0', bytes=[48], logprob=0.0), TopLogprob(token='1', bytes=[49], logprob=-21.326208), TopLogprob(token='00', bytes=[48, 48], logprob=-26.679605), TopLogprob(token='5', bytes=[53], logprob=-27.61255), TopLogprob(token='10', bytes=[49, 48], logprob=-28.367544)])
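
Token log-probabilities like the one above can be turned into a confidence score for the whole query, which is one way to decide when to answer `null` instead of returning a doubtful SQL statement. The sketch below is an assumption about how such a rule could look, not the method used in this notebook; the function names, the threshold of 0.5, and the toy values are all hypothetical.

```python
import math

def sequence_confidence(token_logprobs):
    """Return (joint probability, minimum token probability) for one generation."""
    if not token_logprobs:
        return 0.0, 0.0
    joint = math.exp(sum(token_logprobs))
    weakest = math.exp(min(token_logprobs))
    return joint, weakest

def abstain_if_uncertain(sql, token_logprobs, min_token_prob=0.5):
    """Replace `sql` with 'null' when any token falls below `min_token_prob`."""
    _, weakest = sequence_confidence(token_logprobs)
    return sql if weakest >= min_token_prob else "null"

toy_logprobs = [0.0, -0.01, -1.2]  # made-up values, not model output
print(sequence_confidence(toy_logprobs))
print(abstain_if_uncertain("SELECT 1", toy_logprobs))  # exp(-1.2) is about 0.30 < 0.5, so 'null'
```

For a real prediction, the list of log-probabilities could be obtained with `[t.logprob for t in label_y[key][1]]`. Any threshold would need to be tuned on the validation set.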

Each prediction pairs the generated SQL with its per-token log-probabilities, as in the entry above. The next cell saves the predictions to disk:

In [18]:
import pickle

with open('log_probability_test_new.pickle', 'wb') as f:
    pickle.dump(label_y, f, pickle.HIGHEST_PROTOCOL)

In [42]:
def top_prob(prob):
    if 'top_logprobs' in prob:
        prob['top_logprobs'] = [dict(x) for x in prob['top_logprobs']]

    return prob

final_data={}
for key in label_y:
    token_prob = [top_prob(dict(x)) for x in label_y[key][1]]
    final_data[key] = [label_y[key][0], token_prob]


In [43]:
import json
output_filename = 'log_probability_final.json'
# Save as a JSON file
with open(output_filename, 'w') as outfile:
    json.dump(final_data, outfile)

print(f"File saved: {output_filename}")

File saved: log_probability_final.json
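
For Step 5, the saved structure (`id` mapped to `[SQL, token log-probabilities]`) has to be reduced to one SQL string per id. The sketch below shows one way to do that, applying the same whitespace normalization as `post_process`. The helper `to_submission`, the toy data, and the output filename `prediction.json` are assumptions; check the task's submission instructions for the required name and location.

```python
import json
import os
import re
import tempfile

def to_submission(final_data):
    """Map each id to its SQL string with newlines and repeated spaces collapsed."""
    submission = {}
    for sample_id, (sql, _token_probs) in final_data.items():
        sql = sql.replace("\n", " ")
        submission[sample_id] = re.sub(" +", " ", sql)
    return submission

# Toy data for illustration only.
toy_final = {"id1": ["SELECT  1\nFROM t", []], "id2": ["null", []]}
submission = to_submission(toy_final)
print(submission)

# Written to a temporary directory here; in the notebook this would go under RESULT_DIR.
out_path = os.path.join(tempfile.mkdtemp(), "prediction.json")
with open(out_path, "w") as f:
    json.dump(submission, f)
```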
