In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (including the project's src package) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import sqlglot
import numpy as np
from sqlglot import expressions as exp
from src.parsing_sql import Schema, extract_all
from src.eval_utils import (
    partial_match, 
    compute_tsed
)

from src.parsing_sql import (
    extract_aliases,
    extract_condition,
    get_subqueries,
    _extract_conditions,
    _extract_columns_from_expression,
    _determine_tag,
    _format_expression,
    _get_full_column_name,
    extract_aliases,
    extract_selection,
    extract_aggregation,
    extract_orderby,
    extract_others,
    
    _extract_aliases_from_select,
    _handle_table_or_subquery
)

In [3]:
# Database schema as {table name: {column name: type string}} for five tables
# (lists, movies, ratings_users, lists_users, ratings). It is wrapped in a
# Schema object and handed to extract_all in the cells below.
# NOTE(review): every column within a table carries the same type string
# (e.g. movies.movie_title and movies.movie_url are 'integer'), which looks
# like an artifact of how this dict was generated -- confirm against the
# real database before relying on the types.
schema_dict = {'lists': {'user_id': 'text',
  'list_id': 'text',
  'list_title': 'text',
  'list_movie_number': 'text',
  'list_update_timestamp_utc': 'text',
  'list_creation_timestamp_utc': 'text',
  'list_followers': 'text',
  'list_url': 'text',
  'list_comments': 'text',
  'list_description': 'text',
  'list_cover_image_url': 'text',
  'list_first_image_url': 'text',
  'list_second_image_url': 'text',
  'list_third_image_url': 'text'},
 'movies': {'movie_id': 'integer',
  'movie_title': 'integer',
  'movie_release_year': 'integer',
  'movie_url': 'integer',
  'movie_title_language': 'integer',
  'movie_popularity': 'integer',
  'movie_image_url': 'integer',
  'director_id': 'integer',
  'director_name': 'integer',
  'director_url': 'integer'},
 'ratings_users': {'user_id': 'integer',
  'rating_date_utc': 'integer',
  'user_trialist': 'integer',
  'user_subscriber': 'integer',
  'user_avatar_image_url': 'integer',
  'user_cover_image_url': 'integer',
  'user_eligible_for_trial': 'integer',
  'user_has_payment_method': 'integer'},
 'lists_users': {'user_id': 'text',
  'list_id': 'text',
  'list_update_date_utc': 'text',
  'list_creation_date_utc': 'text',
  'user_trialist': 'text',
  'user_subscriber': 'text',
  'user_avatar_image_url': 'text',
  'user_cover_image_url': 'text',
  'user_eligible_for_trial': 'text',
  'user_has_payment_method': 'text'},
 'ratings': {'movie_id': 'integer',
  'rating_id': 'integer',
  'rating_url': 'integer',
  'rating_score': 'integer',
  'rating_timestamp_utc': 'integer',
  'critic': 'integer',
  'critic_likes': 'integer',
  'critic_comments': 'integer',
  'user_id': 'integer',
  'user_trialist': 'integer',
  'user_subscriber': 'integer',
  'user_eligible_for_trial': 'integer',
  'user_has_payment_method': 'integer'}}

# Schema wrapper around the raw table/column dict; passed to extract_all below.
schema = Schema(schema_dict)

# Example queries over the movies / ratings tables, one statement per line.
sqls = """
SELECT movie_release_year FROM movies WHERE movie_title = 'Cops'
SELECT T1.user_id FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE rating_score = 4 AND rating_timestamp_utc LIKE '2013-05-04 06:33:32' AND T2.movie_title LIKE 'Freaks'
SELECT T1.user_trialist FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE T2.movie_title = 'A Way of Life' AND T1.user_id = 39115684
SELECT T2.movie_title FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE T1.rating_timestamp_utc LIKE '2020%' GROUP BY T2.movie_title ORDER BY COUNT(T2.movie_title) DESC LIMIT 1
SELECT AVG(T1.rating_score), T2.director_name FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE T2.movie_title = 'When Will I Be Loved'
"""
# Turn the text block into a list of statements, each trimmed of outer whitespace.
sqls = [statement.strip() for statement in sqls.strip().splitlines()]

In [4]:
def _print_asts(asts):
    """Print each entry of `asts`: its first element, then the repr of its second element (the AST)."""
    for i, ast in enumerate(asts):
        print(f' [{i}] {ast[0]}')
        print(f' [{i}] ast:')
        print('  ' + repr(ast[1]))


# For every example query, run the extractor and print what it found,
# section by section. Sections with nothing extracted are skipped.
for sql in sqls:
    output = extract_all(sql, schema)
    print('SQL:', sql)
    print('# Selection')
    print(f'  unique columns: {output["sel"]}')
    for i, ast in enumerate(output['sel_asts']):
        # Selection entries are printed with their third element as the type tag.
        print(f' [{i}] type: {ast[2]}')
        print(f' [{i}] ast:')
        print('  ' + repr(ast[1]))
    if output['cond_asts']:
        print('\n# condition')
        print(f'  operations: {output["op_types"]}')
        _print_asts(output['cond_asts'])
    if output['agg_asts']:
        print('\n# aggregation')
        print(f'  unique columns: {output["agg"]}')
        _print_asts(output['agg_asts'])
    if output['orderby_asts']:
        print('\n# orderby')
        print(f'  unique columns: {output["orderby"]}')
        # Iterate the same list that guards this section; the previous code
        # looped over output['group_asts'] here, a copy-paste slip.
        _print_asts(output['orderby_asts'])

    if output['nested']:
        print('\n# nested')
        print(f'  number of nested: {output["nested"]}')
        # check `output['subqueries']` if you want to see the nested queries;
        # the first one is the original query
    if output['distinct']:
        print(f'\n# distinct: {output["distinct"]}')
    if output['limit']:
        print(f'\n# limit: {output["limit"]}')
    print('----------------------------------')

SQL: SELECT movie_release_year FROM movies WHERE movie_title = 'Cops'
# Selection
  unique columns: {'__movies.movie_release_year__'}
 [0] type: <select>
 [0] ast:
  Column(
  this=Identifier(this=movie_release_year, quoted=False),
  table=Identifier(this=movies, quoted=False))

# condition
  operations: {'eq'}
 [0] __movies.movie_title__ eq [placeholder-type:string]
 [0] ast:
  EQ(
  this=Column(
    this=Identifier(this=movie_title, quoted=False),
    table=Identifier(this=movies, quoted=False)),
  expression=Literal(this=[placeholder-type:string], is_string=True))

# nested
  number of nested: 1
----------------------------------
SQL: SELECT T1.user_id FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE rating_score = 4 AND rating_timestamp_utc LIKE '2013-05-04 06:33:32' AND T2.movie_title LIKE 'Freaks'
# Selection
  unique columns: {'__ratings.user_id__'}
 [0] type: <select>
 [0] ast:
  Column(
  this=Identifier(this=user_id, quoted=False),
  table=Identif

# Measurement of Structural Similarity between Source and Target ASTs

* `n` = number of source asts
* `m` = number of target asts

```python
if n == m:
    # the source and the target contain the same number of ASTs
if n != m:
    # the source contains more (or fewer) ASTs than the target,
    # so a penalty must be applied for the missing or extra ASTs
```


Hungarian algorithm - https://hongl.tistory.com/159

* semantic similarity and structural similarity
    * semantic similarity - BERTScore
    * structural similarity - Tree Similarity of Edit Distance (TSED)

In [22]:
# Compare two queries structurally: whole-query TSED first, then the
# per-component partial-match scores.
from src.eval_utils import get_all_partial_score

sql1 = """SELECT T1.USER_ID 
FROM ratings AS T1 
INNER JOIN movies AS T2 
ON T1.movie_id = T2.movie_id 
WHERE 
    rating_score = 4 
    AND rating_timestamp_utc LIKE '2013-05-04 06:33:32' 
    AND T2.movie_title LIKE 'Freaks'
"""

sql2 = """SELECT T1.user_id, COUNT(T2.movie_title)
FROM ratings AS T1 
INNER JOIN movies AS T2 
ON T1.movie_id = T2.movie_id 
GROUP BY T1.user_id
HAVING COUNT(T2.movie_title) > 1
ORDER BY COUNT(T2.movie_title) DESC
"""
schema = Schema(schema_dict)

# Alternative example pair using a `lineitem` table; uncomment all three
# assignments together to try it.
# sql1 = """SELECT
#   COUNT(*) AS count, lineitem.l_returnflag
# FROM lineitem
# WHERE
#   lineitem.l_commitdate < lineitem.l_receiptdate
#   AND lineitem.l_receiptdate >= '1993-01-01'
#   AND lineitem.l_receiptdate < '1994-01-01'
# """

# sql2 = """SELECT
#   COUNT(lineitem.L_RECEIPTDATE), lineitem.l_returnflag
# FROM LINEITEM L
# WHERE
#   lineitem.L_RECEIPTDATE > lineitem.L_COMMITDATE
#   AND STRFTIME('%Y', lineitem.L_RECEIPTDATE) = 'abcd'
# """

# schema = Schema({
#     'lineitem': {'l_receiptdate': 'date', 'l_commitdate': 'date'}
# })

output1 = extract_all(sql1, schema)
output2 = extract_all(sql2, schema)

# subqueries[0] is the formatted form of the original (outermost) query.
formatted_sql1 = output1['subqueries'][0]
formatted_sql2 = output2['subqueries'][0]
tsed, distance = compute_tsed(formatted_sql1, formatted_sql2, build_type='apted')  # apted or zss
print('[SQL1]\n', formatted_sql1.sql(pretty=True))
print()
print('[SQL2]\n', formatted_sql2.sql(pretty=True))
print()
print(f'TSED: {tsed:.4f}')
print(f'Tree Edit Distance: {distance}')

# partial match
print('Partial Match Score')

results, overall_score = get_all_partial_score(output1, output2, build_type='apted', criteria='tsed', penalty=0.01, use_bert=True, rescale_with_baseline=True)
for k, v in results.items():
    # Test against None explicitly: a genuine score of 0.0 is falsy and was
    # previously printed through the unformatted fallback branch.
    if v is not None:
        print(f'- {k}: {v:.4f}')
    else:
        print(f'- {k}: {v}')
print(f'Overall Score: {overall_score:.4f}')

[SQL1]
 SELECT
  ratings.user_id
FROM ratings
INNER JOIN movies
  ON T1.movie_id = T2.movie_id
WHERE
  ratings.rating_score = [placeholder-type:numeric]
  AND ratings.rating_timestamp_utc LIKE '[placeholder-type:string]'
  AND movies.movie_title LIKE '[placeholder-type:string]'

[SQL2]
 SELECT
  ratings.user_id,
  COUNT(movies.movie_title)
FROM ratings
INNER JOIN movies
  ON T1.movie_id = T2.movie_id
GROUP BY
  ratings.user_id
HAVING
  COUNT(movies.movie_title) > [placeholder-type:numeric]
ORDER BY
  COUNT(movies.movie_title)

TSED: 0.4054
Tree Edit Distance: 22
Partial Match Score
table_asts True True True
sel_asts True True True
cond_asts True True True
agg_asts False True False
orderby_asts False True False
subqueries False False False
distinct False False False
limit False False False
- table_asts: 1.0000
- sel_asts: 0.4950
- cond_asts: 0.1698
- agg_asts: 0.0
- orderby_asts: 0.0
- subqueries: None
- distinct: None
- limit: None
Overall Score: 0.0000


# Complexity of SQL

In [16]:
def normalize_values(x, min_value=0, max_value=6):
    """Linearly rescale `x` so that `min_value` maps to 0 and `max_value` maps to 1.

    The arithmetic is applied as-is, so `x` may be a scalar or a NumPy array
    (element-wise). Values outside the range are not clipped.
    """
    span = max_value - min_value
    return (x - min_value) / span

def tanh(x: np.ndarray, k: float):
    """Squash an array of counts into a single score: tanh(log(1 + sum(x / k))).

    Args:
        x: array of values; each is rescaled by `normalize_values` with
            `max_value=k` (and the default `min_value` of 0) before summing.
        k: upper bound used for the rescaling.

    Returns:
        The hyperbolic tangent of log(1 + sum of the rescaled values).
    """
    scaled_total = normalize_values(x, max_value=k).sum()
    return np.tanh(np.log(1 + scaled_total))

def derive_complexity(x: list[int], k=6):
    """Return the complexity score for a list of counts.

    The list is converted to a NumPy array and passed to `tanh` together
    with the scaling constant `k`.
    """
    return tanh(np.array(x), k)

def get_complexity(output, k=6):
    """Average the per-component complexity scores of one extraction result.

    Args:
        output: mapping that is indexed with the keys 'sel', 'sel_asts',
            'cond_asts', 'op_types', 'agg', 'agg_asts', 'orderby',
            'orderby_asts', 'distinct', 'limit', 'nested' and 'table_asts'
            (in this notebook, the value returned by `extract_all`).
        k: scaling constant forwarded to `derive_complexity`.

    Returns:
        The mean of the component scores. Paired components contribute only
        when both entries are truthy, scored from their two lengths. Single
        components contribute when truthy: 'nested' is used as a number,
        'table_asts' by its length, 'distinct' and 'limit' via `int()`.
        Returns 0.0 when no component contributes, rather than the NaN (and
        RuntimeWarning) that `np.mean` produces for an empty list.
    """
    paired_keys = [('sel', 'sel_asts'), ('cond_asts', 'op_types'), ('agg', 'agg_asts'), ('orderby', 'orderby_asts')]
    single_keys = ['distinct', 'limit', 'nested', 'table_asts']
    scores = []

    for first_key, second_key in paired_keys:
        first_items, second_items = output[first_key], output[second_key]
        if first_items and second_items:
            scores.append(derive_complexity([len(first_items), len(second_items)], k=k))

    for key in single_keys:
        value = output[key]
        if not value:
            continue
        if key == 'nested':
            count = value
        elif key == 'table_asts':
            count = len(value)
        else:
            count = int(value)
        scores.append(derive_complexity([count], k=k))

    if not scores:
        # Nothing was extracted, so report zero complexity instead of NaN.
        return 0.0
    return np.mean(scores)

In [17]:
# Complexity scores for the two extraction results (output1, output2) produced
# in the TSED comparison cell above; displayed as a tuple.
get_complexity(output1), get_complexity(output2)

(0.3294679655301611, 0.31176470588235294)