In [None]:
import json
import random
import threading
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai_access import openai_completion

# Input JSON files. Each file is concatenated into `data`, so each is expected
# to hold a list of datapoint dicts -- TODO confirm against the files themselves.
paths = ["<anonymous>"]

data = []
for p in paths:
    # Explicit UTF-8: the default text encoding of open() depends on the
    # platform locale, which can mis-decode non-ASCII questions or schemas.
    with open(p, "r", encoding="utf-8") as f:
        data.extend(json.load(f))

print(len(data))
# Sorted so the ID order is stable across runs; iterating a plain set of
# strings can yield a different order in every interpreter session.
# Assumes db_id values are mutually comparable (e.g. all strings).
all_db_ids = sorted({dp["db_id"] for dp in data})
print(len(all_db_ids))

In [None]:
print(len(all_db_ids))

# Number of databases assigned to the training split; the rest go to eval.
# random.sample raises ValueError if fewer databases than this are available.
TRAIN_DB_COUNT = 65
# Fixed seed so the train/eval split is the same on every run.
SPLIT_SEED = 42

# A dedicated Random instance keeps the split reproducible without touching the
# global `random` state used elsewhere. Sorting first makes the draw independent
# of the incoming order of all_db_ids (which is derived from a set).
train_db_ids = random.Random(SPLIT_SEED).sample(sorted(all_db_ids), k=TRAIN_DB_COUNT)
train_db_id_set = set(train_db_ids)
eval_db_ids = sorted(set(all_db_ids) - train_db_id_set)
eval_db_id_set = set(eval_db_ids)
print(len(train_db_ids), len(eval_db_ids))

# Set membership keeps these counts linear in len(data) instead of scanning a
# list of IDs for every datapoint.
dp_count_train = sum(1 for dp in data if dp["db_id"] in train_db_id_set)
dp_count_eval = sum(1 for dp in data if dp["db_id"] in eval_db_id_set)
print(dp_count_train, dp_count_eval)

In [3]:
# Prompt template for datapoints without a usable hint.
# It is filled with str.format() by generate_COT_prompt, so the single-brace
# fields ({schema}, {question}, {sql1}, {sql1_result}, {sql2}, {sql2_result})
# are placeholders. The doubled braces in "\\box{{SQL1}}" survive formatting as
# the literal text \box{SQL1} / \box{SQL2}, which is the verdict marker that
# reject_sample searches for in the model response.
# The backslash right after the opening quotes suppresses the leading newline.
COT_TEMPLATE_WO_HINT = """\
For a Text-to-SQL scenario, I will provide you:
- a database schema
- a user question
- two SQL queries: one correct and one incorrect, without indicating which is which.
- the execution result of the two SQL queries
Your task is to identify the correct SQL query.
Please carefully analyze the semantic alignment between the SQL queries and the user's question, and examine the differences between the SQL queries.
Requirements (in order of priority):
- (Priority) Judge based on how accurately the SQL answers the user's question
- If the SQL execution result is empty or None, then this SQL is likely to be incorrect.
- Using more tables and columns doesn't definitely make an SQL query correct - consider how well it matches the user's question
- We prefer SQL queries whose results can directly serve as the answer to the user's question
- If neither SQL is ideal, choose the better one that align with the user's question

Please analyze and provide output in this specific format:
1. Question Analysis
2. Semantic and logical observation of SQL 1 (at this stage, avoid discussing potential errors or making conclusions)
3. Semantic and logical observation of SQL 2 (at this stage, avoid discussing potential errors or making conclusions)
4. Analysis of differences between SQL 1 and SQL 2 (create a detailed Markdown table for comparison)
5. Judgement (concluding with either \\box{{SQL1}} or \\box{{SQL2}} as the result).

[Database Schema]
{schema}
[User question]
{question}
[SQL 1]
{sql1}
[SQL 1 Execution Result]
{sql1_result}
[SQL 2]
{sql2}
[SQL 2 Execution Result]
{sql2_result}
"""

# Prompt template for datapoints that carry a hint (the "evidence" field).
# It is filled with str.format() by generate_COT_prompt; besides the fields of
# the no-hint template it takes an extra {hint} placeholder and adds a
# top-priority requirement to follow the hint.
# As in the no-hint template, "\\box{{SQL1}}" renders as the literal \box{SQL1}
# after formatting, and the backslash after the opening quotes suppresses the
# leading newline.
COT_TEMPLATE_WITH_HINT = """\
For a Text-to-SQL scenario, I will provide you:
- a database schema
- a user question
- a important Hint
- two SQL queries: one correct and one incorrect, without indicating which is which.
- the execution result of the two SQL queries
Your task is to identify the correct SQL query.
Please carefully analyze the semantic alignment between the SQL queries and the user's question, and examine the differences between the SQL queries.
Requirements (in order of priority):
- (Priority) If a Hint specifies the solution approach or provides calculation logic, the SQL query that follows the Hint is more likely to be correct
- (Priority) Judge based on how accurately the SQL answers the user's question
- If the SQL execution result is empty or None, then this SQL is likely to be incorrect.
- Using more tables and columns doesn't definitely make an SQL query correct - consider how well it matches the user's question and Hint
- We prefer SQL queries whose results can directly serve as the answer to the user's question
- If neither SQL is ideal, choose the better one based on how well it matches the Hint and the user's question

Please analyze and provide output in this specific 5-step format:
1. Question Analysis
2. Semantic and logical observation of SQL 1 (at this stage, avoid discussing potential errors or making conclusions)
3. Semantic and logical observation of SQL 2 (at this stage, avoid discussing potential errors or making conclusions)
4. Analysis of differences between SQL 1 and SQL 2 (create a detailed Markdown table for comparison)
5. Judgement (concluding with either \\box{{SQL1}} or \\box{{SQL2}} as the result).

[Database Schema]
{schema}
[User question]
{question}
[Hint]
{hint}
[SQL 1]
{sql1}
[SQL 1 Execution Result]
{sql1_result}
[SQL 2]
{sql2}
[SQL 2 Execution Result]
{sql2_result}
"""


def generate_COT_prompt(datapoint: dict) -> str:
    """Render the pairwise SQL-judging prompt for one datapoint.

    Args:
        datapoint: Mapping with the keys "schema", "question", "sql1",
            "sql1_result", "sql2" and "sql2_result", plus an optional
            "evidence" entry holding the hint text. A missing or ``None``
            evidence is treated as "no hint".

    Returns:
        COT_TEMPLATE_WITH_HINT filled in when the stripped hint is longer than
        four characters, otherwise COT_TEMPLATE_WO_HINT filled in.

    Raises:
        KeyError: If one of the required keys is absent from ``datapoint``.
    """
    fields = {
        key: datapoint[key]
        for key in ("schema", "question", "sql1", "sql1_result", "sql2", "sql2_result")
    }

    # `or ""` covers both an absent key and an explicit None, which would
    # otherwise fail on .strip().
    hint = datapoint.get("evidence") or ""

    # NOTE(review): the 4-character threshold is inherited from the original
    # code; presumably it filters placeholder hints such as "None" -- confirm.
    if len(hint.strip()) > 4:
        # The hint is inserted unstripped, exactly as supplied.
        return COT_TEMPLATE_WITH_HINT.format(hint=hint, **fields)
    return COT_TEMPLATE_WO_HINT.format(**fields)

In [4]:
def reject_sample(datapoint, max_attempts=3, max_wired=3, model="gpt-4o"):
    """Build one chain-of-thought sample for a datapoint via rejection sampling.

    A (correct SQL, wrong SQL) pair is drawn at random and shown to the model
    in a random order. The response is kept only when its final verdict marker
    names the correct query; otherwise a new pair is drawn.

    Args:
        datapoint: Mapping with "db_id", "augmented_schema", "question",
            "evidence", "gold_query", "formatted_gold_results" and the
            candidate lists "correct_beam", "correct_sample", "wrong_beam",
            "wrong_sample", "wired_beam", "wired_sample". Each candidate is a
            dict with "q" (the query) and "r" (its execution result).
        max_attempts: Maximum number of model calls for this datapoint.
        max_wired: Maximum number of "wired" candidates mixed into the wrong
            pool. NOTE(review): "wired" is presumably a typo for "weird" in
            the upstream data keys -- confirm before renaming anything.
        model: Model name forwarded to ``openai_completion``.

    Returns:
        A dict with the pair, the prompt, the accepted response and the
        expected label, or ``None`` when there is no wrong candidate or no
        attempt produced an accepted response.
    """
    # Correct pool: every candidate marked correct, plus the gold query.
    # List concatenation builds a new list, so `datapoint` is not mutated.
    correct_pool = datapoint["correct_beam"] + datapoint["correct_sample"]
    correct_pool.append({"q": datapoint["gold_query"], "r": datapoint["formatted_gold_results"]})

    # Wrong pool: every candidate marked wrong, plus a few "wired" ones.
    wrong_pool = datapoint["wrong_beam"] + datapoint["wrong_sample"]
    wired_pool = datapoint["wired_beam"] + datapoint["wired_sample"]
    wrong_pool += random.sample(wired_pool, min(len(wired_pool), max_wired))

    if not wrong_pool:
        return None

    for _ in range(max_attempts):
        correct_sql = random.choice(correct_pool)
        wrong_sql = random.choice(wrong_pool)

        # Randomise which slot holds the correct query, so position carries no
        # signal about correctness.
        if random.random() < 0.5:
            first, second = correct_sql, wrong_sql
            label, other_label = r"\box{SQL1}", r"\box{SQL2}"
        else:
            first, second = wrong_sql, correct_sql
            label, other_label = r"\box{SQL2}", r"\box{SQL1}"

        info = {
            "schema": datapoint["augmented_schema"],
            "question": datapoint["question"],
            "evidence": datapoint["evidence"],
            "sql1": first["q"],
            "sql1_result": first["r"],
            "sql2": second["q"],
            "sql2_result": second["r"],
        }
        prompt = generate_COT_prompt(info)
        response = openai_completion(prompt, model=model)

        # A non-text response cannot contain a verdict; count it as a rejected
        # attempt instead of failing on the substring search.
        if not isinstance(response, str):
            continue

        # Accept only when the LAST verdict marker in the response is the
        # expected one. A plain `label in response` check would also accept a
        # response that mentions the right box in passing but concludes with
        # the wrong one.
        label_pos = response.rfind(label)
        if label_pos != -1 and label_pos > response.rfind(other_label):
            return {
                "db_id": datapoint["db_id"],
                "schema": datapoint["augmented_schema"],
                "question": datapoint["question"],
                "evidence": datapoint["evidence"],
                "gold_query": datapoint["gold_query"],
                "sql1": info["sql1"],
                "sql1_result": info["sql1_result"],
                "sql2": info["sql2"],
                "sql2_result": info["sql2_result"],
                "prompt": prompt,
                "response": response,
                "label": label,
            }

    return None

In [None]:
all_cot_data = []
cot_data_lock = threading.Lock()
# repr() of every exception raised by a worker, kept for the summary below.
failed_samples = []

with ThreadPoolExecutor(max_workers=50) as executor:
    futures = [executor.submit(reject_sample, datapoint) for datapoint in data]
    for future in tqdm(as_completed(futures), total=len(futures)):
        # future.result() re-raises whatever the worker raised. Without this
        # guard, one failed call would abort the loop and discard every sample
        # collected so far, because the dump below would never run.
        try:
            cot_datum = future.result()
        except Exception as exc:
            failed_samples.append(repr(exc))
            continue
        if cot_datum:
            # Results are consumed in this thread only; the lock is kept from
            # the original code and is harmless here.
            with cot_data_lock:
                all_cot_data.append(cot_datum)

print(f"accepted={len(all_cot_data)} failed={len(failed_samples)} total={len(futures)}")
# Show a few distinct failures so systematic errors are not silently ignored.
for failure in sorted(set(failed_samples))[:5]:
    print("worker error:", failure)

with open("./cot_train_data.json", "w") as f:
    json.dump(all_cot_data, f, indent=2)

In [10]:
# Write the collected samples to disk again (same target file as the
# generation cell), pretty-printed with a two-space indent.
with Path("./cot_train_data.json").open("w") as out_file:
    json.dump(all_cot_data, out_file, indent=2)

In [11]:
# Sanity check: count the kept responses that contain neither verdict marker.
count = 0
for dp in all_cot_data:
    has_verdict = any(marker in dp["response"] for marker in (r"\box{SQL1}", r"\box{SQL2}"))
    if not has_verdict:
        count += 1
print(count)

0


In [8]:
# Eyeball the first kept sample: its prompt, a 100-character rule, then the
# model response.
d = all_cot_data[0]
print(d["prompt"], "=" * 100, d["response"], sep="\n")

For a Text-to-SQL scenario, I will provide you:
- a database schema
- a user question
- a important Hint
- two SQL queries: one correct and one incorrect, without indicating which is which.
- the execution result of the two SQL queries
Your task is to identify the correct SQL query.
Please carefully analyze the semantic alignment between the SQL queries and the user's question, and examine the differences between the SQL queries.
Requirements (in order of priority):
- (Priority) If a Hint specifies the solution approach or provides calculation logic, the SQL query that follows the Hint is more likely to be correct
- (Priority) Judge based on how accurately the SQL answers the user's question
- If the SQL execution result is empty or None, then this SQL is likely to be incorrect.
- Using more tables and columns doesn't definitely make an SQL query correct - consider how well it matches the user's question and Hint
- We prefer SQL queries whose results can directly serve as the answer to