In [1]:
db_schema = """
CREATE TABLE PATIENTS 
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    GENDER VARCHAR(5) NOT NULL,
    DOB TIMESTAMP(0) NOT NULL,
    DOD TIMESTAMP(0),
    CONSTRAINT pat_subid_unique UNIQUE (SUBJECT_ID),
    CONSTRAINT pat_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE ADMISSIONS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    ADMITTIME TIMESTAMP(0) NOT NULL,
    DISCHTIME TIMESTAMP(0),
    ADMISSION_TYPE VARCHAR(50) NOT NULL,
    ADMISSION_LOCATION VARCHAR(50) NOT NULL,
    DISCHARGE_LOCATION VARCHAR(50),
    INSURANCE VARCHAR(255) NOT NULL,
    LANGUAGE VARCHAR(10),
    MARITAL_STATUS VARCHAR(50),
    AGE INT NOT NULL,
    CONSTRAINT adm_hadmid_unique UNIQUE (HADM_ID),
    CONSTRAINT adm_rowid_pk PRIMARY KEY (ROW_ID)    
);

CREATE TABLE D_ICD_DIAGNOSES
(
    ROW_ID INT NOT NULL,
    ICD_CODE VARCHAR(10) NOT NULL,
    LONG_TITLE VARCHAR(255) NOT NULL,
    CONSTRAINT d_icd_diag_code_unique UNIQUE (ICD_CODE),
    CONSTRAINT d_icd_diag_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE D_ICD_PROCEDURES 
(
    ROW_ID INT NOT NULL,
    ICD_CODE VARCHAR(10) NOT NULL,
    LONG_TITLE VARCHAR(255) NOT NULL,
    CONSTRAINT d_icd_proc_code_unique UNIQUE (ICD_CODE),
    CONSTRAINT d_icd_proc_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE D_LABITEMS 
(
    ROW_ID INT NOT NULL,
    ITEMID INT NOT NULL,
    LABEL VARCHAR(200),
    CONSTRAINT dlabitems_itemid_unique UNIQUE (ITEMID),
    CONSTRAINT dlabitems_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE D_ITEMS 
(
    ROW_ID INT NOT NULL,
    ITEMID INT NOT NULL,
    LABEL VARCHAR(200) NOT NULL,
    ABBREVIATION VARCHAR(200) NOT NULL,
    LINKSTO VARCHAR(50) NOT NULL,
    CONSTRAINT ditems_itemid_unique UNIQUE (ITEMID),
    CONSTRAINT ditems_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE DIAGNOSES_ICD
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    ICD_CODE VARCHAR(10) NOT NULL,
    CHARTTIME TIMESTAMP(0) NOT NULL,
    CONSTRAINT diagnosesicd_rowid_unique UNIQUE (ROW_ID)
    CONSTRAINT diagnosesicd_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE PROCEDURES_ICD
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    ICD_CODE VARCHAR(10) NOT NULL,
    CHARTTIME TIMESTAMP(0) NOT NULL,
    CONSTRAINT proceduresicd_rowid_unique UNIQUE (ROW_ID),
    CONSTRAINT proceduresicd_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE LABEVENTS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    ITEMID INT NOT NULL,
    CHARTTIME TIMESTAMP(0),
    VALUENUM DOUBLE PRECISION,
    VALUEUOM VARCHAR(20),
    CONSTRAINT labevents_rowid_unuque UNIQUE (ROW_ID),    
    CONSTRAINT labevents_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE PRESCRIPTIONS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    STARTTIME TIMESTAMP(0) NOT NULL,
    STOPTIME TIMESTAMP(0),
    DRUG VARCHAR(255) NOT NULL,
    DOSE_VAL_RX VARCHAR(100) NOT NULL,
    DOSE_UNIT_RX VARCHAR(50) NOT NULL,
    ROUTE VARCHAR(50) NOT NULL,
    CONSTRAINT prescription_rowid_unuque UNIQUE (ROW_ID),    
    CONSTRAINT prescription_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE COST
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    EVENT_TYPE VARCHAR(20) NOT NULL,
    EVENT_ID INT NOT NULL,
    CHARGETIME TIMESTAMP(0) NOT NULL,
    COST DOUBLE PRECISION NOT NULL,
    CONSTRAINT cost_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE CHARTEVENTS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    STAY_ID INT NOT NULL,
    ITEMID INT NOT NULL,
    CHARTTIME TIMESTAMP(0) NOT NULL,
    VALUENUM DOUBLE PRECISION,
    VALUEUOM VARCHAR(50),
    CONSTRAINT chartevents_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE INPUTEVENTS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    STAY_ID INT NOT NULL,
    STARTTIME TIMESTAMP(0) NOT NULL,
    ITEMID INT NOT NULL,
    AMOUNT DOUBLE PRECISION,
    CONSTRAINT inputevents_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE OUTPUTEVENTS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    STAY_ID INT NOT NULL,
    CHARTTIME TIMESTAMP(0) NOT NULL,
    ITEMID INT NOT NULL,
    VALUE DOUBLE PRECISION,
    CONSTRAINT outputevents_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE MICROBIOLOGYEVENTS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    CHARTTIME TIMESTAMP(0) NOT NULL,
    SPEC_TYPE_DESC VARCHAR(100),
    TEST_NAME VARCHAR(100),
    ORG_NAME VARCHAR(100),
    CONSTRAINT micro_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE ICUSTAYS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    STAY_ID INT NOT NULL,    
    FIRST_CAREUNIT VARCHAR(20) NOT NULL,
    LAST_CAREUNIT VARCHAR(20) NOT NULL,
    INTIME TIMESTAMP(0) NOT NULL,
    OUTTIME TIMESTAMP(0),
    CONSTRAINT icustay_stayid_unique UNIQUE (STAY_ID),
    CONSTRAINT icustay_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE TRANSFERS
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT NOT NULL,
    TRANSFER_ID INT NOT NULL,
    EVENTTYPE VARCHAR(20) NOT NULL,
    CAREUNIT VARCHAR(20),
    INTIME TIMESTAMP(0) NOT NULL,
    OUTTIME TIMESTAMP(0),
    CONSTRAINT transfers_rowid_pk PRIMARY KEY (ROW_ID)
);

CREATE TABLE TB_CXR
(
    ROW_ID INT NOT NULL,
    SUBJECT_ID INT NOT NULL,
    HADM_ID INT,
    STUDY_ID INT NOT NULL,
    IMAGE_ID INT NOT NULL,
    STUDYDATETIME TIMESTAMP(0) NOT NULL,
    VIEWPOSITION VARCHAR(20) NOT NULL,
    STUDYORDER INT NOT NULL,
    CONSTRAINT tb_cxr_rowid_pk PRIMARY KEY (ROW_ID)
);
"""

In [2]:
# Define the prompt template (as before)
prompt_template_v1 = """
Given the following database schema:

{schema}

And the following question:

{question}

Please generate an SQL query to answer this question. Note that there is a special function available called func_vqa which can perform visual question answering. It takes a question as input and a column of a database table (with image IDs), and returns true or false.

The query should use the func_vqa function when necessary to analyze chest X-ray images. The function can be used in various parts of the SQL query, including but not limited to the SELECT, WHERE, or HAVING clauses, depending on the requirements of the question.

Please provide the SQL query that answers the given question, incorporating the func_vqa function where appropriate to analyze the chest X-ray images. Only use the tables and attributes from the provided database schema.
"""

In [3]:
prompt_template_v2 = """You are a SQLite expert. Given an input question, create a single syntactically correct SQLite query to run. Only output the SQL Query.

Note that there is a special function available inside SQLite called func_vqa which can perform visual question answering. It takes a question as input and the "study_id" column variable, and returns a list of values. These values can correspond to anatomical findings or correspond to "yes"/"no" in case of a binary question. The question needs to be provided using double quotes. Your created query should use the func_vqa function when necessary to analyze chest X-ray images. The function can be used in various parts of the SQL query, including but not limited to the SELECT, WHERE, or HAVING clauses, depending on the requirements of the question.

Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. Be careful to only query columns that exist in the provided tables, and note which column belongs to which table. If the question involves "today", use date('now') to get the current date.

Use the following tables:

{schema}

Question: {question}
SQLQuery: """

In [4]:
prompt_template_v3 = """You are a SQLite expert. Given an input question, create a single syntactically correct SQLite query to run. Only output the SQL Query.
Note that there is a special function available inside SQLite called func_vqa which can perform visual question answering. It takes a question as input and the study_id column variable, and returns true or false.
Your created query should use the func_vqa function when necessary to analyze chest X-ray images. The function can be used in various parts of the SQL query, including but not limited to the SELECT, WHERE, or HAVING clauses, depending on the requirements of the question.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today". Pay attention to include the func_vqa when necessary.

Here is one example of a correct Text-to-SQL translation:

Question: did patient 10405915 have a chest x-ray showing any abnormalities in the left hemidiaphragm within the same month after being diagnosed with unspecified pleural effusion since 2 year ago?"
SQLQuery: select count(*)>0 from ( select diagnoses_icd.charttime from diagnoses_icd where diagnoses_icd.icd_code = ( select d_icd_diagnoses.icd_code from d_icd_diagnoses where d_icd_diagnoses.long_title = 'unspecified pleural effusion' ) and diagnoses_icd.hadm_id in ( select admissions.hadm_id from admissions where admissions.subject_id = 10405915 ) and datetime(diagnoses_icd.charttime) >= datetime(current_time,'-2 year') ) as t1 join ( select tb_cxr.studydatetime from tb_cxr where tb_cxr.study_id in ( select distinct t2.study_id from ( select distinct tb_cxr.study_id from tb_cxr where tb_cxr.study_id in ( select distinct tb_cxr.study_id from tb_cxr where tb_cxr.subject_id = 10405915 and datetime(tb_cxr.studydatetime) >= datetime(current_time,'-2 year') ) ) as t2 where func_vqa("is any abnormality shown in the left hemidiaphragm on a chest x-ray?", t2.study_id) = true ) ) as t3 where t1.charttime < t3.studydatetime and datetime(t1.charttime,'start of month') = datetime(t3.studydatetime,'start of month')"

Use the following format:

Question: Question here
SQLQuery: SQL Query to run

Only use the following tables:
{schema}

Question: {question}
SQLQuery: """

In [5]:
import os
import re
import openai
from openai import OpenAI
import time
import pandas as pd

openai.api_key = os.environ["OPENAI_API_KEY"]


def generate_statement_vision(question, image_path, base64_image, llm_model):
    import shelve
    print("generate_statement_vision")
    print("HOLA")
    print(question)
    print(type(question))
    print(image_path)
    print(type(image_path))
    print(llm_model)
    print(type(llm_model))
    print("HOLA2")
    cache_key = f"{question}_{image_path}_{llm_model}"
    # cache_key = str(question)+str(study_id)+str(llm_model)
    print(cache_key)
    
    # Open the shelve database
    with shelve.open('vision_cache') as cache:
        # Check if the result is already in the cache
        if cache_key in cache:
            print("CACHE HIT!")
            return cache[cache_key]   
        print("NOT USING CACHE")
    
        # image_data = _load_image(image_path)
        print("HOLA3")
        if llm_model == "gpt-4o-2024-05-13":
            print("HOLA4")
            import openai
            from openai import OpenAI
            openai.api_key = os.environ["OPENAI_API_KEY"]
        
            client = OpenAI(
            organization="org-0ETBijh97KRCc66SwxdqZYas",
            project="proj_PkZ1sRwgxenG2CNRjBHMRYPg",
            )
            print("HOLA5")

            response = client.chat.completions.create(
                model=llm_model,  # Use the appropriate model for image analysis
                messages=[
                    {
                        "role": "system", "content": [
                                {
                                    "type": "text",
                                    "text": "You are a radiologist. Please answer the following question based on the image below. Provide the answer as a JSON list. For binary questions, the list should only contain 'yes' or 'no'. For a question that asks for specific findings, the list should contain all findings that you can see. If you are not sure, make your best guess and always provide an answer.\n\n"
                                }
                            ],
                    },
                    {
                        "role": "user", "content": [
                        {
                            "type": "text",  
                            "text": question
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}" ,
                                "detail": "low",
                            }
                        }
                        ]
                    }
                ],
                max_tokens=4096
            )
            print("HOLA5")
            print(response)
            resp_content = response.choices[0].message.content
        else:
            raise
        cache[cache_key] = resp_content
        print("SAVED TO CACHE")
        return resp_content



# def generate_statement_vision(question, image_path, base64_image, llm_model):
#     import shelve
#     print("generate_statement_vision")
#     print("HOLA")
#     print(question)
#     print(type(question))
#     print(image_path)
#     print(type(image_path))
#     print(llm_model)
#     print(type(llm_model))
#     print("HOLA2")
#     cache_key = f"{question}_{image_path}_{llm_model}"
#     # cache_key = str(question)+str(study_id)+str(llm_model)
#     print(cache_key)
    
#     if llm_model == "gpt-4o-2024-05-13":
#         with shelve.open('vision_cache_gpt4o') as cache:
#             # Check if the result is already in the cache
#             if cache_key in cache:
#                 print("CACHE HIT!")
#                 return cache[cache_key]   
#             print("NOT USING CACHE")
#             import openai
#             from openai import OpenAI
#             openai.api_key = os.environ["OPENAI_API_KEY"]
#             client = OpenAI(
#             organization="org-0ETBijh97KRCc66SwxdqZYas",
#             project="proj_PkZ1sRwgxenG2CNRjBHMRYPg",
#             )
#             print("HOLA5")

#             response = client.chat.completions.create(
#                 model=llm_model,  # Use the appropriate model for image analysis
#                 messages=[
#                     {
#                         "role": "system", "content": [
#                                 {
#                                     "type": "text",
#                                     "text": "You are a radiologist. Please answer the following question based on the image below. Provide the answer as a JSON list. For binary questions, the list should only contain 'yes' or 'no'. For a question that asks for specific findings, the list should contain all findings that you can see. If you are not sure, make your best guess and always provide an answer.\n\n"
#                                 }
#                             ],
#                     },
#                     {
#                         "role": "user", "content": [
#                         {
#                             "type": "text",  
#                             "text": question
#                         },
#                         {
#                             "type": "image_url",
#                             "image_url": {
#                                 "url": f"data:image/jpeg;base64,{base64_image}" ,
#                                 "detail": "low",
#                             }
#                         }
#                         ]
#                     }
#                 ],
#                 max_tokens=4096
#             )
#             print("HOLA5")
#             print(response)
#             resp_content = response.choices[0].message.content
#     else:
#         raise
#     cache[cache_key] = resp_content
#     print("SAVED TO CACHE")
#     return resp_content

def generate_statement_gpt(llm_model, formatted_prompt):
    client = OpenAI(
        organization="org-0ETBijh97KRCc66SwxdqZYas",
        project="proj_PkZ1sRwgxenG2CNRjBHMRYPg",
    )
    completion = client.chat.completions.create(
        model=llm_model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": formatted_prompt},
        ],
    )
    response = completion.choices[0].message.content
    return response.strip()


def extract_sql_query(text):
    if text.startswith("SQLQuery:"):
        return text.split(':', 1)[1].strip()
    # Find the SQL query within triple backticks
    match = re.search(r'```sql\n(.*?)```', text, re.DOTALL)
    if match:
        sql_query = match.group(1)
        # Remove leading/trailing whitespace and newlines
        sql_query = sql_query.strip()
        # Replace multiple whitespace characters with a single space
        sql_query = re.sub(r'\s+', ' ', sql_query)
        return sql_query
    else:
        return ""

In [6]:
def _get_image_paths(study_id):
    image_paths = []
    study_id_file_name = "s"+str(study_id)
    root_dir = "/Users/jofu/src/XMODE-LLMCompiler/files/"
    for root, dirs, files in os.walk(root_dir):
        if study_id_file_name in root:
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
                    image_paths.append(os.path.join(root, file))

    return image_paths


def _load_image(image_url):
    import base64
    try:
        with open(image_url, "rb") as image:
            image = base64.b64encode(image.read()).decode("utf-8")
            print("image read")
            return image
    except FileNotFoundError:
        raise FileNotFoundError(f"Image_path <{image_url}> not found")
    except base64.binascii.Error:
        raise ValueError(f"Image_path <{image_url}> is not a valid image")
    except Exception as e:
        raise e


In [7]:
import sqlparse
import sqlite3

def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def func_vqa(question, study_id):
    import os
    import sys
    sys.path.append(os.path.dirname("tools/"))

    from vqa_m3ae import post_vqa_m3ae
    
    print(f"VQA called with question: '{question}' for study_id: {study_id}")
    llm_model = "gpt-4o-2024-05-13"
    # llm_model = "local_vqa"
    # check if paramaters correct:
    study_id_number = False
    try:
        float(study_id)
        study_id_number = True
    except ValueError:
        study_id_number = False

    question_number = False
    try:
        float(question)
        question_number = True
    except ValueError:
        question_number = False

    if question_number and not study_id_number:
        question_clean = study_id
        study_id_clean = question
    else:
        question_clean = question
        study_id_clean = study_id

    image_paths = _get_image_paths(study_id_clean)
    print(image_paths)
    print(question)
    for image_path in image_paths:
        if llm_model == "gpt-4o-2024-05-13":
            image_data = _load_image(image_path)
            resp = generate_statement_vision(question, image_path, image_data, llm_model)
        elif llm_model == "local_vqa":
            image_id = os.path.splitext(os.path.basename(image_path))[0]
            # image_id = '1b0b0385-a72d064d-be1f11ed-a39331d1-dde8f464'
            resp = post_vqa_m3ae(question, image_id)
            resp = resp["vqa_answers"]
        print("Response: ", resp)
        if "Yes" in resp or "yes" in resp:
            print("Yes")
            return 1
        elif "No" in resp or "no" in resp:
            print("No")
            return 0
        else:
            print("list of findings")
            return ', '.join(resp)

            # cannot return list of tuples...
            return [(item,) for item in resp]


In [8]:
import json
import pandas as pd
from collections import Counter


def execute_query(query, conn):
    try:
        return pd.read_sql_query(query, conn)
    except pd.io.sql.DatabaseError as e:
        error_msg = f"Database error occurred: {str(e)}"
        if 'sqlite3.OperationalError' in str(e):
            error_msg += "\nThis might be due to a syntax error in the SQL query."
        elif 'no such table' in str(e).lower():
            error_msg += "\nThis might be because the table doesn't exist in the database."
        elif 'no such column' in str(e).lower():
            error_msg += "\nThis might be because one of the columns in the query doesn't exist in the table."
        return pd.DataFrame({'error': [error_msg]})
    except ValueError as e:
        error_msg = f"Value error occurred: {str(e)}\nThis might be due to incompatible data types or NULL values in the result."
        return pd.DataFrame({'error': [error_msg]})
    except Exception as e:
        error_msg = f"An unexpected error occurred: {str(e)}"
        return pd.DataFrame({'error': [error_msg]})


def normalize_query(query):
    parsed = sqlparse.parse(query)[0]
    normalized = str(parsed).lower().replace(" ", "")
    return normalized

def calculate_query_accuracy(predicted_queries, expected_queries):
    correct = sum(1 for pred, exp in zip(predicted_queries, expected_queries)
                  if normalize_query(pred) == normalize_query(exp))
    return correct / len(predicted_queries) if predicted_queries else 0

def compare_results(df1, df2):
    # Convert DataFrames to sets of tuples for content comparison
    set1 = set(map(tuple, df1.values))
    set2 = set(map(tuple, df2.values))
    
    # Check if the sets have the same content
    return set1 == set2

def calculate_execution_accuracy(predicted_results, expected_results):
    correct = sum(1 for pred, exp in zip(predicted_results, expected_results)
                  if compare_results(pred, exp))
    return correct / len(predicted_results) if predicted_results else 0

In [9]:
def predict_queries(data):
    predicted_queries = []
    expected_queries = []
    ids = []
    
    print("Predicting queries:")
    for item in data:
        question = item['question']
        expected_query = item['query']
        _id = item["id"]
        
        # Generate the prompt
        prompt = prompt_template_v2.format(schema=db_schema, question=question)
        
        llm_model = "gpt-4o-2024-05-13"
        predicted_query = generate_statement_gpt(llm_model, prompt)
        print(predicted_query)
        predicted_query_clean = extract_sql_query(predicted_query)
        
        print(f"Question: {question}")
        print(f"Predicted Query: {predicted_query_clean}")
        print(f"Expected Query: {expected_query}")
        
        predicted_queries.append(predicted_query_clean)
        expected_queries.append(expected_query)
        ids.append(_id)

    return predicted_queries, expected_queries, ids

def execute_queries(predicted_queries, expected_queries):
    # Connect to the database
    conn = sqlite3.connect('mimic_iv_cxr.db')
    
    # Register the func_vqa function with SQLite
    conn.create_function("func_vqa", 2, func_vqa)
    
    
    predicted_results = []
    expected_results = []


    for predicted_query_clean, expected_query in zip(predicted_queries, expected_queries):
        predicted_query_clean = predicted_query_clean.replace("%y", "%Y")
        expected_query = expected_query.replace("%y", "%Y")
        try:
            if predicted_query_clean:
                # Execute predicted query
                pred_result = execute_query(predicted_query_clean, conn)
            else:
                pred_result = pd.DataFrame({'error': ["No SQL"]})
            predicted_results.append(pred_result)
    
            # Execute expected query
            exp_result = execute_query(expected_query, conn)
            expected_results.append(exp_result)
            
            # Check if results are equivalent
            results_match = compare_results(pred_result, exp_result)
            print(f"Results match: {results_match}")
    
        except sqlite3.Error as e:
            print(f"An error occurred: {e}")
            predicted_results.append(pd.DataFrame())
            expected_results.append(pd.DataFrame())
    
        print("-" * 80)

    # Close the database connection
    conn.close()
    return predicted_results, expected_results

In [10]:
TEST_DATA = "dataset/mimic_iv_cxr/sampled_test_with_scope_preprocessed_balenced_answer.json"

# Load the JSON data
with open(TEST_DATA, 'r') as file:
    data = json.load(file)
    data = data

predicted_queries, expected_queries, ids = predict_queries(data)


Predicting queries:
```sql
SELECT func_vqa("Is the cardiac silhouette's width larger than half of the total thorax width?", "STUDY_ID") 
FROM "TB_CXR" 
WHERE "SUBJECT_ID" = 10284038 
AND strftime('%Y', "STUDYDATETIME") = '2105' 
ORDER BY "STUDYDATETIME" DESC 
LIMIT 1;
```
Question: given the last study of patient 10284038 in 2105, is the cardiac silhouette's width larger than half of the total thorax width?
Predicted Query: SELECT func_vqa("Is the cardiac silhouette's width larger than half of the total thorax width?", "STUDY_ID") FROM "TB_CXR" WHERE "SUBJECT_ID" = 10284038 AND strftime('%Y', "STUDYDATETIME") = '2105' ORDER BY "STUDYDATETIME" DESC LIMIT 1;
Expected Query: select func_vqa("is the cardiac silhouette's width larger than half of the total thorax width?", t1.study_id) from ( select tb_cxr.study_id from tb_cxr where tb_cxr.study_id in ( select distinct tb_cxr.study_id from tb_cxr where tb_cxr.subject_id = 10284038 and strftime('%y',tb_cxr.studydatetime) = '2105' order by tb_

In [11]:
predicted_results, expected_results = execute_queries(predicted_queries, expected_queries)


VQA called with question: 'Is the cardiac silhouette's width larger than half of the total thorax width?' for study_id: 50487710
['/Users/jofu/src/XMODE-LLMCompiler/files/p10/p10284038/s50487710/43f51b0a-cc1f857d-a37a4186-7d42505a-92a2b7ec.jpg']
Is the cardiac silhouette's width larger than half of the total thorax width?
image read
generate_statement_vision
HOLA
Is the cardiac silhouette's width larger than half of the total thorax width?
<class 'str'>
/Users/jofu/src/XMODE-LLMCompiler/files/p10/p10284038/s50487710/43f51b0a-cc1f857d-a37a4186-7d42505a-92a2b7ec.jpg
<class 'str'>
gpt-4o-2024-05-13
<class 'str'>
HOLA2
Is the cardiac silhouette's width larger than half of the total thorax width?_/Users/jofu/src/XMODE-LLMCompiler/files/p10/p10284038/s50487710/43f51b0a-cc1f857d-a37a4186-7d42505a-92a2b7ec.jpg_gpt-4o-2024-05-13
CACHE HIT!
Response:  [
  "no"
]
No
VQA called with question: 'Is the cardiac silhouette's width larger than half of the total thorax width?' for study_id: 54252229
['/

In [12]:
predicted_results[0]

Unnamed: 0,"func_vqa(""Is the cardiac silhouette's width larger than half of the total thorax width?"", ""STUDY_ID"")"
0,0


In [13]:
expected_results[0]

Unnamed: 0,"func_vqa(""is the cardiac silhouette's width larger than half of the total thorax width?"", t1.study_id)"
0,0


In [14]:
df = pd.DataFrame({
    "id": ids,
    "query_predicted": predicted_queries,
    "query_expected": expected_queries,
    "result_predicted": predicted_results,
    "result_expected": expected_results,
})

In [15]:
# Calculate Query Accuracy

def calculate_query_accuracy(row):
    return bool(normalize_query(row["query_predicted"]) == normalize_query(row["query_expected"]))

df["query_accuracy"] = df.apply(calculate_query_accuracy, axis=1)


def calculate_execution_accuracy(row):
    return compare_results(row["result_predicted"], row["result_expected"])

df["execution_accuracy"] = df.apply(calculate_execution_accuracy, axis=1)
df["VQA"] = "GPT4o"

In [16]:
df

Unnamed: 0,id,query_predicted,query_expected,result_predicted,result_expected,query_accuracy,execution_accuracy,VQA
0,5,"SELECT func_vqa(""Is the cardiac silhouette's w...","select func_vqa(""is the cardiac silhouette's w...","func_vqa(""Is the cardiac silhouette's width...","func_vqa(""is the cardiac silhouette's width...",False,True,GPT4o
1,21,"SELECT func_vqa(""Can you confirm the presence ...","select (func_vqa(""can you confirm the presence...","Empty DataFrame Columns: [func_vqa(""Can you co...","Empty DataFrame Columns: [(func_vqa(""can you c...",False,True,GPT4o
2,49,"SELECT func_vqa(""Identify any abnormalities in...","select (func_vqa(""can you identify any abnorma...","Empty DataFrame Columns: [func_vqa(""Identify a...","(func_vqa(""can you identify any abnormaliti...",False,False,GPT4o
3,66,"SELECT ""study_id"", func_vqa(""has any tube or l...","select (func_vqa(""have any tubes/lines related...",...,"Empty DataFrame Columns: [(func_vqa(""have any ...",False,False,GPT4o
4,71,"SELECT ""STUDY_ID"", func_vqa(""airspace opacity ...","select (func_vqa(""can you find airspace opacit...",study_id ...,"(func_vqa(""can you find airspace opacity?"",...",False,False,GPT4o
5,111,"SELECT ""study_id"", func_vqa(""study_id"", ""What ...","select func_vqa(""please identify all anatomica...",...,"func_vqa(""please identify all anatomical sit...",False,False,GPT4o
6,163,"SELECT func_vqa(""enumerate all detected abnorm...","select func_vqa(""enumerate all detected abnorm...","func_vqa(""enumerate all detected abnormaliti...","func_vqa(""enumerate all detected abnormaliti...",False,False,GPT4o
7,439,"SELECT ""VALUE"" FROM func_vqa(""What are the ana...","select func_vqa(""specify all anatomical locati...",...,"func_vqa(""specify all anatomical locations a...",False,False,GPT4o
8,460,SELECT func_vqa('Which anatomical area is asso...,"select func_vqa(""which anatomical area is asso...",Anatomica...,"func_vqa(""which anatomical area is associate...",False,False,GPT4o
9,518,"SELECT func_vqa(""What are the observed abnorma...","select func_vqa(""outline all the observed abno...","func_vqa(""What are the observed abnormaliti...","func_vqa(""outline all the observed abnormal...",False,True,GPT4o


In [17]:
df.iloc[2]["query_predicted"]

'SELECT func_vqa("Identify any abnormalities in the abdomen or right chest wall", "STUDY_ID") FROM "TB_CXR" WHERE "SUBJECT_ID" = 15491652 AND strftime(\'%Y-%m\', "STUDYDATETIME") = \'2015-06\' ORDER BY "STUDYDATETIME" DESC LIMIT 1;'

In [18]:
df.iloc[2]["query_expected"]

'select (func_vqa("can you identify any abnormalities in the abdomen?", t1.study_id)) or (func_vqa("can you identify any abnormalities in the right chest wall?", t1.study_id)) from ( select tb_cxr.study_id from tb_cxr where tb_cxr.study_id in ( select distinct tb_cxr.study_id from tb_cxr where tb_cxr.subject_id = 15491652 and strftime(\'%y-%m\',tb_cxr.studydatetime) = \'2105-06\' order by tb_cxr.studydatetime desc limit 1 ) ) as t1'

In [19]:
df.iloc[2]["result_predicted"]

Unnamed: 0,"func_vqa(""Identify any abnormalities in the abdomen or right chest wall"", ""STUDY_ID"")"


In [20]:
df.iloc[2]["result_expected"]

Unnamed: 0,"(func_vqa(""can you identify any abnormalities in the abdomen?"", t1.study_id)) or (func_vqa(""can you identify any abnormalities in the right chest wall?"", t1.study_id))"
0,1


In [21]:
overall_accuracy = df["query_accuracy"].mean()
print(f"Overall query accuracy: {overall_accuracy:.2%}")

overall_accuracy = df["execution_accuracy"].mean()
print(f"Overall execution accuracy: {overall_accuracy:.2%}")

Overall query accuracy: 0.00%
Overall execution accuracy: 13.33%


In [None]:
raise

In [None]:
df.to_csv("Experiments/ehrxqa_results_local_vqa.csv", index=False)

In [None]:
import sqlite3
import json

conn = sqlite3.connect('mimic_iv_cxr.db')
conn.create_function("func_vqa", 2, func_vqa)

TEST_DATA = "dataset/mimic_iv_cxr/sampled_test_with_scope_preprocessed_balenced_answer.json"

# Load the JSON data
with open(TEST_DATA, 'r') as file:
    data = json.load(file)
    data = data

test_query = data[4]["query"]
test_query = df.iloc[0]["query_expected"]
test_query = test_query.replace("%y", "%Y")

print(test_query)

cursor = conn.cursor()
result = cursor.execute(test_query).fetchone()
print("Test query result:", result)

conn.close()

In [None]:
import sqlite3
import json

conn = sqlite3.connect('mimic_iv_cxr.db')
conn.create_function("func_vqa", 2, func_vqa)

TEST_DATA = "dataset/mimic_iv_cxr/sampled_test_with_scope_preprocessed_balenced_answer.json"

# Load the JSON data
with open(TEST_DATA, 'r') as file:
    data = json.load(file)
    data = data

test_query = data[4]["query"]
test_query = df.iloc[0]["query_predicted"]
test_query = test_query.replace("%y", "%Y")

print(test_query)

cursor = conn.cursor()
result = cursor.execute(test_query).fetchone()
print("Test query result:", result)

conn.close()