In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import json
import sqlglot
import sqlglot.expressions as exp
from tqdm import tqdm
import numpy as np
import pandas as pd
from typing import Optional
from collections import defaultdict
from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
_ = load_dotenv(find_dotenv())

from src.db_utils import get_schema_str, get_data_dict
from src.pymodels import DatabaseModel, QuestionSQL, SparcSample, SpiderSample, Description
from src.prompts import Prompts
from src.database import SqliteDatabase
from src.data_preprocess import (
    load_raw_data,
    process_all_tables,
    filter_samples_by_count_spider_bird,
    process_samples_bird,
    split_train_dev_test,
    save_samples_spider_bird,
    load_samples_spider_bird,
)

In [None]:
# Build the BIRD train/dev/test splits from the raw files and save them under data/.
bird_path = proj_path / 'data' / 'bird'
tables, train_data, dev_data = load_raw_data(bird_path, load_test=False)

# Column descriptions are passed both to table processing here and to prompt building later on.
with (proj_path / 'data' / 'bird_description.json').open() as f:
    all_descriptions = json.load(f)

bird_tables = process_all_tables(tables, descriptions=all_descriptions)

# Official train and dev data are pooled before filtering and re-splitting.
# NOTE(review): the exact meaning of n=10 is defined by the helper — confirm there.
all_data = filter_samples_by_count_spider_bird(train_data+dev_data, n=10)
# presumably sample ids to exclude from processing — TODO confirm why these six are skipped
skip = [622, 6916, 6917, 6930, 6967, 6987]
bird_samples = process_samples_bird(all_data, bird_tables, skip=skip)
train_samples, dev_samples, test_samples = split_train_dev_test(bird_samples, train_ratio=0.6, dev_ratio=0.2)
# TODO: make sure no dev/test SQL also appears in the train split (nothing in this cell checks it)

# Dev and test are written together into a single bird_dev.json file.
save_samples_spider_bird(train_samples, proj_path / 'data' / 'bird_train.json')
save_samples_spider_bird(dev_samples+test_samples, proj_path / 'data' / 'bird_dev.json')
print(len(train_samples), len(dev_samples), len(test_samples))

  0%|          | 0/10956 [00:00<?, ?it/s]

100%|██████████| 10956/10956 [00:03<00:00, 2953.83it/s]


6535 2155 2260


In [5]:
def _samples_to_frame(samples):
    """Tabulate the db_id, final SQL text and sample_id of every sample in `samples`."""
    return pd.DataFrame({
        'db_id': [s.db_id for s in samples],
        'sql': [s.final.sql for s in samples],
        'sample_id': [s.sample_id for s in samples],
    })


# Train gets its own frame; dev and test are pooled (dev first) into a single frame.
df_train = _samples_to_frame(train_samples)
df_dev = _samples_to_frame(dev_samples + test_samples)

In [99]:
# Write both frames to CSV under data/ without the index column.
for frame, filename in ((df_train, 'bird_train.csv'), (df_dev, 'bird_dev.csv')):
    frame.to_csv(proj_path / 'data' / filename, index=False)

In [23]:
df_train.groupby('db_id').size()

db_id
address                          90
airline                          55
app_store                        37
authors                         104
beer_factory                     78
                               ... 
university                       90
video_games                     120
works_cycles                    284
world                            59
world_development_indicators     94
Length: 79, dtype: int64

In [24]:
df_dev.groupby('db_id').size()

db_id
address                          60
airline                          37
app_store                        26
authors                          70
beer_factory                     53
                               ... 
university                       60
video_games                      81
works_cycles                    190
world                            40
world_development_indicators     63
Length: 79, dtype: int64

In [6]:
import sqlparse
from src.parsing_sql import (
    extract_aliases, 
    extract_selection, 
    extract_condition,
    extract_aggregation,
    extract_nested_setoperation,
    extract_others,
    Schema
)
from src.eval_complexity import (
    partial_match
)
import spacy
nlp_spacy = spacy.load('en_core_web_md')

In [34]:
# Bind `i` and `row` to the first (index, row) pair of `sql_db_id_dev` for manual inspection.
# NOTE(review): `sql_db_id_dev` is only assigned inside the similarity loop in a later cell,
# so this cell works only if that cell has already been run in the current kernel.
for i, row in sql_db_id_dev.iterrows():
    break

In [35]:
row

sql          SELECT COUNT(T1.TransactionID) FROM transactio...
sample_id                                                10930
Name: 2143, dtype: object

In [37]:
def _extract_sql_features(sql, schema):
    """Parse one SQL string and collect the pieces compared by the structural score.

    Args:
        sql: SQL text of a single sample.
        schema: `Schema` object of the database the query belongs to.

    Returns:
        dict with keys 'sel', 'cond', 'agg' (first element returned by the matching
        `extract_*` helper) and 'distinct', 'order by', 'limit' (entries of the
        `extract_others` result, `False` when the key is absent).

    Raises:
        Whatever `sqlparse` or the `extract_*` helpers raise; callers record it as an error.
    """
    stmt = sqlparse.parse(sql)[0]
    aliases = extract_aliases(stmt)
    selection = extract_selection(stmt, aliases, schema)[0]
    condition = extract_condition(stmt, aliases, schema)[0]
    aggregation = extract_aggregation(stmt, aliases, schema)[0]
    # The result is not part of the score; the call is kept so that a statement this
    # helper cannot handle is still reported as an error, exactly as before.
    extract_nested_setoperation(stmt)
    others = extract_others(stmt, aliases, schema)
    return {
        'sel': selection,
        'cond': condition,
        'agg': aggregation,
        'distinct': others.get('distinct', False),
        'order by': others.get('order by', False),
        'limit': others.get('limit', False),
    }


# For every database, compare each dev/test query against each train query of the same
# database, both semantically (spaCy vectors) and structurally (parsed SQL components).
db_ids = df_dev['db_id'].unique()
error_ids = []   # (split, sample_id, message) for every query that could not be parsed
results = {}
for db_id in db_ids:
    sql_db_id_dev = df_dev.loc[df_dev['db_id'] == db_id, ['sql', 'sample_id']]
    sql_db_id_train = df_train.loc[df_train['db_id'] == db_id, ['sql', 'sample_id']]
    # spacy similarity
    sql_dev = sql_db_id_dev['sql'].apply(lambda x: nlp_spacy(x))
    sql_train = sql_db_id_train['sql'].apply(lambda x: nlp_spacy(x))
    semantic_sim = np.zeros((len(sql_dev), len(sql_train)))
    for i, sql_d in enumerate(sql_dev):
        for j, sql_t in enumerate(sql_train):
            semantic_sim[i, j] = sql_d.similarity(sql_t)

    # structural similarity
    schema = Schema(bird_tables[db_id].db_schema)

    # Parse every train query once per database instead of once per dev row.
    # `None` marks a train query that failed to parse; its column stays at zero.
    train_features = []
    for _, row_t in sql_db_id_train.iterrows():
        try:
            train_features.append(_extract_sql_features(row_t['sql'], schema))
        except Exception as e:
            train_features.append(None)
            error_entry = ('tr', row_t['sample_id'], str(e))
            if error_entry not in error_ids:
                error_ids.append(error_entry)

    # Allocate the matrix ONCE per database. Creating it inside the dev-row loop wiped
    # the rows filled by earlier iterations, so only the last dev row survived.
    struct_sim = np.zeros((len(sql_dev), len(sql_train)))
    for i, (_, row_d) in tqdm(enumerate(sql_db_id_dev.iterrows()), total=len(sql_db_id_dev), desc=f'DB {db_id}'):
        try:
            feat_d = _extract_sql_features(row_d['sql'], schema)
        except Exception as e:
            # Row i stays all-zero for a dev query that cannot be parsed.
            error_ids.append(('dev', row_d['sample_id'], str(e)))
            continue

        for j, feat_t in enumerate(train_features):
            if feat_t is None:
                continue

            sel_iou, *_ = partial_match(feat_d['sel'], feat_t['sel'])
            cond_iou, *_ = partial_match(feat_d['cond'], feat_t['cond'])
            agg_iou, *_ = partial_match(feat_d['agg'], feat_t['agg'])
            distinct_match = int(feat_d['distinct'] == feat_t['distinct'])
            orderby_match = int(feat_d['order by'] == feat_t['order by'])
            limit_match = int(feat_d['limit'] == feat_t['limit'])
            # Mean of four terms: the three IoU scores plus the averaged exact-match flags.
            struct_sim[i, j] = (sel_iou + cond_iou + agg_iou + (distinct_match + orderby_match + limit_match)/3)/4

    results[db_id] = {
        'sem': semantic_sim,
        'struct': struct_sim
    }

DB movie_platform: 100%|██████████| 67/67 [00:09<00:00,  7.27it/s]
DB book_publishing_company: 100%|██████████| 30/30 [00:01<00:00, 16.35it/s]
DB retail_complains: 100%|██████████| 68/68 [00:10<00:00,  6.38it/s]
DB movies_4: 100%|██████████| 64/64 [00:08<00:00,  7.42it/s]
DB codebase_comments: 100%|██████████| 50/50 [00:04<00:00, 10.99it/s]
DB trains: 100%|██████████| 16/16 [00:00<00:00, 37.06it/s]
DB movie: 100%|██████████| 19/19 [00:00<00:00, 26.09it/s]
DB social_media: 100%|██████████| 32/32 [00:01<00:00, 19.19it/s]
DB cs_semester: 100%|██████████| 46/46 [00:04<00:00,  9.75it/s]
DB computer_student: 100%|██████████| 29/29 [00:01<00:00, 19.20it/s]
DB talkingdata: 100%|██████████| 83/83 [00:13<00:00,  6.08it/s]
DB law_episode: 100%|██████████| 46/46 [00:04<00:00, 10.71it/s]
DB synthea: 100%|██████████| 74/74 [00:13<00:00,  5.46it/s]
DB car_retails: 100%|██████████| 51/51 [00:06<00:00,  8.40it/s]
DB restaurant: 100%|██████████| 47/47 [00:03<00:00, 12.26it/s]
DB soccer_2016: 100%|██████

In [100]:
import pickle

# Serialize the collected parse-error tuples to data/errors_bird.pkl.
errors_path = proj_path / 'data' / 'errors_bird.pkl'
with errors_path.open('wb') as fh:
    pickle.dump(error_ids, fh)

In [67]:
# Replay the extraction steps for one recorded failure so the exception surfaces here
# and the intermediate objects (`sql`, `statement`, `aliases`, ...) can be inspected.
debug_sample_id = 7755  # sample whose recorded failure is being reproduced
for typ, sample_id, msg in set(error_ids):
    if sample_id != debug_sample_id:
        continue

    # Read the query from the split the error was recorded in, using the loop's own
    # sample_id (previously both were ignored in favour of a hard-coded df_dev lookup).
    df = df_dev if typ == 'dev' else df_train
    record = df.loc[df['sample_id'] == sample_id].iloc[0]
    sql = record['sql']

    # Use the schema of this sample's own database, not whichever one the
    # similarity loop happened to leave behind in `schema`.
    schema = Schema(bird_tables[record['db_id']].db_schema)

    statement = sqlparse.parse(sql)[0]
    aliases = extract_aliases(statement)

    sel_d = extract_selection(statement, aliases, schema)[0]
    cond_d = extract_condition(statement, aliases, schema)[0]
    agg_d = extract_aggregation(statement, aliases, schema)[0]
    nested_d = extract_nested_setoperation(statement)
    others_d = extract_others(statement, aliases, schema)
    distinct_d = others_d.get('distinct', False)
    orderby_d = others_d.get('order by', False)
    limit_d = others_d.get('limit', False)

    break

AttributeError: 'NoneType' object has no attribute 'lower'

In [68]:
df_dev.loc[df_dev['sample_id'] == 7755]

Unnamed: 0,db_id,sql,sample_id
1517,hockey,SELECT name FROM Teams WHERE year = 2006 GROUP...,7755


In [69]:
aliases

{'table': {'teams': 'teams'}, 'column': {}}

In [61]:
sql

'SELECT name FROM Teams WHERE year = 2006 GROUP BY tmID, name ORDER BY CAST(SUM(BenchMinor) AS REAL) / 2 DESC LIMIT 1'

In [59]:
sel_d

{'__name__'}

In [None]:
# Same query as dev sample 7755 (hockey), written out over several lines and re-parsed
# into `statement`; this overwrites the `sql` / `statement` bound by the replay cell.
sql = """
SELECT name 
FROM Teams 
WHERE year = 2006 
GROUP BY tmID, name 
ORDER BY CAST(SUM(BenchMinor) AS REAL) / 2 DESC LIMIT 1
"""
statement = sqlparse.parse(sql)[0]

In [70]:
from src.parsing_sql import (
    split_set_operation, 
    tks, 
    is_pwn, 
    get_full_column_name,
    Identifier, TokenList,
    get_orderby_expression
)

In [71]:
# `split_set_operation` (project helper) yields the sub-statements of `statement`; print each one.
# NOTE: `stmt` stays bound to the last sub-statement after the loop, and the next cell
# walks `stmt.flatten()`, so that cell depends on this one having run.
statements = split_set_operation(statement)
for stmt in statements:
    print(stmt)

SELECT name FROM Teams WHERE year = 2006 GROUP BY tmID, name ORDER BY CAST(SUM(BenchMinor) AS REAL) / 2 DESC LIMIT 1


In [93]:
# Walk the flattened tokens of `stmt` and track DISTINCT / ORDER BY / LIMIT by hand,
# printing every token seen while inside the ORDER BY clause (debugging aid).
# Relies on `stmt`, `aliases` and `schema` left over from earlier cells.
others = {'distinct': set(), 'order by': set(), 'limit': False}
distinct_used = False   # True between a DISTINCT keyword and the next token that is not skipped by is_pwn
order_by_used = False   # True between an ORDER BY keyword and a LIMIT keyword (or end of input)
orderby_tokens = []     # expressions completed by a ',' inside ORDER BY
# NOTE: `ord` shadows the Python builtin for the rest of the session; later cells read this list.
ord = []
for token in stmt.flatten():
    if token.ttype is tks.Keyword:
        if token.value.upper() == 'DISTINCT':
            distinct_used = True
            continue
        if token.value.upper() == 'ORDER BY':
            # start collecting a fresh ORDER BY expression
            ord = []
            order_by_used = True
            continue
        if token.value.upper() == 'LIMIT':
            # LIMIT ends the ORDER BY clause; whatever is in `ord` is kept but not flushed
            order_by_used = False
            others['limit'] = True
            continue
    
    if distinct_used:
        # skip tokens for which is_pwn is true, then record the first remaining token as the DISTINCT column
        if is_pwn(token):
            continue
        column_name = get_full_column_name(Identifier([token]), aliases, schema)
        others['distinct'].add(column_name)
        distinct_used = False

    if order_by_used:
        print(token, token.ttype, type(token))
        if token.ttype is tks.Punctuation and token.value == ',':
            # a comma closes the current expression: convert it (if it re-parses to something) and reset
            if sqlparse.parse(str(TokenList(ord))):
                orderby_tokens.append(get_orderby_expression(ord, aliases, schema))
            ord = []
        else:
            ord.append(token)
# After the loop the last ORDER BY expression is still sitting in `ord`; it is never appended
# to `orderby_tokens`, and others['order by'] is never filled in by this cell.

  Token.Text.Whitespace <class 'sqlparse.sql.Token'>
CAST Token.Name <class 'sqlparse.sql.Token'>
( Token.Punctuation <class 'sqlparse.sql.Token'>
SUM Token.Name <class 'sqlparse.sql.Token'>
( Token.Punctuation <class 'sqlparse.sql.Token'>
BenchMinor Token.Name <class 'sqlparse.sql.Token'>
) Token.Punctuation <class 'sqlparse.sql.Token'>
  Token.Text.Whitespace <class 'sqlparse.sql.Token'>
AS Token.Keyword <class 'sqlparse.sql.Token'>
  Token.Text.Whitespace <class 'sqlparse.sql.Token'>
REAL Token.Name.Builtin <class 'sqlparse.sql.Token'>
) Token.Punctuation <class 'sqlparse.sql.Token'>
  Token.Text.Whitespace <class 'sqlparse.sql.Token'>
/ Token.Operator <class 'sqlparse.sql.Token'>
  Token.Text.Whitespace <class 'sqlparse.sql.Token'>
2 Token.Literal.Number.Integer <class 'sqlparse.sql.Token'>
  Token.Text.Whitespace <class 'sqlparse.sql.Token'>
DESC Token.Keyword.Order <class 'sqlparse.sql.Token'>
  Token.Text.Whitespace <class 'sqlparse.sql.Token'>


In [94]:
get_orderby_expression(ord, aliases, schema)

'cast(__sum__) / 2 desc'

In [83]:
# Re-parse the collected ORDER BY tokens as a standalone statement and print its
# top-level tokens, skipping those for which `is_pwn` is true.
for tkn in sqlparse.parse(str(TokenList(ord)))[0].tokens:
    if is_pwn(tkn):
        continue
    print(tkn, tkn.ttype, type(tkn))

CAST(SUM(BenchMinor) AS REAL) / 2 DESC None <class 'sqlparse.sql.Operation'>


In [85]:
tkn = sqlparse.parse(str(TokenList(ord)))[0].tokens[1]
tkn

<Operation 'CAST(S...' at 0x7F6C4A54AAD0>

In [97]:
# Print the direct children of `tkn` with their token type and Python class.
for sub_token in tkn.tokens:
    print(sub_token, sub_token.ttype, type(sub_token))
    # Disabled probe: resolve each child to a column name.
    # get_full_column_name(sub_token, aliases, schema)
    

CAST(SUM(BenchMinor) AS REAL) None <class 'sqlparse.sql.Function'>
  Token.Text.Whitespace <class 'sqlparse.sql.Token'>
/ Token.Operator <class 'sqlparse.sql.Token'>
  Token.Text.Whitespace <class 'sqlparse.sql.Token'>
2 DESC None <class 'sqlparse.sql.Identifier'>


In [96]:
# Print the direct children of `sub_token`.
# NOTE(review): `sub_token` is not assigned in this cell; it is whatever the loop over
# `tkn.tokens` last left bound, so the result depends on cell execution order.
for sub_token2 in sub_token.tokens:
    print(sub_token2, sub_token2.ttype, type(sub_token2))

CAST None <class 'sqlparse.sql.Identifier'>
(SUM(BenchMinor) AS REAL) None <class 'sqlparse.sql.Parenthesis'>


In [87]:
get_full_column_name(sub_token, aliases, schema)

AttributeError: 'NoneType' object has no attribute 'lower'

In [82]:
get_orderby_expression(ord, aliases, schema)

AttributeError: 'NoneType' object has no attribute 'lower'

In [48]:
set(error_ids)

{('dev', 344, "'Token' object is not subscriptable"),
 ('dev', 376, "'Token' object is not subscriptable"),
 ('dev', 511, "'NoneType' object has no attribute 'lower'"),
 ('dev',
  728,
  "Alias trailposi already exists in the column mapping.\n{'trailposi': 'max(position)'}"),
 ('dev', 729, "'NoneType' object has no attribute 'lower'"),
 ('dev', 765, "'NoneType' object has no attribute 'lower'"),
 ('dev', 833, "'NoneType' object has no attribute 'lower'"),
 ('dev', 1190, "'NoneType' object has no attribute 'lower'"),
 ('dev', 1329, "'NoneType' object has no attribute 'lower'"),
 ('dev', 1482, "'NoneType' object has no attribute 'lower'"),
 ('dev', 1536, "'NoneType' object has no attribute 'lower'"),
 ('dev', 1768, "'NoneType' object has no attribute 'lower'"),
 ('dev', 2568, "'NoneType' object has no attribute 'lower'"),
 ('dev', 2573, "'NoneType' object has no attribute 'lower'"),
 ('dev', 2680, "'NoneType' object has no attribute 'lower'"),
 ('dev', 2682, "'NoneType' object has no att

In [28]:
# NOTE(review): neither `nlp` nor `sql_dev_temp` is assigned in any visible cell (the loaded
# spaCy pipeline is bound to `nlp_spacy`), so this cell depends on leftover kernel state —
# confirm the intended names or remove the cell.
x = nlp(sql_dev_temp[0])

In [30]:
x.similarity(x)

1.0

In [20]:
# Print the raw return value of each extractor for the current `statement`,
# using the `aliases` and `schema` currently bound in the kernel.
print(extract_selection(statement, aliases, schema))
print(extract_condition(statement, aliases, schema))
print(extract_aggregation(statement, aliases, schema))
print(extract_nested_setoperation(statement))
print(extract_others(statement, aliases, schema))

({'__movies.movie_title__'}, {('__movies.movie_title__', '<s>')})
({'__movies.movie_release_year__ = 2003', '__ratings.user_id__ = 2941'}, {'='})
(set(), set())
0
{'distinct': set(), 'order by': set(), 'limit': False}


In [18]:
df_dev

Unnamed: 0,db_id,sql
0,movie_platform,SELECT T2.movie_title FROM ratings AS T1 INNER...
1,movie_platform,SELECT COUNT(T1.user_id) FROM ratings AS T1 IN...
2,movie_platform,SELECT T2.movie_title FROM ratings AS T1 INNER...
3,movie_platform,SELECT T2.movie_title FROM ratings AS T1 INNER...
4,movie_platform,"SELECT T2.list_title, T1.user_avatar_image_url..."
...,...,...
4410,debit_card_specializing,"SELECT SUM(T1.Price) , SUM(IIF(T3.Date = '2012..."
4411,debit_card_specializing,SELECT T2.Description FROM transactions_1k AS ...
4412,debit_card_specializing,"SELECT T2.CustomerID, SUM(T2.Price / T2.Amount..."
4413,debit_card_specializing,SELECT T2.Country FROM transactions_1k AS T1 I...


In [None]:
# Tally how many pooled dev/test samples belong to each database.
group_count = {}
for sample in dev_samples + test_samples:
    group_count[sample.db_id] = group_count.get(sample.db_id, 0) + 1



In [4]:
# Load the saved training split from data/bird_train.json.
train_json_path = proj_path / 'data' / 'bird_train.json'
samples = load_samples_spider_bird(train_json_path)

In [43]:
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.runnables import RunnableSequence
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from src.pymodels import SQLResponse
   
# Zero-shot prompt: filled with the schema string, the natural-language question and the evidence text.
prompt = PromptTemplate(
    template=Prompts.zero_shot_inference_bird,
    input_variables=['schema', 'input_query', 'evidence']
)
model_name = 'gpt-4o-mini'
# temperature=0.0 together with logprobs/top_logprobs: the generation loop reads the
# per-token entries from response_metadata['logprobs'].
model_openai = ChatOpenAI(
    model=model_name,
    temperature=0.0,
    logprobs=True,
    top_logprobs=5
)

# Structured-output variant, currently disabled:
# model = model_openai.with_structured_output(SQLResponse, include_raw=True)
# The prompt is piped straight into the chat model with no output parser, so
# `chain.invoke` returns the model message itself.
chain = (prompt | model_openai)

In [None]:
eval_path = proj_path / 'experiments' / 'zero_shot' / 'bird'
# exist_ok=True replaces the check-then-create pair and keeps the cell safe to re-run.
eval_path.mkdir(parents=True, exist_ok=True)

# run zero-shot SQL generation
# NOTE(review): results are only kept in memory; nothing in this cell writes them to eval_path.
results = {}
# Every argument of get_schema_str is looked up by db_id, so the schema string is built
# once per database and reused for all of its samples.
schema_by_db = {}
iterator = tqdm(samples, total=len(samples))
for sample in iterator:
    db_id = sample.db_id
    iterator.set_description(f"Processing {db_id} - {sample.sample_id}")
    if db_id not in schema_by_db:
        schema_by_db[db_id] = get_schema_str(
            schema=bird_tables[db_id].db_schema,
            foreign_keys=bird_tables[db_id].foreign_keys,
            primary_keys=bird_tables[db_id].primary_keys,
            col_explanation=all_descriptions[db_id]
        )
    schema = schema_by_db[db_id]
    output = chain.invoke(input={
        'schema': schema,
        'input_query': sample.final.question,
        'evidence': sample.evidence
    })
    # The model is expected to answer with a JSON object matching SQLResponse;
    # a malformed answer raises here and stops the run.
    o = SQLResponse(**json.loads(output.content))
    usage = output.usage_metadata
    # Per-token log-probability entries; the token-detection cell reads the last value of `logprobs`.
    logprobs = output.response_metadata['logprobs']['content']
    results[sample.sample_id] = {
        'sample_id': sample.sample_id,
        'output': {
            'sql': o.full_sql_query,
            'rationale': o.rationale,
        },
        'usage': usage,
        'logprobs': logprobs
    }

Processing movie_platform - 8:   0%|          | 8/8731 [00:41<12:42:35,  5.25s/it]


KeyboardInterrupt: 

In [None]:
# Collect the logprob entries of the generated SQL: scanning starts once the running
# text contains `full_sql_query` and the current token is exactly 'SELECT'; that token
# and every later one are kept.
sql_tokens = []
txt = ''
start = False
for entry in logprobs:
    txt += entry['token']
    if entry['token'] == 'SELECT' and 'full_sql_query' in txt:
        start = True
        txt = ''
    if start:
        sql_tokens.append(entry)

In [83]:
# Concatenate the text of every collected token except the last one.
''.join(entry['token'] for entry in sql_tokens[:-1])

'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC;"\n'