In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import sys
from pathlib import Path
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import sqlglot
import numpy as np
from sqlglot import expressions as exp
from src.parsing_sql import Schema, extract_all
from src.eval_utils import (
    partial_match, 
    compute_tsed
)

from src.parsing_sql import (
    extract_aliases,
    extract_condition,
    get_subqueries,
    _extract_conditions,
    _extract_columns_from_expression,
    _determine_tag,
    _format_expression,
    _get_full_column_name,
    extract_aliases,
    extract_selection,
    extract_aggregation,
    extract_orderby,
    extract_others,
    
    _extract_aliases_from_select,
    _handle_table_or_subquery
)

In [3]:
# Nested mapping of table name -> {column name -> type string}; passed to
# `Schema(...)` further down in this cell.
# NOTE(review): within each table every column carries the same type string
# (e.g. `movie_title` and `director_name` are 'integer', and every column of
# `lists` is 'text'). That looks like an artifact of how this dict was
# generated rather than the real column types — confirm before relying on them.
schema_dict = {'lists': {'user_id': 'text',
  'list_id': 'text',
  'list_title': 'text',
  'list_movie_number': 'text',
  'list_update_timestamp_utc': 'text',
  'list_creation_timestamp_utc': 'text',
  'list_followers': 'text',
  'list_url': 'text',
  'list_comments': 'text',
  'list_description': 'text',
  'list_cover_image_url': 'text',
  'list_first_image_url': 'text',
  'list_second_image_url': 'text',
  'list_third_image_url': 'text'},
 'movies': {'movie_id': 'integer',
  'movie_title': 'integer',
  'movie_release_year': 'integer',
  'movie_url': 'integer',
  'movie_title_language': 'integer',
  'movie_popularity': 'integer',
  'movie_image_url': 'integer',
  'director_id': 'integer',
  'director_name': 'integer',
  'director_url': 'integer'},
 'ratings_users': {'user_id': 'integer',
  'rating_date_utc': 'integer',
  'user_trialist': 'integer',
  'user_subscriber': 'integer',
  'user_avatar_image_url': 'integer',
  'user_cover_image_url': 'integer',
  'user_eligible_for_trial': 'integer',
  'user_has_payment_method': 'integer'},
 'lists_users': {'user_id': 'text',
  'list_id': 'text',
  'list_update_date_utc': 'text',
  'list_creation_date_utc': 'text',
  'user_trialist': 'text',
  'user_subscriber': 'text',
  'user_avatar_image_url': 'text',
  'user_cover_image_url': 'text',
  'user_eligible_for_trial': 'text',
  'user_has_payment_method': 'text'},
 'ratings': {'movie_id': 'integer',
  'rating_id': 'integer',
  'rating_url': 'integer',
  'rating_score': 'integer',
  'rating_timestamp_utc': 'integer',
  'critic': 'integer',
  'critic_likes': 'integer',
  'critic_comments': 'integer',
  'user_id': 'integer',
  'user_trialist': 'integer',
  'user_subscriber': 'integer',
  'user_eligible_for_trial': 'integer',
  'user_has_payment_method': 'integer'}}

# Sample queries over the movie schema, written one per line.
sql_text = """
SELECT movie_release_year FROM movies WHERE movie_title = 'Cops'
SELECT T1.user_id FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE rating_score = 4 AND rating_timestamp_utc LIKE '2013-05-04 06:33:32' AND T2.movie_title LIKE 'Freaks'
SELECT T1.user_trialist FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE T2.movie_title = 'A Way of Life' AND T1.user_id = 39115684
SELECT T2.movie_title FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE T1.rating_timestamp_utc LIKE '2020%' GROUP BY T2.movie_title ORDER BY COUNT(T2.movie_title) DESC LIMIT 1
SELECT AVG(T1.rating_score), T2.director_name FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE T2.movie_title = 'When Will I Be Loved'
"""
# Split into one stripped statement per line (outer blank lines are dropped first).
sqls = [line.strip() for line in sql_text.strip().split('\n')]
schema = Schema(schema_dict)

In [None]:
# Print a per-clause breakdown of what `extract_all` found in each sample query.

# Clause sections that share one layout: a header, one summary line, then every
# AST entry. Each row is (title, summary label, summary key, AST-list key).
# A section is gated on and iterates over the SAME AST-list key, so the key
# that decides whether a section prints can never drift from the key that is
# actually listed.
clause_sections = (
    ('condition', 'operations', 'op_types', 'cond_asts'),
    ('aggregation', 'unique columns', 'agg', 'agg_asts'),
    ('orderby', 'unique columns', 'orderby', 'orderby_asts'),
)

for sql in sqls:
    parsed_sql = sqlglot.parse_one(sql)
    output = extract_all(parsed_sql, schema)

    print('SQL:', sql)
    print('# Selection')
    print(f'  unique columns: {output["sel"]}')
    for i, ast in enumerate(output['sel_asts']):
        # Selection entries are labelled with their third tuple element.
        print(f' [{i}] type: {ast[2]}')
        print(f' [{i}] ast:')
        print('  ' + repr(ast[1]))

    for title, summary_label, summary_key, asts_key in clause_sections:
        if not output[asts_key]:
            continue
        print(f'\n# {title}')
        print(f'  {summary_label}: {output[summary_key]}')
        for i, ast in enumerate(output[asts_key]):
            # These entries are labelled with their first tuple element.
            print(f' [{i}] {ast[0]}')
            print(f' [{i}] ast:')
            print('  ' + repr(ast[1]))

    if output['nested']:
        print('\n# nested')
        print(f'  number of nested: {output["nested"]}')
        # Check `output['subqueries']` if you want to see the nested queries;
        # its first entry is the original query.
    if output['distinct']:
        print(f'\n# distinct: {output["distinct"]}')
    if output['limit']:
        print(f'\n# limit: {output["limit"]}')
    print('----------------------------------')

# Measurement of Complexity

1. Tree Similarity of Edit Distance (TSED)
2. Set of unique columns, tables, types of functions ...

In [92]:
import sqlglot

# sql1 = """SELECT T1.USER_ID 
# FROM ratings AS T1 
# INNER JOIN movies AS T2 
# ON T1.movie_id = T2.movie_id 
# WHERE 
#     rating_score = 4 
#     AND rating_timestamp_utc LIKE '2013-05-04 06:33:32' 
#     AND T2.movie_title LIKE 'Freaks'
# """

# sql2 = """SELECT T1.user_id, COUNT(T2.movie_title)
# FROM ratings AS T1 
# INNER JOIN movies AS T2 
# ON T1.movie_id = T2.movie_id 
# GROUP BY T1.user_id
# HAVING COUNT(T2.movie_title) > 1
# ORDER BY COUNT(T2.movie_title) DESC
# """

# Query pair to compare: sql1 uses upper-case identifiers, a table alias and a
# STRFTIME filter; sql2 uses lower-case identifiers and an explicit date range.
sql1 = """SELECT
  COUNT(*) AS late_line_items_count
FROM LINEITEM L
WHERE
  lineitem.L_RECEIPTDATE > lineitem.L_COMMITDATE
  AND STRFTIME('%Y', lineitem.L_RECEIPTDATE) = 'abcd'"""

sql2 = """SELECT
  COUNT(*) AS count
FROM lineitem
WHERE
  lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '1993-01-01'
  AND lineitem.l_receiptdate < '1994-01-01'
"""

schema = Schema({'lineitem': {'l_receiptdate': 'date', 'l_commitdate': 'date'}})

output1, output2 = (extract_all(query, schema) for query in (sql1, sql2))

# Entry 0 of 'subqueries' is the original (outermost) query.
formatted_sql1 = output1['subqueries'][0]
formatted_sql2 = output2['subqueries'][0]

# Whole-query comparison; build_type may be 'apted' or 'zss'.
tsed, distance = compute_tsed(formatted_sql1, formatted_sql2, build_type='apted')

# `end='\n\n'` leaves one blank line after each rendered query.
print('[SQL1]\n', formatted_sql1.sql(pretty=True), end='\n\n')
print('[SQL2]\n', formatted_sql2.sql(pretty=True), end='\n\n')
print(f'TSED: {tsed:.4f}')
print(f'Tree Edit Distance: {distance}')


# partial match
print('Partial Match Score')

def get_partial_score(output1, output2, arg):
    """
    Compare one component (`arg`) of two `extract_all` outputs.

    Presence table ("True" means the component is non-empty / truthy):

    target |  prediction  |  score, distance
    True   |  True        |  depends on arg
    True   |  False       |  0.0, np.inf
    False  |  True        |  0.0, np.inf
    False  |  False       |  1.0, 0.0

    arg:
     - use all: 'sel_asts', 'cond_asts', 'agg_asts', 'orderby_asts'
     - only use items from 2nd item in the list: 'subqueries'
     - boolean: 'distinct', 'limit'

    For 'subqueries' presence is decided AFTER dropping the first entry, so
    two queries without nesting score (1.0, 0.0) like any other component
    that is absent on both sides.

    Returns:
        (score, distance). For the boolean args and for every absent case
        these are the scalars from the table above. When an AST component is
        present on both sides, the pair returned by `get_tsed_score` is
        passed through unchanged.

    Raises:
        ValueError: `arg` is present on both sides but is not one of the
            supported names listed above.
    """
    ast_args = ('sel_asts', 'cond_asts', 'agg_asts', 'orderby_asts')
    flag_args = ('distinct', 'limit')

    value1, value2 = output1[arg], output2[arg]
    if arg == 'subqueries':
        # Only the entries after the first one count as nested queries.
        value1, value2 = (value1 or [])[1:], (value2 or [])[1:]

    if not value1 and not value2:
        return 1.0, 0.0
    if not value1 or not value2:
        # np.inf, not np.infty: the latter alias no longer exists in NumPy 2.
        return 0.0, np.inf

    if arg in flag_args:
        # Both sides have the flag; the values themselves are not compared.
        return 1.0, 0.0
    if arg == 'subqueries':
        return get_tsed_score(value1, value2, build_type='apted')
    if arg in ast_args:
        # Entries are 3-tuples whose middle element is the AST to compare.
        source = [ast for _, ast, _ in value1]
        target = [ast for _, ast, _ in value2]
        return get_tsed_score(source, target, build_type='apted')

    raise ValueError(f'unsupported arg for partial score: {arg!r}')

def get_tsed_score(
        source: list[exp.Expression],
        target: list[exp.Expression],
        build_type='apted',
        criteria='tsed'  # tsed or distance
    ):
    """
    Compute pairwise TSED scores and tree edit distances between two AST lists.

    Args:
        source: ASTs forming the rows of the result.
        target: ASTs forming the columns of the result.
        build_type: forwarded to `compute_tsed` as its third argument.
        criteria: 'tsed' (max) or 'distance' (min). Accepted for the planned
            matching step but NOT used yet.

    Returns:
        (scores, distances): two float arrays of shape
        (len(source), len(target)); entry [i, j] holds what
        `compute_tsed(source[i], target[j], build_type)` returned.

    TODO: the remaining planned steps are not implemented, so callers receive
    the full matrices rather than a single overall score:
        2. check possible matchings
        3. choose the matching with the highest score
        4. return the overall score
    """
    shape = (len(source), len(target))
    scores = np.zeros(shape)
    # np.inf, not np.infty: the latter alias no longer exists in NumPy 2.
    distances = np.ones(shape) * np.inf

    for i, ast1 in enumerate(source):
        for j, ast2 in enumerate(target):
            score, distance = compute_tsed(ast1, ast2, build_type)
            scores[i, j] = score
            distances[i, j] = distance

    return scores, distances

# sel_score, sel_dis = get_partial_score(output1, output2, arg='sel_asts')
# print(f'  Selection: tsed={sel_score:.4f} | distance={sel_dis:.2f}')
# cond_score, cond_dis = get_partial_score(output1, output2, arg='cond_asts')
# print(f'  Condition: tsed={cond_score:.4f} | distance={cond_dis:.2f}')
# agg_score, agg_dis = get_partial_score(output1, output2, arg='agg_asts')
# print(f'  Aggregation: tsed={agg_score:.4f} | distance={agg_dis:.2f}')
# orderby_score, orderby_dis = get_partial_score(output1, output2, arg='orderby_asts')
# print(f'  Orderby: tsed={orderby_score:.4f} | distance={orderby_dis:.2f}')
# nested_score, nested_dis = get_partial_score(output1, output2, arg='subqueries')
# print(f'  Nested: tsed={nested_score:.4f} | distance={nested_dis:.2f}')
# distinct_score, distinct_dis = get_partial_score(output1, output2, arg='distinct')
# print(f'  Distinct: tsed={distinct_score:.4f} | distance={distinct_dis:.2f}')
# limit_score, limit_dis = get_partial_score(output1, output2, arg='limit')
# print(f'  Limit: tsed={limit_score:.4f} | distance={limit_dis:.2f}')

# Pull the AST (middle element) out of each condition entry of both queries,
# then compute the pairwise score/distance matrices between them.
source = [cond_ast for _label, cond_ast, _tag in output1['cond_asts']]
target = [cond_ast for _label, cond_ast, _tag in output2['cond_asts']]
score, dis = get_tsed_score(source=source, target=target, build_type='apted')

[SQL1]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_receiptdate > lineitem.l_commitdate
  AND STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]'

[SQL2]
 SELECT
  COUNT(*)
FROM lineitem
WHERE
  lineitem.l_commitdate < lineitem.l_receiptdate
  AND lineitem.l_receiptdate >= '[placeholder-type:string]'
  AND lineitem.l_receiptdate < '[placeholder-type:string]'

TSED: 0.4231
Tree Edit Distance: 15
Partial Match Score


In [106]:
import spacy
try:
    nlp_spacy = spacy.load('en_core_web_md')
except OSError:
    # Model package is not installed: fetch it, then load it so `nlp_spacy`
    # is bound on this path too (otherwise later cells hit a NameError).
    # NOTE(review): in some environments a freshly downloaded model is only
    # importable after a kernel restart — confirm on a clean environment.
    from spacy.cli import download
    download('en_core_web_md')
    nlp_spacy = spacy.load('en_core_web_md')

from bert_score import score as bscore

In [104]:
# Run the SQL text of every condition AST through the spaCy pipeline, then
# print the similarity of each source/target pair (source-major order).
source_spacy = list(map(nlp_spacy, map(str, source)))
target_spacy = list(map(nlp_spacy, map(str, target)))

for src_doc in source_spacy:
    for tgt_doc in target_spacy:
        similarity = src_doc.similarity(tgt_doc)
        print(f'{similarity:.5f}', src_doc, tgt_doc)

0.97697 lineitem.l_receiptdate > lineitem.l_commitdate lineitem.l_receiptdate < '[placeholder-type:string]'
0.98193 lineitem.l_receiptdate > lineitem.l_commitdate lineitem.l_receiptdate >= '[placeholder-type:string]'
1.00000 lineitem.l_receiptdate > lineitem.l_commitdate lineitem.l_commitdate < lineitem.l_receiptdate
0.99397 STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]' lineitem.l_receiptdate < '[placeholder-type:string]'
0.99266 STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]' lineitem.l_receiptdate >= '[placeholder-type:string]'
0.96241 STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]' lineitem.l_commitdate < lineitem.l_receiptdate


In [120]:
from itertools import product
# SQL text of each condition AST.
source_str = list(map(str, source))
target_str = list(map(str, target))
# product() enumerates every (source, target) pair in source-major order;
# zip(*...) splits the pairs back into two aligned sequences for BERTScore.
source_str_list, target_str_list = zip(*product(source_str, target_str))
P, R, F1 = bscore(source_str_list, target_str_list, lang='en', verbose=False)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [122]:
# One line per (source, target) pair: precision, recall, F1, then both texts.
for pair_idx, (src_text, tgt_text) in enumerate(zip(source_str_list, target_str_list)):
    metrics = f'{P[pair_idx]:.5f}, {R[pair_idx]:.5f}, {F1[pair_idx]:.5f}'
    print(metrics, src_text, tgt_text)

0.91560, 0.86583, 0.89002 lineitem.l_receiptdate > lineitem.l_commitdate lineitem.l_receiptdate < '[placeholder-type:string]'
0.91453, 0.86756, 0.89043 lineitem.l_receiptdate > lineitem.l_commitdate lineitem.l_receiptdate >= '[placeholder-type:string]'
0.98415, 0.98415, 0.98415 lineitem.l_receiptdate > lineitem.l_commitdate lineitem.l_commitdate < lineitem.l_receiptdate
0.90197, 0.96393, 0.93192 STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]' lineitem.l_receiptdate < '[placeholder-type:string]'
0.90342, 0.96516, 0.93327 STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]' lineitem.l_receiptdate >= '[placeholder-type:string]'
0.82550, 0.91720, 0.86893 STRFTIME('%Y', lineitem.l_receiptdate) = '[placeholder-type:string]' lineitem.l_commitdate < lineitem.l_receiptdate


In [95]:
# Pairwise TSED matrix from `get_tsed_score`: one row per source condition,
# one column per target condition.
score

array([[0.42857143, 0.42857143, 0.28571429],
       [0.57142857, 0.57142857, 0.14285714]])

In [89]:
# Pairwise tree edit distances, laid out like `score` (rows: source, columns: target).
dis

array([[4., 4., 5.],
       [3., 3., 6.]])

In [123]:
# Best-scoring target column for each source row. This is a per-row greedy
# pick: nothing stops two sources from choosing the same target, and ties
# resolve to the lowest column index (both rows pick column 0 here).
np.argmax(score, axis=1)

array([0, 0])

In [148]:
# use ranking of two metrics to determine the best matching

Table(
  this=Identifier(this=LINEITEM, quoted=False),
  alias=TableAlias(
    this=Identifier(this=L, quoted=False)))

In [140]:
# NOTE(review): `expr` is not assigned in any cell visible in this notebook.
# Judging by the surrounding outputs it is presumably the `Table` node for
# `LINEITEM L` — confirm, and define it explicitly so this cell survives
# Restart & Run All.
expr.args['this']

Identifier(this=LINEITEM, quoted=False)

In [141]:
# Build a new identifier with the name lower-cased, keeping the `quoted` flag.
table_ident = expr.args['this']
exp.Identifier(this=table_ident.name.lower(), quoted=table_ident.quoted)

Identifier(this=lineitem, quoted=False)