<!-- Original Implementation by Gyubok Lee -->
<!-- Refined by Sunjun Kweon on 2024-01-15. -->
<!-- Note: This Jupyter notebook is tailored to the unique requirements of the EHRSQL project. It includes specific modifications and additional adjustments to cater to the dataset and experiment objectives. -->

# OpenAI Model (ChatGPT) Sample Code for EHRSQL: Reliable Text-to-SQL Modeling on Electronic Health Records

<p align="left" float="left">
  <img src="https://github.com/glee4810/ehrsql-2024/raw/master/image/logo.png" height="100" />
</p>

<!-- ## Task Introduction
The goal of the task is to **develop a reliable text-to-SQL system on EHR**. Unlike standard text-to-SQL tasks, this system must handle all types of questions, including answerable and unanswerable ones with respect to the EHR database structure. For answerable questions, the system must accurately generate SQL queries. For unanswerable questions, the system must correctly identify them as such, thereby preventing incorrect SQL predictions for infeasible questions. The range of questions includes answerable queries about MIMIC-IV, covering topics such as patient demographics, vital signs, and specific disease survival rates ([EHRSQL](https://github.com/glee4810/EHRSQL)). Additionally, there are specially designed unanswerable questions intended to challenge the system. Successfully completing this task will result in the creation of a reliable question-answering system for EHRs, significantly improving the flexibility and efficiency of clinical knowledge exploration in hospitals. -->

## Steps of Baseline Code

- [x] Step 0: Prerequisites (OpenAI API)
- [x] Step 1: Clone the GitHub Repository and Install Dependencies
- [x] Step 2: Import Global Packages and Define File Paths
- [x] Step 3: Load Data and Prepare Datasets
- [x] Step 4: Building a predictive model using ChatGPT
- [x] Step 5: Submission


## Step 0: Prerequisites (OpenAI API key)



In [None]:
import getpass
from IPython.display import clear_output

clear_output()
# Please enter your API key
new_api_key = ''
while len(new_api_key) == 0:
    new_api_key = getpass.getpass("Please input your API key: ")
    clear_output()

When submitting your code for verification, be sure to include your OpenAI API key with it, following the format of `sample_submission_chatgpt_api_key.json`.

## Step 1: Clone the GitHub Repository and Install Dependencies

Before you begin, make sure you're in the correct directory. If you need to reset the repository directory, remove the existing one by executing the following lines:

In [None]:
%cd /content
!rm -rf ehrsql-2024

/content


Now, clone the repository and install the required Python packages:

In [None]:
# Cloning the GitHub repository
!git clone -q https://github.com/glee4810/ehrsql-2024.git
%cd ehrsql-2024

# Installing dependencies
!pip install -q tiktoken openai

/content/ehrsql-2024
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
llmx 0.0.15a0 requires cohere, which is not installed.
tensorflow-probability 0.22.0 requires typing-extensions<4.6.0, but you have typing-extensions 4.9.0 which is incompatible.

Load the `autoreload` extension with the `%load_ext` magic command so that modules are automatically reloaded before each cell is executed:

In [None]:
%load_ext autoreload
%autoreload 2

## Step 2: Import Global Packages and Define File Paths

After setting up the repository and dependencies, the next step is to import packages that will be used globally throughout this notebook and to define the file paths to our datasets.

In [None]:
import os
import json
import pandas as pd
from tqdm import tqdm

# Directory paths for database, results and scoring program
DB_ID = 'mimic_iv'
BASE_DATA_DIR = 'sample_data'
RESULT_DIR = 'sample_result_submission/'
SCORE_PROGRAM_DIR = 'scoring_program/'

# File paths for the dataset and labels
TABLES_PATH = os.path.join('data', DB_ID, 'tables.json')               # JSON containing database schema
TRAIN_DATA_PATH = os.path.join(BASE_DATA_DIR, 'train', 'data.json')    # JSON file with natural language questions for training data
TRAIN_LABEL_PATH = os.path.join(BASE_DATA_DIR, 'train', 'label.json')  # JSON file with corresponding SQL queries for training data
VALID_DATA_PATH = os.path.join(BASE_DATA_DIR, 'valid', 'data.json')    # JSON file for validation data
DB_PATH = os.path.join('data', DB_ID, f'{DB_ID}.sqlite')               # Database path

In [None]:
# This function loads and processes a database schema from a JSON file.

def load_schema(DATASET_JSON):
    schema_df = pd.read_json(DATASET_JSON)
    schema_df = schema_df.drop(['column_names','table_names'], axis=1)
    schema = []
    f_keys = []
    p_keys = []
    for index, row in schema_df.iterrows():
        tables = row['table_names_original']
        col_names = row['column_names_original']
        col_types = row['column_types']
        foreign_keys = row['foreign_keys']
        primary_keys = row['primary_keys']
        for col, col_type in zip(col_names, col_types):
            index, col_name = col
            if index > -1:
                schema.append([row['db_id'], tables[index], col_name, col_type])
        for primary_key in primary_keys:
            index, column = col_names[primary_key]
            p_keys.append([row['db_id'], tables[index], column])
        for foreign_key in foreign_keys:
            first, second = foreign_key
            first_index, first_column = col_names[first]
            second_index, second_column = col_names[second]
            f_keys.append([row['db_id'], tables[first_index], tables[second_index], first_column, second_column])
    db_schema = pd.DataFrame(schema, columns=['Database name', 'Table Name', 'Field Name', 'Type'])
    primary_key = pd.DataFrame(p_keys, columns=['Database name', 'Table Name', 'Primary Key'])
    foreign_key = pd.DataFrame(f_keys,
                        columns=['Database name', 'First Table Name', 'Second Table Name', 'First Table Foreign Key',
                                 'Second Table Foreign Key'])
    return db_schema, primary_key, foreign_key

# Generates a string representation of foreign key relationships in a MySQL-like format for a specific database.
def find_foreign_keys_MYSQL_like(foreign, db_id):
    df = foreign[foreign['Database name'] == db_id]
    output = "["
    for index, row in df.iterrows():
        output += row['First Table Name'] + '.' + row['First Table Foreign Key'] + " = " + row['Second Table Name'] + '.' + row['Second Table Foreign Key'] + ', '
    output = output[:-2] + "]"
    if len(output)==1:
        output = '[]'
    return output

# Creates a string representation of the fields (columns) in each table of a specific database, formatted in a MySQL-like syntax.
def find_fields_MYSQL_like(db_schema, db_id):
    df = db_schema[db_schema['Database name'] == db_id]
    df = df.groupby('Table Name')
    output = ""
    for name, group in df:
        output += "Table " +name+ ', columns = ['
        for index, row in group.iterrows():
            output += row["Field Name"]+', '
        output = output[:-2]
        output += "]\n"
    return output

# Generates a comprehensive textual prompt describing the database schema, including tables, columns, and foreign key relationships.
def create_schema_prompt(db_id, db_schema, primary_key, foreign_key, is_lower=True):
    prompt = find_fields_MYSQL_like(db_schema, db_id)
    prompt += "Foreign_keys = " + find_foreign_keys_MYSQL_like(foreign_key, db_id)
    if is_lower:
        prompt = prompt.lower()
    return prompt
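
To see what these helpers produce, the sketch below builds the same prompt format from a hand-written, two-table schema. The toy tables and the `toy_schema_prompt` helper are illustrative assumptions, not part of the repository; the dictionary only mirrors the fields that `load_schema` reads (`table_names_original`, `column_names_original`, `foreign_keys`).

```python
# Toy schema in the same layout that load_schema reads from tables.json.
# Each column is [table_index, column_name]; index -1 marks the "*" placeholder.
toy_schema = {
    "db_id": "toy_db",
    "table_names_original": ["patients", "admissions"],
    "column_names_original": [
        [-1, "*"],
        [0, "subject_id"],
        [0, "gender"],
        [1, "hadm_id"],
        [1, "subject_id"],
    ],
    # Each foreign key is a pair of indices into column_names_original.
    "foreign_keys": [[4, 1]],
}

def toy_schema_prompt(schema, is_lower=True):
    """Hypothetical compact equivalent of create_schema_prompt for one schema dict."""
    tables = schema["table_names_original"]
    columns = schema["column_names_original"]
    lines = []
    # Tables are listed alphabetically, as pandas groupby does in the notebook code.
    for table_index, table in sorted(enumerate(tables), key=lambda pair: pair[1]):
        fields = [name for index, name in columns if index == table_index]
        lines.append(f"Table {table}, columns = [{', '.join(fields)}]")
    pairs = []
    for first, second in schema["foreign_keys"]:
        first_table, first_column = columns[first]
        second_table, second_column = columns[second]
        pairs.append(f"{tables[first_table]}.{first_column} = {tables[second_table]}.{second_column}")
    prompt = "\n".join(lines) + "\nForeign_keys = [" + ", ".join(pairs) + "]"
    return prompt.lower() if is_lower else prompt

toy_prompt = toy_schema_prompt(toy_schema)
print(toy_prompt)
```

Each table becomes one line listing its columns, and the foreign keys are appended as a single bracketed list of `table.column = table.column` pairs.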

## Step 3: Load Data and Prepare Datasets

Now that our environment and paths are set up, the next step is to load the data and prepare it for our model. This involves loading the MIMIC-IV database schema, reading the training and validation data from JSON files, and building the schema prompt that will be passed to the model.

In [None]:
from utils.data_io import read_json as read_data

db_schema, primary_key, foreign_key = load_schema(TABLES_PATH)

train_data = read_data(TRAIN_DATA_PATH)
train_label = read_data(TRAIN_LABEL_PATH)

valid_data = read_data(VALID_DATA_PATH)

table_prompt = create_schema_prompt(DB_ID, db_schema, primary_key, foreign_key)

### Data Statistics

In [None]:
print("Train data:", (len(train_data['data']), len(train_label)))
print("Valid data:", len(valid_data['data']))

Train data: (20, 20)
Valid data: 20


### Data Format

Before proceeding with the model, it is a good idea to explore the dataset. This includes checking the keys in the dataset and viewing the first entry to understand the structure of the data.



In [None]:
# Explore keys and data structure
print(train_data.keys())
print(train_data['version'])
print(train_data['data'][0])

# Explore the label structure
print(train_label.keys())
print(train_label[list(train_label.keys())[0]])

dict_keys(['version', 'data'])
mimiciii_v1.0_sample
{'question': 'had there been a microbiology test done until 35 months ago for patient 46330?', 'id': '2680598b7213c3a13024d289'}
dict_keys(['2680598b7213c3a13024d289', 'bf0911c0b58077c8f1d6f81c', '63878496a698752a63ebef4c', 'd274aaeaebb3bf290111c9e9', 'fc1a576ff1dca02da376ab1f', 'e246f17f4e96c69fa0e1ca12', 'b43caaea4224eabfa85a17c6', '242cd957777ec217877ec636', 'ca1485115687adc483443bf0', 'a957874c4a875205a8924e22', '5123baaf7143d5e7e673608f', 'ba96ef29e76cefcafd105d48', '6134d95b33e1f45d8c4cb3ae', '6d0e8046e4803f7cdc34bc07', '623cb0c0669e682042e8db7f', '4ebfb78b77f705129db7d86b', '0110ce1384fc2de34f7201ea', 'a742ea513b62cac50e64948d', '349c9a170043e0a7eca2ba7e', 'dd1b63bf2b32ea3d905a813b'])
select count(*)>0 from microbiologyevents where microbiologyevents.hadm_id in ( select admissions.hadm_id from admissions where admissions.subject_id = 46330 ) and datetime(microbiologyevents.charttime) <= datetime(current_time,'-35 month')
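
The printed structures show that questions and labels are linked by `id`. A minimal sketch of pairing them follows. The two-entry toy data only mimics the format above: the ids are made up, the SQL is a placeholder, and the assumption that an unanswerable question carries the label `'null'` follows the `generate` docstring later in this notebook rather than the sample output shown here.

```python
# Toy data in the same shape as data.json and label.json (ids are made up).
toy_data = {
    "version": "toy",
    "data": [
        {"id": "q1", "question": "had there been a microbiology test done until 35 months ago for patient 46330?"},
        {"id": "q2", "question": "what are the precautions to take after the packed cell transfusion procedure?"},
    ],
}
toy_label = {
    "q1": "select 1",  # placeholder standing in for a real SQL label
    "q2": "null",      # assumed marker for an unanswerable question
}

def pair_questions_with_labels(data, label):
    """Return (id, question, label) tuples, matching entries by their id."""
    return [(s["id"], s["question"], label[s["id"]]) for s in data["data"]]

pairs = pair_questions_with_labels(toy_data, toy_label)
answerable = [p for p in pairs if p[2] != "null"]
print(len(pairs), len(answerable))
```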


In [None]:
# Prompt for ChatGPT containing metadata of the database, such as columns and foreign keys

print(table_prompt)

table admissions, columns = [row_id, subject_id, hadm_id, admittime, dischtime, admission_type, admission_location, discharge_location, insurance, language, marital_status, age]
table chartevents, columns = [row_id, subject_id, hadm_id, stay_id, itemid, charttime, valuenum, valueuom]
table cost, columns = [row_id, subject_id, hadm_id, event_type, event_id, chargetime, cost]
table d_icd_diagnoses, columns = [row_id, icd_code, long_title]
table d_icd_procedures, columns = [row_id, icd_code, long_title]
table d_items, columns = [row_id, itemid, label, abbreviation, linksto]
table d_labitems, columns = [row_id, itemid, label]
table diagnoses_icd, columns = [row_id, subject_id, hadm_id, icd_code, charttime]
table icustays, columns = [row_id, subject_id, hadm_id, stay_id, first_careunit, last_careunit, intime, outtime]
table inputevents, columns = [row_id, subject_id, hadm_id, stay_id, starttime, itemid, amount]
table labevents, columns = [row_id, subject_id, hadm_id, itemid, charttime, valu

## Step 4: Building a predictive model using ChatGPT

In [None]:
# Save your api key into json file
import json

api_path = 'sample_submission_chatgpt_api_key.json'
json_data = {}
json_data['key'] = new_api_key
with open(api_path, 'w') as file:
    json.dump(json_data, file, indent=4)

In [None]:
import os
import re
import json
import tiktoken

import openai
from openai import OpenAI

def post_process(answer):
    # Collapse newlines and repeated spaces so each prediction fits on one line
    answer = answer.replace('\n', ' ')
    answer = re.sub('[ ]+', ' ', answer)
    return answer

class Model():
    def __init__(self):
        current_real_dir = os.getcwd()
        # current_real_dir = os.path.dirname(os.path.realpath(__file__))
        target_dir = os.path.join(current_real_dir, 'sample_submission_chatgpt_api_key.json')

        if not os.path.isfile(target_dir):
            raise Exception("Error: no API key file found.")
        with open(target_dir, 'r') as f:
            api_key = json.load(f)['key']
        if api_key == "":
            raise Exception("Error: the API key file contains an empty key.")

        # Create the client from the key file, so this cell does not depend on
        # `json_data` from the previous cell still being defined.
        self.client = OpenAI(api_key=api_key)

    def ask_chatgpt(self, prompt, model="gpt-3.5-turbo-16k", temperature=0.0):
        response = self.client.chat.completions.create(
                    model=model,
                    temperature=temperature,
                    messages=prompt
                )
        return response.choices[0].message.content

    def generate(self, input_data):
        """
        Arguments:
            input_data: list of python dictionaries containing 'id' and 'input'
        Returns:
            labels: python dictionary containing sql prediction or 'null' values associated with ids
        """

        labels = {}

        for sample in input_data:
            answer = self.ask_chatgpt(sample['input'])
            labels[sample["id"]] = post_process(answer)

        return labels

In [None]:
myModel = Model()
data = valid_data["data"]

In [None]:
# System prompt for ChatGPT
system_msg = "Given the following SQL tables, your job is to write queries given a user’s request. If you think you cannot get the correct SQL, answer with 'null'."

In [None]:
input_data = []
for sample in tqdm(data):
    sample_dict = {}
    sample_dict['id'] = sample['id']
    conversation = [{"role": "system", "content": system_msg + '\n\n' + table_prompt}]
    user_question_wrapper = lambda question: '\n\n' + f"""NLQ: \"{question}\"\nSQL: \""""
    conversation.append({"role": "user", "content": user_question_wrapper(sample['question'])})
    sample_dict['input'] = conversation
    input_data.append(sample_dict)

100%|██████████| 20/20 [00:00<00:00, 25274.50it/s]
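
The loop above produces one two-message conversation per question. The same construction can be wrapped in a function, sketched here with a hypothetical helper name `build_conversation` and shortened stand-ins for `system_msg` and `table_prompt`; the demo question is made up for illustration.

```python
def build_conversation(question, system_msg, table_prompt):
    """Build the system/user message pair used for one question."""
    system_content = system_msg + '\n\n' + table_prompt
    # The user turn ends with an open quote so the model continues with SQL.
    user_content = '\n\n' + f'NLQ: "{question}"\nSQL: "'
    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]

# Stand-in values for illustration only.
demo_system_msg = "Given the following SQL tables, your job is to write queries given a user's request."
demo_table_prompt = "table patients, columns = [row_id, subject_id, gender]"
demo = build_conversation("what is the gender of patient 70709?", demo_system_msg, demo_table_prompt)
print(demo[1]["content"])
```

Keeping the schema in the system message and only the question in the user message means the schema text is identical for every sample.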


In [None]:
# First message

print(conversation[0]['content'])

Given the following SQL tables, your job is to write queries given a user’s request. If you think you cannot get the correct SQL, answer with 'null'.

table admissions, columns = [row_id, subject_id, hadm_id, admittime, dischtime, admission_type, admission_location, discharge_location, insurance, language, marital_status, age]
table chartevents, columns = [row_id, subject_id, hadm_id, stay_id, itemid, charttime, valuenum, valueuom]
table cost, columns = [row_id, subject_id, hadm_id, event_type, event_id, chargetime, cost]
table d_icd_diagnoses, columns = [row_id, icd_code, long_title]
table d_icd_procedures, columns = [row_id, icd_code, long_title]
table d_items, columns = [row_id, itemid, label, abbreviation, linksto]
table d_labitems, columns = [row_id, itemid, label]
table diagnoses_icd, columns = [row_id, subject_id, hadm_id, icd_code, charttime]
table icustays, columns = [row_id, subject_id, hadm_id, stay_id, first_careunit, last_careunit, intime, outtime]
table inputevents, colum

In [None]:
# Second message

print(conversation[1]['content'])



NLQ: "what are the precautions to take after the packed cell transfusion procedure?"
SQL: "


In [None]:
# Generate answers (SQL) from ChatGPT
label_y = myModel.generate(input_data)
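
Each call in `generate` is a network request, so a transient failure (for example a rate limit) would abort the whole loop. The sketch below shows one common mitigation, retrying with exponential backoff. `with_retries` and its parameters are hypothetical and not part of the baseline, and the example deliberately uses a fake flaky function instead of the real API.

```python
import time

def with_retries(func, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Call func(); on an exception wait base_delay * 2**attempt and try again."""
    for attempt in range(max_attempts):
        try:
            return func()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(base_delay * 2 ** attempt)

# Fake request that fails twice before succeeding (stands in for an API call).
calls = {"count": 0}

def flaky_request():
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("temporary failure")
    return "select 1"

delays = []
# Record the delays instead of actually sleeping, to keep the demo instant.
result = with_retries(flaky_request, sleep=delays.append)
print(result, delays)
```

In the notebook, the call inside `generate` could be wrapped as `with_retries(lambda: self.ask_chatgpt(sample['input']))`.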

Below is what the predicted labels (SQL queries) look like:

In [None]:
label_y

{'d93058c12f048d80eaae1c99': 'SELECT gender FROM patients WHERE subject_id = 70709"',
 'd3941cc75cbecd1ca1eedc6e': 'SELECT p.drug FROM prescriptions p JOIN admissions a ON p.hadm_id = a.hadm_id JOIN procedures_icd pi ON a.hadm_id = pi.hadm_id WHERE p.subject_id = 14467 AND pi.icd9_code = \'oth extraoc mus-tend op\' AND p.startdate <= DATE_ADD(pi.charttime, INTERVAL 2 DAY) AND p.enddate >= pi.charttime"',
 'dc1c58baffd2987e2ad7b582': 'SELECT d_icd_procedures.short_title, COUNT(*) AS count FROM procedures_icd JOIN admissions ON procedures_icd.hadm_id = admissions.hadm_id JOIN d_icd_procedures ON procedures_icd.icd9_code = d_icd_procedures.icd9_code WHERE admissions.age >= 20 AND admissions.age < 30 AND EXTRACT(YEAR FROM admissions.admittime) = EXTRACT(YEAR FROM CURRENT_DATE) GROUP BY d_icd_procedures.short_title ORDER BY count DESC LIMIT 4;"',
 '5ff51cd11c92b00c656f621d': 'SELECT COUNT(DISTINCT admissions.subject_id) FROM admissions JOIN diagnoses_icd ON admissions.hadm_id = diagnoses_ic
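
Several predictions above end with a stray `"`, because the user prompt opens a quote after `SQL:` and the model closes it. A minimal sketch of an extra clean-up step follows. `clean_prediction` is a hypothetical helper, not part of the repository, and treating an empty answer as `'null'` is an assumption.

```python
import re

def clean_prediction(answer):
    """Normalize whitespace and strip the quote that closes the prompt's open quote."""
    answer = re.sub(r'\s+', ' ', answer).strip()
    # Drop a trailing double quote only when the quotes are unbalanced, so a
    # query that legitimately ends with a quoted identifier is left alone.
    if answer.endswith('"') and answer.count('"') % 2 == 1:
        answer = answer[:-1].rstrip()
    # Assumption: an empty answer is treated as an abstention.
    return answer if answer else 'null'

print(clean_prediction('SELECT gender FROM patients WHERE subject_id = 70709"'))
```

A step like this could run after `post_process` and before the predictions are written to disk.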

In [None]:
from utils.data_io import write_json as write_label

# Save the predictions to a JSON file
os.makedirs(RESULT_DIR, exist_ok=True)
SCORING_OUTPUT_DIR = os.path.join(RESULT_DIR, 'prediction.json')
write_label(SCORING_OUTPUT_DIR, label_y)

# Verify the file creation
print("Listing files in RESULT_DIR:")
!ls {RESULT_DIR}

Listing files in RESULT_DIR:
prediction.json


## Step 5: Submission

In this final step, we'll prepare and submit our results to the Codabench competition.

In [None]:
# Change to the directory containing the prediction file
%cd {RESULT_DIR}

# Compress the prediction.json file into a ZIP archive
!zip predictions.zip prediction.json

/content/ehrsql-2024/sample_result_submission
  adding: prediction.json (deflated 65%)


- Submission File: Ensure that the `predictions.zip` file contains only the `prediction.json` file. This ZIP archive is the required format for submission to Codabench.

- Submitting on Codabench: Navigate to the Codabench competition page and go to the **My Submissions** tab. Upload the `predictions.zip` file following the provided instructions. Make sure to adhere to any guidelines or submission requirements detailed on the competition page.
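
The check that the archive contains only `prediction.json` can also be done programmatically. Below is a minimal sketch using Python's standard `zipfile` module, with a temporary directory and a toy prediction standing in for the real files; `zip_contains_only` is a hypothetical helper name.

```python
import json
import os
import tempfile
import zipfile

def zip_contains_only(zip_path, expected_name):
    """Return True if the archive holds exactly one member with the expected name."""
    with zipfile.ZipFile(zip_path) as archive:
        return archive.namelist() == [expected_name]

# Build a toy predictions.zip in a temporary directory (illustrative data only).
with tempfile.TemporaryDirectory() as tmp_dir:
    json_path = os.path.join(tmp_dir, "prediction.json")
    with open(json_path, "w") as f:
        json.dump({"toy-id": "null"}, f)
    zip_path = os.path.join(tmp_dir, "predictions.zip")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        # arcname keeps the member at the archive root, without the temp path.
        archive.write(json_path, arcname="prediction.json")
    is_valid = zip_contains_only(zip_path, "prediction.json")

print(is_valid)
```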