In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to src/ (e.g. src.parsing_sql, src.eval_utils) are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path
# Make the project root (the kernel's working directory) importable so the
# `src.*` imports below resolve. Guard the append so re-running this cell does
# not pile up duplicate sys.path entries.
proj_path = Path('.').resolve()
if str(proj_path) not in sys.path:
    sys.path.append(str(proj_path))

import sqlglot
import numpy as np
from sqlglot import expressions as exp
from src.parsing_sql import Schema, extract_all
from src.eval_utils import (
    partial_match, 
    compute_tsed
)

In [3]:
# TPC-H table -> {column name: column type} mapping that is handed to the
# project's `Schema`; the keys are lowercased further down before use.
# NOTE(review): several numeric-looking columns are typed 'text'
# (P_SIZE, P_RETAILPRICE, S_ACCTBAL) while the analogous C_ACCTBAL is
# 'integer' — confirm these types are intentional.
schema_dict = {
    'PART': {
        'P_PARTKEY': 'integer',
        'P_NAME': 'text',
        'P_MFGR': 'text',
        'P_BRAND': 'text',
        'P_TYPE': 'text',
        'P_SIZE': 'text',
        'P_CONTAINER': 'text',
        'P_RETAILPRICE': 'text',
        'P_COMMENT': 'text',
    },
    'REGION': {
        'R_REGIONKEY': 'integer',
        'R_NAME': 'text',
        'R_COMMENT': 'text',
    },
    'NATION': {
        'N_NATIONKEY': 'integer',
        'N_NAME': 'text',
        'N_REGIONKEY': 'integer',
        'N_COMMENT': 'text',
    },
    'SUPPLIER': {
        'S_SUPPKEY': 'integer',
        'S_NAME': 'text',
        'S_ADDRESS': 'text',
        'S_NATIONKEY': 'integer',
        'S_PHONE': 'text',
        'S_ACCTBAL': 'text',
        'S_COMMENT': 'text',
    },
    'CUSTOMER': {
        'C_CUSTKEY': 'integer',
        'C_NAME': 'text',
        'C_ADDRESS': 'text',
        'C_NATIONKEY': 'integer',
        'C_PHONE': 'text',
        'C_ACCTBAL': 'integer',
        'C_MKTSEGMENT': 'text',
        'C_COMMENT': 'text',
    },
    'PARTSUPP': {
        'PS_PARTKEY': 'integer',
        'PS_SUPPKEY': 'integer',
        'PS_AVAILQTY': 'integer',
        'PS_SUPPLYCOST': 'integer',
        'PS_COMMENT': 'text',
    },
    'ORDERS': {
        'O_ORDERKEY': 'integer',
        'O_CUSTKEY': 'integer',
        'O_ORDERSTATUS': 'text',
        'O_TOTALPRICE': 'integer',
        'O_ORDERDATE': 'text',
        'O_ORDERPRIORITY': 'text',
        'O_CLERK': 'text',
        'O_SHIPPRIORITY': 'integer',
        'O_COMMENT': 'text',
    },
    'LINEITEM': {
        'L_ORDERKEY': 'integer',
        'L_PARTKEY': 'integer',
        'L_SUPPKEY': 'integer',
        'L_LINENUMBER': 'integer',
        'L_QUANTITY': 'integer',
        'L_EXTENDEDPRICE': 'integer',
        'L_DISCOUNT': 'integer',
        'L_TAX': 'integer',
        'L_RETURNFLAG': 'text',
        'L_LINESTATUS': 'text',
        'L_SHIPDATE': 'text',
        'L_COMMITDATE': 'text',
        'L_RECEIPTDATE': 'text',
        'L_SHIPINSTRUCT': 'text',
        'L_SHIPMODE': 'text',
        'L_COMMENT': 'text',
    },
}

def lowercase_json_keys(data):
    """
    Return a copy of a JSON-like structure with every dict key lowercased.

    Dicts are rebuilt with ``key.lower()`` keys, lists are rebuilt element by
    element, and any other value is returned unchanged. Nesting is handled
    recursively.

    :param data: JSON object (dictionary, list, or scalar)
    :return: the same structure with all dictionary keys in lowercase
    """
    if isinstance(data, list):
        return [lowercase_json_keys(element) for element in data]
    if not isinstance(data, dict):
        return data
    lowered = {}
    for name, child in data.items():
        lowered[name.lower()] = lowercase_json_keys(child)
    return lowered

# Lowercase every table/column name, echo the normalised mapping, and wrap it
# in the project's Schema object that `extract_all` receives below.
schema_dict = lowercase_json_keys(schema_dict)
print(schema_dict)

schema = Schema(schema_dict)


{'part': {'p_partkey': 'integer', 'p_name': 'text', 'p_mfgr': 'text', 'p_brand': 'text', 'p_type': 'text', 'p_size': 'text', 'p_container': 'text', 'p_retailprice': 'text', 'p_comment': 'text'}, 'region': {'r_regionkey': 'integer', 'r_name': 'text', 'r_comment': 'text'}, 'nation': {'n_nationkey': 'integer', 'n_name': 'text', 'n_regionkey': 'integer', 'n_comment': 'text'}, 'supplier': {'s_suppkey': 'integer', 's_name': 'text', 's_address': 'text', 's_nationkey': 'integer', 's_phone': 'text', 's_acctbal': 'text', 's_comment': 'text'}, 'customer': {'c_custkey': 'integer', 'c_name': 'text', 'c_address': 'text', 'c_nationkey': 'integer', 'c_phone': 'text', 'c_acctbal': 'integer', 'c_mktsegment': 'text', 'c_comment': 'text'}, 'partsupp': {'ps_partkey': 'integer', 'ps_suppkey': 'integer', 'ps_availqty': 'integer', 'ps_supplycost': 'integer', 'ps_comment': 'text'}, 'orders': {'o_orderkey': 'integer', 'o_custkey': 'integer', 'o_orderstatus': 'text', 'o_totalprice': 'integer', 'o_orderdate': 'te

In [4]:
from src.parsing_sql import (
    extract_aliases,
    extract_condition,
    get_subqueries,
    _extract_conditions,
    _extract_columns_from_expression,
    _determine_tag,
    _format_expression,
    _get_full_column_name,
    extract_aliases,
    extract_selection,
    extract_aggregation,
    extract_orderby,
    extract_others,
    
    _extract_aliases_from_select,
    _handle_table_or_subquery
)

In [5]:
import sqlglot

from operator import itemgetter


def _mean_pairwise_tsed(pred_items, gold_items, key=None):
    """Average TSED over every (prediction, gold) pair of items.

    :param pred_items: clause entries taken from the predicted query
    :param gold_items: clause entries taken from the gold query
    :param key: optional callable extracting the AST to compare from an entry;
        when None the entries themselves are compared
    :return: mean of the TSED value (first element returned by
        ``compute_tsed``) over the full cross product of entries
    """
    extract = key if key is not None else (lambda item: item)
    pair_scores = [
        compute_tsed(extract(pred), extract(gold), build_type='apted')[0]
        for pred in pred_items
        for gold in gold_items
    ]
    return np.mean(pair_scores)


def _clause_score(pred_items, gold_items, similarity):
    """Score one query component from its presence in prediction and gold.

    :param pred_items: the component as extracted from the prediction
    :param gold_items: the component as extracted from the gold query
    :param similarity: callable ``(pred_items, gold_items) -> float`` used
        when both queries contain the component
    :return: ``similarity(...)`` if both contain the component, 1.0 if
        neither does, and 0.0 if exactly one does. A missing clause and a
        spurious extra clause are penalised alike.
    """
    if pred_items and gold_items:
        return similarity(pred_items, gold_items)
    if pred_items or gold_items:
        return 0.0
    return 1.0


def matching_score(sql1, sql2, db_schema=None, verbose=True):
    """Compare a predicted SQL query against a gold query.

    Both strings are parsed with sqlglot and decomposed with `extract_all`.
    The outermost query (``subqueries[0]``) is compared as a whole with TSED.
    Each component (selection, condition, aggregation, ORDER BY, nested
    subqueries, DISTINCT, LIMIT) then receives a partial score via
    `_clause_score`.

    :param sql1: predicted SQL string
    :param sql2: gold (target) SQL string
    :param db_schema: Schema used to resolve columns; defaults to the
        notebook-level ``schema``
    :param verbose: print the formatted queries and every partial score
    :return: ``(tsed, tree_edit_distance, selection, condition, aggregation,
        nested)``. ORDER BY, DISTINCT and LIMIT scores are printed but not
        returned.
    """
    if db_schema is None:
        db_schema = schema
    pred = extract_all(sqlglot.parse_one(sql1), db_schema)   # prediction
    gold = extract_all(sqlglot.parse_one(sql2), db_schema)   # target

    pred_main = pred['subqueries'][0]
    gold_main = gold['subqueries'][0]
    if verbose:
        # Print before scoring so the queries are visible even if TSED fails.
        print('[SQL1]\n', pred_main.sql(pretty=True))
        print()
        print('[SQL2]\n', gold_main.sql(pretty=True))
        print()
    all_tsed, all_distance = compute_tsed(pred_main, gold_main, build_type='apted')  # apted or zss

    # Each *_asts entry carries the AST to compare at index 1.
    def compare_asts(pred_items, gold_items):
        return _mean_pairwise_tsed(pred_items, gold_items, key=itemgetter(1))

    # DISTINCT / LIMIT only score presence, not content.
    def both_present(pred_items, gold_items):
        return 1.0

    # NOTE(review): all-pairs averaging means two identical clause lists with
    # more than one entry generally score below 1.0 (mismatched pairs are
    # averaged in). `partial_match` from src.eval_utils is imported but
    # unused; confirm whether a best-match alignment was intended.
    component_scores = {
        'Selection': _clause_score(pred['sel_asts'], gold['sel_asts'], compare_asts),
        'Condition': _clause_score(pred['cond_asts'], gold['cond_asts'], compare_asts),
        'Aggregation': _clause_score(pred['agg_asts'], gold['agg_asts'], compare_asts),
        'Orderby': _clause_score(pred['orderby_asts'], gold['orderby_asts'], compare_asts),
        # Subquery ASTs are compared directly (no index-1 extraction).
        'Nested': _clause_score(pred['subqueries'][1:], gold['subqueries'][1:], _mean_pairwise_tsed),
        'Distinct': _clause_score(pred['distinct'], gold['distinct'], both_present),
        'Limit': _clause_score(pred['limit'], gold['limit'], both_present),
    }

    if verbose:
        print(f'TSED: {all_tsed:.4f}')
        print(f'Tree Edit Distance: {all_distance}')
        print('Partial Match Score')
        for label, value in component_scores.items():
            print(f'  {label}: {value:.4f}')

    return (all_tsed, all_distance,
            component_scores['Selection'], component_scores['Condition'],
            component_scores['Aggregation'], component_scores['Nested'])

In [6]:
import os
import json
question_path = 'data/tpch/questions_new.json'
test_queries = ['q12', 'q13', 'q14', 'q15', 'q16', 'q17', 'q18', 'q19', 'q20', 'q21']


def read_results(result_path):
    """
    Load the gold and predicted SQL for every held-out TPC-H question.

    Questions are read from ``question_path``. Only those whose ``ref_id`` is
    listed in ``test_queries`` are kept, in the order they appear in the file.
    For each kept question, ``<instance_id>_gold.sql`` and
    ``<instance_id>_pred.sql`` are read from ``result_path``.

    :param result_path: directory holding the ``*_gold.sql`` / ``*_pred.sql`` files
    :return: list of dicts with keys ``instance_id``, ``ref_id``, ``question``,
        ``gold_sql`` and ``pred_sql``
    :raises FileNotFoundError: if the question file or a kept question's SQL
        file is missing
    """
    # Read as UTF-8 explicitly: the default is the platform locale encoding,
    # which is not UTF-8 on every system.
    with open(question_path, 'r', encoding='utf-8') as f:
        questions = json.load(f)

    wanted = set(test_queries)
    NLsamples = []
    for question in questions:
        if question['ref_id'] not in wanted:
            continue
        stem = os.path.join(result_path, question['instance_id'])
        with open(stem + "_gold.sql", 'r', encoding='utf-8') as f:
            gold_sql = f.read()
        with open(stem + "_pred.sql", 'r', encoding='utf-8') as f:
            pred_sql = f.read()
        NLsamples.append({'instance_id': question['instance_id'],
                          'ref_id': question['ref_id'],
                          'question': question['question'],
                          'gold_sql': gold_sql,
                          'pred_sql': pred_sql,
                          })
    return NLsamples

In [7]:
# Evaluate the gpt-4o_0_all run: score every prediction against its gold query
# and print the column-wise mean of
# [TSED, tree edit distance, selection, condition, aggregation, nested].
result_path = 'data/tpch/outputs/gpt-4o_0_all/results/'
NLsamples = read_results(result_path)
scores = np.array([
    [*matching_score(sample['pred_sql'], sample['gold_sql'])]
    for sample in NLsamples
])
print(np.mean(scores, axis=0))

[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '[placeholder-type:string]'
  AND lineitem.l_receiptdate < '[placeholder-type:string]'

TSED: 0.2703
Tree Edit Distance: 27
Partial Match Score
  Selection: 1.0000
  Condition: 0.3214
  Aggregation: 1.0
  Orderby: 1.0 (no orderby)
  Nested: 1.0
  Distinct: 1.0
  Limit: 1.0
[SQL1]
 SELECT
  customer.c_custkey,
  COUNT(orders.o_orderkey)
FROM customer
LEFT JOIN ORDERS AS T2
  ON T1.C_CUSTKEY = T2.O_CUSTKEY
GROUP BY
  customer.c_custkey

[SQL2]
 SELECT
  None.c_count,
  COUNT(*)
FROM (
  SELECT
    customer.c_custkey,
    COUNT(orders.o_orderkey)
  FROM customer
  LEFT OUTER JOIN orders
    ON c_custkey = o_custkey
  GROUP BY
    customer.c_cu

In [8]:
# Evaluate the gpt-4o_2_all run: score every prediction against its gold query
# and print the column-wise mean of
# [TSED, tree edit distance, selection, condition, aggregation, nested].
result_path = 'data/tpch/outputs/gpt-4o_2_all/results/'
NLsamples = read_results(result_path)
scores = np.array([
    [*matching_score(sample['pred_sql'], sample['gold_sql'])]
    for sample in NLsamples
])
print(np.mean(scores, axis=0))

[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '[placeholder-type:string]'
  AND lineitem.l_receiptdate < '[placeholder-type:string]'

TSED: 0.2703
Tree Edit Distance: 27
Partial Match Score
  Selection: 1.0000
  Condition: 0.3214
  Aggregation: 1.0
  Orderby: 1.0 (no orderby)
  Nested: 1.0
  Distinct: 1.0
  Limit: 1.0
[SQL1]
 SELECT
  customer.c_custkey,
  COUNT(orders.o_orderkey)
FROM customer
LEFT JOIN ORDERS AS T2
  ON T1.C_CUSTKEY = T2.O_CUSTKEY
GROUP BY
  customer.c_custkey

[SQL2]
 SELECT
  None.c_count,
  COUNT(*)
FROM (
  SELECT
    customer.c_custkey,
    COUNT(orders.o_orderkey)
  FROM customer
  LEFT OUTER JOIN orders
    ON c_custkey = o_custkey
  GROUP BY
    customer.c_cu

TSED: 0.1471
Tree Edit Distance: 58
Partial Match Score
  Selection: 0.5000
  Condition: 0.1571
  Aggregation: 1.0000
  Orderby: 0.3333
  Nested: 1.0
  Distinct: 1.0
  Limit: 1.0
[ 0.10848391 60.          0.34073365  0.35803883  0.65        0.60928571]


In [9]:
# Evaluate the gpt-4o_3_all run: score every prediction against its gold query
# and print the column-wise mean of
# [TSED, tree edit distance, selection, condition, aggregation, nested].
result_path = 'data/tpch/outputs/gpt-4o_3_all/results/'
NLsamples = read_results(result_path)
scores = np.array([
    [*matching_score(sample['pred_sql'], sample['gold_sql'])]
    for sample in NLsamples
])
print(np.mean(scores, axis=0))

[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND lineitem.l_receiptdate BETWEEN '[placeholder-type:string]' AND '[placeholder-type:string]'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '[placeholder-type:string]'
  AND lineitem.l_receiptdate < '[placeholder-type:string]'

TSED: 0.3243
Tree Edit Distance: 25
Partial Match Score
  Selection: 1.0000
  Condition: 0.3452
  Aggregation: 1.0
  Orderby: 1.0 (no orderby)
  Nested: 1.0
  Distinct: 1.0
  Limit: 1.0
[SQL1]
 SELECT
  customer.c_custkey,
  COUNT(orders.o_orderkey)
FROM customer
LEFT JOIN ORDERS AS T2
  ON T1.C_CUSTKEY = T2.O_CUSTKEY
GROUP BY
  customer.c_custkey

[SQL2]
 SELECT
  None.c_count,
  COUNT(*)
FROM (
  SELECT
    customer.c_custkey,
    COUNT(orders.o_orderkey)
  FROM customer
  LEFT OUTER JOIN orders
    ON c_custkey = o_custkey
  GROU

In [10]:
# Evaluate the gpt-4o-mini_0_all run: score every prediction against its gold
# query and print the column-wise mean of
# [TSED, tree edit distance, selection, condition, aggregation, nested].
result_path = 'data/tpch/outputs/gpt-4o-mini_0_all/results/'
NLsamples = read_results(result_path)
scores = np.array([
    [*matching_score(sample['pred_sql'], sample['gold_sql'])]
    for sample in NLsamples
])
print(np.mean(scores, axis=0))

[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_shipdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]'

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '[placeholder-type:string]'
  AND lineitem.l_receiptdate < '[placeholder-type:string]'

TSED: 0.2703
Tree Edit Distance: 27
Partial Match Score
  Selection: 1.0000
  Condition: 0.3214
  Aggregation: 1.0
  Orderby: 1.0 (no orderby)
  Nested: 1.0
  Distinct: 1.0
  Limit: 1.0
[SQL1]
 SELECT
  customer.c_custkey,
  COUNT(orders.o_orderkey)
FROM customer
LEFT JOIN ORDERS AS O
  ON C.C_CUSTKEY = O.O_CUSTKEY
GROUP BY
  customer.c_custkey

[SQL2]
 SELECT
  None.c_count,
  COUNT(*)
FROM (
  SELECT
    customer.c_custkey,
    COUNT(orders.o_orderkey)
  FROM customer
  LEFT OUTER JOIN orders
    ON c_custkey = o_custkey
  GROUP BY
    customer.c_custkey

In [11]:
# Evaluate the gpt-4o-mini_2_all run: score every prediction against its gold
# query and print the column-wise mean of
# [TSED, tree edit distance, selection, condition, aggregation, nested].
result_path = 'data/tpch/outputs/gpt-4o-mini_2_all/results/'
NLsamples = read_results(result_path)
scores = np.array([
    [*matching_score(sample['pred_sql'], sample['gold_sql'])]
    for sample in NLsamples
])
print(np.mean(scores, axis=0))

[SQL1]
 SELECT
  COUNT(*)
FROM orders
WHERE
  orders.o_orderdate >= '[placeholder-type:string]'
  AND orders.o_orderdate < '[placeholder-type:string]'
  AND EXISTS(
    SELECT
      *
    FROM lineitem
    WHERE
      lineitem.l_orderkey = orders.o_orderkey
      AND lineitem.l_commitdate < lineitem.l_receiptdate
  )

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '[placeholder-type:string]'
  AND lineitem.l_receiptdate < '[placeholder-type:string]'

TSED: 0.0488
Tree Edit Distance: 39
Partial Match Score
  Selection: 0.7500
  Condition: 0.1955
  Aggregation: 1.0
  Orderby: 1.0 (no orderby)
  Nested: 1.0
  Distinct: 1.0
  Limit: 1.0
[SQL1]
 SELECT
  customer.c_custkey
FROM customer
LEFT JOIN ORDERS AS O
  ON C.C_CUSTKEY = O.O_CUSTKEY
GROUP BY
  customer.c_custkey
ORDER BY
  COUNT(orders.o_orderkey)

[SQL2]
 SELECT
  None.c_count,
  COUNT(*)
FROM (
  SELE

TSED: 0.0864
Tree Edit Distance: 74
Partial Match Score
  Selection: 0.3333
  Condition: 0.1571
  Aggregation: 0.3333
  Orderby: 0.3333
  Nested: 1.0
  Distinct: 1.0
  Limit: 1.0
[ 0.09581866 62.2         0.2504051   0.35293744  0.58333333  0.6       ]


In [13]:
# Evaluate the gpt-4o-mini_3_all run: score every prediction against its gold
# query and print the column-wise mean of
# [TSED, tree edit distance, selection, condition, aggregation, nested].
result_path = 'data/tpch/outputs/gpt-4o-mini_3_all/results/'
NLsamples = read_results(result_path)
scores = np.array([
    [*matching_score(sample['pred_sql'], sample['gold_sql'])]
    for sample in NLsamples
])
print(np.mean(scores, axis=0))

[SQL1]
 SELECT
  COUNT(*)
FROM orders
WHERE
  orders.o_orderdate >= '[placeholder-type:string]'
  AND orders.o_orderdate < '[placeholder-type:string]'
  AND EXISTS(
    SELECT
      *
    FROM lineitem
    WHERE
      lineitem.l_orderkey = orders.o_orderkey
      AND lineitem.l_commitdate < lineitem.l_receiptdate
  )

[SQL2]
 SELECT
  COUNT(*)
FROM orders, lineitem
WHERE
  orders.o_orderkey = lineitem.l_orderkey
  AND lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '[placeholder-type:string]'
  AND lineitem.l_receiptdate < '[placeholder-type:string]'

TSED: 0.0488
Tree Edit Distance: 39
Partial Match Score
  Selection: 0.7500
  Condition: 0.1955
  Aggregation: 1.0
  Orderby: 1.0 (no orderby)
  Nested: 1.0
  Distinct: 1.0
  Limit: 1.0
[SQL1]
 SELECT
  customer.c_custkey
FROM customer
LEFT JOIN ORDERS AS O
  ON C.C_CUSTKEY = O.O_CUSTKEY
GROUP BY
  customer.c_custkey
ORDER BY
  COUNT(orders.o_orderkey)

[SQL2]
 SELECT
  None.c_count,
  COUNT(*)
FROM (
  SELE