In [1]:
%load_ext autoreload
%autoreload 2

In [11]:
import sys
from pathlib import Path
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import json
from tqdm import tqdm
import numpy as np
import pandas as pd
from typing import Optional
from collections import defaultdict
from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
_ = load_dotenv(find_dotenv())

from src.db_utils import get_schema_str, get_data_dict
from src.pymodels import DatabaseModel, QuestionSQL, SparcSample, SpiderSample, Description
from src.prompts import Prompts
from src.database import SqliteDatabase
from src.data_preprocess import (
    load_raw_data,
    process_all_tables,
    filter_samples_by_count_spider_bird,
    process_samples_spider,
    process_samples_bird,
    split_train_dev_test,
    save_samples_spider_bird,
    load_samples_spider_bird,
)
from copy import deepcopy
bird_path = proj_path / 'data' / 'bird'
tables, train_data, dev_data = load_raw_data(bird_path, load_test=False)

with (proj_path / 'data' / 'bird_description.json').open() as f:
    all_descriptions = json.load(f)

bird_tables = process_all_tables(tables, descriptions=all_descriptions)

In [14]:
spider_path = proj_path / 'data' / 'spider'
tables, train_data, dev_data = load_raw_data(spider_path, load_test=False)

with (proj_path / 'data' / 'description.json').open() as f:
    all_descriptions = json.load(f)

spider_tables = process_all_tables(tables, descriptions=all_descriptions)
all_data = filter_samples_by_count_spider_bird(train_data+dev_data, n=10)
skip = [622, 6916, 6917, 6930, 6967, 6987]
spider_samples = process_samples_spider(all_data, spider_tables, skip=skip)
train_samples, dev_samples, test_samples = split_train_dev_test(spider_samples, train_ratio=0.6, dev_ratio=0.2)
# makesure the dev/test sql is not in the train sql

save_samples_spider_bird(train_samples, proj_path / 'data' / 'spider_train.json')
save_samples_spider_bird(dev_samples, proj_path / 'data' / 'spider_dev.json')
save_samples_spider_bird(test_samples, proj_path / 'data' / 'spider_test.json')
print(len(train_samples), len(dev_samples), len(test_samples))

# train_samples = load_samples_spider_bird(proj_path / 'data' / 'spider_train.json')
# dev_samples = load_samples_spider_bird(proj_path / 'data' / 'spider_dev.json')
# test_samples = load_samples_spider_bird(proj_path / 'data' / 'spider_test.json')

100%|██████████| 8023/8023 [00:01<00:00, 4143.27it/s]

4760 1555 1702





In [11]:
# all_data = filter_samples_by_count_spider_bird(train_data+dev_data, n=10)
# with open(proj_path / 'data' / 'bird_skip.txt') as f:
#     skip = [int(line.strip()) for line in f]

# bird_samples = process_samples_bird(all_data, bird_tables, skip=skip)
# train_samples, dev_samples, test_samples = split_train_dev_test(bird_samples, train_ratio=0.6, dev_ratio=0.2)
# makesure the dev/test sql is not in the train sql

# save_samples_spider_bird(train_samples, proj_path / 'data' / 'bird_train.json')
# save_samples_spider_bird(dev_samples, proj_path / 'data' / 'bird_dev.json')
# save_samples_spider_bird(test_samples, proj_path / 'data' / 'bird_test.json')
# print(len(train_samples), len(dev_samples), len(test_samples))

100%|██████████| 10956/10956 [00:03<00:00, 3053.34it/s]


In [12]:
# import pickle
# with (proj_path / 'data' / 'bird_samples.pkl').open('wb') as f:
#     pickle.dump(bird_samples, f)

# with (proj_path / 'data' / 'bird_samples.pkl').open('rb') as f:
#     bird_samples = pickle.load(f)

In [13]:
# df_train = pd.DataFrame({
#     'db_id': [x.db_id for x in train_samples], 
#     'sql': [x.final.sql for x in train_samples],
#     'sample_id': [x.sample_id for x in train_samples]
# })
# df_dev = pd.DataFrame({
#     'db_id': [x.db_id for x in dev_samples] + [x.db_id for x in test_samples],
#     'sql': [x.final.sql for x in dev_samples] + [x.final.sql for x in test_samples],
#     'sample_id': [x.sample_id for x in dev_samples] + [x.sample_id for x in test_samples]
# })

# df_train.to_csv(proj_path / 'data' / 'bird_train.csv', index=False)
# df_dev.to_csv(proj_path / 'data' / 'bird_dev.csv', index=False)

# df_train = pd.read_csv(proj_path / 'data' / 'bird_train.csv')
# df_dev = pd.read_csv(proj_path / 'data' / 'bird_dev.csv')

In [15]:
import sqlglot
from src.parsing_sql import extract_all, Schema
from src.eval_utils import (
    get_structural_score,
    get_semantic_score,
    get_all_partial_score, 
    get_complexity,
    get_all_structural_score
)
import pickle

In [17]:
error_ids = []
results = defaultdict()
parsed = defaultdict(list)
tables = spider_tables
for db_id, samples in spider_samples.items():
    schema = Schema(tables[db_id].db_schema)
    iterator = tqdm(range(len(samples)), total=len(samples), desc=f"{db_id}")
    for i in iterator:
        sql_i = samples[i].final.sql
        try:
            ei = extract_all(sql_i, schema)
            assert len(ei['sel']) > 0, f'No selection found-{db_id}-{i}'
        except Exception as e:
            error_ids.append((db_id, i, samples[i].sample_id, str(e)))
            parsed[db_id].append(None)
            continue
        parsed[db_id].append(ei)

# with (proj_path / 'data' / 'bird_parsed.pkl').open('wb') as f:
#     pickle.dump(parsed, f)

department_management: 100%|██████████| 16/16 [00:00<00:00, 1121.45it/s]
farm: 100%|██████████| 40/40 [00:00<00:00, 1459.78it/s]
student_assessment: 100%|██████████| 53/53 [00:00<00:00, 1144.85it/s]
bike_1: 100%|██████████| 104/104 [00:00<00:00, 1120.69it/s]
book_2: 100%|██████████| 21/21 [00:00<00:00, 1356.98it/s]
musical: 100%|██████████| 40/40 [00:00<00:00, 1373.12it/s]
twitter_1: 100%|██████████| 27/27 [00:00<00:00, 1127.60it/s]
product_catalog: 100%|██████████| 42/42 [00:00<00:00, 1387.96it/s]
flight_1: 100%|██████████| 96/96 [00:00<00:00, 1299.23it/s]
allergy_1: 100%|██████████| 98/98 [00:00<00:00, 1052.05it/s]
store_1: 100%|██████████| 111/111 [00:00<00:00, 1083.61it/s]
journal_committee: 100%|██████████| 18/18 [00:00<00:00, 1206.32it/s]
customers_card_transactions: 100%|██████████| 80/80 [00:00<00:00, 1375.66it/s]
race_track: 100%|██████████| 42/42 [00:00<00:00, 1319.46it/s]
coffee_shop: 100%|██████████| 18/18 [00:00<00:00, 1266.93it/s]
chinook_1: 100%|██████████| 84/84 [00:00<

In [18]:
len(error_ids), error_ids

(0, [])

In [33]:
from itertools import combinations

epsilon: float=1e-9

for db_id, samples in bird_samples.items():
    semantic_sim = np.zeros((len(samples), len(samples))) + np.eye(len(samples))
    structural_sim = np.zeros((len(samples), len(samples))) + np.eye(len(samples))
    overall_sim = np.zeros((len(samples), len(samples))) + np.eye(len(samples))
    
    idxs = list(combinations(range(len(samples)), 2))

    iterator = tqdm(idxs, total=len(idxs), desc=f"{db_id}")
    for i, j in iterator:
        if parsed[db_id][i] is None:
            continue
        if parsed[db_id][j] is None:
            continue
        ei = parsed[db_id][i]        
        ej = parsed[db_id][j]
            
        _, structural_score = get_all_structural_score(ei, ej)
        structural_sim[i, j] = structural_score
        structural_sim[j, i] = structural_score

        # _, final_score = get_all_partial_score(ei, ej, use_bert=False)

        # structural_sim[i, j] = final_score['structural']
        # semantic_sim[j, i] = final_score['semantic']
        # overall_sim[i, j] = final_score['overall']

        # structural_sim[j, i] = final_score['structural']
        # semantic_sim[j, i] = final_score['semantic']
        # overall_sim[j, i] = final_score['overall']

    # results[db_id] = {
    #     'semantic': semantic_sim,
    #     'struct': structural_sim,
    #     'overall': overall_sim
    # }
    results[db_id] = structural_sim

with (proj_path / 'data' / 'bird_similarity.pkl').open('wb') as f:
    pickle.dump(results, f)

movie_platform:   0%|          | 0/13861 [00:00<?, ?it/s]

movie_platform: 100%|██████████| 13861/13861 [00:25<00:00, 551.75it/s]
book_publishing_company: 100%|██████████| 2628/2628 [00:05<00:00, 451.58it/s]
retail_complains: 100%|██████████| 14028/14028 [00:36<00:00, 379.87it/s]
movies_4: 100%|██████████| 12403/12403 [00:28<00:00, 433.16it/s]
codebase_comments: 100%|██████████| 7503/7503 [00:14<00:00, 522.07it/s]
trains: 100%|██████████| 780/780 [00:05<00:00, 143.41it/s]
movie: 100%|██████████| 1035/1035 [00:02<00:00, 411.75it/s]
social_media: 100%|██████████| 3003/3003 [00:06<00:00, 469.62it/s]
cs_semester: 100%|██████████| 6328/6328 [00:22<00:00, 286.20it/s]
computer_student: 100%|██████████| 2556/2556 [00:04<00:00, 613.34it/s]
talkingdata: 100%|██████████| 21115/21115 [01:08<00:00, 308.74it/s]
law_episode: 100%|██████████| 6441/6441 [00:14<00:00, 442.36it/s]
synthea: 100%|██████████| 17020/17020 [00:54<00:00, 313.55it/s]
car_retails: 100%|██████████| 7875/7875 [00:20<00:00, 388.87it/s]
restaurant: 100%|██████████| 6786/6786 [00:10<00:00, 6

In [36]:
results[db_id].mean(axis=0), results[db_id].std(axis=0)

(array([0.12260781, 0.096525  , 0.251725  , 0.21155156, 0.24779219,
        0.12495156, 0.19441406, 0.22925781, 0.19630625, 0.22925781,
        0.23320938, 0.12773281, 0.13516875, 0.14355625, 0.1277125 ,
        0.15592187, 0.14828438, 0.24885   , 0.23357813, 0.22532656,
        0.2066    , 0.11531094, 0.15339219, 0.16306562, 0.13000313,
        0.25411719, 0.2490625 , 0.25685313, 0.19337031, 0.21204688,
        0.1798875 , 0.204775  , 0.18465156, 0.16548281, 0.15114688,
        0.17525   , 0.147575  , 0.13854375, 0.18831563, 0.19460313,
        0.1861625 , 0.20758594, 0.17511406, 0.2056125 , 0.19446875,
        0.21154687, 0.21319844, 0.17478594, 0.20631094, 0.19719531,
        0.20636562, 0.19647188, 0.199425  , 0.18439688, 0.20267344,
        0.21765625, 0.15651094, 0.12222031, 0.08344063, 0.18713125,
        0.12469062, 0.16565937, 0.17372031, 0.19274844]),
 array([0.17034594, 0.15414484, 0.26850189, 0.18522114, 0.26558842,
        0.14104186, 0.16740815, 0.23658794, 0.22516845, 0.

In [20]:
import warnings
import sqlglot
import numpy as np

from zss import simple_distance, Node
from sqlglot import expressions as exp
from apted import APTED
from apted.helpers import Tree
from typing import Tuple
from itertools import product
from scipy.optimize import linear_sum_assignment 
from transformers import logging as tfloggings
from src.eval_utils import (
    get_partial_score,
    get_all_partial_score,
    get_final_score,
    get_semantic_score,
    get_structural_score,
    partial_matching_with_penalty,
    build_tree,
    _build_node,
    _build_tree,
    compute_tsed
)

In [21]:
build_type: str='apted'
criteria: str='tsed'
penalty: float = 0.01
use_bert: bool = True
rescale_with_baseline: bool = True

source_output = ei
target_output = ej
args = ['table_asts', 'sel_asts', 'cond_asts', 'agg_asts', 'orderby_asts', 'subqueries', 'distinct', 'limit']
results = {}
for arg in args:
    # structural_score, semantic_score, score = get_partial_score(source_output, target_output, arg)
    source_exists = bool(source_output[arg]) if arg != 'subqueries' else bool(source_output[arg][1:])
    target_exists = bool(target_output[arg]) if arg != 'subqueries' else bool(target_output[arg][1:])

    if target_exists and source_exists:
        if arg in ['sel_asts', 'cond_asts', 'agg_asts', 'orderby_asts', 'table_asts']:
            source = [ast for _, ast, _ in source_output[arg]]
            target = [ast for _, ast, _ in target_output[arg]]
            structural_score, semantic_score, score = get_final_score(source, target, build_type, criteria, penalty, use_bert, rescale_with_baseline)
        elif arg == 'subqueries':
            source = source_output[arg][1:]
            target = target_output[arg][1:]
            structural_score, semantic_score, score = get_final_score(source, target, build_type, criteria, penalty, use_bert, rescale_with_baseline)
        elif arg in ['distinct', 'limit']:
            score = 1.0 if criteria == 'tsed' else 0.0
            structural_score, semantic_score = score, score
    elif target_exists ^ source_exists:
        score = 0.0 if criteria == 'tsed' else np.infty
        structural_score, semantic_score = score, score 
    else:
        # they don't exist in both so, we can't measure the score
        score = None
        structural_score, semantic_score = score, score

# structural_score = get_structural_score(source, target, build_type, criteria, penalty)

In [335]:
n = len(source)
m = len(target)
scores = np.zeros((n, m), dtype=np.float32)
distance = np.zeros((n, m), dtype=np.float32)
for i, ast1 in enumerate(source):
    for j, ast2 in enumerate(target):
        score, dis = compute_tsed(ast1, ast2, build_type)

TypeError: 'NoneType' object is not iterable

In [236]:
# tree2, node_count2 = build_tree(ast2, build_type)
ast_node = ast2

In [251]:
expr = ast_node.this.this
expr.args.keys()

dict_keys(['kind', 'hint', 'distinct', 'expressions', 'limit', 'from', 'joins', 'where'])

In [351]:
db_id

'law_episode'

In [265]:
from sqlglot.dialects.dialect import Dialect
from src.parsing_sql import _format_select, _format_expression

In [352]:
ss = sqlglot.parse_one("""
SELECT t3.years, t3.episode_id FROM ( SELECT DISTINCT T2.year AS years, T2.episode_id, row_number() OVER (PARTITION BY T2.episode_id ORDER BY T2.year) AS rm FROM Person AS T1 INNER JOIN Award AS T2 ON T1.person_id = T2.person_id WHERE T2.award = 'Television' AND T2.award_category = 'Silver Gavel Award' AND T1.name = 'Constantine Makris' AND T2.result = 'Winner' AND T2.organization = 'American Bar Association Silver Gavel Awards for Media and the Arts' ) AS T3 GROUP BY t3.episode_id HAVING COUNT(t3.years - t3.rm) >= 2
""")

In [357]:
orders = ss.args.get('from').this.this.args.get('expressions')[-1].this.args.get('order')
orders

Order(
  expressions=[
    Ordered(
      this=Column(
        this=Identifier(this=year, quoted=False),
        table=Identifier(this=T2, quoted=False)),
      nulls_first=True)])

In [359]:
for o in orders:
    break

In [363]:
name, new_expr = _format_expression(
    o, 
    {'table': {}, 'column': {}}, 
    schema, 
    True
)

In [365]:
exp.Order(expressions=[])

Order(
  )

In [344]:
sql_i

"SELECT T2.Consumption FROM transactions_1k AS T1 INNER JOIN yearmonth AS T2 ON T1.CustomerID = T2.CustomerID WHERE T1.Price / T1.Amount > 29.00 AND T1.ProductID = 5 AND T2.Date = '201208'"

In [343]:
expr.expressions[2]

Window(
  this=RowNumber(),
  partition_by=[
    Column(
      this=Identifier(this=episode_id, quoted=False),
      table=Identifier(this=award, quoted=False))],
  order=[
    Column(
      this=Identifier(this=year, quoted=False),
      table=Identifier(this=award, quoted=False))],
  over=OVER)

In [350]:
ss = exp.Window(
    this=exp.RowNumber(),
    partition_by=[exp.Column(this=exp.Identifier(this='c1'))],
    order=exp.Column(this=exp.Identifier(this='c2')),
)

ss.sql()

'ROW_NUMBER() OVER (PARTITION BY c1 c2 ORDER BY )'

In [338]:
_format_select(expr)

AttributeError: 'list' object has no attribute 'args'

In [303]:
Dialect.get_or_raise('sqlite').generate(expr.args.get('expressions'))

AttributeError: 'list' object has no attribute 'parent'

In [301]:
Dialect.get_or_raise('sqlite').generate(ss.args.get('joins'))

AttributeError: 'list' object has no attribute 'parent'

In [272]:
expr.args.get('expression')

In [None]:
expr.generate

In [269]:
expr.expression

Select(
  distinct=Distinct(),
  expressions=[
    Column(
      this=Identifier(this=year, quoted=False),
      table=Identifier(this=award, quoted=False)),
    Column(
      this=Identifier(this=episode_id, quoted=False),
      table=Identifier(this=award, quoted=False)),
    Window(
      this=RowNumber(),
      partition_by=[
        Column(
          this=Identifier(this=episode_id, quoted=False),
          table=Identifier(this=award, quoted=False))],
      order=[
        Column(
          this=Identifier(this=year, quoted=False),
          table=Identifier(this=award, quoted=False))],
      over=OVER)],
  from=From(
    this=Table(
      this=Identifier(this=Person, quoted=False),
      alias=TableAlias(
        this=Identifier(this=T1, quoted=False)))),
  joins=[
    Join(
      this=Table(
        this=Identifier(this=Award, quoted=False),
        alias=TableAlias(
          this=Identifier(this=T2, quoted=False))),
      kind=INNER,
      on=EQ(
        this=Column(
        

In [260]:
str(expr.args.get('where'))

"WHERE T2.award = 'Television' AND T2.award_category = 'Silver Gavel Award' AND T1.name = 'Constantine Makris' AND T2.result = 'Winner' AND T2.organization = 'American Bar Association Silver Gavel Awards for Media and the Arts'"

In [103]:
from src.parsing_sql import (
    extract_aliases,
    extract_condition,
    get_subqueries,
    _extract_conditions,
    _extract_columns_from_expression,
    _determine_tag,
    _format_expression,
    _get_full_column_name,
    extract_aliases,
    extract_selection,
    extract_aggregation,
    extract_orderby,
    extract_others,
    extract_from_join,
    _extract_aliases_from_select,
    _handle_table_or_subquery,
    Schema
)

---

In [159]:
from src.database import SqliteDatabase
from sqlglot import expressions as exp
from src.db_utils import get_data_dict

from pprint import pprint
# for table in tables:
#     data_dict = get_data_dict(table)
#     break
print(list(bird_samples.keys())[:5])

['movie_platform', 'book_publishing_company', 'retail_complains', 'movies_4', 'codebase_comments']


In [207]:
db_id = 'law_episode'
sample_idx = 76
typ = 'train'
# database = SqliteDatabase(db_file=str(proj_path / 'data' / 'bird' / typ / f'{typ}_databases' / db_id / f'{db_id}.sqlite'))
samples = bird_samples[db_id]
schema = Schema(bird_tables[db_id].db_schema)
sql_i = [x.final.sql for i, x in enumerate(samples) if i == sample_idx][0]
# sql_i = [x for i, x in enumerate(samples) if x.sample_id == 2467][0].final.sql
# sql_i = """
# SELECT c1.TITLE, c1.year, c1.rating FROM movie c1 WHERE c1.year = 2000
# """
# print(sql_i)
print(sqlglot.transpile(sql_i, pretty=True)[0])
# res = database.execute(sql_i)
# print(res)
ei = extract_all(sql_i, schema)
ei

SELECT
  t3.years,
  t3.episode_id
FROM (
  SELECT DISTINCT
    T2.year AS years,
    T2.episode_id,
    ROW_NUMBER() OVER (PARTITION BY T2.episode_id ORDER BY T2.year) AS rm
  FROM Person AS T1
  INNER JOIN Award AS T2
    ON T1.person_id = T2.person_id
  WHERE
    T2.award = 'Television'
    AND T2.award_category = 'Silver Gavel Award'
    AND T1.name = 'Constantine Makris'
    AND T2.result = 'Winner'
    AND T2.organization = 'American Bar Association Silver Gavel Awards for Media and the Arts'
) AS T3
GROUP BY
  t3.episode_id
HAVING
  COUNT(t3.years - t3.rm) >= 2


defaultdict(set,
            {'aliases': {'table': {'T3': "SELECT DISTINCT T2.year AS years, T2.episode_id, ROW_NUMBER() OVER (PARTITION BY T2.episode_id ORDER BY T2.year) AS rm FROM Person AS T1 INNER JOIN Award AS T2 ON T1.person_id = T2.person_id WHERE T2.award = 'Television' AND T2.award_category = 'Silver Gavel Award' AND T1.name = 'Constantine Makris' AND T2.result = 'Winner' AND T2.organization = 'American Bar Association Silver Gavel Awards for Media and the Arts'",
               'T1': 'person',
               'T2': 'award'},
              'column': {'years': 'T2.year',
               'rm': 'ROW_NUMBER() OVER (PARTITION BY T2.episode_id ORDER BY T2.year)'}},
             'distinct': True,
             'limit': False,
             'table_asts': {("(SELECT DISTINCT T2.year AS years, T2.episode_id, ROW_NUMBER() OVER (PARTITION BY T2.episode_id ORDER BY T2.year) AS rm FROM Person AS T1 INNER JOIN Award AS T2 ON T1.person_id = T2.person_id WHERE T2.award = 'Television' AND T2.award

In [205]:
sql_i = [x.final.sql for i, x in enumerate(samples) if i == sample_idx][0]

parsed_query = sqlglot.parse_one(sql_i)
aliases = extract_aliases(parsed_query)
subqueries = get_subqueries(parsed_query)
print(len(subqueries), parsed_query.sql(pretty=True))
aliases

2 SELECT
  t3.years,
  t3.episode_id
FROM (
  SELECT DISTINCT
    T2.year AS years,
    T2.episode_id,
    ROW_NUMBER() OVER (PARTITION BY T2.episode_id ORDER BY T2.year) AS rm
  FROM Person AS T1
  INNER JOIN Award AS T2
    ON T1.person_id = T2.person_id
  WHERE
    T2.award = 'Television'
    AND T2.award_category = 'Silver Gavel Award'
    AND T1.name = 'Constantine Makris'
    AND T2.result = 'Winner'
    AND T2.organization = 'American Bar Association Silver Gavel Awards for Media and the Arts'
) AS T3
GROUP BY
  t3.episode_id
HAVING
  COUNT(t3.years - t3.rm) >= 2


{'table': {'T3': "SELECT DISTINCT T2.year AS years, T2.episode_id, ROW_NUMBER() OVER (PARTITION BY T2.episode_id ORDER BY T2.year) AS rm FROM Person AS T1 INNER JOIN Award AS T2 ON T1.person_id = T2.person_id WHERE T2.award = 'Television' AND T2.award_category = 'Silver Gavel Award' AND T1.name = 'Constantine Makris' AND T2.result = 'Winner' AND T2.organization = 'American Bar Association Silver Gavel Awards for Media and the Arts'",
  'T1': 'person',
  'T2': 'award'},
 'column': {'years': 'T2.year',
  'rm': 'ROW_NUMBER() OVER (PARTITION BY T2.episode_id ORDER BY T2.year)'}}

In [206]:
anonymize_literal= True

for query in subqueries:
    query = deepcopy(query)
    tables = set()
    from_clause = query.args.get('from')
    if from_clause and isinstance(from_clause, exp.From):
        name, expr = _format_expression(from_clause.this, aliases, schema, anonymize_literal)
        tables.add((name, expr, 'from'))

    join_clause = query.args.get('joins', [])
    if join_clause:
        for join_expr in join_clause:
            left_expr = join_expr.args.get('this')
            left_name, new_left_expr = _format_expression(left_expr, aliases, schema, anonymize_literal=False)
            join_expr.args['this'] = new_left_expr
            table = exp.Table(this=left_expr)
            tables.add((left_name, table, 'join'))


In [194]:
tables = set()
from_clause = query.args.get('from')
if from_clause and isinstance(from_clause, exp.From):
    name, expr = _format_expression(from_clause.this, aliases, schema, anonymize_literal)
    tables.add((name, expr, 'from'))

join_clause = query.args.get('joins', [])
if join_clause:
    for join_expr in join_clause:
        left_expr = join_expr.args.get('this')
        left_name, new_left_expr = _format_expression(left_expr, aliases, schema, anonymize_literal=False)
        join_expr.args['this'] = new_left_expr
        table = exp.Table(this=left_expr)
        tables.add((left_name, table, 'join'))

In [189]:
_format_expression(query, aliases, schema, anonymize_literal=False)

AttributeError: 'list' object has no attribute 'args'

In [86]:
from_expr = parsed_query.args.get('from').this
_format_expression(from_expr, aliases, schema, anonymize_literal=False)

('League AS t2',
 Table(
   this=Identifier(this=League, quoted=False),
   alias=TableAlias(
     this=Identifier(this=t2, quoted=False))))

In [39]:
join_exprs = parsed_query.args.get('joins')
for join_expr in join_exprs:
    left_expr = join_expr.args.get('this')
    left_name, new_left_expr = _format_expression(left_expr, aliases, schema, anonymize_literal=False)
    join_expr.args['this'] = new_left_expr
    right_expr = join_expr.args.get('on')
    right_name, new_right_expr = _format_expression(right_expr, aliases, schema, anonymize_literal=False)
    join_expr.args['on'] = new_right_expr
    break

In [70]:
exp.Table(this=left_expr)

Table(
  this=Subquery(
    this=Select(
      expressions=[
        Column(
          this=Identifier(this=league_id, quoted=False),
          table=Identifier(this=match, quoted=False)),
        Max(
          this=Column(
            this=Identifier(this=cnt, quoted=False),
            table=Identifier(this=subquery_t1, quoted=False)))],
      from=From(
        this=Subquery(
          this=Select(
            expressions=[
              Column(
                this=Identifier(this=league_id, quoted=False)),
              Alias(
                this=Count(
                  this=Column(
                    this=Identifier(this=id, quoted=False)),
                  big_int=True),
                alias=Identifier(this=cnt, quoted=False))],
            from=From(
              this=Table(
                this=Identifier(this=Match, quoted=False))),
            group=Group(
              expressions=[
                Column(
                  this=Identifier(this=league_id, quoted=Fals

In [67]:
f'{left_name.lower()} on {right_name}' 

'(select match.league_id, max(subquery_t1.cnt) from (select league_id, count(id) as cnt from match group by league_id) as subquery) as t1 on __(select league_id, max(cnt) as max_count from (select league_id, count(id) as cnt from match group by league_id) as subquery).league_id__ eq __league.id__'

In [60]:
right_name

'__(select league_id, max(cnt) as max_count from (select league_id, count(id) as cnt from match group by league_id) as subquery).league_id__ eq __league.id__'

In [27]:
expr = left_expr
args = [(k, v) for k, v in expr.args.items() if v]
for arg, sub_expr in args:
    print(arg, type(sub_expr))
    sub_args = [(k, v) for k, v in sub_expr.args.items() if v]
    for k, v in sub_args:
        print(k, type(v))

this <class 'sqlglot.expressions.Select'>
expressions <class 'list'>
from <class 'sqlglot.expressions.From'>
alias <class 'sqlglot.expressions.TableAlias'>
this <class 'sqlglot.expressions.Identifier'>


In [44]:
aliases

{'table': {'t2': 'league',
  't1': 'SELECT league_id, MAX(cnt) AS max_count FROM (SELECT league_id, COUNT(id) AS cnt FROM Match GROUP BY league_id) AS subquery',
  'subquery': 'SELECT league_id, COUNT(id) AS cnt FROM Match GROUP BY league_id',
  'Match': 'match'},
 'column': {'max_count': 'MAX(cnt)', 'cnt': 'COUNT(id)'}}

In [51]:
_format_expression(deepcopy(left_expr), aliases, schema, anonymize_literal=False)

('(SELECT match.league_id, MAX(subquery_t1.cnt) FROM (SELECT league_id, COUNT(id) AS cnt FROM Match GROUP BY league_id) AS subquery) AS t1',
 Subquery(
   this=Select(
     expressions=[
       Column(
         this=Identifier(this=league_id, quoted=False),
         table=Identifier(this=match, quoted=False)),
       Max(
         this=Column(
           this=Identifier(this=cnt, quoted=False),
           table=Identifier(this=subquery_t1, quoted=False)))],
     from=From(
       this=Subquery(
         this=Select(
           expressions=[
             Column(
               this=Identifier(this=league_id, quoted=False)),
             Alias(
               this=Count(
                 this=Column(
                   this=Identifier(this=id, quoted=False)),
                 big_int=True),
               alias=Identifier(this=cnt, quoted=False))],
           from=From(
             this=Table(
               this=Identifier(this=Match, quoted=False))),
           group=Group(
          

In [49]:
aliases

{'table': {'t2': 'league',
  't1': 'SELECT league_id, MAX(cnt) AS max_count FROM (SELECT league_id, COUNT(id) AS cnt FROM Match GROUP BY league_id) AS subquery',
  'subquery': 'SELECT league_id, COUNT(id) AS cnt FROM Match GROUP BY league_id',
  'Match': 'match'},
 'column': {'max_count': 'MAX(cnt)', 'cnt': 'COUNT(id)'}}

In [47]:
str(left_expr)

'(SELECT match.league_id, MAX(subquery_t1.cnt) FROM (SELECT league_id, COUNT(id) AS cnt FROM Match GROUP BY league_id) AS subquery) AS t1'

In [88]:
for query in subqueries:
    break
    sel_cols, sel_types = extract_selection(query, aliases, schema, False)
    conds, op_types = extract_condition(query, aliases, schema)
    agg_cols, agg_types  = extract_aggregation(query, aliases, schema)
    orderby_cols, orderby_asts = extract_orderby(query, aliases, schema)
    others = extract_others(query)
    # pass

Select

In [89]:
unique_columns = set()
selection_asts = set()

for select_exp in query.select():
    # break
    continue
    print(select_exp, isinstance(expr, (exp.Column, exp.Star)))
    expr_str, expr = _format_expression(select_exp, aliases, schema, True)
    print(expr_str)
    tag = _determine_tag(expr)
    selection_asts.add((expr_str, expr, tag))
    columns = _extract_columns_from_expression(expr, aliases, schema)
    unique_columns.update(columns)

In [90]:
aliases

{'table': {'T': "SELECT p_name, p_size FROM part WHERE p_name IN ('pink powder drab lawn cyan', 'cornflower sky burlywood green beige')",
  'part': 'part'},
 'column': {}}

In [91]:
select_exp

Column(
  this=Identifier(this=p_name, quoted=False),
  table=Identifier(this=T, quoted=False))

In [116]:
expr = select_exp
col = expr
column_name = col.name.lower()
table_alias = col.table if col.table else None
real_table_name = aliases['table'][table_alias]

In [117]:
real_table_name

"SELECT p_name, p_size FROM part WHERE p_name IN ('pink powder drab lawn cyan', 'cornflower sky burlywood green beige')"

In [118]:
_get_full_column_name(deepcopy(select_exp), aliases, schema)

"__(select p_name, p_size from part where p_name in ('pink powder drab lawn cyan', 'cornflower sky burlywood green beige')).p_name__"

In [119]:
original_table = aliases['table'][expr.table]
quoted = expr.args['table'].quoted
quoted

False

In [120]:
_format_expression(deepcopy(select_exp), aliases, schema, True)

("__(select p_name, p_size from part where p_name in ('pink powder drab lawn cyan', 'cornflower sky burlywood green beige')).p_name__",
 Column(
   this=Identifier(this=p_name, quoted=False),
   table=Subquery(
     this=Select(
       expressions=[
         Column(
           this=Identifier(this=p_name, quoted=False)),
         Column(
           this=Identifier(this=p_size, quoted=False))],
       from=From(
         this=Table(
           this=Identifier(this=part, quoted=False))),
       where=Where(
         this=In(
           this=Column(
             this=Identifier(this=p_name, quoted=False)),
           expressions=[
             Literal(this=pink powder drab lawn cyan, is_string=True),
             Literal(this=cornflower sky burlywood green beige, is_string=True)]))))))

In [136]:
for alias in parsed_query.find_all(exp.Alias):
    print(alias.alias, str(alias.this), alias.args.get('this').args.get('table'))

east COUNT("Order ID") None
west (SELECT COUNT("Order ID") FROM west_superstore WHERE "Order Date" LIKE '2015%') None


In [113]:
expr_str, expr = _format_expression(select_exp.this, aliases, schema, True)
expr_str

'[placeholder-subquery]'

In [115]:
select_exp

Column(
  this=Subquery(
    this=Select(
      expressions=[
        Count(
          this=Column(
            this=Identifier(this=order id, quoted=True),
            table=Identifier(this=east_superstore, quoted=False)),
          big_int=True)],
      from=From(
        this=Table(
          this=Identifier(this=west_superstore, quoted=False))),
      where=Where(
        this=Like(
          this=Column(
            this=Identifier(this=order date, quoted=True)),
          expression=Literal(this=2015%, is_string=True))))))

Where

In [24]:
condition_asts = set()
operator_types = set()

for clause_name in ("where", "having"):
    clause = query.args.get(clause_name)
    if clause:
        # clause is WHERE or HAVING. We only want the conditions inside.
        # clause.this is the actual condition expression (e.g., AND/OR tree).
        break
        conds = _extract_conditions(
            clause.this, aliases, schema, operator_types, anonymize_literal=True
        )
        if not isinstance(conds, list):
            conds = [conds]
        condition_asts.update(conds)
        assert len(condition_asts) > 0, f"Failed to extract conditions from {clause_name} clause {len(condition_asts)}"

In [25]:
aliases

{'table': {'t4': 'players_teams',
  't5': 'players',
  't1': 'teams',
  't2': 'teams',
  't3': 'teams'},
 'column': {}}

In [38]:
expr = clause.this.expression.args.get('query').this
expr.args.get('distinct').args

{'on': None}

order by

In [162]:
for query in subqueries:
    sel_cols, sel_types = extract_selection(query, aliases, schema, False)
    conds, op_types = extract_condition(query, aliases, schema)
    agg_cols, agg_types  = extract_aggregation(query, aliases, schema)
    orderby_cols, orderby_asts = extract_orderby(query, aliases, schema)
    others = extract_others(query)
    pass

In [161]:
unique_columns = set()
otherby_asts = set()

order = query.args.get('order')
if order and isinstance(order, exp.Order):
    # order_node.expressions is a list of Ordered expressions
    for order_expr in order.expressions:
        # ordered_expr.this is the actual expression being ordered
        expr_str, expr = _format_expression(
            order_expr.this, aliases, schema, True
        )
        tag = _determine_tag(expr)
        otherby_asts.add((expr_str, expr, tag))
        columns = _extract_columns_from_expression(expr, aliases, schema)
        unique_columns.update(columns)


In [227]:
expr = clause.this
left = expr.args.get('this')
right = expr.args.get('expression')
expr = right.expression
expr

Column(
  this=Identifier(this=T1, quoted=False))

In [35]:
args = [(k, type(v)) for k, v in expr.args.items() if v]
args

[('distinct', sqlglot.expressions.Distinct),
 ('expressions', list),
 ('from', sqlglot.expressions.From),
 ('joins', list),
 ('where', sqlglot.expressions.Where)]

In [232]:
_format_expression(expr, aliases, schema, remove_alias=True, anonymize_literal=True)

('(SELECT MAX(MAX(FTAG), MAX(FTHG)) FROM matchs WHERE season = 2021)',
 Column(
   this=Subquery(
     this=Select(
       expressions=[
         Max(
           this=Max(
             this=Column(
               this=Identifier(this=FTAG, quoted=False))),
           expressions=[
             Max(
               this=Column(
                 this=Identifier(this=FTHG, quoted=False)))])],
       from=From(
         this=Table(
           this=Identifier(this=matchs, quoted=False))),
       where=Where(
         this=EQ(
           this=Column(
             this=Identifier(this=season, quoted=False)),
           expression=Literal(this=2021, is_string=False)))))))

In [230]:
right.args['this']

Column(
  this=Identifier(this=FTAG, quoted=False))

In [138]:
sqlglot.parse_one(aliases['column'][column_name].strip('(').strip(')'))

Select(
  expressions=[
    Max(
      this=Max(
        this=Column(
          this=Identifier(this=FTAG, quoted=False))),
      expressions=[
        Max(
          this=Column(
            this=Identifier(this=FTHG, quoted=False)))])],
  from=From(
    this=Table(
      this=Identifier(this=matchs, quoted=False))),
  where=Where(
    this=EQ(
      this=Column(
        this=Identifier(this=season, quoted=False)),
      expression=Literal(this=2021, is_string=False))))

In [186]:
exp.Column(this=exp.Subquery(this=sqlglot.parse_one(aliases['column'][column_name].strip('(').strip(')'))))

Column(
  this=Subquery(
    this=Select(
      expressions=[
        Max(
          this=Max(
            this=Column(
              this=Identifier(this=FTAG, quoted=False))),
          expressions=[
            Max(
              this=Column(
                this=Identifier(this=FTHG, quoted=False)))])],
      from=From(
        this=Table(
          this=Identifier(this=matchs, quoted=False))),
      where=Where(
        this=EQ(
          this=Column(
            this=Identifier(this=season, quoted=False)),
          expression=Literal(this=2021, is_string=False))))))

In [231]:
name = _get_full_column_name(expr, aliases, schema)
name

'(SELECT MAX(MAX(FTAG), MAX(FTHG)) FROM matchs WHERE season = 2021)'

In [175]:
col = right.expression

column_name = col.name.lower()
if (not schema.check_column_exist(column_name)) and (column_name not in aliases['column']):
    assert False, f"Column {column_name} not found in schema"

# If the column is *, map directly to __all__
if column_name == '*':
    print(schema.idMap['*'])

table_alias = col.table.lower() if col.table else None

In [176]:
possible_tables = [v for v in aliases['table'].values() if 'select' not in v.lower()]
real_table_name = schema.get_table_name(column_name, possible_tables)
possible_tables = [k for k, v in aliases['table'].items() if 'select' in v.lower()]
possible_tables

[]

In [125]:
aliases['table']

{'matchs': 'matchs'}

In [124]:
list(aliases['table'].items())

[('matchs', 'matchs')]

In [55]:
_format_expression(right, aliases, schema,  anonymize_literal=True)

None


AssertionError: Table not found for column rating_score rating_score

In [42]:
_extract_conditions(
    left, aliases, schema, operator_types, anonymize_literal=True
)

('__ratings_users.user_id__',
 Column(
   this=Identifier(this=user_id, quoted=False),
   table=Identifier(this=ratings_users, quoted=False)))

In [40]:
schema.schema[possible_tables[0]]

{'user_id': 'text',
 'list_id': 'text',
 'list_update_date_utc': 'text',
 'list_creation_date_utc': 'text',
 'user_trialist': 'text',
 'user_subscriber': 'text',
 'user_avatar_image_url': 'text',
 'user_cover_image_url': 'text',
 'user_eligible_for_trial': 'text',
 'user_has_payment_method': 'text'}

In [30]:
name = _get_full_column_name(left, aliases, schema)
name

'__list_title__'

In [179]:
_format_expression(left, aliases, schema, remove_alias=True, anonymize_literal=False)

('__user_reviews.app__',
 Column(
   this=Identifier(this=App, quoted=False),
   table=Identifier(this=user_reviews, quoted=False)))

In [176]:
'user_reviews' in schema.schema

True

In [171]:
left.args.get('table') and left.table.lower() in aliases['table']

In [26]:
unique_columns = set()
otherby_asts = set()
order = query.args.get('order')
for order_expr in order.expressions:
    # ordered_expr.this is the actual expression being ordered
    columns = _extract_columns_from_expression(order_expr.this, aliases, schema)
    unique_columns.update(columns)

    tag = _determine_tag(order_expr)
    expr_str, expr = _format_expression(order_expr.this, aliases, schema, remove_alias=True, anonymize_literal=False)
    otherby_asts.add((expr_str, expr, tag))


In [27]:
otherby_asts

{('count(__movies.movie_title__)',
  Count(
    this=Column(
      this=Identifier(this=movie_title, quoted=False)),
    big_int=True),
  '<f>')}

In [25]:
isinstance(expr.this, exp.Expression)

True

In [278]:
conditions = set()
operator_types = set()

for clause_name in ("where", "having"):
    clause = query.args.get(clause_name)
    if clause:
        # clause is WHERE or HAVING. We only want the conditions inside.
        # clause.this is the actual condition expression (e.g., AND/OR tree).
        # conds = _extract_conditions(
        #     clause.this, aliases, schema, operator_types
        # )
        # conditions.update(conds)
        break

In [273]:
expr = clause.this

conditions = []
left = expr.args.get('this')
right = expr.args.get('expression')
# if left:
#     left_cond = _extract_conditions(left, aliases, schema, operator_types)
#     if isinstance(left_cond, str):
#         left_cond = [left_cond]
#     conditions.extend(left_cond)
# if right:
#     right_cond = _extract_conditions(right, aliases, schema, operator_types)
#     if isinstance(right_cond, str):
#         right_cond = [right_cond]
#     conditions.extend(right_cond)

In [282]:
_extract_conditions(clause.this, aliases, schema, operator_types, anonymize_literal=True)

('cast(substring(__ratings.rating_timestamp_utc__, [placeholder-type:numeric], [placeholder-type:numeric]), int) eq [placeholder-type:numeric]',
 EQ(
   this=Cast(
     this=Substring(
       this=Column(
         this=Identifier(this=rating_timestamp_utc, quoted=False),
         table=Identifier(this=ratings, quoted=False)),
       start=Literal(this=[placeholder-type:numeric], is_string=False),
       length=Literal(this=[placeholder-type:numeric], is_string=False)),
     to=DataType(this=Type.INT, nested=False)),
   expression=Literal(this=[placeholder-type:numeric], is_string=False)))

In [258]:
left_cond, new_expr = _format_expression(left.this, aliases, schema, remove_alias=True, anonymize_literal=True)
left_cond, new_expr

('cast(substring(__ratings.rating_timestamp_utc__, [PLACEHOLDER-TYPE:NUMERIC], [PLACEHOLDER-TYPE:NUMERIC]), int) eq [PLACEHOLDER-TYPE:NUMERIC]',
 EQ(
   this=Cast(
     this=Substring(
       this=Column(
         this=Identifier(this=rating_timestamp_utc, quoted=False),
         table=Identifier(this=ratings, quoted=False)),
       start=Literal(this=[PLACEHOLDER-TYPE:NUMERIC], is_string=False),
       length=Literal(this=[PLACEHOLDER-TYPE:NUMERIC], is_string=False)),
     to=DataType(this=Type.INT, nested=False)),
   expression=Literal(this=[PLACEHOLDER-TYPE:NUMERIC], is_string=False)))

In [242]:
new_expr

EQ(
  this=Cast(
    this=Substring(
      this=Column(
        this=Identifier(this=rating_timestamp_utc, quoted=False),
        table=Identifier(this=ratings, quoted=False)),
      start=Literal(this=[PLACEHOLDER-TYPE:NUMERIC], is_string=False),
      length=Literal(this=[PLACEHOLDER-TYPE:NUMERIC], is_string=False)),
    to=DataType(this=Type.INT, nested=False)),
  expression=Literal(this=2020, is_string=False))

In [65]:
sql1 = list(sel_types)[0][1]
sql2 = list(sel_types)[2][1]
print(sql1)
print(sql2)

SUM(ratings.rating_score)
movies.movie_title


In [None]:
from src.eval_utils import compute_tsed
# from collections import defaultdict
# from sqlglot.diff import _get_leaves, _is_same_type

sql1 = sqlglot.parse_one("SELECT c1 FROM t1 WHERE a = 1 AND b = 2")
sql2 = sqlglot.parse_one("SELECT c1,c2 FROM t1 WHERE a = 2 AND b = 2")

sql_i

Remove(expression=Literal(this=1, is_string=False))
Insert(expression=Literal(this=2, is_string=False))
Insert(expression=Column(
  this=Identifier(this=c2, quoted=False)))
Move(expression=Literal(this=2, is_string=False))


defaultdict(int,
            {sqlglot.diff.Remove: 1,
             sqlglot.diff.Insert: 2,
             sqlglot.diff.Keep: 11,
             sqlglot.diff.Move: 1})

In [43]:
from zss import simple_distance, Node
from sqlglot import expressions as exp
from apted import APTED
from apted.helpers import Tree

def build_tree(ast_node: exp.Query, build_type: str):
    """Build a tree from an AST node.
    
    Args:
        ast_node (exp.Expression): The root AST node.
        build_type (str): The type of tree to build (zss or apted).
    """
    tree_node, node_count = _build_tree(ast_node, build_type)
    if build_type == 'apted':
        tree_node = Tree.from_text(tree_node + '}')
    return tree_node, node_count

def _build_tree(ast_node: exp.Query, build_type: str):
    tree_node = _build_node(ast_node, build_type)
    node_count = 1
    # Recursively add children and count nodes
    for child in ast_node.args.values():
        if isinstance(child, exp.Expression):
            child_node, child_count = _build_tree(child, build_type)
            tree_node = _add_child(child_node, tree_node, build_type)
            node_count += child_count
        elif isinstance(child, list):
            for sub_child in child:
                if isinstance(sub_child, exp.Expression):
                    sub_child_node, sub_child_count = _build_tree(sub_child, build_type)
                    tree_node = _add_child(sub_child_node, tree_node, build_type)
                    node_count += sub_child_count
    return tree_node, node_count

def _build_node(ast_node: exp.Expression, build_type: str):
    node_name = f'{ast_node.key}({str(ast_node)})'
    if build_type == 'zss':
        return Node(node_name)
    elif build_type == 'apted':
        return '{' + node_name
    else:
        raise ValueError(f"Invalid build type: {build_type} (zss or apted)")
    
def _add_child(child_node: exp.Expression|str, parent_node: Node|str, build_type: str):
    if build_type == 'zss':
        parent_node.addkid(child_node)
    elif build_type == 'apted':
        parent_node += child_node + '}'
    else:
        raise ValueError(f"Invalid build type: {build_type} (zss or apted)")
    return parent_node

def stringify_zsstree(node, level=0):
    result = "  " * level + node.label + "\n"
    for child in node.children:
        result += stringify_zsstree(child, level + 1)
    return result

def compute_tsed(sql1, sql2, build_type='zss'):
    tree1, node_count1 = build_tree(sql1, build_type)
    tree2, node_count2 = build_tree(sql2, build_type)
    if build_type == 'zss':
        distance = simple_distance(tree1, tree2)
    elif build_type == 'apted':
        distance = APTED(tree1, tree2).compute_edit_distance()
    tsed = max(1-distance/max(node_count1,node_count2), 0)
    return tsed, distance
    
build_type = 'apted' # zss or apted
# sql1 = sqlglot.parse_one("SELECT c1 FROM t1 WHERE a = 1 AND b = 2")
# sql2 = sqlglot.parse_one("SELECT c1,c2 FROM t1 WHERE a = 2 AND b = 2")
sql1 = sqlglot.parse_one(sql_i)
sql2 = sqlglot.parse_one("SELECT c1 FROM t1 WHERE a = 1 AND b = 2")
tsed, distance = compute_tsed(sql1, sql2, build_type)
print(f'TSED: {tsed:.4f}, Distance: {distance}')

TSED: 0.0182, Distance: 54


In [66]:
for diff in d:
    if isinstance(diff, Move):
        print(diff)

Move(expression=Literal(this=2, is_string=False))


In [56]:
_extract_conditions(left, aliases, schema, operator_types)

'cast(substring(__ratings.rating_timestamp_utc__)) eq 2020'

In [102]:
unique_columns = set()
selection_types = set()

for select_exp in query.select():
    print(select_exp)
    columns_in_item = _extract_columns_from_expression(select_exp, aliases, schema)
    unique_columns.update(columns_in_item)

    tag = _determine_tag(select_exp)
    expr_str = _format_expression(select_exp, aliases, schema, remove_alias=True)
    selection_types.add((expr_str, tag))

CASE WHEN 'Studio Entertainment[NI 1]' > 'Disney Media Networks' THEN 'Studio Entertainment[NI 1]' ELSE 'Disney Media Networks' END


In [110]:
_extract_columns_from_expression(select_exp, aliases, schema)

set()

In [114]:
select_exp

Case(
  ifs=[
    If(
      this=GT(
        this=Literal(this=Studio Entertainment[NI 1], is_string=True),
        expression=Literal(this=Disney Media Networks, is_string=True)),
      true=Literal(this=Studio Entertainment[NI 1], is_string=True))],
  default=Literal(this=Disney Media Networks, is_string=True))

In [113]:
type(select_exp).__bases__

(sqlglot.expressions.Func,)

In [112]:
len(list(select_exp.find_all(exp.Column, exp.Star)))

0

In [111]:
for col in select_exp.find_all(exp.Column, exp.Star):
    break

col

Column(
  this=Identifier(this=Complaint ID, quoted=True))

In [108]:
col

Column(
  this=Identifier(this=Complaint ID, quoted=True))

In [62]:
_get_full_column_name(col, aliases, schema)

'__complaint id__'

In [65]:
schema.schema

{'state': {'StateCode': 'text', 'State': 'text', 'Region': 'text'},
 'callcenterlogs': {'Date received': 'text',
  'Complaint ID': 'text',
  'rand client': 'text',
  'phonefinal': 'text',
  'vru+line': 'text',
  'call_id': 'text',
  'priority': 'text',
  'type': 'text',
  'outcome': 'text',
  'server': 'text',
  'ser_start': 'text',
  'ser_exit': 'text',
  'ser_time': 'text'},
 'client': {'client_id': 'text',
  'sex': 'text',
  'day': 'text',
  'month': 'text',
  'year': 'text',
  'age': 'text',
  'social': 'text',
  'first': 'text',
  'middle': 'text',
  'last': 'text',
  'phone': 'text',
  'email': 'text',
  'address_1': 'text',
  'address_2': 'text',
  'city': 'text',
  'state': 'text',
  'zipcode': 'text',
  'district_id': 'text'},
 'district': {'district_id': 'text',
  'city': 'text',
  'state_abbrev': 'text',
  'division': 'text'},
 'events': {'Date received': 'date',
  'Product': 'date',
  'Sub-product': 'date',
  'Issue': 'date',
  'Sub-issue': 'date',
  'Consumer complaint nar

In [30]:
database.execute('''
SELECT `Complaint ID` FROM callcenterlogs WHERE `Complaint ID` = "CR0656522" ORDER BY ser_time DESC LIMIT 3
''')

Unnamed: 0,Complaint ID
0,CR0656522


In [31]:
sqlglot.parse_one('''
SELECT "Complaint ID" FROM callcenterlogs WHERE "Complaint ID" = "CR0656522" ORDER BY ser_time DESC LIMIT 3
''')

Select(
  expressions=[
    Column(
      this=Identifier(this=Complaint ID, quoted=True))],
  limit=Limit(
    expression=Literal(this=3, is_string=False)),
  from=From(
    this=Table(
      this=Identifier(this=callcenterlogs, quoted=False))),
  where=Where(
    this=EQ(
      this=Column(
        this=Identifier(this=Complaint ID, quoted=True)),
      expression=Column(
        this=Identifier(this=CR0656522, quoted=True)))),
  order=Order(
    expressions=[
      Ordered(
        this=Column(
          this=Identifier(this=ser_time, quoted=False)),
        desc=True,
        nulls_first=False)]))

In [43]:
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.runnables import RunnableSequence
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from src.pymodels import SQLResponse
   
prompt = PromptTemplate(
    template=Prompts.zero_shot_inference_bird,
    input_variables=['schema', 'input_query', 'evidence']
)
model_name = 'gpt-4o-mini'
model_openai = ChatOpenAI(
    model=model_name,
    temperature=0.0,
    logprobs=True,
    top_logprobs=5
)

# model = model_openai.with_structured_output(SQLResponse, include_raw=True)
chain = (prompt | model_openai)

In [None]:
eval_path = proj_path / 'experiments' / 'zero_shot' / 'bird'
if not eval_path.exists():
    eval_path.mkdir(parents=True)

# run zero-shot SQL generation
results = {}
iterator = tqdm(samples, total=len(samples))
for i, sample in enumerate(iterator):
    db_id = sample.db_id
    iterator.set_description(f"Processing {db_id} - {sample.sample_id}")
    schema = get_schema_str(
        schema=bird_tables[db_id].db_schema,
        foreign_keys=bird_tables[db_id].foreign_keys,
        primary_keys=bird_tables[db_id].primary_keys,
        col_explanation=all_descriptions[db_id]    
    )
    output = chain.invoke(input={
        'schema': schema,
        'input_query': sample.final.question,
        'evidence': sample.evidence
    })
    o = SQLResponse(**json.loads(output.content))
    usage = output.usage_metadata
    logprobs = output.response_metadata['logprobs']['content']
    results[sample.sample_id] = {
        'sample_id': sample.sample_id,
        'output': {
            'sql': o.full_sql_query,
            'rationale': o.rationale,
        },
        'usage': usage,
        'logprobs': logprobs
    }

Processing movie_platform - 8:   0%|          | 8/8731 [00:41<12:42:35,  5.25s/it]


KeyboardInterrupt: 

In [None]:
# detect token indices after `full_sql_query`
txt = ''
sql_tokens = []
start = False
for i, x in enumerate(logprobs):
    txt += x['token']
    if 'full_sql_query' in txt:
        if x['token'] == 'SELECT':
            start = True
            txt = ''

    if start:
        sql_tokens.append(x)

In [83]:
''.join([x['token'] for x in sql_tokens][:-1])

'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC;"\n'