In [1]:
%load_ext autoreload
%autoreload 2

In [7]:
import sys
from pathlib import Path
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import json
import sqlglot
import sqlglot.expressions as exp
from tqdm import tqdm
import numpy as np
import pandas as pd
from typing import Optional
from collections import defaultdict
from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
_ = load_dotenv(find_dotenv())

from src.db_utils import get_schema_str, get_data_dict
from src.pymodels import DatabaseModel, QuestionSQL, SparcSample, SpiderSample, Description
from src.prompts import Prompts
from src.database import SqliteDatabase
from src.data_preprocess import (
    load_raw_data,
    process_all_tables,
    filter_samples_by_count_spider_bird,
    process_samples_bird,
    split_train_dev_test,
    save_samples_spider_bird,
    load_samples_spider_bird,
)
from src.parsing_sql import extract_all, Schema

bird_path = proj_path / 'data' / 'bird'
tables, train_data, dev_data = load_raw_data(bird_path, load_test=False)

with (proj_path / 'data' / 'bird_description.json').open() as f:
    all_descriptions = json.load(f)

bird_tables = process_all_tables(tables, descriptions=all_descriptions)

In [26]:
all_data = filter_samples_by_count_spider_bird(train_data+dev_data, n=10)
skip = [622, 6916, 6917, 6930, 6967, 6987]
bird_samples = process_samples_bird(all_data, bird_tables, skip=skip)
train_samples, dev_samples, test_samples = split_train_dev_test(bird_samples, train_ratio=0.6, dev_ratio=0.2)
# makesure the dev/test sql is not in the train sql

save_samples_spider_bird(train_samples, proj_path / 'data' / 'bird_train.json')
save_samples_spider_bird(dev_samples+test_samples, proj_path / 'data' / 'bird_dev.json')
print(len(train_samples), len(dev_samples), len(test_samples))

100%|██████████| 10956/10956 [00:03<00:00, 2829.24it/s]

6535 2155 2260





In [None]:
df_train = pd.DataFrame({
    'db_id': [x.db_id for x in train_samples], 
    'sql': [x.final.sql for x in train_samples],
    'sample_id': [x.sample_id for x in train_samples]
})
df_dev = pd.DataFrame({
    'db_id': [x.db_id for x in dev_samples] + [x.db_id for x in test_samples],
    'sql': [x.final.sql for x in dev_samples] + [x.final.sql for x in test_samples],
    'sample_id': [x.sample_id for x in dev_samples] + [x.sample_id for x in test_samples]
})

In [4]:
# df_train.to_csv(proj_path / 'data' / 'bird_train.csv', index=False)
# df_dev.to_csv(proj_path / 'data' / 'bird_dev.csv', index=False)

df_train = pd.read_csv(proj_path / 'data' / 'bird_train.csv')
df_dev = pd.read_csv(proj_path / 'data' / 'bird_dev.csv')

In [8]:
from src.eval_utils import (
    partial_match
)
import sqlglot
import spacy
nlp_spacy = spacy.load('en_core_web_md')

In [17]:
db_ids = df_dev['db_id'].unique()
error_ids = []
results = defaultdict()
for k, db_id in enumerate(db_ids):
    sql_db_id_dev = df_dev.loc[df_dev['db_id'] == db_id, ['sql', 'sample_id']]
    sql_db_id_train = df_train.loc[df_train['db_id'] == db_id, ['sql', 'sample_id']]
    # spacy similarity
    sql_dev = sql_db_id_dev['sql'].apply(lambda x: nlp_spacy(x))
    sql_train = sql_db_id_train['sql'].apply(lambda x: nlp_spacy(x))
    semantic_sim = np.zeros((len(sql_dev), len(sql_train)))
    for i, sql_d in enumerate(sql_dev):
        for j, sql_t in enumerate(sql_train):
            semantic_sim[i, j] = sql_d.similarity(sql_t)

    # structural similarity
    schema = Schema(bird_tables[db_id].db_schema)
    
    for i, (_, row_d) in tqdm(enumerate(sql_db_id_dev.iterrows()), total=len(sql_db_id_dev), desc=f'[{k}] DB {db_id}'):
        sql_d = row_d['sql']
        sid_d = row_d['sample_id']
        struct_sim = np.zeros((len(sql_dev), len(sql_train)))
        try:
            parsed_sql_d = sqlglot.parse_one(sql_d)
            results_d = extract_all(parsed_sql_d, schema)
            assert len(results_d['sel']) > 0, 'No selection found'
        except Exception as e:
            error_ids.append(('dev', sid_d, str(e)))
            continue

        for j, (_, row_t) in enumerate(sql_db_id_train.iterrows()):
            sql_t = row_t['sql']
            sid_t = row_t['sample_id']
            try:    
                parsed_sql_t = sqlglot.parse_one(sql_t)
                results_t = extract_all(parsed_sql_t, schema)
                assert len(results_t['sel']) > 0, 'No selection found'
            except Exception as e:
                if ('tr', sid_t, str(e)) not in error_ids:
                    error_ids.append(('tr', sid_t, str(e)))
                continue

            sel_iou, *_ = partial_match(results_d['sel'], results_t['sel'])
            cond_iou, *_ = partial_match(results_d['cond'], results_t['cond'])
            agg_iou, *_ = partial_match(results_d['agg'], results_t['agg'])
            dis_iou, *_ = partial_match(results_d['distinct'], results_t['distinct'])
            ord_iou, *_ = partial_match(results_d['order by'], results_t['order by'])
            lim = int(results_d['limit'] == results_t['limit'])
            nested = int(results_d['nested'] == results_t['nested'])
            struct_sim[i, j] = sel_iou + cond_iou + agg_iou + dis_iou + ord_iou + (lim + nested) / 2

    results[db_id] = {
        'sem': semantic_sim,
        'struct': struct_sim
    }  

DB movie_platform: 100%|██████████| 67/67 [00:03<00:00, 19.85it/s]
DB book_publishing_company: 100%|██████████| 30/30 [00:00<00:00, 47.19it/s]
DB retail_complains: 100%|██████████| 68/68 [00:01<00:00, 46.45it/s]
DB movies_4: 100%|██████████| 64/64 [00:03<00:00, 17.03it/s]
DB codebase_comments: 100%|██████████| 50/50 [00:02<00:00, 23.00it/s]
DB trains: 100%|██████████| 16/16 [00:00<00:00, 71.72it/s]
DB movie: 100%|██████████| 19/19 [00:00<00:00, 282.73it/s]
DB social_media: 100%|██████████| 32/32 [00:00<00:00, 46.69it/s]
DB cs_semester: 100%|██████████| 46/46 [00:01<00:00, 24.71it/s]
DB computer_student: 100%|██████████| 29/29 [00:00<00:00, 42.45it/s]
DB talkingdata: 100%|██████████| 83/83 [00:03<00:00, 22.07it/s]
DB law_episode: 100%|██████████| 46/46 [00:01<00:00, 28.92it/s]
DB synthea: 100%|██████████| 74/74 [00:05<00:00, 13.50it/s]
DB car_retails: 100%|██████████| 51/51 [00:02<00:00, 23.14it/s]
DB restaurant: 100%|██████████| 47/47 [00:01<00:00, 30.85it/s]
DB soccer_2016: 100%|█████

In [18]:
import pickle 
with open(proj_path / 'data' / 'errors_bird.pkl', 'wb') as f:
    pickle.dump(error_ids, f)

with open(proj_path / 'data' / 'errors_bird.pkl', 'rb') as f:
    error_ids = pickle.load(f)

In [21]:
error_ids

[('tr', 12, 'too many values to unpack (expected 2)'),
 ('tr', 17, 'too many values to unpack (expected 2)'),
 ('tr', 21, 'too many values to unpack (expected 2)'),
 ('tr', 30, 'too many values to unpack (expected 2)'),
 ('tr', 38, 'too many values to unpack (expected 2)'),
 ('tr', 40, 'too many values to unpack (expected 2)'),
 ('tr', 44, 'too many values to unpack (expected 2)'),
 ('tr', 45, 'too many values to unpack (expected 2)'),
 ('tr', 48, 'too many values to unpack (expected 2)'),
 ('tr', 74, 'too many values to unpack (expected 2)'),
 ('dev', 151, 'too many values to unpack (expected 2)'),
 ('dev', 157, 'too many values to unpack (expected 2)'),
 ('dev', 158, 'too many values to unpack (expected 2)'),
 ('dev', 160, 'sequence item 0: expected str instance, tuple found'),
 ('dev', 162, 'too many values to unpack (expected 2)'),
 ('dev', 228, 'sequence item 0: expected str instance, tuple found'),
 ('dev', 229, 'sequence item 0: expected str instance, tuple found'),
 ('dev', 230

In [25]:
sql = """SELECT 'Date received' FROM callcenterlogs WHERE ser_time = ( SELECT MAX(ser_time) FROM callcenterlogs )"""
sqlglot.parse_one(sql)

Select(
  expressions=[
    Literal(this=Date received, is_string=True)],
  from=From(
    this=Table(
      this=Identifier(this=callcenterlogs, quoted=False))),
  where=Where(
    this=EQ(
      this=Column(
        this=Identifier(this=ser_time, quoted=False)),
      expression=Subquery(
        this=Select(
          expressions=[
            Max(
              this=Column(
                this=Identifier(this=ser_time, quoted=False)))],
          from=From(
            this=Table(
              this=Identifier(this=callcenterlogs, quoted=False))))))))

In [37]:
from src.parsing_sql import (
    extract_aliases,
    get_subqueries,
    extract_selection,
    extract_aggregation,
    extract_condition,
    extract_others, 
    _extract_conditions
)

In [54]:
sql = df_train.loc[df_train['sample_id'] == 17, 'sql'].values[0]
print(sql)
parsed_query = sqlglot.parse_one(sql)
# results = extract_all(parsed_query, schema)
aliases = extract_aliases(parsed_query)
subqueries = get_subqueries(parsed_query)
results = defaultdict(set)
nested = len(subqueries)

SELECT list_url FROM lists WHERE list_update_timestamp_utc LIKE '2012%' AND list_followers BETWEEN 1 AND 2 ORDER BY list_update_timestamp_utc DESC LIMIT 1


In [55]:
for query in subqueries:
    sel_cols, sel_types  = extract_selection(query, aliases, schema)
    conds, op_types = extract_condition(query, aliases, schema)
    agg_cols, agg_types  = extract_aggregation(query, aliases, schema)
    others = extract_others(query, aliases, schema)

ValueError: too many values to unpack (expected 2)

In [56]:
conditions = set()
operator_types = set()

for clause_name in ("where", "having"):
    clause = query.args.get(clause_name)
    if clause:
        break
clause

Where(
  this=And(
    this=Like(
      this=Column(
        this=Identifier(this=list_update_timestamp_utc, quoted=False)),
      expression=Literal(this=2012%, is_string=True)),
    expression=Between(
      this=Column(
        this=Identifier(this=list_followers, quoted=False)),
      low=Literal(this=1, is_string=False),
      high=Literal(this=2, is_string=False))))

In [57]:
clause.this

And(
  this=Like(
    this=Column(
      this=Identifier(this=list_update_timestamp_utc, quoted=False)),
    expression=Literal(this=2012%, is_string=True)),
  expression=Between(
    this=Column(
      this=Identifier(this=list_followers, quoted=False)),
    low=Literal(this=1, is_string=False),
    high=Literal(this=2, is_string=False)))

In [58]:
ops, conds = _extract_conditions(clause.this, aliases, schema)

ValueError: too many values to unpack (expected 2)

In [53]:
expr = clause.this

operations = []
conditions = []
left = expr.args.get('this')
right = expr.args.get('expression')
if left:
    ops, conds = _extract_conditions(left, aliases, schema)
    operations.extend(ops)
    conditions.extend(conds)
if right:
    ops, conds = _extract_conditions(right, aliases, schema)
    operations.extend(ops)
    conditions.extend(conds)

ValueError: too many values to unpack (expected 2)

In [50]:
right

Between(
  this=Column(
    this=Identifier(this=list_followers, quoted=False)),
  low=Literal(this=1, is_string=False),
  high=Literal(this=2, is_string=False))

In [51]:
type(exp.Between).__bases__

(type,)

In [43]:
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.runnables import RunnableSequence
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from src.pymodels import SQLResponse
   
prompt = PromptTemplate(
    template=Prompts.zero_shot_inference_bird,
    input_variables=['schema', 'input_query', 'evidence']
)
model_name = 'gpt-4o-mini'
model_openai = ChatOpenAI(
    model=model_name,
    temperature=0.0,
    logprobs=True,
    top_logprobs=5
)

# model = model_openai.with_structured_output(SQLResponse, include_raw=True)
chain = (prompt | model_openai)

In [None]:
eval_path = proj_path / 'experiments' / 'zero_shot' / 'bird'
if not eval_path.exists():
    eval_path.mkdir(parents=True)

# run zero-shot SQL generation
results = {}
iterator = tqdm(samples, total=len(samples))
for i, sample in enumerate(iterator):
    db_id = sample.db_id
    iterator.set_description(f"Processing {db_id} - {sample.sample_id}")
    schema = get_schema_str(
        schema=bird_tables[db_id].db_schema,
        foreign_keys=bird_tables[db_id].foreign_keys,
        primary_keys=bird_tables[db_id].primary_keys,
        col_explanation=all_descriptions[db_id]    
    )
    output = chain.invoke(input={
        'schema': schema,
        'input_query': sample.final.question,
        'evidence': sample.evidence
    })
    o = SQLResponse(**json.loads(output.content))
    usage = output.usage_metadata
    logprobs = output.response_metadata['logprobs']['content']
    results[sample.sample_id] = {
        'sample_id': sample.sample_id,
        'output': {
            'sql': o.full_sql_query,
            'rationale': o.rationale,
        },
        'usage': usage,
        'logprobs': logprobs
    }

Processing movie_platform - 8:   0%|          | 8/8731 [00:41<12:42:35,  5.25s/it]


KeyboardInterrupt: 

In [None]:
# detect token indices after `full_sql_query`
txt = ''
sql_tokens = []
start = False
for i, x in enumerate(logprobs):
    txt += x['token']
    if 'full_sql_query' in txt:
        if x['token'] == 'SELECT':
            start = True
            txt = ''

    if start:
        sql_tokens.append(x)

In [83]:
''.join([x['token'] for x in sql_tokens][:-1])

'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC;"\n'