In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path
# Put the current working directory on sys.path so the `src.*` imports below resolve.
# NOTE(review): this assumes the notebook is launched from the project root — confirm.
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import sqlglot
import numpy as np
from sqlglot import expressions as exp
from src.parsing_sql import Schema, extract_all
from src.eval_utils import (
    partial_match, 
    compute_tsed
)

from src.parsing_sql import (
    extract_aliases,
    extract_condition,
    get_subqueries,
    _extract_conditions,
    _extract_columns_from_expression,
    _determine_tag,
    _format_expression,
    _get_full_column_name,
    extract_aliases,
    extract_selection,
    extract_aggregation,
    extract_orderby,
    extract_others,
    
    _extract_aliases_from_select,
    _handle_table_or_subquery
)

# from warnings import catch_warnings, simplefilter
# #with warnings.catch_warnings(action='ignore'):
# with catch_warnings():
#     simplefilter("ignore")

In [3]:
# Mapping of table name -> {column name -> declared type}. It is lowercased and passed to
# `Schema` further down in this cell.
# NOTE(review): the table/column names look like the TPC-H schema (data paths in this
# notebook live under data/tpch/) — confirm.
# NOTE(review): some numeric-looking columns are declared 'text' (P_SIZE, P_RETAILPRICE,
# S_ACCTBAL) while C_ACCTBAL is 'integer' — verify these types are intentional.
schema_dict = {'PART': {'P_PARTKEY': 'integer',
  'P_NAME': 'text',
  'P_MFGR': 'text',
  'P_BRAND': 'text',
  'P_TYPE': 'text',
  'P_SIZE': 'text',
  'P_CONTAINER': 'text',
  'P_RETAILPRICE': 'text',
  'P_COMMENT': 'text'},
 'REGION': {'R_REGIONKEY': 'integer',
  'R_NAME': 'text',
  'R_COMMENT': 'text'},
 'NATION': {'N_NATIONKEY': 'integer',
  'N_NAME': 'text',
  'N_REGIONKEY': 'integer',
  'N_COMMENT': 'text'},
 'SUPPLIER': {'S_SUPPKEY': 'integer',
  'S_NAME': 'text',
  'S_ADDRESS': 'text',
  'S_NATIONKEY': 'integer',
  'S_PHONE': 'text',
  'S_ACCTBAL': 'text',
  'S_COMMENT': 'text'},
 'CUSTOMER': {'C_CUSTKEY': 'integer',
  'C_NAME': 'text',
  'C_ADDRESS': 'text',
  'C_NATIONKEY': 'integer',
  'C_PHONE': 'text',
  'C_ACCTBAL': 'integer',
  'C_MKTSEGMENT': 'text',
  'C_COMMENT': 'text'},
 'PARTSUPP': {'PS_PARTKEY': 'integer',
  'PS_SUPPKEY': 'integer',
  'PS_AVAILQTY': 'integer',
  'PS_SUPPLYCOST': 'integer',
  'PS_COMMENT': 'text'},
 'ORDERS': {'O_ORDERKEY': 'integer',
  'O_CUSTKEY': 'integer',
  'O_ORDERSTATUS': 'text',
  'O_TOTALPRICE': 'integer',
  'O_ORDERDATE': 'text',
  'O_ORDERPRIORITY': 'text',
  'O_CLERK': 'text',
  'O_SHIPPRIORITY': 'integer',
  'O_COMMENT': 'text'},
 'LINEITEM': {'L_ORDERKEY': 'integer',
  'L_PARTKEY': 'integer',
  'L_SUPPKEY': 'integer',
  'L_LINENUMBER': 'integer',
  'L_QUANTITY': 'integer',
  'L_EXTENDEDPRICE': 'integer',
  'L_DISCOUNT': 'integer',
  'L_TAX': 'integer',
  'L_RETURNFLAG': 'text',
  'L_LINESTATUS': 'text',
  'L_SHIPDATE': 'text',
  'L_COMMITDATE': 'text',
  'L_RECEIPTDATE': 'text',
  'L_SHIPINSTRUCT': 'text',
  'L_SHIPMODE': 'text',
  'L_COMMENT': 'text'}
}

def lowercase_json_keys(data):
    """
    Return a rebuilt copy of a JSON-like structure with every dictionary key lowercased.

    Dictionaries and lists are walked recursively; any other value (strings,
    numbers, tuples, ...) is returned unchanged. Dictionary keys must provide
    ``.lower()``. If two keys collide after lowercasing, the later one wins.

    :param data: JSON-like object (dictionary, list, or leaf value)
    :return: the same structure with all dictionary keys in lowercase
    """
    # Leaves first: anything that is neither a dict nor a list passes through untouched.
    if not isinstance(data, (dict, list)):
        return data

    if isinstance(data, dict):
        lowered = {}
        for name, child in data.items():
            new_name = name.lower()
            lowered[new_name] = lowercase_json_keys(child)
        return lowered

    return list(map(lowercase_json_keys, data))

# Lowercase every table and column name before building the schema.
schema_dict = lowercase_json_keys(schema_dict)
print(schema_dict)
    
# Notebook-level Schema instance; compute_score() reads it as a global.
schema = Schema(schema_dict)


{'part': {'p_partkey': 'integer', 'p_name': 'text', 'p_mfgr': 'text', 'p_brand': 'text', 'p_type': 'text', 'p_size': 'text', 'p_container': 'text', 'p_retailprice': 'text', 'p_comment': 'text'}, 'region': {'r_regionkey': 'integer', 'r_name': 'text', 'r_comment': 'text'}, 'nation': {'n_nationkey': 'integer', 'n_name': 'text', 'n_regionkey': 'integer', 'n_comment': 'text'}, 'supplier': {'s_suppkey': 'integer', 's_name': 'text', 's_address': 'text', 's_nationkey': 'integer', 's_phone': 'text', 's_acctbal': 'text', 's_comment': 'text'}, 'customer': {'c_custkey': 'integer', 'c_name': 'text', 'c_address': 'text', 'c_nationkey': 'integer', 'c_phone': 'text', 'c_acctbal': 'integer', 'c_mktsegment': 'text', 'c_comment': 'text'}, 'partsupp': {'ps_partkey': 'integer', 'ps_suppkey': 'integer', 'ps_availqty': 'integer', 'ps_supplycost': 'integer', 'ps_comment': 'text'}, 'orders': {'o_orderkey': 'integer', 'o_custkey': 'integer', 'o_orderstatus': 'text', 'o_totalprice': 'integer', 'o_orderdate': 'te

# Measurement of Structural Similarity between Source and Target SQLs

In [7]:
import sqlglot

def compute_score(sql1, sql2):
    """
    Compare two SQL strings, print a report, and return their averaged partial-match scores.

    Both queries are parsed with ``extract_all`` against the notebook-level
    ``schema``. The first entry of each result's ``'subqueries'`` list is used
    as the formatted query for the tree-edit comparison (``compute_tsed``).
    The per-component and averaged scores come from ``get_all_partial_score``.

    :param sql1: first SQL string (callers in this notebook pass the predicted query)
    :param sql2: second SQL string (callers in this notebook pass the gold query)
    :return: tuple ``(structural, semantic, overall)`` read from the overall-score
        dict. TSED and the tree edit distance are only printed, not returned.
    """
    print(sql1)
    print(sql2)
    output1 = extract_all(sql1, schema)
    output2 = extract_all(sql2, schema)

    formatted_sql1 = output1['subqueries'][0]
    formatted_sql2 = output2['subqueries'][0]
    tsed, distance = compute_tsed(formatted_sql1, formatted_sql2, build_type='apted')  # apted or zss
    print('[SQL1]\n', formatted_sql1.sql(pretty=True))
    print()
    print('[SQL2]\n', formatted_sql2.sql(pretty=True))
    print()
    print(f'TSED: {tsed:.4f}')
    print(f'Tree Edit Distance: {distance}')

    # partial match
    print('Partial Match Score')
    from src.eval_utils import get_all_partial_score

    results, overall_score = get_all_partial_score(
        output1, output2, 
        build_type='apted', 
        criteria='tsed', 
        penalty=0.01, 
        use_bert=True, 
        rescale_with_baseline=True)
    for k, v in results.items():
        score = v[-1]
        # Compare against None explicitly: a genuine score of 0.0 is falsy, and a
        # plain truthiness test would send it down the unformatted branch
        # (printing "0.0" instead of "0.0000").
        if score is not None:
            print(f'- {k}: {score:.4f}')
        else:
            print(f'- {k}: {score}')
    print('Average score')
    for k, v in overall_score.items():
        print(f'- {k}: {v:.4f}')
    return overall_score['structural'], overall_score['semantic'], overall_score['overall']

### Read predicted and gold SQL

In [8]:
import os
import json
# JSON file with the benchmark questions; read_results() loads it on every call.
question_path = 'data/tpch/questions_new.json'
# Only questions whose ref_id appears in this list are read and scored.
test_queries = ['q12', 'q13', 'q14', 'q15', 'q16', 'q17', 'q18', 'q19', 'q20', 'q21']

def read_results(result_path):
    """
    Load the gold/predicted SQL pair for every selected question.

    The question list is read from the module-level ``question_path``; only
    entries whose ``ref_id`` is in the module-level ``test_queries`` are kept.
    For each kept entry, ``<instance_id>_gold.sql`` and ``<instance_id>_pred.sql``
    are read from ``result_path``.

    :param result_path: directory holding the ``*_gold.sql`` / ``*_pred.sql`` files
    :return: list of dicts with keys ``instance_id``, ``ref_id``, ``question``,
        ``gold_sql`` and ``pred_sql``, in the order the questions appear in the file
    """
    with open(question_path, 'r') as fh:
        all_questions = json.load(fh)

    # Phase 1: read both SQL files for every selected question.
    sql_sources = (('gold_sql', "_gold.sql"), ('pred_sql', "_pred.sql"))
    loaded = []
    for entry in all_questions:
        if entry['ref_id'] in test_queries:
            sql_texts = {}
            for field, suffix in sql_sources:
                with open(os.path.join(result_path, entry['instance_id'])+suffix, 'r') as fh:
                    sql_texts[field] = fh.read()
            loaded.append((entry, sql_texts))

    # Phase 2: assemble the flat sample records handed back to the caller.
    NLsamples = [
        {
            'instance_id': entry['instance_id'],
            'ref_id': entry['ref_id'],
            'question': entry['question'],
            'gold_sql': sql_texts['gold_sql'],
            'pred_sql': sql_texts['pred_sql'],
        }
        for entry, sql_texts in loaded
    ]
    return NLsamples

def print_scores(result_path):
    """
    Score every sample found via ``read_results`` and print the column-wise mean.

    Each sample's predicted SQL is compared against its gold SQL with
    ``compute_score``, which returns ``(structural, semantic, overall)``.
    The printed array is the mean of those three values over all samples.

    :param result_path: location passed through unchanged to ``read_results``
    """
    per_sample = []
    for sample in read_results(result_path):
        predicted = sample['pred_sql']
        gold = sample['gold_sql']
        structural, semantic, overall = compute_score(predicted, gold)
        per_sample.append([structural, semantic, overall])
    score_matrix = np.array(per_sample)
    print(np.mean(score_matrix, axis=0))

### GPT-4o

In [9]:
# Score the predictions stored for run `gpt-4o_0_all`: prints a per-query report, then the mean scores.
result_path = 'data/tpch/outputs/gpt-4o_0_all/results/'
print_scores(result_path)

SELECT COUNT(*) AS late_line_items_count FROM LINEITEM WHERE L_RECEIPTDATE > L_COMMITDATE AND strftime('%Y', L_RECEIPTDATE) = '1993';
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '1993'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.2432
Tree Edit Distance: 28
Partial Match Score
- table_asts: 0.4950
- sel_asts: 1.0000
- cond_asts: 0.2956
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.5800
- semantic: 0.6226
- overall: 0.5969
SELECT T1.C_CUST

- table_asts: 0.7475
- sel_asts: 0.2751
- cond_asts: 0.3412
- agg_asts: 0.4873
- orderby_asts: None
- subqueries: 0.7812
- distinct: 1.0000
- limit: None
Average score
- structural: 0.5349
- semantic: 0.7981
- overall: 0.6054
SELECT SUM(L.L_EXTENDEDPRICE * (1 - L.L_DISCOUNT)) AS total_discounted_revenue
FROM LINEITEM L
JOIN PART P ON L.L_PARTKEY = P.P_PARTKEY
WHERE L.L_SHIPMODE = 'AIR' AND P.P_BRAND = 'Brand#12';
select
	sum(l_extendedprice* (1 - l_discount)) as revenue
from
	lineitem,
	part
where
	(
		p_partkey = l_partkey
        and p_brand = 'Brand#12'
		and l_shipmode in ('AIR', 'AIR REG')
	);
[SQL1]
 SELECT
  SUM(lineitem.l_extendedprice * (
    1 - lineitem.l_discount
  ))
FROM lineitem
JOIN PART AS P
  ON L.L_PARTKEY = P.P_PARTKEY
WHERE
  lineitem.l_shipmode = 'AIR' AND part.p_brand = 'Brand#12'

[SQL2]
 SELECT
  SUM(lineitem.l_extendedprice * (
    1 - lineitem.l_discount
  ))
FROM lineitem, part
WHERE
  (
    part.p_partkey = lineitem.l_partkey
    AND part.p_brand = 'Brand#1

In [10]:
# Score the predictions stored for run `gpt-4o_2_all`: prints a per-query report, then the mean scores.
result_path = 'data/tpch/outputs/gpt-4o_2_all/results/'
print_scores(result_path)

SELECT COUNT(*) AS late_line_items_count
FROM LINEITEM
WHERE L_RECEIPTDATE > L_COMMITDATE
  AND strftime('%Y', L_RECEIPTDATE) = '1993';
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '1993'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.2432
Tree Edit Distance: 28
Partial Match Score
- table_asts: 0.4950
- sel_asts: 1.0000
- cond_asts: 0.2956
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.5800
- semantic: 0.6226
- overall: 0.5969
SELECT T1.C_CU

- table_asts: 0.6633
- sel_asts: 0.1331
- cond_asts: 0.3267
- agg_asts: 0.0
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.2726
- semantic: 0.2970
- overall: 0.2808
SELECT COUNT(DISTINCT C.C_CUSTKEY) AS large_volume_customers_count
FROM CUSTOMER C
JOIN ORDERS O ON C.C_CUSTKEY = O.O_CUSTKEY
JOIN LINEITEM L ON O.O_ORDERKEY = L.L_ORDERKEY
GROUP BY C.C_CUSTKEY
HAVING SUM(L.L_QUANTITY) > 300;
select
	count(distinct c_custkey)
from
	customer,
	orders,
	lineitem
where
	o_orderkey in (
		select
			l_orderkey
		from
			lineitem
		group by
			l_orderkey having
				sum(l_quantity) > 300
	)
	and c_custkey = o_custkey
	and o_orderkey = l_orderkey;
[SQL1]
 SELECT
  COUNT(DISTINCT C.C_CUSTKEY)
FROM customer
JOIN ORDERS AS O
  ON C.C_CUSTKEY = O.O_CUSTKEY
JOIN LINEITEM AS L
  ON O.O_ORDERKEY = L.L_ORDERKEY
GROUP BY
  customer.c_custkey
HAVING
  SUM(lineitem.l_quantity) > 300

[SQL2]
 SELECT
  COUNT(DISTINCT c_custkey)
FROM customer, orders, lineite

In [11]:
# Score the predictions stored for run `gpt-4o_3_all`: prints a per-query report, then the mean scores.
result_path = 'data/tpch/outputs/gpt-4o_3_all/results/'
print_scores(result_path)

SELECT COUNT(*) AS late_line_items_count
FROM LINEITEM
WHERE L_RECEIPTDATE > L_COMMITDATE
  AND L_RECEIPTDATE BETWEEN '1993-01-01' AND '1993-12-31';
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND lineitem.l_receiptdate BETWEEN '1993-01-01' AND '1993-12-31'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.2973
Tree Edit Distance: 26
Partial Match Score
- table_asts: 0.4950
- sel_asts: 1.0000
- cond_asts: 0.3204
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.5879
- semantic: 0.6305
- over

- table_asts: 0.4539
- sel_asts: 0.6178
- cond_asts: 0.2425
- agg_asts: 1.0000
- orderby_asts: None
- subqueries: 1.0000
- distinct: 1.0000
- limit: None
Average score
- structural: 0.7088
- semantic: 0.7432
- overall: 0.7190
SELECT SUM(L_EXTENDEDPRICE * (1 - L_DISCOUNT)) AS total_discounted_revenue FROM LINEITEM AS T1 JOIN PART AS T2 ON T1.L_PARTKEY = T2.P_PARTKEY WHERE T1.L_SHIPMODE = 'AIR' AND T2.P_BRAND = 'Brand#12';
select
	sum(l_extendedprice* (1 - l_discount)) as revenue
from
	lineitem,
	part
where
	(
		p_partkey = l_partkey
        and p_brand = 'Brand#12'
		and l_shipmode in ('AIR', 'AIR REG')
	);
[SQL1]
 SELECT
  SUM(lineitem.l_extendedprice * (
    1 - lineitem.l_discount
  ))
FROM lineitem
JOIN PART AS T2
  ON T1.L_PARTKEY = T2.P_PARTKEY
WHERE
  lineitem.l_shipmode = 'AIR' AND part.p_brand = 'Brand#12'

[SQL2]
 SELECT
  SUM(lineitem.l_extendedprice * (
    1 - lineitem.l_discount
  ))
FROM lineitem, part
WHERE
  (
    part.p_partkey = lineitem.l_partkey
    AND part.p_brand

In [12]:
# Score the predictions stored for run `gpt-4o_1_all`: prints a per-query report, then the mean scores.
result_path = 'data/tpch/outputs/gpt-4o_1_all/results/'
print_scores(result_path)

SELECT COUNT(*) AS late_lineitem_count FROM LINEITEM WHERE L_RECEIPTDATE > L_COMMITDATE AND strftime('%Y', L_RECEIPTDATE) = '1993';
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '1993'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.2432
Tree Edit Distance: 28
Partial Match Score
- table_asts: 0.4950
- sel_asts: 1.0000
- cond_asts: 0.2956
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.5800
- semantic: 0.6226
- overall: 0.5969
SELECT C.C_CUSTKEY

- table_asts: 1.0000
- sel_asts: 0.1331
- cond_asts: 0.3267
- agg_asts: 0.0
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.3567
- semantic: 0.3812
- overall: 0.3649
SELECT COUNT(DISTINCT O_CUSTKEY) AS large_volume_customers_count
FROM ORDERS
WHERE O_ORDERKEY IN (
    SELECT L_ORDERKEY
    FROM LINEITEM
    GROUP BY L_ORDERKEY
    HAVING SUM(L_QUANTITY) > 300
);
select
	count(distinct c_custkey)
from
	customer,
	orders,
	lineitem
where
	o_orderkey in (
		select
			l_orderkey
		from
			lineitem
		group by
			l_orderkey having
				sum(l_quantity) > 300
	)
	and c_custkey = o_custkey
	and o_orderkey = l_orderkey;
[SQL1]
 SELECT
  COUNT(DISTINCT O_CUSTKEY)
FROM orders
WHERE
  orders.o_orderkey IN (
    SELECT
      lineitem.l_orderkey
    FROM LINEITEM
    GROUP BY
      L_ORDERKEY
    HAVING
      SUM(L_QUANTITY) > 300
  )

[SQL2]
 SELECT
  COUNT(DISTINCT c_custkey)
FROM customer, orders, lineitem
WHERE
  orders.o_orderkey IN (
    SELEC

### GPT-4o-mini

In [13]:
# Score the predictions stored for run `gpt-4o-mini_0_all`: prints a per-query report, then the mean scores.
result_path = 'data/tpch/outputs/gpt-4o-mini_0_all/results/'
print_scores(result_path)

SELECT COUNT(*) AS late_line_items_count FROM LINEITEM T1 WHERE T1.L_RECEIPTDATE > T1.L_SHIPDATE AND strftime('%Y', T1.L_RECEIPTDATE) = '1993'
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_shipdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '1993'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.2432
Tree Edit Distance: 28
Partial Match Score
- table_asts: 0.4950
- sel_asts: 1.0000
- cond_asts: 0.2914
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.5800
- semantic: 0.6182
- overall: 0.5955
SELECT C.

- table_asts: 0.4950
- sel_asts: -0.0102
- cond_asts: 0.2425
- agg_asts: 0.0
- orderby_asts: None
- subqueries: 0.0
- distinct: 1.0000
- limit: None
Average score
- structural: 0.2888
- semantic: 0.4165
- overall: 0.2879
SELECT SUM(L_EXTENDEDPRICE * (1 - L_DISCOUNT)) AS total_discounted_revenue
FROM LINEITEM T1
JOIN PARTSUPP T2 ON T1.L_PARTKEY = T2.PS_PARTKEY
JOIN PART T3 ON T2.PS_PARTKEY = T3.P_PARTKEY
JOIN ORDERS T4 ON T1.L_ORDERKEY = T4.O_ORDERKEY
WHERE T3.P_BRAND = 'Brand#12' AND T1.L_SHIPMODE = 'AIR'
select
	sum(l_extendedprice* (1 - l_discount)) as revenue
from
	lineitem,
	part
where
	(
		p_partkey = l_partkey
        and p_brand = 'Brand#12'
		and l_shipmode in ('AIR', 'AIR REG')
	);
[SQL1]
 SELECT
  SUM(lineitem.l_extendedprice * (
    1 - lineitem.l_discount
  ))
FROM lineitem
JOIN PARTSUPP AS T2
  ON T1.L_PARTKEY = T2.PS_PARTKEY
JOIN PART AS T3
  ON T2.PS_PARTKEY = T3.P_PARTKEY
JOIN ORDERS AS T4
  ON T1.L_ORDERKEY = T4.O_ORDERKEY
WHERE
  part.p_brand = 'Brand#12' AND lineitem

In [14]:
# Score the predictions stored for run `gpt-4o-mini_2_all`: prints a per-query report, then the mean scores.
result_path = 'data/tpch/outputs/gpt-4o-mini_2_all/results/'
print_scores(result_path)

SELECT COUNT(*) AS order_count
FROM ORDERS T1
WHERE T1.O_ORDERDATE >= '1993-01-01'
AND T1.O_ORDERDATE < '1993-04-01'
AND EXISTS (
    SELECT *
    FROM LINEITEM T2
    WHERE T2.L_ORDERKEY = T1.O_ORDERKEY
    AND T2.L_COMMITDATE < T2.L_RECEIPTDATE
);
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM orders
WHERE
  orders.o_orderdate >= '1993-01-01'
  AND orders.o_orderdate < '1993-04-01'
  AND EXISTS(
    SELECT
      *
    FROM lineitem
    WHERE
      lineitem.l_orderkey = orders.o_orderkey
      AND lineitem.l_commitdate < lineitem.l_receiptdate
  )

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.0488
Tree Edit Distance: 39


- table_asts: 0.6633
- sel_asts: 0.1480
- cond_asts: 0.3267
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.3651
- semantic: 0.4131
- overall: 0.3793
SELECT COUNT(DISTINCT c.C_CUSTKEY) AS customer_count
FROM CUSTOMER c
JOIN ORDERS o ON c.C_CUSTKEY = o.O_CUSTKEY
JOIN LINEITEM l ON o.O_ORDERKEY = l.L_ORDERKEY
GROUP BY c.C_CUSTKEY
HAVING SUM(l.L_QUANTITY) > 300;
select
	count(distinct c_custkey)
from
	customer,
	orders,
	lineitem
where
	o_orderkey in (
		select
			l_orderkey
		from
			lineitem
		group by
			l_orderkey having
				sum(l_quantity) > 300
	)
	and c_custkey = o_custkey
	and o_orderkey = l_orderkey;
[SQL1]
 SELECT
  COUNT(DISTINCT c.C_CUSTKEY)
FROM customer
JOIN ORDERS AS o
  ON c.C_CUSTKEY = o.O_CUSTKEY
JOIN LINEITEM AS l
  ON o.O_ORDERKEY = l.L_ORDERKEY
GROUP BY
  customer.c_custkey
HAVING
  SUM(lineitem.l_quantity) > 300

[SQL2]
 SELECT
  COUNT(DISTINCT c_custkey)
FROM customer, orders, lineitem
WHERE
  ord

In [14]:
# Score the predictions stored for run `gpt-4o-mini_3_all`: prints a per-query report, then the mean scores.
result_path = 'data/tpch/outputs/gpt-4o-mini_3_all/results/'
print_scores(result_path)

select count(*) as order_count
from orders o
where o.o_orderdate >= '1993-01-01'
  and o.o_orderdate < '1994-01-01'
  and exists (
    select *
    from lineitem l
    where l.l_orderkey = o.o_orderkey
      and l.l_commitdate < l.l_receiptdate
  );
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM orders
WHERE
  orders.o_orderdate >= '[placeholder-type:string]'
  AND orders.o_orderdate < '[placeholder-type:string]'
  AND EXISTS(
    SELECT
      *
    FROM lineitem
    WHERE
      lineitem.l_orderkey = orders.o_orderkey
      AND lineitem.l_commitdate < lineitem.l_receiptdate
  )

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '[placeholder-type:string]'
  AND lineitem.l_receiptdate < '[pla

- table_asts: 0.7475
- sel_asts: -0.0101
- cond_asts: 0.1753
- agg_asts: 0.0
- orderby_asts: None
- subqueries: 0.0
- distinct: 1.0000
- limit: None
Overall Score: 0.3188
SELECT SUM(L.L_EXTENDEDPRICE * L.L_DISCOUNT) AS revenue
FROM LINEITEM AS L
JOIN PARTSUPP AS PS ON L.L_PARTKEY = PS.PS_PARTKEY
JOIN PART AS P ON PS.PS_PARTKEY = P.P_PARTKEY
JOIN ORDERS AS O ON L.L_ORDERKEY = O.O_ORDERKEY
WHERE L.L_SHIPMODE = 'air' AND P.P_BRAND = 'Brand#12'
select
	sum(l_extendedprice* (1 - l_discount)) as revenue
from
	lineitem,
	part
where
	(
		p_partkey = l_partkey
        and p_brand = 'Brand#12'
		and l_shipmode in ('AIR', 'AIR REG')
	);
[SQL1]
 SELECT
  SUM(lineitem.l_extendedprice * lineitem.l_discount)
FROM lineitem
JOIN partsupp
  ON L.L_PARTKEY = PS.PS_PARTKEY
JOIN part
  ON PS.PS_PARTKEY = P.P_PARTKEY
JOIN orders
  ON L.L_ORDERKEY = O.O_ORDERKEY
WHERE
  lineitem.l_shipmode = '[placeholder-type:string]'
  AND part.p_brand = '[placeholder-type:string]'

[SQL2]
 SELECT
  SUM(lineitem.l_extended

In [15]:
# Score the predictions stored for run `gpt-4o-mini_1_all`: prints a per-query report, then the mean scores.
result_path = 'data/tpch/outputs/gpt-4o-mini_1_all/results/'
print_scores(result_path)

SELECT COUNT(*) AS late_lineitem_count FROM LINEITEM T1 WHERE strftime('%Y', T1.L_RECEIPTDATE) = '1993' AND T1.L_COMMITDATE < T1.L_RECEIPTDATE;
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  STRFTIME('%Y', lineitem.l_receiptdate) = '1993'
  AND lineitem.l_commitdate < lineitem.l_receiptdate

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.2703
Tree Edit Distance: 27
Partial Match Score
- table_asts: 0.4950
- sel_asts: 1.0000
- cond_asts: 0.3920
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.6276
- semantic: 0.6304
- overall: 0.6290
SELECT

- table_asts: 1.0000
- sel_asts: 0.1331
- cond_asts: 0.3267
- agg_asts: 0.0
- orderby_asts: 0.0
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.2854
- semantic: 0.3049
- overall: 0.2920
SELECT COUNT(DISTINCT C.C_CUSTKEY) AS CustomerCount
FROM CUSTOMER C
JOIN ORDERS O ON C.C_CUSTKEY = O.O_CUSTKEY
JOIN LINEITEM L ON O.O_ORDERKEY = L.L_ORDERKEY
GROUP BY C.C_CUSTKEY
HAVING SUM(L.L_QUANTITY) > 300;
select
	count(distinct c_custkey)
from
	customer,
	orders,
	lineitem
where
	o_orderkey in (
		select
			l_orderkey
		from
			lineitem
		group by
			l_orderkey having
				sum(l_quantity) > 300
	)
	and c_custkey = o_custkey
	and o_orderkey = l_orderkey;
[SQL1]
 SELECT
  COUNT(DISTINCT C.C_CUSTKEY)
FROM customer
JOIN ORDERS AS O
  ON C.C_CUSTKEY = O.O_CUSTKEY
JOIN LINEITEM AS L
  ON O.O_ORDERKEY = L.L_ORDERKEY
GROUP BY
  customer.c_custkey
HAVING
  SUM(lineitem.l_quantity) > 300

[SQL2]
 SELECT
  COUNT(DISTINCT c_custkey)
FROM customer, orders, lineitem
WHERE
  orders

## Model-based method

In [22]:
import os
import json

# Same question file and ref_id selection as the earlier "Read pred SQL, gold SQL" cell.
question_path = 'data/tpch/questions_new.json'
# Directory from which read_results() below loads each question's gold SQL (<instance_id>.sql).
root_path = 'data/tpch/gold_new/'
test_queries = ['q12', 'q13', 'q14', 'q15', 'q16', 'q17', 'q18', 'q19', 'q20', 'q21']

def read_results(result_path):
    """
    Load gold SQL from disk and predicted SQL from a JSON results file.

    This redefines the directory-based ``read_results`` from the earlier cell:
    here ``result_path`` is a JSON file that is indexed by ``ref_id`` to get the
    predicted SQL. Gold SQL is read from ``<root_path>/<instance_id>.sql``.
    Questions come from the module-level ``question_path`` and are filtered by
    the module-level ``test_queries``.

    :param result_path: path to a JSON file mapping ``ref_id`` to predicted SQL
    :return: list of dicts with keys ``instance_id``, ``ref_id``, ``question``,
        ``gold_sql`` and ``pred_sql``, in the order the questions appear in the file
    """
    with open(question_path, 'r') as fh:
        all_questions = json.load(fh)

    with open(result_path, 'r') as fh:
        predictions = json.load(fh)

    # Phase 1: collect gold and predicted SQL for every selected question.
    loaded = []
    for entry in all_questions:
        ref = entry['ref_id']
        if ref in test_queries:
            with open(os.path.join(root_path, entry['instance_id'])+".sql", 'r') as fh:
                gold_text = fh.read()
            loaded.append((entry, gold_text, predictions[ref]))

    # Phase 2: assemble the flat sample records handed back to the caller.
    NLsamples = [
        {
            'instance_id': entry['instance_id'],
            'ref_id': entry['ref_id'],
            'question': entry['question'],
            'gold_sql': gold_text,
            'pred_sql': pred_text,
        }
        for entry, gold_text, pred_text in loaded
    ]
    return NLsamples

def print_scores(result_path):
    """
    Score every sample found via ``read_results`` and print the column-wise mean.

    Each sample's predicted SQL is compared against its gold SQL with
    ``compute_score``, which returns ``(structural, semantic, overall)``.
    The printed array is the mean of those three values over all samples.

    :param result_path: location passed through unchanged to ``read_results``
    """
    per_sample = []
    for sample in read_results(result_path):
        predicted = sample['pred_sql']
        gold = sample['gold_sql']
        structural, semantic, overall = compute_score(predicted, gold)
        per_sample.append([structural, semantic, overall])
    score_matrix = np.array(per_sample)
    print(np.mean(score_matrix, axis=0))
    


In [23]:
# Score the candidate-generation JSON of the `tpch_new_0` run. The path is relative to the
# working directory and points outside this project, under ../CHESS.
result_path = '../CHESS/results/dev/gpt-4o-mini/tpch_new_0/2025-01-07-11-19-33/-candidate_generation.json'
print_scores(result_path)

SELECT COUNT(*) FROM LINEITEM WHERE L_RECEIPTDATE > L_COMMITDATE AND strftime('%Y', L_RECEIPTDATE) = '1993';
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '1993'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.2432
Tree Edit Distance: 28
Partial Match Score
- table_asts: 0.4950
- sel_asts: 1.0000
- cond_asts: 0.2956
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.5800
- semantic: 0.6226
- overall: 0.5969
SELECT COUNT(O.O_ORDERKEY) AS order_count

- table_asts: 0.4950
- sel_asts: -0.0102
- cond_asts: 0.2425
- agg_asts: 0.0
- orderby_asts: None
- subqueries: 0.0
- distinct: 1.0000
- limit: None
Average score
- structural: 0.2888
- semantic: 0.4123
- overall: 0.2879
SELECT SUM(L_EXTENDEDPRICE * (1 - L_DISCOUNT / 100)) AS total_discounted_revenue FROM LINEITEM T1 JOIN PART T2 ON T1.L_PARTKEY = T2.P_PARTKEY WHERE T1.L_SHIPMODE = 'air' AND T2.P_BRAND = 'Brand#12';
select
	sum(l_extendedprice* (1 - l_discount)) as revenue
from
	lineitem,
	part
where
	(
		p_partkey = l_partkey
        and p_brand = 'Brand#12'
		and l_shipmode in ('AIR', 'AIR REG')
	);
[SQL1]
 SELECT
  SUM(lineitem.l_extendedprice * (
    1 - lineitem.l_discount / 100
  ))
FROM lineitem
JOIN PART AS T2
  ON T1.L_PARTKEY = T2.P_PARTKEY
WHERE
  lineitem.l_shipmode = 'air' AND part.p_brand = 'Brand#12'

[SQL2]
 SELECT
  SUM(lineitem.l_extendedprice * (
    1 - lineitem.l_discount
  ))
FROM lineitem, part
WHERE
  (
    part.p_partkey = lineitem.l_partkey
    AND part.p_bran

In [24]:
# Score the candidate-generation JSON of the `tpch_new_1` run. The path is relative to the
# working directory and points outside this project, under ../CHESS.
result_path = '../CHESS/results/dev/gpt-4o-mini/tpch_new_1/2025-01-07-11-15-00/-candidate_generation.json'
print_scores(result_path)

select count(*) as late_lineitem_count from LINEITEM where year(L_RECEIPTDATE) = 1993 and L_COMMITDATE < L_RECEIPTDATE;
select
	count(*) as count
from
	orders,
	lineitem
where
	o_orderkey = l_orderkey
	and l_commitdate < l_receiptdate
	and l_receiptdate >= '1993-01-01'
	and l_receiptdate < '1994-01-01';
[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  YEAR(lineitem.l_receiptdate) = 1993
  AND lineitem.l_commitdate < lineitem.l_receiptdate

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'

TSED: 0.2703
Tree Edit Distance: 27
Partial Match Score
- table_asts: 0.4950
- sel_asts: 1.0000
- cond_asts: 0.3845
- agg_asts: None
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.6217
- semantic: 0.6318
- overall: 0.6265
SELECT c.c_custkey, COUNT(o.o_orderkey) AS

- table_asts: 1.0000
- sel_asts: 0.1331
- cond_asts: 0.3267
- agg_asts: 0.0
- orderby_asts: None
- subqueries: None
- distinct: None
- limit: None
Average score
- structural: 0.3567
- semantic: 0.3812
- overall: 0.3649
SELECT COUNT(DISTINCT T1.O_CUSTKEY) AS large_volume_customers_count FROM ORDERS AS T1 JOIN LINEITEM AS T2 ON T1.O_ORDERKEY = T2.L_ORDERKEY GROUP BY T1.O_ORDERKEY HAVING SUM(T2.L_QUANTITY) > 300;
select
	count(distinct c_custkey)
from
	customer,
	orders,
	lineitem
where
	o_orderkey in (
		select
			l_orderkey
		from
			lineitem
		group by
			l_orderkey having
				sum(l_quantity) > 300
	)
	and c_custkey = o_custkey
	and o_orderkey = l_orderkey;
[SQL1]
 SELECT
  COUNT(DISTINCT T1.O_CUSTKEY)
FROM orders
JOIN LINEITEM AS T2
  ON T1.O_ORDERKEY = T2.L_ORDERKEY
GROUP BY
  orders.o_orderkey
HAVING
  SUM(lineitem.l_quantity) > 300

[SQL2]
 SELECT
  COUNT(DISTINCT c_custkey)
FROM customer, orders, lineitem
WHERE
  orders.o_orderkey IN (
    SELECT
      lineitem.l_orderkey
    FROM