## Load Data

In [1]:
import json
import numpy as np
import pandas as pd

from pathlib import Path
from src.db_utils import get_schema_str, get_data_dict, get_schema_str_with_tables
from src.database import SqliteDatabase, DuckDBDatabase


## Schema description

In [2]:
import sqlite3
db_path = 'data/tpch/TPC-H.db'
def get_database_schema(db_path, tables_list) -> str:
    """Return the CREATE TABLE statements of a SQLite database as one string.

    Args:
        db_path: Path of the SQLite database file to inspect.
        tables_list: Names of the tables to include. A falsy value
            (``[]`` or ``None``) selects every table.

    Returns:
        The selected CREATE TABLE statements, in the order ``sqlite_master``
        returns them, each followed by a blank line. An empty string is
        returned when no table matches.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Fetch the name and the create statement of every table in one query,
        # so no table name has to be interpolated into a second SQL string.
        rows = conn.execute(
            "SELECT name, sql FROM sqlite_master WHERE type='table';"
        ).fetchall()
    finally:
        # Release the connection even when the query above raises.
        conn.close()

    selected = [
        create_statement
        for table_name, create_statement in rows
        # Skip rows without a stored statement instead of failing on None.
        if create_statement and (not tables_list or table_name in tables_list)
    ]
    return ''.join(create_statement + '\n\n' for create_statement in selected)

print(get_database_schema(db_path=db_path, tables_list=[]))

CREATE TABLE NATION (
  N_NATIONKEY INTEGER PRIMARY KEY NOT NULL,
  N_NAME      TEXT NOT NULL,
  N_REGIONKEY INTEGER NOT NULL,
  N_COMMENT   TEXT,
  FOREIGN KEY (N_REGIONKEY) REFERENCES REGION(R_REGIONKEY)
)

CREATE TABLE REGION (
  R_REGIONKEY INTEGER PRIMARY KEY NOT NULL,
  R_NAME      TEXT NOT NULL,
  R_COMMENT   TEXT
)

CREATE TABLE PART (
  P_PARTKEY     INTEGER PRIMARY KEY NOT NULL,
  P_NAME        TEXT NOT NULL,
  P_MFGR        TEXT NOT NULL,
  P_BRAND       TEXT NOT NULL,
  P_TYPE        TEXT NOT NULL,
  P_SIZE        INTEGER NOT NULL,
  P_CONTAINER   TEXT NOT NULL,
  P_RETAILPRICE INTEGER NOT NULL,
  P_COMMENT     TEXT NOT NULL
)

CREATE TABLE SUPPLIER (
  S_SUPPKEY   INTEGER PRIMARY KEY NOT NULL,
  S_NAME      TEXT NOT NULL,
  S_ADDRESS   TEXT NOT NULL,
  S_NATIONKEY INTEGER NOT NULL,
  S_PHONE     TEXT NOT NULL,
  S_ACCTBAL   INTEGER NOT NULL,
  S_COMMENT   TEXT NOT NULL,
  FOREIGN KEY (S_NATIONKEY) REFERENCES NATION(N_NATIONKEY)
)

CREATE TABLE PARTSUPP (
  PS_PARTKEY    IN

## Q2: Minimum Cost Supply Query

In [3]:
NLquery1 = "Find top 100 suppliers with the highest account balance \
that I can place an order for parts of type includes COPPER, \
with size 42 in region EUROPE, and with the minimum cost supply."

NLquery12 = "Find suppliers with the minimum cost supply \
that I can place an order for parts of type includes COPPER, \
with size 42 in region EUROPE. Return top 100 suppliers with the highest account balance."

gold_sql1 = """select
	s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment
from
	part, supplier, partsupp, nation, region
where
	p_partkey = ps_partkey
	and s_suppkey = ps_suppkey
	and p_size = 42
	and p_type like '%COPPER'
	and s_nationkey = n_nationkey
	and n_regionkey = r_regionkey
	and r_name = 'EUROPE'
	and ps_supplycost = (
		select
			min(ps_supplycost)
		from
			partsupp,
			supplier,
			nation,
			region
		where
			p_partkey = ps_partkey
			and s_suppkey = ps_suppkey
			and s_nationkey = n_nationkey
			and n_regionkey = r_regionkey
			and r_name = 'EUROPE'
	)
order by
	s_acctbal desc,
	n_name,
	s_name,
	p_partkey
LIMIT 100;"""

NLquery2 = "Find top 10 suppliers with the highest account balance \
that I can place an order in ASIA region for parts of type includes BRASS, \
with size 15, and with the minimum cost supply."

gold_sql2 = """select
	s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment
from
	part, supplier, partsupp, nation, region
where
	p_partkey = ps_partkey
	and s_suppkey = ps_suppkey
	and p_size = 15
	and p_type like '%BRASS'
	and s_nationkey = n_nationkey
	and n_regionkey = r_regionkey
	and r_name = 'ASIA'
	and ps_supplycost = (
		select
			min(ps_supplycost)
		from
			partsupp,
			supplier,
			nation,
			region
		where
			p_partkey = ps_partkey
			and s_suppkey = ps_suppkey
			and s_nationkey = n_nationkey
			and n_regionkey = r_regionkey
			and r_name = 'ASIA'
	)
order by
	s_acctbal desc,
	n_name,
	s_name,
	p_partkey
LIMIT 10;"""

# Evaluation samples for Q2: the two phrasings of the same request
# (NLquery1 and NLquery12) are both paired with gold_sql1.
# NOTE(review): NLquery2 / gold_sql2 are defined above but are not part of this
# list — confirm that leaving them out is intended.
# NOTE(review): the questions say the part type "includes COPPER", while
# gold_sql1 filters with `like '%COPPER'` (suffix match only) — confirm which
# of the two is the intended semantics.
NLsamples = [{'question': NLquery1, 'sql': gold_sql1},
             {'question': NLquery12, 'sql': gold_sql1}]

In [4]:
# Sanity check: run the first gold query against the database file at db_path.
# The call is the cell's last expression, so its result is what the cell displays.
database = SqliteDatabase(db_file=db_path)
database.execute(gold_sql1)

Unnamed: 0,S_ACCTBAL,S_NAME,N_NAME,P_PARTKEY,P_MFGR,S_ADDRESS,S_PHONE,S_COMMENT
0,9967.45,Supplier#000002302,FRANCE,69795,Manufacturer#5,wMEzrsX2KKpTaJGE3uGEUibymG,16-486-165-5642,gly carefully bold deposits. accounts nag b
1,9925.04,Supplier#000003400,ROMANIA,73399,Manufacturer#1,IZSzKpRL1RNar39LvF,29-295-531-2833,unts along the ironic accounts must have to ha...
2,9915.38,Supplier#000006085,ROMANIA,31078,Manufacturer#2,T3Ju68MUhIb2hpTO3f8OGG,29-528-113-9241,iously bold sauternes. slyly regular asymptotes
3,9915.38,Supplier#000006085,ROMANIA,198527,Manufacturer#5,T3Ju68MUhIb2hpTO3f8OGG,29-528-113-9241,iously bold sauternes. slyly regular asymptotes
4,9828.21,Supplier#000000647,UNITED KINGDOM,23140,Manufacturer#3,x5U7MBZmwfG9,33-258-202-4782,s the slyly even ideas poach fluffily
...,...,...,...,...,...,...,...,...
95,7822.90,Supplier#000000674,FRANCE,180673,Manufacturer#3,jMxLRDxoP1Pf kzzyMVIfLB,16-128-338-8014,thely after the furiously even pains. quietly
96,7814.84,Supplier#000004126,ROMANIA,181607,Manufacturer#2,3s9EL2QxI5lAEeSPr9aDv0 O0X7SP PA4TQWAAYn,29-497-666-4765,fluffy packages. furiously ironic r
97,7814.08,Supplier#000006465,UNITED KINGDOM,33961,Manufacturer#4,"R0ofppl4Gkm,b,U5uCA0YL9wm3el luro0T",33-155-333-2168,ious dependencies. slyly regular depths doubt....
98,7814.08,Supplier#000006465,UNITED KINGDOM,68946,Manufacturer#3,"R0ofppl4Gkm,b,U5uCA0YL9wm3el luro0T",33-155-333-2168,ious dependencies. slyly regular depths doubt....


In [8]:
import os 
from dotenv import load_dotenv, find_dotenv
from collections import defaultdict
from tqdm import tqdm
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

_ = load_dotenv(find_dotenv())

## Chain-of-Thought Prompt: gpt-4o-mini

In [20]:
# Schema of one generated answer: the SQL query plus the reasoning behind it.
# (Documented with comments rather than docstrings so the model classes
# themselves stay unchanged.)
class OutputFormat(BaseModel):
    full_sql_query: str = Field(description='The full SQL query.')
    rationale: str = Field(description='The step-by-step reasoning to generate the SQL query.')

# Top-level schema handed to with_structured_output below: a list of OutputFormat entries.
class Response(BaseModel):
    output: list[OutputFormat]

# Prompt text with two placeholders, {schema} and {input_query}.
# The doubled braces {{ }} are escapes that render as literal braces.
# NOTE(review): the FORMATTING section describes a flat JSON object, while
# Response nests the entries under an `output` list — confirm this is intended.
template = '''### TASK
You are tasked with generating a SQL query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

# Chat model used for this section, with temperature fixed at 0.0.
model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

# Ask the model for output matching the Response schema, then pipe the prompt into it.
model = model_openai.with_structured_output(Response)
chain = (prompt | model)

# Generate one SQL prediction per sample in NLsamples and collect the results
# in `all_full_sql` (one dict per sample) for the evaluation cell.

# Neither the database id nor the schema depends on the sample, so they are
# computed once instead of re-reading the database on every iteration.
db_id = 'tpc-h'
db_schema = get_database_schema(db_path=db_path, tables_list=[])

all_full_sql = list()
for idx, sample in enumerate(tqdm(NLsamples)):
    input_data = {'schema': db_schema, 'input_query': sample['question']}
    # `.output` is the list declared on Response; only its first entry is kept.
    output = chain.invoke(input=input_data).output
    all_full_sql.append(
        {
            'sql_idx': idx,
            'db_id': db_id,
            'question': sample['question'],
            'rationale': output[0].rationale,
            'full_sql_query': output[0].full_sql_query,
            'gold_sql': sample['sql'],
        }
    )

# Print every collected field as 'key': value, one per line.
for record in all_full_sql:
    for k, v in record.items():
        print("'{}': {}".format(k, v))

100%|██████████| 2/2 [00:06<00:00,  3.01s/it]

'sql_idx': 0
'db_id': tpc-h
'question': Find top 100 suppliers with the highest account balance that I can place an order for parts of type includes COPPER, with size 42 in region EUROPE, and with the minimum cost supply.
'rationale': 1. We need to find suppliers who can provide parts of type 'COPPER' and size 42. This requires joining the SUPPLIER table with the PARTSUPP and PART tables.
2. We also need to filter suppliers based on the region 'EUROPE', which requires joining the SUPPLIER table with the NATION table and then filtering by the region.
3. The account balance of suppliers is the key metric we want to sort by, so we will order the results by S.S_ACCTBAL in descending order.
4. Finally, we limit the results to the top 100 suppliers.
'full_sql_query': SELECT S.S_SUPPKEY, S.S_NAME, S.S_ACCTBAL
FROM SUPPLIER S
JOIN NATION N ON S.S_NATIONKEY = N.N_NATIONKEY
JOIN PARTSUPP PS ON S.S_SUPPKEY = PS.PS_SUPPKEY
JOIN PART P ON PS.PS_PARTKEY = P.P_PARTKEY
WHERE P.P_TYPE LIKE '%COPPER%' A




In [None]:
## database execution evaluation
from src.evaluate import compare_execution

# Execute each predicted query and its gold query, compare the two results and
# record a per-sample score; finally print the scores and their average.
output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    database = SqliteDatabase(db_file=db_path)
    error_info = None
    try:
        # Execution happens inside the try block so that a predicted query that
        # fails to run is scored 0 instead of aborting the whole evaluation.
        pred_result = database.execute(data['full_sql_query'])
        print('pred_result\n', pred_result)
        gold_result = database.execute(data['gold_sql'])
        print('gold_result\n', gold_result)
        score = compare_execution(pred_result, gold_result)
    except Exception as e:
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Python Script Error:' + str(e)
    # A zero score without an exception means the results did not match.
    if score == 0 and error_info is None:
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Guard the average: an empty `all_full_sql` reports 0.0 instead of raising
# ZeroDivisionError.
score = sum(item['score'] for item in output_results) / len(output_results) if output_results else 0.0
print(f"Final score: {score}")


## Chain-of-Thought: gpt-4o

In [21]:
# Schema of one generated answer: the SQL query plus the reasoning behind it.
# (Documented with comments rather than docstrings so the model classes
# themselves stay unchanged.)
class OutputFormat(BaseModel):
    full_sql_query: str = Field(description='The full SQL query.')
    rationale: str = Field(description='The step-by-step reasoning to generate the SQL query.')

# Top-level schema handed to with_structured_output below: a list of OutputFormat entries.
class Response(BaseModel):
    output: list[OutputFormat]

# Prompt text with two placeholders, {schema} and {input_query}.
# The doubled braces {{ }} are escapes that render as literal braces.
# NOTE(review): the FORMATTING section describes a flat JSON object, while
# Response nests the entries under an `output` list — confirm this is intended.
template = '''### TASK
You are tasked with generating a SQL query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

# Chat model used for this section, with temperature fixed at 0.0.
model_openai = ChatOpenAI(
    model='gpt-4o',
    temperature=0.0,
)

# Ask the model for output matching the Response schema, then pipe the prompt into it.
model = model_openai.with_structured_output(Response)
chain = (prompt | model)

# Generate one SQL prediction per sample in NLsamples and collect the results
# in `all_full_sql` (one dict per sample) for the evaluation cell.

# Neither the database id nor the schema depends on the sample, so they are
# computed once instead of re-reading the database on every iteration.
db_id = 'tpc-h'
db_schema = get_database_schema(db_path=db_path, tables_list=[])

all_full_sql = list()
for idx, sample in enumerate(tqdm(NLsamples)):
    input_data = {'schema': db_schema, 'input_query': sample['question']}
    # `.output` is the list declared on Response; only its first entry is kept.
    output = chain.invoke(input=input_data).output
    all_full_sql.append(
        {
            'sql_idx': idx,
            'db_id': db_id,
            'question': sample['question'],
            'rationale': output[0].rationale,
            'full_sql_query': output[0].full_sql_query,
            'gold_sql': sample['sql'],
        }
    )

# Print every collected field as 'key': value, one per line.
for record in all_full_sql:
    for k, v in record.items():
        print("'{}': {}".format(k, v))

100%|██████████| 2/2 [00:10<00:00,  5.02s/it]

'sql_idx': 0
'db_id': tpc-h
'question': Find top 100 suppliers with the highest account balance that I can place an order for parts of type includes COPPER, with size 42 in region EUROPE, and with the minimum cost supply.
'rationale': 1. Identify the relevant tables: SUPPLIER, PART, PARTSUPP, NATION, REGION.
2. Join SUPPLIER with NATION to get the region information.
3. Join NATION with REGION to filter suppliers in the EUROPE region.
4. Join SUPPLIER with PARTSUPP to get the supply cost and part information.
5. Join PARTSUPP with PART to filter parts of type including 'COPPER' and size 42.
6. Filter the results to get the minimum supply cost for each supplier.
7. Order the results by account balance in descending order.
8. Limit the results to the top 100 suppliers.
'full_sql_query': SELECT S.S_SUPPKEY, S.S_NAME, S.S_ACCTBAL, MIN(PS.PS_SUPPLYCOST) AS MIN_SUPPLYCOST
FROM SUPPLIER S
JOIN NATION N ON S.S_NATIONKEY = N.N_NATIONKEY
JOIN REGION R ON N.N_REGIONKEY = R.R_REGIONKEY
JOIN PARTSU




In [None]:
## database execution evaluation
from src.evaluate import compare_execution

# Execute each predicted query and its gold query, compare the two results and
# record a per-sample score; finally print the scores and their average.
output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    database = SqliteDatabase(db_file=db_path)
    error_info = None
    try:
        # Execution happens inside the try block so that a predicted query that
        # fails to run is scored 0 instead of aborting the whole evaluation.
        pred_result = database.execute(data['full_sql_query'])
        print('pred_result\n', pred_result)
        gold_result = database.execute(data['gold_sql'])
        print('gold_result\n', gold_result)
        score = compare_execution(pred_result, gold_result)
    except Exception as e:
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Python Script Error:' + str(e)
    # A zero score without an exception means the results did not match.
    if score == 0 and error_info is None:
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Guard the average: an empty `all_full_sql` reports 0.0 instead of raising
# ZeroDivisionError.
score = sum(item['score'] for item in output_results) / len(output_results) if output_results else 0.0
print(f"Final score: {score}")


### Chain-of-Thought: o1-preview

In [38]:
from openai import OpenAI
# OpenAI client for the o1-preview run; the API key is read from the
# OPENAI_O1_KEY environment variable (None is passed when it is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_O1_KEY"))

# Prompt text with two placeholders, {schema} and {input_query}.
# The doubled braces {{ }} are escapes that render as literal braces.
template = '''### TASK
You are tasked with generating a SQL query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

# Used only to render the prompt text (via prompt.format) in this section.
prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

# Generate one SQL prediction per sample with o1-preview through the OpenAI
# client and collect the results in `all_full_sql` for the evaluation cell.

# Imported once here rather than on every loop iteration.
from langchain_core.utils.json import parse_json_markdown

# Neither the database id nor the schema depends on the sample, so they are
# computed once instead of re-reading the database on every iteration.
db_id = 'tpc-h'
db_schema = get_database_schema(db_path=db_path, tables_list=[])

all_full_sql = list()
for idx, sample in enumerate(tqdm(NLsamples)):
    formatted_prompt = prompt.format(schema=db_schema, input_query=sample['question'])

    response = client.chat.completions.create(
        model="o1-preview",
        messages=[
            {
                "role": "user",
                "content": formatted_prompt
            }
        ]
    )
    # The reply is plain text; parse the JSON object the prompt asked for out
    # of the first choice. A separate name keeps `response` the API object.
    response_text = response.choices[0].message.content
    output = parse_json_markdown(response_text)
    all_full_sql.append(
        {
            'sql_idx': idx,
            'db_id': db_id,
            'question': sample['question'],
            'rationale': output['rationale'],
            'full_sql_query': output['full_sql_query'],
            'gold_sql': sample['sql'],
        }
    )

# Print every collected field as 'key': value, one per line.
for record in all_full_sql:
    for k, v in record.items():
        print("'{}': {}".format(k, v))


100%|██████████| 2/2 [01:36<00:00, 48.10s/it]

'sql_idx': 0
'db_id': tpc-h
'question': Find top 100 suppliers with the highest account balance that I can place an order for parts of type includes COPPER, with size 42 in region EUROPE, and with the minimum cost supply.
'rationale': Step 1: Identify parts of type that includes 'COPPER' and size 42 from the PART table.

Step 2: Find the minimum supply cost for these parts from the PARTSUPP table.

Step 3: Identify suppliers who supply these parts at the minimum supply cost by joining PARTSUPP and the result from Step 2.

Step 4: Select suppliers located in the region 'EUROPE' by joining SUPPLIER, NATION, and REGION tables and filtering by R_NAME = 'EUROPE'.

Step 5: Combine the above results to get suppliers in 'EUROPE' who supply the required parts at minimum cost.

Step 6: Order these suppliers by highest account balance (S_ACCTBAL) and select the top 100.
'full_sql_query': WITH Qualified_Parts AS (
    SELECT P_PARTKEY
    FROM PART
    WHERE P_TYPE LIKE '%COPPER%' AND P_SIZE = 42





In [40]:
## database execution evaluation
from src.evaluate import compare_execution

# Execute each predicted query and its gold query, compare the two results and
# record a per-sample score; finally print the scores and their average.
output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    database = SqliteDatabase(db_file=db_path)
    error_info = None
    try:
        # A query that fails to execute is caught below and scored 0.
        pred_result = database.execute(data['full_sql_query'])
        print('pred_result\n', pred_result)
        gold_result = database.execute(data['gold_sql'])
        print('gold_result\n', gold_result)
        score = compare_execution(pred_result, gold_result)
    except Exception as e:
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Python Script Error:' + str(e)
    # A zero score without an exception means the results did not match.
    if score == 0 and error_info is None:
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Guard the average: an empty `all_full_sql` reports 0.0 instead of raising
# ZeroDivisionError.
score = sum(item['score'] for item in output_results) / len(output_results) if output_results else 0.0
print(f"Final score: {score}")


  0%|          | 0/2 [00:00<?, ?it/s]

An error occurred: Execution failed on sql 'WITH Qualified_Parts AS (
    SELECT P_PARTKEY
    FROM PART
    WHERE P_TYPE LIKE '%COPPER%' AND P_SIZE = 42
),
Min_Part_Supp_Cost AS (
    SELECT PS_PARTKEY, MIN(PS_SUPPLYCOST) AS MIN_SUPPLYCOST
    FROM PARTSUPP
    WHERE PS_PARTKEY IN (SELECT P_PARTKEY FROM Qualified_Parts)
    GROUP BY PS_PARTKEY
),
Min_Cost_Suppliers AS (
    SELECT PS_PARTKEY, PS_SUPPKEY
    FROM PARTSUPP
    INNER JOIN Min_Part_Supp_Cost ON PARTSUPP.PS_PARTKEY = Min_Part_Supp_Cost.PS_PARTKEY AND PARTSUPP.PS_SUPPLYCOST = Min_Part_Supp_Cost.MIN_SUPPLYCOST
),
Suppliers_In_Europe AS (
    SELECT S_SUPPKEY, S_NAME, S_ACCTBAL, S_ADDRESS, S_PHONE, S_COMMENT
    FROM SUPPLIER
    INNER JOIN NATION ON SUPPLIER.S_NATIONKEY = N_NATIONKEY
    INNER JOIN REGION ON NATION.N_REGIONKEY = REGION.R_REGIONKEY
    WHERE R_NAME = 'EUROPE'
)
SELECT DISTINCT S_SUPPKEY, S_NAME, S_ACCTBAL, S_ADDRESS, S_PHONE, S_COMMENT
FROM Suppliers_In_Europe
INNER JOIN Min_Cost_Suppliers ON Suppliers_In_Eur

100%|██████████| 2/2 [00:00<00:00,  5.71it/s]

gold_result
     S_ACCTBAL              S_NAME          N_NAME  P_PARTKEY          P_MFGR  \
0     9967.45  Supplier#000002302          FRANCE      69795  Manufacturer#5   
1     9925.04  Supplier#000003400         ROMANIA      73399  Manufacturer#1   
2     9915.38  Supplier#000006085         ROMANIA      31078  Manufacturer#2   
3     9915.38  Supplier#000006085         ROMANIA     198527  Manufacturer#5   
4     9828.21  Supplier#000000647  UNITED KINGDOM      23140  Manufacturer#3   
..        ...                 ...             ...        ...             ...   
95    7822.90  Supplier#000000674          FRANCE     180673  Manufacturer#3   
96    7814.84  Supplier#000004126         ROMANIA     181607  Manufacturer#2   
97    7814.08  Supplier#000006465  UNITED KINGDOM      33961  Manufacturer#4   
98    7814.08  Supplier#000006465  UNITED KINGDOM      68946  Manufacturer#3   
99    7814.08  Supplier#000006465  UNITED KINGDOM     198907  Manufacturer#2   

                          




## Q3: Shipping Priority Query

In [12]:
NLquery1 = "List 10 orders had not been shipped as of the date 1995-03-20. \
List orders in the HOUSEHOLD market and with the highest revenue."

gold_sql1 = """select
	l_orderkey,
	sum(l_extendedprice * (1 - l_discount)) as revenue,
	o_orderdate,
	o_shippriority
from
	customer,
	orders,
	lineitem
where
	c_mktsegment = 'HOUSEHOLD'
	and c_custkey = o_custkey
	and l_orderkey = o_orderkey
	and o_orderdate < '1995-03-20'
	and l_shipdate > '1995-03-20'
group by
	l_orderkey,
	o_orderdate,
	o_shippriority
order by
	revenue desc,
	o_orderdate
LIMIT 10;"""

NLquery2 = "Find me top 10 unshipped orders as of the date 1995-03-20 in the BUILDING customer segment. \
List orders in the highest revenue."

gold_sql2 = """select
	l_orderkey,
	sum(l_extendedprice * (1 - l_discount)) as revenue,
	o_orderdate,
	o_shippriority
from
	customer,
	orders,
	lineitem
where
	c_mktsegment = 'BUILDING'
	and c_custkey = o_custkey
	and l_orderkey = o_orderkey
	and o_orderdate < '1995-03-20'
	and l_shipdate > '1995-03-20'
group by
	l_orderkey,
	o_orderdate,
	o_shippriority
order by
	revenue desc,
	o_orderdate
LIMIT 10;"""

# Evaluation samples for Q3: each question is paired with its own gold SQL.
# This rebinds NLsamples, replacing the Q2 samples for the cells that follow.
NLsamples = [{'question': NLquery1, 'sql': gold_sql1},
             {'question': NLquery2, 'sql': gold_sql2}]

In [13]:
# Sanity check: run the first Q3 gold query against the database file at db_path.
# The call is the cell's last expression, so its result is what the cell displays.
database = SqliteDatabase(db_file=db_path)
database.execute(gold_sql1)

Unnamed: 0,L_ORDERKEY,revenue,O_ORDERDATE,O_SHIPPRIORITY
0,4994400,423834.7976,1995-03-09,0
1,5577601,407855.0202,1995-03-11,0
2,2160291,401149.7805,1995-03-18,0
3,2845094,401094.1393,1995-03-06,0
4,1902471,400497.3847,1995-03-01,0
5,2346242,392580.0394,1995-03-17,0
6,2529826,387365.156,1995-02-17,0
7,5881319,383377.6244,1995-03-13,0
8,5329575,374659.8572,1995-03-07,0
9,2906022,370116.1556,1995-02-27,0


## gpt-4o-mini

In [14]:
# Schema of one generated answer: the SQL query plus the reasoning behind it.
# (Documented with comments rather than docstrings so the model classes
# themselves stay unchanged.)
class OutputFormat(BaseModel):
    full_sql_query: str = Field(description='The full SQL query.')
    rationale: str = Field(description='The step-by-step reasoning to generate the SQL query.')

# Top-level schema handed to with_structured_output below: a list of OutputFormat entries.
class Response(BaseModel):
    output: list[OutputFormat]

# Prompt text with two placeholders, {schema} and {input_query}.
# The doubled braces {{ }} are escapes that render as literal braces.
# NOTE(review): the FORMATTING section describes a flat JSON object, while
# Response nests the entries under an `output` list — confirm this is intended.
template = '''### TASK
You are tasked with generating a SQL query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

# Chat model used for this section, with temperature fixed at 0.0.
model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

# Ask the model for output matching the Response schema, then pipe the prompt into it.
model = model_openai.with_structured_output(Response)
chain = (prompt | model)

# Generate one SQL prediction per sample in NLsamples and collect the results
# in `all_full_sql` (one dict per sample) for the evaluation cell.

# Neither the database id nor the schema depends on the sample, so they are
# computed once instead of re-reading the database on every iteration.
db_id = 'tpc-h'
db_schema = get_database_schema(db_path=db_path, tables_list=[])

all_full_sql = list()
for idx, sample in enumerate(tqdm(NLsamples)):
    input_data = {'schema': db_schema, 'input_query': sample['question']}
    # `.output` is the list declared on Response; only its first entry is kept.
    output = chain.invoke(input=input_data).output
    all_full_sql.append(
        {
            'sql_idx': idx,
            'db_id': db_id,
            'question': sample['question'],
            'rationale': output[0].rationale,
            'full_sql_query': output[0].full_sql_query,
            'gold_sql': sample['sql'],
        }
    )

# Print every collected field as 'key': value, one per line.
for record in all_full_sql:
    for k, v in record.items():
        print("'{}': {}".format(k, v))

100%|██████████| 2/2 [00:06<00:00,  3.11s/it]

'sql_idx': 0
'db_id': tpc-h
'question': List 10 orders had not been shipped as of the date 1995-03-20. List orders in the HOUSEHOLD market and with the highest revenue.
'rationale': 1. We need to find orders that have not been shipped as of a specific date (1995-03-20). This means we will filter orders where the order status is not 'SHIPPED'.
2. We also need to ensure that the orders were placed on or before the specified date, so we will add a condition for the order date.
3. Additionally, we want to filter these orders to only include those from customers in the 'HOUSEHOLD' market segment. This requires a subquery to get customer keys from the CUSTOMER table where the market segment is 'HOUSEHOLD'.
4. Since we want to list the orders with the highest revenue, we will order the results by the total price in descending order.
5. Finally, we will limit the results to the top 10 orders.
'full_sql_query': SELECT O_ORDERKEY, O_TOTALPRICE 
FROM ORDERS 
WHERE O_ORDERSTATUS != 'SHIPPED' 
AND 




In [15]:
## database execution evaluation
from src.evaluate import compare_execution

# Execute each predicted query and its gold query, compare the two results and
# record a per-sample score; finally print the scores and their average.
output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    database = SqliteDatabase(db_file=db_path)
    error_info = None
    try:
        # A query that fails to execute is caught below and scored 0.
        pred_result = database.execute(data['full_sql_query'])
        print('pred_result\n', pred_result)
        gold_result = database.execute(data['gold_sql'])
        print('gold_result\n', gold_result)
        score = compare_execution(pred_result, gold_result)
    except Exception as e:
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Python Script Error:' + str(e)
    # A zero score without an exception means the results did not match.
    if score == 0 and error_info is None:
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Guard the average: an empty `all_full_sql` reports 0.0 instead of raising
# ZeroDivisionError.
score = sum(item['score'] for item in output_results) / len(output_results) if output_results else 0.0
print(f"Final score: {score}")


  0%|          | 0/2 [00:00<?, ?it/s]

pred_result
    O_ORDERKEY  O_TOTALPRICE
0     4722021     544089.09
1     4515876     510061.60
2     5893511     490806.51
3     4267751     485141.38
4     3234337     475530.92
5     2942469     469630.44
6      551136     469605.59
7     4165504     469071.40
8     3152929     468490.69
9     2224069     466840.78


 50%|█████     | 1/2 [00:01<00:01,  1.67s/it]

gold_result
    L_ORDERKEY      revenue O_ORDERDATE  O_SHIPPRIORITY
0     4994400  423834.7976  1995-03-09               0
1     5577601  407855.0202  1995-03-11               0
2     2160291  401149.7805  1995-03-18               0
3     2845094  401094.1393  1995-03-06               0
4     1902471  400497.3847  1995-03-01               0
5     2346242  392580.0394  1995-03-17               0
6     2529826  387365.1560  1995-02-17               0
7     5881319  383377.6244  1995-03-13               0
8     5329575  374659.8572  1995-03-07               0
9     2906022  370116.1556  1995-02-27               0
pred_result
 Empty DataFrame
Columns: [O_ORDERKEY, O_TOTALPRICE]
Index: []


100%|██████████| 2/2 [00:02<00:00,  1.49s/it]

gold_result
    L_ORDERKEY      revenue O_ORDERDATE  O_SHIPPRIORITY
0     3459808  405838.6989  1995-03-04               0
1      492164  390324.0610  1995-02-19               0
2     1188320  384537.9359  1995-03-09               0
3     1368674  379739.9997  1995-03-16               0
4      824738  379394.1262  1995-03-19               0
5     2435712  378673.0558  1995-02-26               0
6     4878020  378376.7952  1995-03-12               0
7     5521732  375153.9215  1995-03-13               0
8     2531012  373597.1154  1995-03-18               0
9     2628192  373133.3094  1995-02-22               0
{0: 0, 1: 0}
Final score: 0.0





## gpt-4o

In [16]:
# Schema of one generated answer: the SQL query plus the reasoning behind it.
# (Documented with comments rather than docstrings so the model classes
# themselves stay unchanged.)
class OutputFormat(BaseModel):
    full_sql_query: str = Field(description='The full SQL query.')
    rationale: str = Field(description='The step-by-step reasoning to generate the SQL query.')

# Top-level schema handed to with_structured_output below: a list of OutputFormat entries.
class Response(BaseModel):
    output: list[OutputFormat]

# Prompt text with two placeholders, {schema} and {input_query}.
# The doubled braces {{ }} are escapes that render as literal braces.
# NOTE(review): the FORMATTING section describes a flat JSON object, while
# Response nests the entries under an `output` list — confirm this is intended.
template = '''### TASK
You are tasked with generating a SQL query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

# Chat model used for this section, with temperature fixed at 0.0.
model_openai = ChatOpenAI(
    model='gpt-4o',
    temperature=0.0,
)

# Ask the model for output matching the Response schema, then pipe the prompt into it.
model = model_openai.with_structured_output(Response)
chain = (prompt | model)

# Generate one SQL prediction per sample in NLsamples and collect the results
# in `all_full_sql` (one dict per sample) for the evaluation cell.

# Neither the database id nor the schema depends on the sample, so they are
# computed once instead of re-reading the database on every iteration.
db_id = 'tpc-h'
db_schema = get_database_schema(db_path=db_path, tables_list=[])

all_full_sql = list()
for idx, sample in enumerate(tqdm(NLsamples)):
    input_data = {'schema': db_schema, 'input_query': sample['question']}
    # `.output` is the list declared on Response; only its first entry is kept.
    output = chain.invoke(input=input_data).output
    all_full_sql.append(
        {
            'sql_idx': idx,
            'db_id': db_id,
            'question': sample['question'],
            'rationale': output[0].rationale,
            'full_sql_query': output[0].full_sql_query,
            'gold_sql': sample['sql'],
        }
    )

# Print every collected field as 'key': value, one per line.
for record in all_full_sql:
    for k, v in record.items():
        print("'{}': {}".format(k, v))

100%|██████████| 2/2 [00:08<00:00,  4.35s/it]

'sql_idx': 0
'db_id': tpc-h
'question': List 10 orders had not been shipped as of the date 1995-03-20. List orders in the HOUSEHOLD market and with the highest revenue.
'rationale': To solve this query, we need to follow these steps:
1. Identify orders that belong to customers in the 'HOUSEHOLD' market segment.
2. Filter these orders to include only those that had not been shipped as of the date '1995-03-20'.
3. Calculate the revenue for each order.
4. Sort the orders by revenue in descending order.
5. Limit the result to the top 10 orders.

We will need to join the CUSTOMER, ORDERS, and LINEITEM tables to achieve this.
'full_sql_query': SELECT O.O_ORDERKEY, O.O_TOTALPRICE, O.O_ORDERDATE
FROM ORDERS O
JOIN CUSTOMER C ON O.O_CUSTKEY = C.C_CUSTKEY
JOIN LINEITEM L ON O.O_ORDERKEY = L.L_ORDERKEY
WHERE C.C_MKTSEGMENT = 'HOUSEHOLD'
  AND L.L_SHIPDATE > '1995-03-20'
GROUP BY O.O_ORDERKEY, O.O_TOTALPRICE, O.O_ORDERDATE
ORDER BY SUM(L.L_EXTENDEDPRICE * (1 - L.L_DISCOUNT)) DESC
LIMIT 10;
'gold_s




In [17]:
## database execution evaluation
# Runs each predicted query and its gold query against the SQLite database,
# scores the pair with `compare_execution`, and reports the mean score.
from src.evaluate import compare_execution

output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    database = SqliteDatabase(db_file=db_path)
    try:
        pred_result = database.execute(data['full_sql_query'])
        print('pred_result\n', pred_result)
        gold_result = database.execute(data['gold_sql'])
        print('gold_result\n', gold_result)

        score = compare_execution(pred_result, gold_result)
        error_info = None
    except Exception as e:
        # Any failure (predicted query, gold query, or the comparison itself)
        # is recorded as a zero score for this instance; the loop keeps going.
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Python Script Error:' + str(e)
    # Both queries ran, but the comparison scored the pair as 0.
    if score == 0 and error_info is None:
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Guard the mean against an empty prediction list, which would otherwise
# raise ZeroDivisionError.
if output_results:
    score = sum(item['score'] for item in output_results) / len(output_results)
else:
    score = 0.0
print(f"Final score: {score}")


  0%|          | 0/2 [00:00<?, ?it/s]

pred_result
    O_ORDERKEY  O_TOTALPRICE O_ORDERDATE
0     3967937     502906.33  1995-07-27
1     1395745     502742.76  1998-07-28
2     2199712     515531.82  1996-09-30
3       95808     492147.15  1995-10-10
4     2000131     485869.93  1997-12-21
5     2820355     487405.74  1995-10-10
6     3605638     486911.38  1998-07-26
7     1869860     480012.15  1998-04-28
8     1905157     483987.93  1996-11-14
9     2184164     477728.86  1997-09-25


 50%|█████     | 1/2 [00:03<00:03,  3.68s/it]

gold_result
    L_ORDERKEY      revenue O_ORDERDATE  O_SHIPPRIORITY
0     4994400  423834.7976  1995-03-09               0
1     5577601  407855.0202  1995-03-11               0
2     2160291  401149.7805  1995-03-18               0
3     2845094  401094.1393  1995-03-06               0
4     1902471  400497.3847  1995-03-01               0
5     2346242  392580.0394  1995-03-17               0
6     2529826  387365.1560  1995-02-17               0
7     5881319  383377.6244  1995-03-13               0
8     5329575  374659.8572  1995-03-07               0
9     2906022  370116.1556  1995-02-27               0
pred_result
    O_ORDERKEY O_ORDERDATE      REVENUE
0     3043270  1997-02-12  512788.6848
1     2232932  1997-04-13  511470.6210
2     4676257  1997-08-02  479250.9736
3     1672039  1997-11-17  471719.2877
4     2844870  1996-09-16  470162.6432
5     4206947  1996-07-18  462686.4295
6     2366755  1996-12-09  462021.3817
7     3456515  1998-03-26  461070.3114
8     1251844  199

100%|██████████| 2/2 [00:07<00:00,  3.59s/it]

gold_result
    L_ORDERKEY      revenue O_ORDERDATE  O_SHIPPRIORITY
0     3459808  405838.6989  1995-03-04               0
1      492164  390324.0610  1995-02-19               0
2     1188320  384537.9359  1995-03-09               0
3     1368674  379739.9997  1995-03-16               0
4      824738  379394.1262  1995-03-19               0
5     2435712  378673.0558  1995-02-26               0
6     4878020  378376.7952  1995-03-12               0
7     5521732  375153.9215  1995-03-13               0
8     2531012  373597.1154  1995-03-18               0
9     2628192  373133.3094  1995-02-22               0
{0: 0, 1: 0}
Final score: 0.0





## o1-preview

In [20]:
import os

from langchain_core.utils.json import parse_json_markdown
from openai import OpenAI

# The API key is read from the OPENAI_O1_KEY environment variable
# (`.get` yields None when the variable is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_O1_KEY"))

# Prompt asking for a JSON object with `rationale` and `full_sql_query`.
# Doubled braces are literal braces; {schema} and {input_query} are filled in
# by `prompt.format` below.
template = '''### TASK
You are tasked with generating a SQLite query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQLite query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

# The database id and schema text are identical for every sample, so build
# them once instead of re-opening the SQLite file on each iteration.
db_id = 'tpc-h'
db_schema = get_database_schema(db_path=db_path, tables_list=[])

all_full_sql = list()
for idx in tqdm(range(len(NLsamples))):
    sample = NLsamples[idx]
    formatted_prompt = prompt.format(schema=db_schema, input_query=sample['question'])

    completion = client.chat.completions.create(
        model="o1-preview",
        messages=[
            {
                "role": "user",
                "content": formatted_prompt
            }
        ]
    )
    # Keep the raw reply text under its own name rather than rebinding the
    # completion object, then parse it into a dict.
    response = completion.choices[0].message.content
    output = parse_json_markdown(response)
    # One record per question: the prediction next to the gold SQL, in the
    # key order that the printing loop below relies on.
    full_sql_output = {
        'sql_idx': idx,
        'db_id': db_id,
        'question': sample['question'],
        'rationale': output['rationale'],
        'full_sql_query': output['full_sql_query'],
        'gold_sql': sample['sql'],
    }
    all_full_sql.append(full_sql_output)

# Dump every record field-by-field for manual inspection.
for record in all_full_sql:
    for key, value in record.items():
        print(f"'{key}': {value}")


100%|██████████| 2/2 [01:29<00:00, 44.55s/it]

'sql_idx': 0
'db_id': tpc-h
'question': List 10 orders had not been shipped as of the date 1995-03-20. List orders in the HOUSEHOLD market and with the highest revenue.
'rationale': To list the top 10 orders that had not been shipped as of date '1995-03-20', in the HOUSEHOLD market and with the highest revenue, we perform the following steps:

1. **Join** the `ORDERS` table with the `CUSTOMER` table on `O_CUSTKEY = C_CUSTKEY` to get orders along with customer details.

2. **Filter** the customers to include only those in the HOUSEHOLD market segment by checking `C_MKTSEGMENT = 'HOUSEHOLD'`.

3. **Exclude** orders that have any line items shipped on or before '1995-03-20' by using a `NOT EXISTS` subquery on the `LINEITEM` table where `L_SHIPDATE <= '1995-03-20'`. This ensures we only get orders where none of the items have been shipped as of that date.

4. **Order** the results by `O_TOTALPRICE` in descending order to get the orders with the highest revenue.

5. **Limit** the results to




In [21]:
## database execution evaluation
# Runs each predicted query and its gold query against the SQLite database,
# scores the pair with `compare_execution`, and reports the mean score.
from src.evaluate import compare_execution

output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    database = SqliteDatabase(db_file=db_path)
    try:
        pred_result = database.execute(data['full_sql_query'])
        print('pred_result\n', pred_result)
        gold_result = database.execute(data['gold_sql'])
        print('gold_result\n', gold_result)

        score = compare_execution(pred_result, gold_result)
        error_info = None
    except Exception as e:
        # Any failure (predicted query, gold query, or the comparison itself)
        # is recorded as a zero score for this instance; the loop keeps going.
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Python Script Error:' + str(e)
    # Both queries ran, but the comparison scored the pair as 0.
    if score == 0 and error_info is None:
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Guard the mean against an empty prediction list, which would otherwise
# raise ZeroDivisionError.
if output_results:
    score = sum(item['score'] for item in output_results) / len(output_results)
else:
    score = 0.0
print(f"Final score: {score}")


  0%|          | 0/2 [00:00<?, ?it/s]

pred_result
    O_ORDERKEY  O_TOTALPRICE
0     2199712     515531.82
1     3967937     502906.33
2     1395745     502742.76
3       95808     492147.15
4     2820355     487405.74
5     3605638     486911.38
6     2000131     485869.93
7     1905157     483987.93
8     3702855     481105.12
9     1869860     480012.15


 50%|█████     | 1/2 [00:03<00:03,  3.82s/it]

gold_result
    L_ORDERKEY      revenue O_ORDERDATE  O_SHIPPRIORITY
0     4994400  423834.7976  1995-03-09               0
1     5577601  407855.0202  1995-03-11               0
2     2160291  401149.7805  1995-03-18               0
3     2845094  401094.1393  1995-03-06               0
4     1902471  400497.3847  1995-03-01               0
5     2346242  392580.0394  1995-03-17               0
6     2529826  387365.1560  1995-02-17               0
7     5881319  383377.6244  1995-03-13               0
8     5329575  374659.8572  1995-03-07               0
9     2906022  370116.1556  1995-02-27               0
pred_result
    O_ORDERKEY  O_TOTALPRICE
0     4878020     400786.97
1      824738     400279.58
2     2435712     399571.28
3     1188320     398320.88
4     1368674     397552.93
5     5521732     396006.95
6     2300070     389990.64
7     2531012     387829.83
8     2628192     387472.25
9      993600     382481.78


100%|██████████| 2/2 [00:05<00:00,  2.94s/it]

gold_result
    L_ORDERKEY      revenue O_ORDERDATE  O_SHIPPRIORITY
0     3459808  405838.6989  1995-03-04               0
1      492164  390324.0610  1995-02-19               0
2     1188320  384537.9359  1995-03-09               0
3     1368674  379739.9997  1995-03-16               0
4      824738  379394.1262  1995-03-19               0
5     2435712  378673.0558  1995-02-26               0
6     4878020  378376.7952  1995-03-12               0
7     5521732  375153.9215  1995-03-13               0
8     2531012  373597.1154  1995-03-18               0
9     2628192  373133.3094  1995-02-22               0
{0: 0, 1: 0}
Final score: 0.0



