## Load Data

In [1]:
import json
import numpy as np
import pandas as pd

from pathlib import Path
from src.db_utils import get_schema_str, get_data_dict, get_schema_str_with_tables
from src.database import SqliteDatabase, DuckDBDatabase
from src.sparc_preprocess import (
    load_sparc_data,
    process_all_tables, 
    filter_samples_by_count, 
    process_samples, 
    split_train_dev
)

# duckdb.sql('INSTALL sqlite')
# duckdb.sql('SET GLOBAL sqlite_all_varchar = true;')

# Paths are resolved against the current working directory, so the notebook
# has to be started from the project root for `data/sparc` to be found.
proj_path = Path('.').resolve()
sparc_path = proj_path / 'data' / 'sparc'

# Load the table metadata together with the original train and dev splits.
tables, train_data, dev_data = load_sparc_data(sparc_path)
print(f'Number of train: {len(train_data)} | Number of dev: {len(dev_data)}')

# Pre-process the table metadata (output structure is defined in src.sparc_preprocess).
sparc_tables = process_all_tables(tables)
# Pool both original splits, then filter by count with n=5.
# Per the original author's note this keeps entries with at least 5 samples;
# what exactly is counted is decided inside the helper -- TODO confirm.
all_data = filter_samples_by_count(train_data+dev_data, n=5)
# Group the samples (original author's note: {db_id: list of samples}).
sparc_samples = process_samples(all_data)
# Re-split the pooled samples into new train/dev sets with ratio=0.8
# (presumably 80% train / 20% dev -- confirm against split_train_dev).
train_samples, dev_samples = split_train_dev(sparc_samples, ratio=0.8)

Number of train: 3034 | Number of dev: 422


## Schema description

In [33]:
import sqlite3
# SQLite database file used for the TPC-H experiments; the path is relative to
# the working directory.
db_path = 'data/tpch/TPC-H.db'
def get_database_schema(db_path, tables_list) -> str:
    """Return the CREATE TABLE statements of a SQLite database as one string.

    Args:
        db_path: Path of the SQLite database file to inspect.
        tables_list: Names of the tables to include. An empty (or ``None``)
            value disables filtering, so every table is returned.

    Returns:
        The CREATE statement stored in ``sqlite_master`` for each selected
        table, each followed by a blank line, in the order the catalog
        returns them.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Fetch names and CREATE statements in a single pass. This avoids one
        # query per table and, more importantly, avoids interpolating table
        # names into SQL text (which fails for names containing a quote).
        rows = conn.execute(
            "SELECT name, sql FROM sqlite_master WHERE type='table';"
        ).fetchall()
    finally:
        # Release the connection even when the catalog query raises.
        conn.close()

    return ''.join(
        create_statement + '\n\n'
        for table_name, create_statement in rows
        if not tables_list or table_name in tables_list
    )

# An empty tables_list disables the name filter, so the CREATE statement of
# every table in the database is printed.
print(get_database_schema(db_path=db_path, tables_list=[]))

CREATE TABLE NATION (
  N_NATIONKEY INTEGER PRIMARY KEY NOT NULL,
  N_NAME      TEXT NOT NULL,
  N_REGIONKEY INTEGER NOT NULL,
  N_COMMENT   TEXT,
  FOREIGN KEY (N_REGIONKEY) REFERENCES REGION(R_REGIONKEY)
)

CREATE TABLE REGION (
  R_REGIONKEY INTEGER PRIMARY KEY NOT NULL,
  R_NAME      TEXT NOT NULL,
  R_COMMENT   TEXT
)

CREATE TABLE PART (
  P_PARTKEY     INTEGER PRIMARY KEY NOT NULL,
  P_NAME        TEXT NOT NULL,
  P_MFGR        TEXT NOT NULL,
  P_BRAND       TEXT NOT NULL,
  P_TYPE        TEXT NOT NULL,
  P_SIZE        INTEGER NOT NULL,
  P_CONTAINER   TEXT NOT NULL,
  P_RETAILPRICE INTEGER NOT NULL,
  P_COMMENT     TEXT NOT NULL
)

CREATE TABLE SUPPLIER (
  S_SUPPKEY   INTEGER PRIMARY KEY NOT NULL,
  S_NAME      TEXT NOT NULL,
  S_ADDRESS   TEXT NOT NULL,
  S_NATIONKEY INTEGER NOT NULL,
  S_PHONE     TEXT NOT NULL,
  S_ACCTBAL   INTEGER NOT NULL,
  S_COMMENT   TEXT NOT NULL,
  FOREIGN KEY (S_NATIONKEY) REFERENCES NATION(N_NATIONKEY)
)

CREATE TABLE PARTSUPP (
  PS_PARTKEY    IN

In [34]:
# Natural-language question 1 and its hand-written reference ("gold") SQL.
NLquery1 = (
    "Find top 100 suppliers with the highest account balance "
    "that I can place an order for parts of type includes COPPER, "
    "with size 42 in region EUROPE, and with the minimum cost supply."
)

gold_sql1 = """select
	s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment
from
	part, supplier, partsupp, nation, region
where
	p_partkey = ps_partkey
	and s_suppkey = ps_suppkey
	and p_size = 42
	and p_type like '%COPPER'
	and s_nationkey = n_nationkey
	and n_regionkey = r_regionkey
	and r_name = 'EUROPE'
	and ps_supplycost = (
		select
			min(ps_supplycost)
		from
			partsupp,
			supplier,
			nation,
			region
		where
			p_partkey = ps_partkey
			and s_suppkey = ps_suppkey
			and s_nationkey = n_nationkey
			and n_regionkey = r_regionkey
			and r_name = 'EUROPE'
	)
order by
	s_acctbal desc,
	n_name,
	s_name,
	p_partkey
LIMIT 100;"""

# Natural-language question 2: same query shape with a different region,
# part type, part size and row limit.
NLquery2 = (
    "Find top 10 suppliers with the highest account balance "
    "that I can place an order in ASIA region for parts of type includes BRASS, "
    "with size 15, and with the minimum cost supply."
)

gold_sql2 = """select
	s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment
from
	part, supplier, partsupp, nation, region
where
	p_partkey = ps_partkey
	and s_suppkey = ps_suppkey
	and p_size = 15
	and p_type like '%BRASS'
	and s_nationkey = n_nationkey
	and n_regionkey = r_regionkey
	and r_name = 'ASIA'
	and ps_supplycost = (
		select
			min(ps_supplycost)
		from
			partsupp,
			supplier,
			nation,
			region
		where
			p_partkey = ps_partkey
			and s_suppkey = ps_suppkey
			and s_nationkey = n_nationkey
			and n_regionkey = r_regionkey
			and r_name = 'ASIA'
	)
order by
	s_acctbal desc,
	n_name,
	s_name,
	p_partkey
LIMIT 10;"""

# Pair every question with its gold SQL in the record format the later cells read.
NLsamples = [
    {'question': question, 'sql': gold_sql}
    for question, gold_sql in ((NLquery1, gold_sql1), (NLquery2, gold_sql2))
]

In [38]:
# Open the TPC-H SQLite file through the project's database wrapper.
database = SqliteDatabase(db_file=db_path)
# Run the first gold query as a sanity check; as the cell's last expression,
# its result is what the notebook displays.
database.execute(gold_sql1)

Unnamed: 0,S_ACCTBAL,S_NAME,N_NAME,P_PARTKEY,P_MFGR,S_ADDRESS,S_PHONE,S_COMMENT
0,9967.45,Supplier#000002302,FRANCE,69795,Manufacturer#5,wMEzrsX2KKpTaJGE3uGEUibymG,16-486-165-5642,gly carefully bold deposits. accounts nag b
1,9925.04,Supplier#000003400,ROMANIA,73399,Manufacturer#1,IZSzKpRL1RNar39LvF,29-295-531-2833,unts along the ironic accounts must have to ha...
2,9915.38,Supplier#000006085,ROMANIA,31078,Manufacturer#2,T3Ju68MUhIb2hpTO3f8OGG,29-528-113-9241,iously bold sauternes. slyly regular asymptotes
3,9915.38,Supplier#000006085,ROMANIA,198527,Manufacturer#5,T3Ju68MUhIb2hpTO3f8OGG,29-528-113-9241,iously bold sauternes. slyly regular asymptotes
4,9828.21,Supplier#000000647,UNITED KINGDOM,23140,Manufacturer#3,x5U7MBZmwfG9,33-258-202-4782,s the slyly even ideas poach fluffily
...,...,...,...,...,...,...,...,...
95,7822.90,Supplier#000000674,FRANCE,180673,Manufacturer#3,jMxLRDxoP1Pf kzzyMVIfLB,16-128-338-8014,thely after the furiously even pains. quietly
96,7814.84,Supplier#000004126,ROMANIA,181607,Manufacturer#2,3s9EL2QxI5lAEeSPr9aDv0 O0X7SP PA4TQWAAYn,29-497-666-4765,fluffy packages. furiously ironic r
97,7814.08,Supplier#000006465,UNITED KINGDOM,33961,Manufacturer#4,"R0ofppl4Gkm,b,U5uCA0YL9wm3el luro0T",33-155-333-2168,ious dependencies. slyly regular depths doubt....
98,7814.08,Supplier#000006465,UNITED KINGDOM,68946,Manufacturer#3,"R0ofppl4Gkm,b,U5uCA0YL9wm3el luro0T",33-155-333-2168,ious dependencies. slyly regular depths doubt....


In [39]:
import os 
from dotenv import load_dotenv, find_dotenv
from collections import defaultdict
from tqdm import tqdm
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

# Load variables from the .env file located by find_dotenv(); the return value
# is assigned to `_` so the cell shows no output.
# Presumably this provides the API credentials ChatOpenAI needs -- confirm.
_ = load_dotenv(find_dotenv())

## Chain-of-Thought: gpt-4o-mini

In [41]:
# These models define the structured-output schema handed to the LLM.
# `#` comments are used instead of class docstrings on purpose: a docstring
# would likely be forwarded to the model as the schema description and so
# change the request.
class OutputFormat(BaseModel):
    # `rationale` is declared before `full_sql_query` so that the schema asks
    # for the reasoning first and the SQL second. This matches the key order
    # shown in the prompt's FORMATTING section and lets the query be generated
    # after the step-by-step reasoning rather than before it.
    rationale: str = Field(description='The step-by-step reasoning to generate the SQL query.')
    full_sql_query: str = Field(description='The full SQL query.')

class Response(BaseModel):
    # List of candidate answers; the generation loop reads only `output[0]`.
    output: list[OutputFormat]

# Prompt text for chain-of-thought SQL generation. `{schema}` and
# `{input_query}` are filled in per request; the doubled braces are escapes
# that render as literal `{` / `}` in the JSON example.
# NOTE(review): the FORMATTING section describes a single JSON object, while
# the structured-output schema (`Response`) wraps a list of such objects --
# confirm the mismatch is intended.
template = '''### TASK
You are tasked with generating a SQL query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

# Declare the two placeholders the template expects.
prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

# gpt-4o-mini with temperature 0.0 (presumably to keep runs as repeatable as
# the API allows -- confirm).
model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

# Have the model answer with a `Response` object instead of free text, then
# pipe the rendered prompt into it.
model = model_openai.with_structured_output(Response)
chain = (prompt | model)

# The database id and the schema text are identical for every sample, so they
# are built once instead of re-reading the SQLite file on each iteration.
db_id = 'tpc-h'
db_schema = get_database_schema(db_path=db_path, tables_list=[])

# One record per question: the generated SQL, its rationale, and the gold SQL
# that the evaluation cell compares against.
all_full_sql = []
for idx, sample in enumerate(tqdm(NLsamples)):
    input_data = {'schema': db_schema, 'input_query': sample['question']}
    output = chain.invoke(input=input_data).output
    if not output:
        # Fail with a message that names the sample instead of a bare IndexError.
        raise ValueError(f'Model returned no SQL candidate for sample {idx}')
    best = output[0]  # only the first candidate is kept
    all_full_sql.append({
        'sql_idx': idx,
        'db_id': db_id,
        'question': sample['question'],
        'rationale': best.rationale,
        'full_sql_query': best.full_sql_query,
        'gold_sql': sample['sql'],
    })
all_full_sql

100%|██████████| 2/2 [00:08<00:00,  4.35s/it]


[{'sql_idx': 0,
  'db_id': 'tpc-h',
  'question': 'Find top 100 suppliers with the highest account balance that I can place an order for parts of type includes COPPER, with size 42 in region EUROPE, and with the minimum cost supply.',
  'rationale': "1. We need to find suppliers with the highest account balance, so we will select from the SUPPLIER table and include the account balance (S_ACCTBAL).  \n2. We need to filter suppliers based on the parts they supply, specifically those of type 'COPPER' and size 42. This requires joining the SUPPLIER table with the PARTSUPP and PART tables.  \n3. We also need to filter suppliers based on their region, which is 'EUROPE'. This requires joining the SUPPLIER table with the NATION table and then filtering based on the region.  \n4. We will order the results by account balance in descending order to get the top suppliers.  \n5. Finally, we limit the results to the top 100 suppliers.",
  'full_sql_query': "SELECT S.S_SUPPKEY, S.S_NAME, S.S_ACCTBAL 

In [42]:
## database execution evaluation
from src.evaluate import compare_execution

# Execute each predicted query and its gold query, then score the prediction
# by comparing the two results with compare_execution.
output_results = []
for data in tqdm(all_full_sql):
    sql_idx = data['sql_idx']
    # NOTE(review): a fresh wrapper per sample mirrors the original code; it
    # could probably be created once if SqliteDatabase is reusable -- confirm.
    database = SqliteDatabase(db_file=db_path)

    sample_score = 0
    error_info = None
    pred_result = None
    try:
        # The predicted SQL comes from the LLM and may be invalid. A failure
        # here is recorded as a zero score instead of aborting the whole run.
        pred_result = database.execute(data['full_sql_query'])
    except Exception as e:
        print(f"An error occurred: {e}")
        error_info = 'SQL Execution Error:' + str(e)
    else:
        print('pred_result\n', pred_result)

    # The gold query is trusted; if it fails the setup is broken, so let it raise.
    gold_result = database.execute(data['gold_sql'])
    print('gold_result\n', gold_result)

    if error_info is None:
        try:
            sample_score = compare_execution(pred_result, gold_result)
        except Exception as e:
            print(f"An error occurred: {e}")
            sample_score = 0
            error_info = 'Python Script Error:' + str(e)
    if sample_score == 0 and error_info is None:
        # Both queries ran and were compared, but the results differ.
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": sample_score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Mean score over all samples; an empty evaluation yields 0.0 instead of a
# ZeroDivisionError.
score = (
    sum(item['score'] for item in output_results) / len(output_results)
    if output_results else 0.0
)
print(f"Final score: {score}")


  0%|          | 0/2 [00:00<?, ?it/s]

pred_result
     S_SUPPKEY              S_NAME  S_ACCTBAL
0        2302  Supplier#000002302    9967.45
1        3400  Supplier#000003400    9925.04
2        6085  Supplier#000006085    9915.38
3        6085  Supplier#000006085    9915.38
4         647  Supplier#000000647    9828.21
..        ...                 ...        ...
95       6646  Supplier#000006646    8337.57
96       3142  Supplier#000003142    8307.93
97       3142  Supplier#000003142    8307.93
98       7343  Supplier#000007343    8296.62
99       8584  Supplier#000008584    8280.18

[100 rows x 3 columns]


 50%|█████     | 1/2 [00:00<00:00,  1.71it/s]

gold_result
     S_ACCTBAL              S_NAME          N_NAME  P_PARTKEY          P_MFGR  \
0     9967.45  Supplier#000002302          FRANCE      69795  Manufacturer#5   
1     9925.04  Supplier#000003400         ROMANIA      73399  Manufacturer#1   
2     9915.38  Supplier#000006085         ROMANIA      31078  Manufacturer#2   
3     9915.38  Supplier#000006085         ROMANIA     198527  Manufacturer#5   
4     9828.21  Supplier#000000647  UNITED KINGDOM      23140  Manufacturer#3   
..        ...                 ...             ...        ...             ...   
95    7822.90  Supplier#000000674          FRANCE     180673  Manufacturer#3   
96    7814.84  Supplier#000004126         ROMANIA     181607  Manufacturer#2   
97    7814.08  Supplier#000006465  UNITED KINGDOM      33961  Manufacturer#4   
98    7814.08  Supplier#000006465  UNITED KINGDOM      68946  Manufacturer#3   
99    7814.08  Supplier#000006465  UNITED KINGDOM     198907  Manufacturer#2   

                          

100%|██████████| 2/2 [00:01<00:00,  1.92it/s]

gold_result
    S_ACCTBAL              S_NAME     N_NAME  P_PARTKEY          P_MFGR  \
0    9988.93  Supplier#000005449    VIETNAM      25448  Manufacturer#4   
1    9925.41  Supplier#000005391      CHINA     112879  Manufacturer#4   
2    9903.47  Supplier#000002334    VIETNAM      17331  Manufacturer#2   
3    9859.43  Supplier#000009403    VIETNAM      59402  Manufacturer#1   
4    9845.98  Supplier#000000175      JAPAN       5174  Manufacturer#1   
5    9809.13  Supplier#000002591      CHINA      70083  Manufacturer#3   
6    9704.66  Supplier#000008944      INDIA      66437  Manufacturer#4   
7    9694.06  Supplier#000004614      INDIA     159583  Manufacturer#5   
8    9681.99  Supplier#000000250  INDONESIA        249  Manufacturer#4   
9    9664.02  Supplier#000009995  INDONESIA      39994  Manufacturer#2   

                                  S_ADDRESS          S_PHONE  \
0  fhc8lUuZdqWUujcVaWogowEq1WVL9Y8m1efwCl3G  31-787-239-2170   
1              BfIsR LpIHomv77D0EU,T4x0VyZ4 




## Chain-of-Thought: gpt-4o

In [43]:
# These models define the structured-output schema handed to the LLM.
# `#` comments are used instead of class docstrings on purpose: a docstring
# would likely be forwarded to the model as the schema description and so
# change the request.
class OutputFormat(BaseModel):
    # `rationale` is declared before `full_sql_query` so that the schema asks
    # for the reasoning first and the SQL second. This matches the key order
    # shown in the prompt's FORMATTING section and lets the query be generated
    # after the step-by-step reasoning rather than before it.
    rationale: str = Field(description='The step-by-step reasoning to generate the SQL query.')
    full_sql_query: str = Field(description='The full SQL query.')

class Response(BaseModel):
    # List of candidate answers; the generation loop reads only `output[0]`.
    output: list[OutputFormat]

# Prompt text for chain-of-thought SQL generation. `{schema}` and
# `{input_query}` are filled in per request; the doubled braces are escapes
# that render as literal `{` / `}` in the JSON example.
# NOTE(review): the FORMATTING section describes a single JSON object, while
# the structured-output schema (`Response`) wraps a list of such objects --
# confirm the mismatch is intended.
template = '''### TASK
You are tasked with generating a SQL query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

# Declare the two placeholders the template expects.
prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

# gpt-4o with temperature 0.0 (presumably to keep runs as repeatable as the
# API allows -- confirm).
model_openai = ChatOpenAI(
    model='gpt-4o',
    temperature=0.0,
)

# Have the model answer with a `Response` object instead of free text, then
# pipe the rendered prompt into it.
model = model_openai.with_structured_output(Response)
chain = (prompt | model)

# The database id and the schema text are identical for every sample, so they
# are built once instead of re-reading the SQLite file on each iteration.
db_id = 'tpc-h'
db_schema = get_database_schema(db_path=db_path, tables_list=[])

# One record per question: the generated SQL, its rationale, and the gold SQL
# that the evaluation cell compares against.
all_full_sql = []
for idx, sample in enumerate(tqdm(NLsamples)):
    input_data = {'schema': db_schema, 'input_query': sample['question']}
    output = chain.invoke(input=input_data).output
    if not output:
        # Fail with a message that names the sample instead of a bare IndexError.
        raise ValueError(f'Model returned no SQL candidate for sample {idx}')
    best = output[0]  # only the first candidate is kept
    all_full_sql.append({
        'sql_idx': idx,
        'db_id': db_id,
        'question': sample['question'],
        'rationale': best.rationale,
        'full_sql_query': best.full_sql_query,
        'gold_sql': sample['sql'],
    })
all_full_sql

100%|██████████| 2/2 [00:07<00:00,  3.73s/it]


[{'sql_idx': 0,
  'db_id': 'tpc-h',
  'question': 'Find top 100 suppliers with the highest account balance that I can place an order for parts of type includes COPPER, with size 42 in region EUROPE, and with the minimum cost supply.',
  'rationale': "1. Identify the relevant tables: SUPPLIER, PART, PARTSUPP, NATION, REGION.\n2. Join SUPPLIER with NATION on S_NATIONKEY = N_NATIONKEY to get the nation information.\n3. Join NATION with REGION on N_REGIONKEY = R_REGIONKEY to filter suppliers in the EUROPE region.\n4. Join PARTSUPP with PART on PS_PARTKEY = P_PARTKEY to get part information.\n5. Filter parts of type including 'COPPER' and size 42.\n6. Join the result with SUPPLIER on PS_SUPPKEY = S_SUPPKEY to get supplier information.\n7. Filter suppliers based on the minimum supply cost.\n8. Order the suppliers by account balance in descending order.\n9. Limit the result to the top 100 suppliers.",
  'full_sql_query': "SELECT S_SUPPKEY, S_NAME, S_ACCTBAL, PS_SUPPLYCOST\nFROM SUPPLIER\nJOIN

In [44]:
## database execution evaluation
from src.evaluate import compare_execution

# Execute each predicted query and its gold query, then score the prediction
# by comparing the two results with compare_execution.
output_results = []
for data in tqdm(all_full_sql):
    sql_idx = data['sql_idx']
    # NOTE(review): a fresh wrapper per sample mirrors the original code; it
    # could probably be created once if SqliteDatabase is reusable -- confirm.
    database = SqliteDatabase(db_file=db_path)

    sample_score = 0
    error_info = None
    pred_result = None
    try:
        # The predicted SQL comes from the LLM and may be invalid. A failure
        # here is recorded as a zero score instead of aborting the whole run.
        pred_result = database.execute(data['full_sql_query'])
    except Exception as e:
        print(f"An error occurred: {e}")
        error_info = 'SQL Execution Error:' + str(e)
    else:
        print('pred_result\n', pred_result)

    # The gold query is trusted; if it fails the setup is broken, so let it raise.
    gold_result = database.execute(data['gold_sql'])
    print('gold_result\n', gold_result)

    if error_info is None:
        try:
            sample_score = compare_execution(pred_result, gold_result)
        except Exception as e:
            print(f"An error occurred: {e}")
            sample_score = 0
            error_info = 'Python Script Error:' + str(e)
    if sample_score == 0 and error_info is None:
        # Both queries ran and were compared, but the results differ.
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": sample_score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Mean score over all samples; an empty evaluation yields 0.0 instead of a
# ZeroDivisionError.
score = (
    sum(item['score'] for item in output_results) / len(output_results)
    if output_results else 0.0
)
print(f"Final score: {score}")


  0%|          | 0/2 [00:00<?, ?it/s]

pred_result
     S_SUPPKEY              S_NAME  S_ACCTBAL  PS_SUPPLYCOST
0        2302  Supplier#000002302    9967.45         321.56
1        3400  Supplier#000003400    9925.04         556.12
2        6085  Supplier#000006085    9915.38         487.57
3        6085  Supplier#000006085    9915.38         746.36
4         647  Supplier#000000647    9828.21         455.95
..        ...                 ...        ...            ...
95       6646  Supplier#000006646    8337.57         126.58
96       3142  Supplier#000003142    8307.93         548.85
97       3142  Supplier#000003142    8307.93         802.30
98       7343  Supplier#000007343    8296.62         210.42
99       8584  Supplier#000008584    8280.18         101.18

[100 rows x 4 columns]


 50%|█████     | 1/2 [00:00<00:00,  1.68it/s]

gold_result
     S_ACCTBAL              S_NAME          N_NAME  P_PARTKEY          P_MFGR  \
0     9967.45  Supplier#000002302          FRANCE      69795  Manufacturer#5   
1     9925.04  Supplier#000003400         ROMANIA      73399  Manufacturer#1   
2     9915.38  Supplier#000006085         ROMANIA      31078  Manufacturer#2   
3     9915.38  Supplier#000006085         ROMANIA     198527  Manufacturer#5   
4     9828.21  Supplier#000000647  UNITED KINGDOM      23140  Manufacturer#3   
..        ...                 ...             ...        ...             ...   
95    7822.90  Supplier#000000674          FRANCE     180673  Manufacturer#3   
96    7814.84  Supplier#000004126         ROMANIA     181607  Manufacturer#2   
97    7814.08  Supplier#000006465  UNITED KINGDOM      33961  Manufacturer#4   
98    7814.08  Supplier#000006465  UNITED KINGDOM      68946  Manufacturer#3   
99    7814.08  Supplier#000006465  UNITED KINGDOM     198907  Manufacturer#2   

                          

100%|██████████| 2/2 [00:01<00:00,  1.95it/s]

gold_result
    S_ACCTBAL              S_NAME     N_NAME  P_PARTKEY          P_MFGR  \
0    9988.93  Supplier#000005449    VIETNAM      25448  Manufacturer#4   
1    9925.41  Supplier#000005391      CHINA     112879  Manufacturer#4   
2    9903.47  Supplier#000002334    VIETNAM      17331  Manufacturer#2   
3    9859.43  Supplier#000009403    VIETNAM      59402  Manufacturer#1   
4    9845.98  Supplier#000000175      JAPAN       5174  Manufacturer#1   
5    9809.13  Supplier#000002591      CHINA      70083  Manufacturer#3   
6    9704.66  Supplier#000008944      INDIA      66437  Manufacturer#4   
7    9694.06  Supplier#000004614      INDIA     159583  Manufacturer#5   
8    9681.99  Supplier#000000250  INDONESIA        249  Manufacturer#4   
9    9664.02  Supplier#000009995  INDONESIA      39994  Manufacturer#2   

                                  S_ADDRESS          S_PHONE  \
0  fhc8lUuZdqWUujcVaWogowEq1WVL9Y8m1efwCl3G  31-787-239-2170   
1              BfIsR LpIHomv77D0EU,T4x0VyZ4 


