In [1]:
%load_ext autoreload
%autoreload 2

In [33]:
import sys
from pathlib import Path
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import json
from tqdm import tqdm
import numpy as np
import pandas as pd
from typing import Optional
from collections import defaultdict
from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
_ = load_dotenv(find_dotenv())

from src.db_utils import get_schema_str, get_data_dict
from src.pymodels import DatabaseModel, QuestionSQL, SparcSample, SpiderSample, Description
from src.prompts import Prompts
from src.database import SqliteDatabase
from src.data_preprocess import (
    load_raw_data,
    process_all_tables,
    filter_samples_by_count_spider_bird,
    process_samples_bird,
    split_train_dev_test,
    save_samples_spider_bird,
    load_samples_spider_bird,
)

bird_path = proj_path / 'data' / 'bird'
tables, train_data, dev_data = load_raw_data(bird_path, load_test=False)

with (proj_path / 'data' / 'bird_description.json').open() as f:
    all_descriptions = json.load(f)

bird_tables = process_all_tables(tables, descriptions=all_descriptions)

In [3]:
all_data = filter_samples_by_count_spider_bird(train_data+dev_data, n=10)
skip = [622, 6916, 6917, 6930, 6967, 6987]
bird_samples = process_samples_bird(all_data, bird_tables, skip=skip)
# train_samples, dev_samples, test_samples = split_train_dev_test(bird_samples, train_ratio=0.6, dev_ratio=0.2)
# makesure the dev/test sql is not in the train sql

# save_samples_spider_bird(train_samples, proj_path / 'data' / 'bird_train.json')
# save_samples_spider_bird(dev_samples, proj_path / 'data' / 'bird_dev.json')
# save_samples_spider_bird(test_samples, proj_path / 'data' / 'bird_test.json')
# print(len(train_samples), len(dev_samples), len(test_samples))

100%|██████████| 10956/10956 [00:03<00:00, 3058.51it/s]


In [6]:
bird_samples.keys()

dict_keys(['movie_platform', 'book_publishing_company', 'retail_complains', 'movies_4', 'codebase_comments', 'trains', 'movie', 'social_media', 'cs_semester', 'computer_student', 'talkingdata', 'law_episode', 'synthea', 'car_retails', 'restaurant', 'soccer_2016', 'music_tracker', 'world_development_indicators', 'movielens', 'superstore', 'shooting', 'genes', 'app_store', 'regional_sales', 'european_football_1', 'professional_basketball', 'shakespeare', 'cars', 'donor', 'video_games', 'authors', 'college_completion', 'public_review_platform', 'citeseer', 'simpson_episodes', 'student_loan', 'mental_health_survey', 'disney', 'legislator', 'olympics', 'address', 'beer_factory', 'sales', 'menu', 'shipping', 'language_corpus', 'airline', 'books', 'food_inspection_2', 'coinmarketcap', 'retail_world', 'retails', 'ice_hockey_draft', 'works_cycles', 'image_and_language', 'hockey', 'world', 'music_platform_2', 'university', 'sales_in_weather', 'mondial_geo', 'software_company', 'chicago_crime', '

In [None]:
# df_train = pd.DataFrame({
#     'db_id': [x.db_id for x in train_samples], 
#     'sql': [x.final.sql for x in train_samples],
#     'sample_id': [x.sample_id for x in train_samples]
# })
# df_dev = pd.DataFrame({
#     'db_id': [x.db_id for x in dev_samples] + [x.db_id for x in test_samples],
#     'sql': [x.final.sql for x in dev_samples] + [x.final.sql for x in test_samples],
#     'sample_id': [x.sample_id for x in dev_samples] + [x.sample_id for x in test_samples]
# })

# df_train.to_csv(proj_path / 'data' / 'bird_train.csv', index=False)
# df_dev.to_csv(proj_path / 'data' / 'bird_dev.csv', index=False)

# df_train = pd.read_csv(proj_path / 'data' / 'bird_train.csv')
# df_dev = pd.read_csv(proj_path / 'data' / 'bird_dev.csv')

In [34]:
from src.eval_utils import (
    partial_match
)
import sqlglot
import spacy
from src.parsing_sql import extract_all, Schema
nlp_spacy = spacy.load('en_core_web_md')

In [314]:
error_ids = []
results = defaultdict()
parsed = defaultdict(list)
for db_id, samples in bird_samples.items():
    schema = Schema(bird_tables[db_id].db_schema)
    iterator = tqdm(range(len(samples)), total=len(samples), desc=f"{db_id}")
    for i in iterator:
        sql_i = samples[i].final.sql
        spacy_i = nlp_spacy(sql_i)
        try:
            parsed_i = sqlglot.parse_one(sql_i)
            ei = extract_all(parsed_i, schema)
            assert len(ei['sel']) > 0, f'No selection found-{db_id}-{i}'
        except Exception as e:
            error_ids.append((db_id, i, str(e)))
            parsed[db_id].append(None)
            continue
        parsed[db_id].append((spacy_i, ei))

for db_id, samples in tqdm(bird_samples.items(), total=len(bird_samples), desc='Computing similarity'):
    semantic_sim = np.zeros((len(samples), len(samples)))
    struct_sim = np.zeros((len(samples), len(samples)))
    # iterator = tqdm(range(len(samples)), total=len(samples), desc=f"{db_id}")
    for i in range(len(samples)):
        if parsed[db_id][i] is None:
            continue
        spacy_i, ei = parsed[db_id][i]        
        for j in range(i+1, len(samples)):
            if parsed[db_id][j] is None:
                continue
            spacy_j, ej = parsed[db_id][j]

            sel_iou, *_ = partial_match(ei['sel'], ej['sel'])
            cond_iou, *_ = partial_match(ei['cond'], ej['cond'])
            agg_iou, *_ = partial_match(ei['agg'], ej['agg'])
            dis_iou, *_ = partial_match(ei['distinct'], ej['distinct'])
            ord_iou, *_ = partial_match(ei['order by'], ej['order by'])
            lim = int(ei['limit'] == ej['limit'])
            nested = int(ei['nested'] == ej['nested'])
            
            semantic_sim[i, j] = spacy_i.similarity(spacy_j)
            semantic_sim[j, i] = semantic_sim[i, j]
            struct_sim[i, j] = sel_iou + cond_iou + agg_iou + dis_iou + ord_iou + (lim + nested) / 2
            struct_sim[j, i] = struct_sim[i, j]

    results[db_id] = {
        'semantic': semantic_sim,
        'struct': struct_sim
    }

movie_platform: 100%|██████████| 167/167 [00:00<00:00, 242.52it/s]
book_publishing_company: 100%|██████████| 73/73 [00:00<00:00, 221.86it/s]
retail_complains: 100%|██████████| 168/168 [00:00<00:00, 208.79it/s]
movies_4: 100%|██████████| 158/158 [00:00<00:00, 202.67it/s]
codebase_comments: 100%|██████████| 123/123 [00:00<00:00, 237.19it/s]
trains: 100%|██████████| 40/40 [00:00<00:00, 197.41it/s]
movie: 100%|██████████| 46/46 [00:00<00:00, 192.40it/s]
social_media: 100%|██████████| 78/78 [00:00<00:00, 244.38it/s]
cs_semester: 100%|██████████| 113/113 [00:00<00:00, 181.12it/s]
computer_student: 100%|██████████| 72/72 [00:00<00:00, 199.57it/s]
talkingdata: 100%|██████████| 206/206 [00:00<00:00, 212.40it/s]
law_episode: 100%|██████████| 114/114 [00:00<00:00, 219.34it/s]
synthea: 100%|██████████| 185/185 [00:00<00:00, 205.85it/s]
car_retails: 100%|██████████| 126/126 [00:00<00:00, 227.06it/s]
restaurant: 100%|██████████| 117/117 [00:00<00:00, 267.84it/s]
soccer_2016: 100%|██████████| 258/258

In [107]:
from src.parsing_sql import (
    extract_aliases,
    extract_condition,
    get_subqueries,
    _extract_conditions,
    _format_expression,
    _format_condition_expression,
    OPERATOR_MAP
)

In [None]:
samples = bird_samples['video_games']
sql_i = [x.final.sql for x in samples if x.sample_id == 3364][0]
parsed_i = sqlglot.parse_one(sql_i)
ei = extract_all(parsed_i, schema)

In [305]:
sql_i = """SELECT DISTINCT T4.genre_name 
FROM game_platform AS T1 
INNER JOIN game_publisher AS T2 ON T1.game_publisher_id = T2.id 
INNER JOIN game AS T3 ON T2.game_id = T3.id 
INNER JOIN genre AS T4 ON T3.genre_id = T4.id 
WHERE T1.release_year NOT BETWEEN 2000 AND 2002 AND T3.platform_id = 1 AND T3.price < 50 
"""
# T1.release_year NOT BETWEEN 2000 AND 2002 AND T3.platform_id = 1 AND T3.price < 50 
parsed_i = sqlglot.parse_one(sql_i)

In [306]:
aliases = extract_aliases(parsed_i)
subqueries = get_subqueries(parsed_i)
for query in subqueries:
    break

In [311]:
conditions = set()
operator_types = set()

# conds, op_types = extract_condition(query, aliases, schema)
for clause_name in ("where", "having"):
    clause = query.args.get(clause_name)
    if clause:
        conds = _extract_conditions(
            clause.this, aliases, schema, operator_types
        )
        # conditions.update(conds)
        break

In [313]:
conds, operator_types

(['__game_platform.release_year__ not between 2000 and 2002',
  '__game.platform_id__ eq 1',
  '__game.price__ lt 50'],
 {'between', 'eq', 'lt', 'not'})

In [301]:
expr = clause.this
left = expr.args.get('this')
right = expr.args.get('expression')

In [302]:
left

And(
  this=Not(
    this=Between(
      this=Column(
        this=Identifier(this=release_year, quoted=False),
        table=Identifier(this=T1, quoted=False)),
      low=Literal(this=2000, is_string=False),
      high=Literal(this=2002, is_string=False))),
  expression=EQ(
    this=Column(
      this=Identifier(this=platform_id, quoted=False),
      table=Identifier(this=T3, quoted=False)),
    expression=Literal(this=1, is_string=False)))

In [303]:
left_cond = _extract_conditions(left, aliases, schema, operator_types) if left else ''
left_cond

'(__game_platform.release_year__ not between 2000 and 2002 and __game.platform_id__ eq 1)'

In [284]:
type(expr).__bases__

(sqlglot.expressions.Binary, sqlglot.expressions.Predicate)

In [282]:
cond_str = _format_expression(expr, aliases, schema, remove_alias=True)
cond_str

"__game.platform_id__ like '%1%'"

In [217]:

operations = []
conditions = []
left = expr.args.get('this')
right = expr.args.get('expression')
# ops, conds = _extract_conditions(left, aliases, schema)
_extract_conditions(left, aliases, schema, operator_types)

'__game_platform.release_year__ between 2000 and 2002'

In [216]:
_format_expression(expr.args.get('this'), aliases, schema, remove_alias=True)

'__game_platform.release_year__'

In [210]:
expr.args.get('low')

Literal(this=2000, is_string=False)

In [118]:
left2 = left.args.get('this')
left2

Column(
  this=Identifier(this=release_year, quoted=False),
  table=Identifier(this=T1, quoted=False))

In [315]:
import pickle 
with open(proj_path / 'data' / 'errors_bird.pkl', 'wb') as f:
    pickle.dump(error_ids, f)

# with open(proj_path / 'data' / 'errors_bird.pkl', 'rb') as f:
#     error_ids = pickle.load(f)

In [318]:
set(error_ids)

{('california_schools', 0, 'No selection found-california_schools-0'),
 ('california_schools', 1, 'No selection found-california_schools-1'),
 ('california_schools', 31, 'No selection found-california_schools-31'),
 ('disney', 1, 'No selection found-disney-1'),
 ('disney', 17, 'No selection found-disney-17'),
 ('disney', 42, 'No selection found-disney-42'),
 ('disney', 44, 'No selection found-disney-44'),
 ('disney', 78, 'No selection found-disney-78'),
 ('disney', 79, 'No selection found-disney-79'),
 ('disney', 102, 'No selection found-disney-102'),
 ('disney', 108, 'No selection found-disney-108'),
 ('disney', 112, 'No selection found-disney-112'),
 ('european_football_2',
  11,
  "Required keyword: 'this' missing for <class 'sqlglot.expressions.Datetime'>. Line 1, Col: 26.\n  SELECT DISTINCT DATETIME(\x1b[4m)\x1b[0m - T2.birthday age FROM Player_Attributes AS t1 INNER JOIN Player AS t2 ON t1.player_api_id = t2.pla"),
 ('food_inspection', 2, 'No selection found-food_inspection-2'),


In [25]:
sql = """SELECT 'Date received' FROM callcenterlogs WHERE ser_time = ( SELECT MAX(ser_time) FROM callcenterlogs )"""
sqlglot.parse_one(sql)

Select(
  expressions=[
    Literal(this=Date received, is_string=True)],
  from=From(
    this=Table(
      this=Identifier(this=callcenterlogs, quoted=False))),
  where=Where(
    this=EQ(
      this=Column(
        this=Identifier(this=ser_time, quoted=False)),
      expression=Subquery(
        this=Select(
          expressions=[
            Max(
              this=Column(
                this=Identifier(this=ser_time, quoted=False)))],
          from=From(
            this=Table(
              this=Identifier(this=callcenterlogs, quoted=False))))))))

In [37]:
from src.parsing_sql import (
    extract_aliases,
    get_subqueries,
    extract_selection,
    extract_aggregation,
    extract_condition,
    extract_others, 
    _extract_conditions
)

In [54]:
sql = df_train.loc[df_train['sample_id'] == 17, 'sql'].values[0]
print(sql)
parsed_query = sqlglot.parse_one(sql)
# results = extract_all(parsed_query, schema)
aliases = extract_aliases(parsed_query)
subqueries = get_subqueries(parsed_query)
results = defaultdict(set)
nested = len(subqueries)

SELECT list_url FROM lists WHERE list_update_timestamp_utc LIKE '2012%' AND list_followers BETWEEN 1 AND 2 ORDER BY list_update_timestamp_utc DESC LIMIT 1


In [55]:
for query in subqueries:
    sel_cols, sel_types  = extract_selection(query, aliases, schema)
    conds, op_types = extract_condition(query, aliases, schema)
    agg_cols, agg_types  = extract_aggregation(query, aliases, schema)
    others = extract_others(query, aliases, schema)

ValueError: too many values to unpack (expected 2)

In [56]:
conditions = set()
operator_types = set()

for clause_name in ("where", "having"):
    clause = query.args.get(clause_name)
    if clause:
        break
clause

Where(
  this=And(
    this=Like(
      this=Column(
        this=Identifier(this=list_update_timestamp_utc, quoted=False)),
      expression=Literal(this=2012%, is_string=True)),
    expression=Between(
      this=Column(
        this=Identifier(this=list_followers, quoted=False)),
      low=Literal(this=1, is_string=False),
      high=Literal(this=2, is_string=False))))

In [57]:
clause.this

And(
  this=Like(
    this=Column(
      this=Identifier(this=list_update_timestamp_utc, quoted=False)),
    expression=Literal(this=2012%, is_string=True)),
  expression=Between(
    this=Column(
      this=Identifier(this=list_followers, quoted=False)),
    low=Literal(this=1, is_string=False),
    high=Literal(this=2, is_string=False)))

In [58]:
ops, conds = _extract_conditions(clause.this, aliases, schema)

ValueError: too many values to unpack (expected 2)

In [53]:
expr = clause.this

operations = []
conditions = []
left = expr.args.get('this')
right = expr.args.get('expression')
if left:
    ops, conds = _extract_conditions(left, aliases, schema)
    operations.extend(ops)
    conditions.extend(conds)
if right:
    ops, conds = _extract_conditions(right, aliases, schema)
    operations.extend(ops)
    conditions.extend(conds)

ValueError: too many values to unpack (expected 2)

In [50]:
right

Between(
  this=Column(
    this=Identifier(this=list_followers, quoted=False)),
  low=Literal(this=1, is_string=False),
  high=Literal(this=2, is_string=False))

In [51]:
type(exp.Between).__bases__

(type,)

In [43]:
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.runnables import RunnableSequence
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from src.pymodels import SQLResponse
   
prompt = PromptTemplate(
    template=Prompts.zero_shot_inference_bird,
    input_variables=['schema', 'input_query', 'evidence']
)
model_name = 'gpt-4o-mini'
model_openai = ChatOpenAI(
    model=model_name,
    temperature=0.0,
    logprobs=True,
    top_logprobs=5
)

# model = model_openai.with_structured_output(SQLResponse, include_raw=True)
chain = (prompt | model_openai)

In [None]:
eval_path = proj_path / 'experiments' / 'zero_shot' / 'bird'
if not eval_path.exists():
    eval_path.mkdir(parents=True)

# run zero-shot SQL generation
results = {}
iterator = tqdm(samples, total=len(samples))
for i, sample in enumerate(iterator):
    db_id = sample.db_id
    iterator.set_description(f"Processing {db_id} - {sample.sample_id}")
    schema = get_schema_str(
        schema=bird_tables[db_id].db_schema,
        foreign_keys=bird_tables[db_id].foreign_keys,
        primary_keys=bird_tables[db_id].primary_keys,
        col_explanation=all_descriptions[db_id]    
    )
    output = chain.invoke(input={
        'schema': schema,
        'input_query': sample.final.question,
        'evidence': sample.evidence
    })
    o = SQLResponse(**json.loads(output.content))
    usage = output.usage_metadata
    logprobs = output.response_metadata['logprobs']['content']
    results[sample.sample_id] = {
        'sample_id': sample.sample_id,
        'output': {
            'sql': o.full_sql_query,
            'rationale': o.rationale,
        },
        'usage': usage,
        'logprobs': logprobs
    }

Processing movie_platform - 8:   0%|          | 8/8731 [00:41<12:42:35,  5.25s/it]


KeyboardInterrupt: 

In [None]:
# detect token indices after `full_sql_query`
txt = ''
sql_tokens = []
start = False
for i, x in enumerate(logprobs):
    txt += x['token']
    if 'full_sql_query' in txt:
        if x['token'] == 'SELECT':
            start = True
            txt = ''

    if start:
        sql_tokens.append(x)

In [83]:
''.join([x['token'] for x in sql_tokens][:-1])

'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC;"\n'