In [4]:
import json

from constants import DEV_JSON_PATH, TEST_JSON_PATH

# NOTE: despite its name, `dev_set` holds the TEST split (TEST_JSON_PATH).
# The name is kept because later cells index into it.
# Explicit UTF-8: the default encoding is locale-dependent (e.g. cp1252 on
# Windows), which can mis-decode non-ASCII questions or values.
with open(TEST_JSON_PATH, "r", encoding="utf-8") as input_file:
    dev_set = json.load(input_file)

# One predicted SQL query per line. Line i is the prediction for dev_set[i].
with open("./0328_test_end2end.txt", "r", encoding="utf-8") as input_file:
    answer_set = [line.strip() for line in input_file]

# Later cells pair answers with examples purely by position, so a length
# mismatch would silently attach queries to the wrong questions.
assert len(answer_set) == len(dev_set), (
    f"{len(answer_set)} predicted queries for {len(dev_set)} examples"
)

len(answer_set)


2147

In [5]:
import sqlite3
from constants import DATABASE_PATH_PATTERN
import util
import prompt

# Prompt for the self-correction pass. It shows the model the schema, the
# foreign keys, the question, a query whose execution produced an error, and
# the SQLite error text, then asks for a corrected [SQL Query].
# Filled via str.format with: schema, foreign_key, question, wrong_query,
# error_info (see valid_prompt). .lstrip() removes only the leading newline
# after the opening quotes; the trailing newline after "is: " is kept.
# NOTE(review): the text has grammar slips ("Consider use", "Database server
# return following error information when execute", "[Foreign keys]" vs
# "[Foreign Keys]") and a trailing comma on the ORDER BY rule. It is kept
# verbatim because any rewording changes the model input; fix it deliberately
# if the prompt is re-tuned.
VALID_INSTRUCTION = """
Given [Database Schema] and [Foreign Keys], your task is to write a [SQL Query] to answer the [Question].

[Database Schema] Every table consists of several columns. Each line describes the column name, column type and optional value examples.
{schema}
[Foreign keys]
{foreign_key}
[Question]
{question}
[Constraints] Your [SQL Query] should satisfy the following constraints:
- In `SELECT <column>`, only use the column given in the [Database Schema].
- In `FROM <table>` or `JOIN <table>`, only use the table given in the [Database Schema].
- In `JOIN`, only use the tables and columns in the [Foreign keys].
- Without any specific instructions, Use `ASC` for `ORDER BY` by default, 
- Consider use `DISTINCT` when you need to eliminate duplicates.
- The content in quotes is case sensitive.
- Prioritize column whose value examples are more relevant to the [Question].

[Wrong SQL Query]
{wrong_query}
[Error Information] Database server return following error information when execute above [Wrong SQL Query], this may help you write the right SQL query.
{error_info}
[SQL Query] this right sql query is: 
""".lstrip()

# Output template: a "[SQL Query]" header line followed by {answer}.
# .strip() removes the surrounding newlines.
# NOTE(review): not referenced in the visible cells. It is presumably used to
# build training targets elsewhere; confirm, or remove it.
STEP_2_OUTPUT_PATTERN = """
[SQL Query]
{answer}
""".strip()


def execute(db_path: str, sql: str) -> str:
    """Try to execute `sql` against the SQLite database at `db_path`.

    Returns "" when the statement executes without error. Otherwise returns
    the error message text, which is fed into the correction prompt.

    The database is opened read-only, for two reasons:
    - Generated queries (e.g. ``DROP TABLE``) cannot modify the benchmark
      database.
    - A wrong path raises instead of silently creating an empty database
      that would report misleading "no such table" errors.

    No result rows are fetched. Errors that arise only while stepping
    through later rows may therefore go unreported.

    Raises:
        sqlite3.Error: if the database file cannot be opened.
    """
    import contextlib
    import pathlib

    # as_uri() percent-encodes characters such as '?', '#' and spaces, so the
    # path survives inside a URI filename.
    db_uri = pathlib.Path(db_path).resolve().as_uri() + "?mode=ro"
    # A sqlite3 connection's own context manager only commits or rolls back;
    # closing() guarantees the connection and cursor are actually closed.
    with contextlib.closing(sqlite3.connect(db_uri, uri=True)) as conn:
        with contextlib.closing(conn.cursor()) as cursor:
            try:
                cursor.execute(sql)
            except (sqlite3.Error, sqlite3.Warning) as exc:
                # Before Python 3.12, multi-statement input raises
                # sqlite3.Warning, which is NOT a subclass of sqlite3.Error.
                return str(exc)
    return ""


def valid_prompt(
    db_path: str, question: str, comment_map: dict, wrong_query: str, error_info: str
):
    """Render VALID_INSTRUCTION for one query that failed to execute.

    Args:
        db_path: Path to the SQLite database the query ran against.
        question: Natural-language question the query should answer.
        comment_map: Passed through to prompt.build_schema_str (the visible
            caller passes {}). Presumably column annotations; confirm.
        wrong_query: The SQL that failed.
        error_info: The error text returned by `execute`.

    Returns:
        The fully formatted prompt string.
    """
    table_map = util.get_sqlite_schema_table_with_type_map(db_path, [])
    # The dict literal is evaluated left to right, so the schema string is
    # built before the foreign-key string, as in the original ordering.
    # NOTE(review): the literal 3 is presumably the number of value examples
    # per column in build_schema_str; confirm.
    fields = {
        "schema": prompt.build_schema_str(db_path, table_map, question, comment_map, 3),
        "foreign_key": prompt.build_foreign_key_str(db_path, list(table_map.keys())),
        "question": question,
        "wrong_query": wrong_query,
        "error_info": error_info,
    }
    return VALID_INSTRUCTION.format(**fields)


# Collect a correction prompt for every predicted query that fails to execute.
# "index" records the query's position in answer_set/dev_set, so corrected
# SQL can be written back to the right slot later.
valid_dataset = []
for idx, raw_sql in enumerate(answer_set):
    example = dev_set[idx]
    db_path = DATABASE_PATH_PATTERN.format(db_id=example["db_id"])
    error = execute(db_path, raw_sql)
    if not error:
        continue
    valid_dataset.append(
        {
            "instruction": valid_prompt(db_path, example["question"], {}, raw_sql, error),
            "index": idx,
        }
    )

with open("./valid_dataset.json", "w") as output_file:
    json.dump(valid_dataset, output_file)

In [6]:
import re  # needed only for SQL extraction in this cell

# The corrected query is assumed to start at the first upper-case SELECT
# keyword. \s (not a literal space) also accepts "SELECT\n" and "SELECT\t".
_SQL_START = re.compile(r"\bSELECT\s")


def _extract_sql(model_output: str):
    """Return the whitespace-normalised SQL starting at the first SELECT in
    `model_output`, or None when no SELECT keyword is present."""
    match = _SQL_START.search(model_output)
    if match is None:
        return None
    return " ".join(model_output[match.start():].split())


# Model outputs for the correction prompts. They are matched to
# `valid_dataset` purely by position.
with open("valid_test.json", "r", encoding="utf-8") as input_file:
    valid_result = json.load(input_file)
print(len(valid_result))
assert len(valid_result) == len(valid_dataset), (
    f"valid_test.json has {len(valid_result)} outputs but valid_dataset has "
    f"{len(valid_dataset)} prompts; they are matched by position"
)

replaced, unparsed = [], []
for record, case in zip(valid_dataset, valid_result):
    idx = record["index"]
    sql = _extract_sql(case)
    if sql is None:
        # Previously find() returned -1 and the answer became the LAST
        # CHARACTER of the model output. Keep the original query instead.
        unparsed.append(idx)
        sql = answer_set[idx]
    else:
        replaced.append(idx)
    record["answer"] = sql
    answer_set[idx] = sql

print(f"replaced {len(replaced)} answers: {replaced}")
if unparsed:
    print(f"no SELECT found for indices {unparsed}; original answers kept")

with open("new_answer.txt", "w", encoding="utf-8") as output_file:
    output_file.writelines(answer + "\n" for answer in answer_set)


52
101
378
442
471
583
585
588
695
729
731
733
734
737
738
757
758
805
806
809
810
835
901
911
912
952
961
1033
1034
1044
1157
1159
1269
1349
1356
1357
1388
1389
1415
1508
1547
1548
1705
1706
1909
1919
1920
1965
1966
2013
2142
2145
2146
