In [1]:
import os
from openai import OpenAI
from sentence_transformers import SentenceTransformer, util
import requests
import json

In [2]:
# OpenAI API client; the key is read from the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Chat model name passed to the chat.completions.create calls below.
model = "gpt-4o"
# Sentence-embedding model used by sort_rows to embed the question and the table rows.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

### UTILS


In [3]:
def execute_sql_query(sql_query):
    """POST a SQL statement to the local query service and return the decoded JSON reply.

    Args:
        sql_query: SQL text, sent as the "query" field of the JSON request body.

    Returns:
        The parsed JSON body of the HTTP response.
    """
    payload = {"query": sql_query}
    reply = requests.post("http://localhost:3000/execute-query", json=payload)
    return reply.json()


def _strip_markdown_fences(text):
    """Return ``text`` trimmed of surrounding whitespace and of a wrapping ``` fence.

    ``None`` is treated as an empty reply so callers always receive a string.
    """
    cleaned = (text or "").strip()
    if not cleaned.startswith("```"):
        return cleaned

    cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]

    # The opening fence may carry a language tag on its own line ("```sql");
    # drop that line only when it holds nothing but such a tag.
    first_line, newline, remainder = cleaned.partition("\n")
    if newline and first_line.strip().lower() in {"", "sql", "sqlite", "sqlite3"}:
        cleaned = remainder
    return cleaned.strip()


def create_initial_sqlquery(context_aware_prompt):
    """Ask the chat model for a first SQL query for the given prompt.

    Args:
        context_aware_prompt: Full user prompt (schema, sample rows, question),
            sent as the user message.

    Returns:
        The model's reply as a string, with surrounding whitespace and any
        wrapping markdown code fence removed. An empty string is returned when
        the reply has no content.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": "You are an AI assistant specialized in rewriting database queries.",
            },
            {"role": "user", "content": context_aware_prompt},
        ],
        temperature=0.3,
    )

    # The prompt forbids markdown, but the reply is normalized anyway so that a
    # fenced answer does not get executed (or written to disk) verbatim.
    return _strip_markdown_fences(completion.choices[0].message.content)


def close_database():
    """POST to the query service's /close-database endpoint and return the decoded JSON reply."""
    return requests.post("http://localhost:3000/close-database").json()

### STEP 1: User Inputs


### Create the schema


In [None]:
def initialize_database():
    """POST to the query service's /initialize-database endpoint.

    Returns:
        The reply's "details" value when its "status" field equals False,
        otherwise None.
    """
    reply = requests.post("http://localhost:3000/initialize-database").json()
    # Equality (not identity or truthiness) is used on purpose, so a status of 0
    # is treated the same way as False.
    if reply["status"] != False:  # noqa: E712
        return None
    return reply["details"]


def create_schema(schema):
    """POST a schema definition to the query service's /create-schema endpoint.

    Args:
        schema: Schema text, sent as the "schema" field of the JSON request body.

    Returns:
        The parsed JSON body of the HTTP response.
    """
    request_body = {"schema": schema}
    reply = requests.post("http://localhost:3000/create-schema", json=request_body)
    return reply.json()

def get_tables():
    """Return the names of all tables listed in sqlite_master.

    Returns:
        A list with the "name" value of every row returned by the query.
    """
    table_listing = execute_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table';"
    )["output"]
    return [entry["name"] for entry in table_listing]


def get_columns(table_name):
    """Return the column names of ``table_name`` as reported by PRAGMA table_info.

    Args:
        table_name: Name of the table to inspect.

    Returns:
        A list with the "name" value of every row returned by the PRAGMA.
    """
    # Quote the identifier (doubling embedded quotes) so table names that are
    # reserved words or contain spaces/punctuation do not break the statement.
    quoted_table = '"' + str(table_name).replace('"', '""') + '"'
    pragma_rows = execute_sql_query(f"PRAGMA table_info({quoted_table});")["output"]
    return [column["name"] for column in pragma_rows]


def get_rows(table_name):
    """Return every row of ``table_name``.

    Args:
        table_name: Name of the table to read.

    Returns:
        The "output" value of the query service's reply for ``SELECT *``.
    """
    # Quote the identifier (doubling embedded quotes) so table names that are
    # reserved words or contain spaces/punctuation do not break the statement.
    quoted_table = '"' + str(table_name).replace('"', '""') + '"'
    return execute_sql_query(f"SELECT * FROM {quoted_table};")["output"]

### STEP 2: Question Rewriting


In [None]:
def sort_rows(nlq, rows):
    """Order ``rows`` by embedding similarity to the natural-language question.

    Each row is embedded from its ``str()`` form and scored by cosine
    similarity against the embedding of ``nlq``.

    Args:
        nlq: Natural-language question text.
        rows: Iterable of rows (any objects with a meaningful ``str()``).

    Returns:
        A new list of the rows, most similar first. Rows with equal scores keep
        their original relative order. An empty input yields an empty list.
    """
    rows = list(rows)
    if not rows:
        return []

    # Encode the question once and all rows in a single batch.
    nlq_embedding = embedding_model.encode(nlq, convert_to_tensor=True)
    rows_embeddings = embedding_model.encode(
        [str(row) for row in rows], convert_to_tensor=True
    )

    # One similarity score per row, converted to plain floats.
    scores = util.pytorch_cos_sim(nlq_embedding, rows_embeddings)[0].tolist()

    # Sort on the score only: comparing whole (score, row) tuples would fall
    # back to comparing the rows themselves on ties, which fails for dict rows.
    ranked = sorted(zip(scores, rows), key=lambda pair: pair[0], reverse=True)
    return [row for _, row in ranked]

In [6]:
def question_rewrite(nlq, five_rows, columns):
    """Ask the chat model to align a question with the schema and sample data.

    Args:
        nlq: Original natural-language question.
        five_rows: Sample data interpolated into the prompt.
        columns: Schema description interpolated into the prompt.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    # Zero-shot rewriting prompt; the text below is sent verbatim.
    rewrite_prompt = f"""
    You are an AI assistant tasked with ensuring the given question aligns with the provided database schema and data. Your role is to:

    1. Analyze the question for clarity and accuracy.
    2. Ensure the question matches the terminology, structure, and relationships (e.g., foreign keys) in the database schema.
    3. Rewrite the question if needed to make it unambiguous and aligned with the schema.
    4. If the question is already clear and aligned, return it without changes.

    ### Instructions:
    - Only rewrite the question if necessary for clarity or schema alignment.
    - Maintain the intent of the original question.
    - Do not modify the schema or sample data.

    ### Input Details:
    **Database Schema:**
    {columns}

    **Sample Data:**
    {five_rows}

    **Original Question:**
    {nlq}

    ### Output:
    - Provide the updated question if rewritten.
    - If no changes are needed, return the original question as is.
    """

    chat_messages = [{"role": "user", "content": rewrite_prompt}]
    response = client.chat.completions.create(
        model=model,
        messages=chat_messages,
        temperature=0.3,
    )
    reply_text = response.choices[0].message.content
    return reply_text.strip()

### STEP 3: Context-Aware Prompt Generation


In [7]:
def generate_context_aware_prompt(nlq, top_five_rows, columns):
    """Build the SQL-generation prompt for a question.

    Args:
        nlq: Question text placed under "Question".
        top_five_rows: Sample data placed under "Sample Data".
        columns: Schema description placed under "Database Schema".

    Returns:
        The prompt string. No model call is made here.
    """
    return f"""
    You are tasked with generating a valid SQL query based on the provided question, database schema, and sample data. Follow these strict guidelines:

    1. **No LEFT JOIN**: Avoid using the LEFT JOIN keyword.
    2. **No Aliases for Aggregate Functions**: Do not assign aliases to aggregate functions.
    3. **Simplicity**: Focus on generating the simplest query that satisfies the question.
    4. **Output Format**: Do not include additional text, explanations, or formatting like ```sql, code, or markdown```. Always return a valid SQL query ending with a semicolon.

    ### Hints:
    - Use COUNT(*) instead of using column names in the COUNT function.
    - Use OR instead of keywords like IN or BETWEEN for better performance.

    ### Input Details:
    **Database Schema:**
    {columns}

    **Sample Data:**
    {top_five_rows}

    **Question:**
    {nlq}

    ### Output:
    A single valid SQL query that adheres to the above rules.
    """

### STEP 4: Execution-Guided Refinement


In [8]:
def execution_guided_refinement(context_aware_prompt, nlq, columns):
    """Generate a SQL query and refine it using execution feedback.

    An initial query is generated and executed. Then, for at most five rounds,
    a refined candidate is requested. Refinement stops as soon as the candidate
    text (ignoring surrounding whitespace) equals the current query, or the
    candidate's output has the same length as the current output; the candidate
    from that round is discarded.

    Args:
        context_aware_prompt: Prompt used to generate the initial query.
        nlq: Question text forwarded to each refinement round.
        columns: Schema description forwarded to each refinement round.

    Returns:
        A ``(sql_query, sql_query_output)`` tuple for the last accepted query.
    """
    current_query = create_initial_sqlquery(context_aware_prompt)
    current_output = execute_sql_query(sql_query=current_query)["output"]

    rounds_left = 5
    while rounds_left > 0:
        rounds_left -= 1
        candidate_query, candidate_output = refinement_iteration(
            nlq, current_query, current_output, columns
        )

        query_unchanged = candidate_query.strip() == current_query.strip()
        if query_unchanged or len(candidate_output) == len(current_output):
            break

        current_query, current_output = candidate_query, candidate_output

    return current_query, current_output


def refinement_iteration(nlq, sql_query, sql_query_output, columns):
    """Ask the chat model to verify/correct a query, then execute the result.

    Args:
        nlq: Natural-language question the query should answer.
        sql_query: Current SQL query to verify.
        sql_query_output: Output of executing ``sql_query``.
        columns: Schema description interpolated into the prompt.

    Returns:
        A ``(refined_query, refined_sql_query_output)`` tuple, where the second
        item is the "output" value of the query service's reply for the refined
        query. This is the same shape callers use for the initial query's
        output, so the two can be compared directly.
    """
    prompt = f"""
    You are tasked with verifying and correcting an SQL query based on the provided natural language question, its current output, and the database schema. Follow these strict rules:

    ### Input:
    1. **Natural Language Question**: The original query intent.
    2. **SQL Query**: The current SQL query to verify.
    3. **Query Output**: The result of executing the SQL query (can be empty if no rows satisfy the query).
    4. **Database Schema**: Structure of the database tables.
    
    ### Output:
    - If the query and results are both correct, return the original query without changes.
    - Ensure column references are consistent with the schema and tables. Always return a valid SQL query ending with a semicolon.
    - Do not include additional text, explanations, or formatting like ```sql, code, or markdown```.

    ### Provided Data:
    **Database Schema:**
    {columns}

    **Natural Language Question:**
    {nlq}

    **SQL Query:**
    {sql_query}

    **Query Output:**
    {sql_query_output}

    ### Output:
    - A valid SQL query that satisfies the above instructions.
    """

    # Call the LLM to refine the SQL query
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": prompt},
        ],
        temperature=0.3,
    )

    refined_query = completion.choices[0].message.content

    # Execute the refined SQL query to get feedback. Only the "output" payload
    # is returned: handing back the whole response dict would make len()-based
    # comparisons against a previous output count dict keys instead of rows.
    refined_sql_query_output = execute_sql_query(refined_query)["output"]

    return refined_query, refined_sql_query_output

In [None]:
if __name__ == "__main__":
    # Driver: for each of the first 100 examples in dev.json, load the matching
    # schema into the query service, build the prompts, run execution-guided
    # refinement, and append the final query to predictions.sql (one per line).

    with open("dev.json") as f:
        dev = json.load(f)

    # Slicing (instead of indexing range(0, 100)) avoids an IndexError when
    # dev.json holds fewer than 100 examples.
    for example in dev[:100]:
        database_id = example["db_id"]
        nlq = example["question"]
        database_dir = f"database/{database_id}/"

        # Read the first .sql file found in the database directory. `schema` is
        # reset on every iteration so a directory without a .sql file can never
        # silently reuse the previous example's schema.
        schema = None
        for path in os.listdir(database_dir):
            if path.endswith(".sql"):
                with open(database_dir + path) as f:
                    schema = f.read()
                break
        if schema is None:
            raise FileNotFoundError(f"No .sql schema file found in {database_dir}")

        # Initialize the database
        initialize_database()

        try:
            # Create the schema
            create_schema(schema)

            # Get the table names
            tables = get_tables()

            # Get the columns and rows for each table
            columns = {table: get_columns(table) for table in tables}
            rows = {table: get_rows(table) for table in tables}

            # Sort the rows based on the similarity with the NLQ
            sorted_rows = {table: sort_rows(nlq, rows[table]) for table in tables}

            # Rewrite the NLQ based on the first 5 rows of each table
            five_rows = {table: rows[table][:5] for table in tables}
            rewriten_nlq = question_rewrite(
                nlq=nlq, five_rows=five_rows, columns=columns
            )

            # Get the top 5 similar rows for each table
            top_five_rows = {table: sorted_rows[table][:5] for table in tables}

            # Generate the context-aware prompt
            context_aware_prompt = generate_context_aware_prompt(
                nlq=rewriten_nlq, top_five_rows=top_five_rows, columns=columns
            )

            # Execute the execution-guided refinement process
            # NOTE(review): refinement receives the original question, not
            # rewriten_nlq — confirm this is intended.
            final_sql_query, final_sql_query_output = execution_guided_refinement(
                context_aware_prompt=context_aware_prompt, nlq=nlq, columns=columns
            )

            # Write the final SQL query to a file (one query per line)
            final_sql_query = final_sql_query.replace("\n", " ")
            with open("predictions.sql", "a") as f:
                f.write(final_sql_query)
                f.write("\n")
        finally:
            # Close the database even when a step above fails, so the next
            # example (or a re-run) does not start from a leftover database.
            close_database()