In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import matplotlib.pyplot as plt
import seaborn as sns


import json
import pickle
from tqdm import tqdm
import numpy as np
import pandas as pd
from typing import Optional
from collections import defaultdict
from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.runnables import RunnableSequence
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.prompts import PromptTemplate


_ = load_dotenv(find_dotenv())

from src.db_utils import get_schema_str, get_data_dict
from src.pymodels import (
    DatabaseModel, 
    SpiderSample, 
    BirdSample, 
    BODescription,
    SQLResponse
)
from src.prompts import Prompts
from src.database import SqliteDatabase
from src.data_preprocess import (
    load_raw_data,
    process_all_tables,
    filter_samples_by_count_spider_bird,
    process_samples_bird,
    split_train_dev_test,
    save_samples_spider_bird,
    load_samples_spider_bird,
)

from src.parsing_sql import Schema, extract_all
from src.eval_utils import get_complexity, result_eq, check_if_exists_orderby
from run_bo_sql import get_vector_store
from copy import deepcopy
# Locate the BIRD dataset and read its raw tables plus the train/dev data (test split not loaded).
bird_path = proj_path / 'data' / 'bird'
tables, train_data, dev_data = load_raw_data(bird_path, load_test=False)

# Descriptions JSON that is handed to process_all_tables below.
with (proj_path / 'data' / 'bird_description.json').open() as f:
    all_descriptions = json.load(f)

# Processed table models and the previously saved train/dev/test sample files.
bird_tables = process_all_tables(tables, descriptions=all_descriptions)
train_samples = load_samples_spider_bird(proj_path / 'data' / 'bird_train.json')
dev_samples = load_samples_spider_bird(proj_path / 'data' / 'bird_dev.json')
test_samples = load_samples_spider_bird(proj_path / 'data' / 'bird_test.json')

In [3]:
# spider_path = proj_path / 'data' / 'spider'
# tables, train_data, dev_data = load_raw_data(spider_path, load_test=False)

# with (proj_path / 'data' / 'description.json').open() as f:
#     all_descriptions = json.load(f)
# seed = 42
# all_data = filter_samples_by_count_spider_bird(train_data+dev_data, n=10)

# with open(proj_path / 'data' / 'bird_skip.txt') as f:
#     skip = [int(line.strip()) for line in f]

# bird_samples = process_samples_bird(all_data, bird_tables, skip=skip)
# train_samples, dev_samples, test_samples = split_train_dev_test(bird_samples, train_ratio=0.6, dev_ratio=0.2, seed=seed)

# save_samples_spider_bird(train_samples, proj_path / 'data' / 'bird_train.json')
# save_samples_spider_bird(dev_samples, proj_path / 'data' / 'bird_dev.json')
# save_samples_spider_bird(test_samples, proj_path / 'data' / 'bird_test.json')
# print(len(train_samples), len(dev_samples), len(test_samples))

In [88]:
# experiment_folder = proj_path / 'experiments' / 'bird'
# prediction_path = experiment_folder / 'predictions' / 'create_bo'
# tables = bird_tables
# bos = []
# for p in prediction_path.glob('bird_train_bo_*.json'):
#     with p.open() as f:
#         bos = json.load(f)

#     db_id = p.stem.split('_', 3)[-1]
#     schema = Schema(tables[db_id].db_schema)
#     for bo in bos:
#         output = extract_all(bo['gold_sql'], schema)
#         bo['gold_complexity'] = get_complexity(output)
    
#     with p.open('w') as f:
#         json.dump(bos, f, indent=4)

# bos = {}
# for p in prediction_path.glob('bird_train_bo_*.json'):
#     db_id = p.stem.split('_', 3)[-1]
#     with p.open() as f:
#         bos[db_id] = json.load(f)

# with (experiment_folder / 'predictions' / 'create_bo' / f'final_bird_train_bo.json').open('w') as f:
#     json.dump(bos, f, indent=4)

In [6]:
from run_bo_sql import Sampler
# TODO: run spider 4567
ds = 'bird'
task = 'valid_bo'
experiment_folder = proj_path / 'experiments' / ds
prediction_path = experiment_folder / 'predictions' / task
eval_path = experiment_folder / 'evals' / task

# dev_samples = load_samples_spider_bird(proj_path / 'data' / f'{ds}_dev.json')
# Gather every prediction record, grouped as db_id -> retrieved train BO id -> list[dict].
pred_res = defaultdict(dict)
for p in prediction_path.glob(f'{ds}_dev_*.json'):
    # File stems look like '<ds>_dev_<db_id>-<idx>'; split on the last '-' so a
    # dash inside db_id cannot break the unpacking.
    name = p.stem.split('_', 2)[-1]
    db_id, idx = name.rsplit('-', 1)
    with p.open() as f:
        res = json.load(f)

    # Append each record under the train BO it was retrieved with. The previous
    # version read `r` before the loop defined it and overwrote the list with a
    # single record.
    for r in res:
        pred_res[db_id].setdefault(r['retrieved'], []).append(r)

save_path = prediction_path / f'final_{ds}_dev.jsonl'

In [None]:
res

[{'sample_id': 997,
  'gold_sql': 'SELECT T1.p_id FROM taughtBy AS T1 INNER JOIN person AS T2 ON T1.p_id = T2.p_id WHERE T2.professor = 1 GROUP BY T1.p_id HAVING COUNT(DISTINCT T1.course_id) > 3',
  'retrieved': 1022,
  'rationale': ["Identify the relevant tables: 'taughtby' for teaching assignments and 'course' for course details.",
   "Join the 'taughtby' table with the 'course' table on 'course_id' to access course information for each professor.",
   "Group the results by 'p_id' to aggregate the number of courses taught by each professor.",
   'Use the HAVING clause to filter out professors who teach more than 3 courses.'],
  'pred_sql': 'SELECT taughtby.p_id FROM taughtby INNER JOIN course ON taughtby.course_id = course.course_id GROUP BY taughtby.p_id HAVING COUNT(taughtby.course_id) > 3;',
  'token_usage': {'tokens': 912, 'cost': 0.00019755}},
 {'sample_id': 996,
  'gold_sql': 'SELECT T1.courseLevel FROM course AS T1 INNER JOIN taughtBy AS T2 ON T1.course_id = T2.course_id GROUP

In [171]:
db_ids = list(bos.keys())
n = 20  # number of database ids per group
# Chunk the ids into consecutive groups of size n. Stepping over the whole list
# (instead of a fixed number of iterations) guarantees no id is dropped however
# many databases there are.
partial_db_ids = {
    i: db_ids[start:start + n]
    for i, start in enumerate(range(0, len(db_ids), n))
}
print(partial_db_ids.keys())

# NOTE: JSON serialisation turns the integer group keys into strings.
with open(experiment_folder / f'partial_{ds}_db_ids.json', 'w') as f:
    json.dump(partial_db_ids, f, indent=4)

dict_keys([0, 1, 2, 3, 4, 5, 6, 7])


In [172]:
# Reload the grouping from disk; after the JSON round trip the group keys are strings.
with open(experiment_folder / f'partial_{ds}_db_ids.json') as f:
    partial_db_ids = json.load(f)

# Sampler over all loaded business objects (bos must already be defined in the kernel).
sampler = Sampler(bos)

In [None]:
sh ./scripts/valid_bo/valid_bo_bird_

In [174]:
from itertools import product, islice

def batched(iterable, n, *, strict=False):
    """Yield consecutive tuples of at most ``n`` items taken from ``iterable``.

    Args:
        iterable: any iterable to split up.
        n: batch size; must be at least one.
        strict: when true, a shorter final batch raises instead of being yielded.

    Raises:
        ValueError: if ``n < 1``, or if ``strict`` is set and the last batch is
            incomplete. Because this is a generator, the error surfaces on
            iteration rather than at call time.
    """
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            # source exhausted
            return
        if strict and len(chunk) != n:
            raise ValueError('batched(): incomplete batch')
        yield chunk

# Plan the validation batches:
#   group key -> '<db_id>-<batch idx>' -> sampled train BOs and workload size.
sampled_ids = {}
for db_id_group in partial_db_ids:
    group_key = str(db_id_group)
    sampled_ids[group_key] = defaultdict()
    for db_id in partial_db_ids[group_key]:
        # dev samples that belong to this database
        x_samples = [sample for sample in dev_samples if sample.db_id == db_id]
        batches = sampler.sample(db_id, 3, 50, rt_idx=False)
        for idx_bos, train_bos in enumerate(batches):
            # print(f'{db_id}-{idx_bos} :', f'{len(train_bos)}', f'{len(list(product(train_bos, x_samples)))}')
            sampled_ids[group_key][f'{db_id}-{idx_bos}'] = {
                'train_bos': train_bos,
                # one iteration per (train BO, dev sample) pair
                'n_iter': len(train_bos) * len(x_samples),
                'total_bos_in_batch': len(train_bos)
            }

with (experiment_folder / f'partial_{ds}_batch.json').open('w') as f:
    json.dump(sampled_ids, f, indent=4)

In [175]:
# For every group: number of planned batch files, total iterations, and mean iterations per file.
for db_id_group in partial_db_ids:
    print(len(sampled_ids[str(db_id_group)]))
    niters = [x['n_iter'] for x in sampled_ids[str(db_id_group)].values()]
    print(f'n_iter: {sum(niters)}, iter per file: {np.mean(niters):.2f}')

82
n_iter: 8394, iter per file: 102.37
50
n_iter: 4583, iter per file: 91.66
62
n_iter: 6707, iter per file: 108.18
75
n_iter: 10838, iter per file: 144.51
68
n_iter: 6246, iter per file: 91.85
64
n_iter: 6218, iter per file: 97.16
70
n_iter: 7298, iter per file: 104.26
70
n_iter: 8072, iter per file: 115.31


In [6]:
# One row per business object: its database id and gold-SQL complexity.
# `df` starts as a list of records and is converted to a DataFrame at the end.
df = []
for db_id, bs in bos.items():
    for b in bs:
        res = {'db_id': db_id, 'gold_complexity': b['gold_complexity']}
        df.append(res)

df = pd.DataFrame(df)

In [115]:
from itertools import pairwise

def _format_interval(x: pd.Interval):
    return pd.Interval(
        left=int(np.floor(x.left)), 
        right=int(np.floor(x.right)),
        closed=x.closed
    )

def _get_categories(s: pd.Series):
    tiles = [0, 0.2, 0.4, 0.6, 0.8, 1]
    df = pd.qcut(s, q=tiles, duplicates='drop')
    return df

def _get_df_from_bos(bos):
    """Flatten ``bos`` into one DataFrame and attach a per-database ``category`` column.

    Each record becomes a row carrying its ``db_id``. ``gold_complexity`` is
    binned separately within every database via ``_get_categories`` and the
    resulting intervals are integer-formatted with ``_format_interval``.
    """
    records = [{'db_id': db_id, **b} for db_id, bs in bos.items() for b in bs]
    df = pd.DataFrame(records)

    per_db_bins = df.groupby('db_id')['gold_complexity'].apply(_get_categories)
    per_db_bins = per_db_bins.rename('category').apply(_format_interval)
    # Drop the db_id index level so the bins line up with df's row index for the merge.
    per_db_bins = per_db_bins.reset_index('db_id', drop=True)
    return df.merge(per_db_bins, left_index=True, right_index=True)

In [299]:
# create a sampler with gold_complexity
from typing import Iterator
class Sampler():
    """Draw category-stratified batches of business objects for one database.

    The frame built by ``_get_df_from_bos`` carries a ``category`` column; each
    batch takes up to ``n_sample`` rows from every category that still has
    unsampled rows, and no row is drawn twice within one call.
    """
    def __init__(self, bos: dict[str, list[dict]], seed: Optional[int] = None):
        """
        Args:
            bos: mapping of db_id to its list of business-object records.
            seed: optional seed that makes the sampling reproducible. ``None``
                keeps the previous behaviour (pandas' default random state).
        """
        self.bos = bos
        self.df = _get_df_from_bos(bos)
        self._rng = np.random.default_rng(seed) if seed is not None else None

    def _get_sample_batch(self, db_id: str, n_sample: int=1, n_stop: int=20):
        """Return a list of batches, each a list of sampled ``sample_id`` values.

        Sampling stops once no categorised rows remain for ``db_id`` or once at
        least ``n_stop`` ids have been drawn in total (the last batch may
        overshoot ``n_stop``).
        """
        rng = self._rng
        sample_batch = []
        n_sampled = 0
        # Filter this database's rows once, then shrink the remainder per batch
        # instead of re-scanning the full frame on every iteration.
        remaining = self.df.loc[self.df['db_id'] == db_id]
        while remaining['category'].nunique() > 0 and n_sampled < n_stop:
            sample_ids = (
                remaining.groupby('category')['sample_id']
                .apply(lambda ids: ids.sample(min(len(ids), n_sample), random_state=rng))
                .tolist()
            )
            sample_batch.append(sample_ids)
            n_sampled += len(sample_ids)
            remaining = remaining.loc[~remaining['sample_id'].isin(sample_ids)]

        return sample_batch

    def sample(self, db_id: str, n_sample: int=1, n_stop: int=20) -> Iterator:
        """Yield each batch for ``db_id`` as a list of row dicts (``orient='records'``)."""
        sample_batch = self._get_sample_batch(db_id, n_sample, n_stop)
        for sample_ids in sample_batch:
            s = self.df.loc[(self.df['db_id'] == db_id) & self.df['sample_id'].isin(sample_ids)]
            s = s.to_dict(orient='records')
            yield s

In [323]:
# Quick check: print the size of every batch drawn for the 'address' database.
sampler = Sampler(bos)

for train_bos in sampler.sample('address', n_sample=1, n_stop=20):
    print(len(train_bos))

[autoreload of run_bo_sql failed: Traceback (most recent call last):
  File "/home/simonjisu/code/BusinessObjects/.venv/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/simonjisu/code/BusinessObjects/.venv/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 500, in superreload
    update_generic(old_obj, new_obj)
  File "/home/simonjisu/code/BusinessObjects/.venv/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 397, in update_generic
    update(a, b)
  File "/home/simonjisu/code/BusinessObjects/.venv/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 365, in update_class
    update_instances(old, new)
  File "/home/simonjisu/code/BusinessObjects/.venv/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 319, in update_instances
    refs = gc.get_referrers(old)
           ^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
]


5
5
5
5


In [None]:
samples_by_db_id = defaultdict(list)
for sample in samples:
    samples_by_db_id[sample.db_id].append(sample)

sampler = Sampler(bos)

for db_id, samples in samples_by_db_id.items():

15

In [319]:
dev_samples[:3]

[BirdSample(sample_id=165, db_id='movie_platform', final=QuestionSQL(question='What is the name of the movie that was rated recently by user 57756708?', sql='SELECT T2.movie_title FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE T1.user_id = 57756708 ORDER BY T1.rating_timestamp_utc DESC LIMIT 1', source_tables=['movies', 'ratings']), evidence='user 57756708 refers to user_id = 57756708; rated recently refers to MAX(rating_timestamp_utc)', bo=None),
 BirdSample(sample_id=137, db_id='movie_platform', final=QuestionSQL(question='For all the users who gave "A Shot in the Dark" a rating, how many percent of them is a paying subscriber?', sql="SELECT CAST(SUM(CASE WHEN T1.user_has_payment_method = 1 THEN 1 ELSE 0 END) AS REAL) * 100 / COUNT(*) FROM ratings AS T1 INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id INNER JOIN lists_users AS T3 ON T1.user_id = T3.user_id WHERE T2.movie_title = 'A Shot in the Dark'", source_tables=['movies', 'ratings', 'lists_users

In [316]:
json.dumps({'description': bo['ba'], 'virtual_table': bo['vt']}, indent=2)

'{\n  "description": "The virtual table counts the number of cities associated with a specific congress representative based on their first and last names, while also filtering for cities that have a certain number of employees. It joins the \'congress\' table with the \'state\' table to match the state abbreviation and then joins with the \'zip_data\' table to access city information.",\n  "virtual_table": "SELECT COUNT(zip_data.city) FROM congress INNER JOIN state AS T2 ON T1.abbreviation = T2.abbreviation INNER JOIN zip_data AS T3 ON T2.abbreviation = T3.state WHERE congress.first_name = \'[placeholder-type:string]\' AND congress.last_name = \'[placeholder-type:string]\' AND zip_data.employees = [placeholder-type:numeric]"\n}'

In [None]:
ds = 'bird'
task = 'zero_shot_hint'
typ = 'dev'
experiment_folder = proj_path / 'experiments' / ds
prediction_path = experiment_folder / 'predictions' / task
eval_path = experiment_folder / 'evals' / task

# file_name = f'{ds}_{typ}_parsed.pkl'
# with (eval_path / file_name).open('rb') as f:
#     target_parsed = pickle.load(f)

In [308]:
prediction_path.parent.parent

PosixPath('/home/simonjisu/code/BusinessObjects/experiments/bird')

In [304]:
bos['address'][:4]

[{'sample_id': 5156,
  'vt': "SELECT area_code.area_code, country.county FROM area_code INNER JOIN country AS T2 ON T1.zip_code = T2.zip_code INNER JOIN zip_data AS T3 ON T1.zip_code = T3.zip_code WHERE zip_data.city = '[placeholder-type:string]'",
  'ba': "The virtual table provides the area code and county information for a specific city based on its zip code. It combines data from the 'area_code', 'country', and 'zip_data' tables, filtering results to match the specified city name.",
  'gold_complexity': 10,
  'gold_sql': "SELECT T1.area_code, T2.county FROM area_code AS T1 INNER JOIN country AS T2 ON T1.zip_code = T2.zip_code INNER JOIN zip_data AS T3 ON T1.zip_code = T3.zip_code WHERE T3.city = 'Savoy'"},
 {'sample_id': 5211,
  'vt': 'SELECT alias.alias FROM alias INNER JOIN zip_data AS T2 ON T1.zip_code = T2.zip_code WHERE zip_data.population_2020 = (SELECT MAX(zip_data.population_2020) FROM zip_data)',
  'ba': "The virtual table retrieves the aliases of cities from the 'alias' t

In [4]:
from pydantic import BaseModel
from langchain_openai import ChatOpenAI

from langchain_community.callbacks.manager import get_openai_callback

# Minimal structured-output schema with a single free-text field.
class Out(BaseModel):
    response: str

llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    stream_usage=True,
)
# Bind the schema so the model's reply is structured according to `Out`.
model = llm.with_structured_output(Out)


# Run one request inside the OpenAI callback context and print the usage it collected.
with get_openai_callback() as cb:
    result = model.invoke("Tell me a joke with JSON format")
    print(cb)

Tokens Used: 85
	Prompt Tokens: 51
	Completion Tokens: 34
Successful Requests: 1
Total Cost (USD): $2.805e-05


In [63]:
# Group the training samples by database id.
samples_by_db_id = defaultdict(list)
for sample in train_samples:
    samples_by_db_id[sample.db_id].append(sample)

# x holds the number of training samples per database.
x = []
for db_id, samples in samples_by_db_id.items():
    x.append(len(samples))

# mean, std, min, max of samples per database
print(np.mean(x), np.std(x), np.min(x), np.max(x))

80.26582278481013 46.229123611557306 11 280


In [64]:
# Group the dev samples by database id (rebinds samples_by_db_id and x from the train cell).
samples_by_db_id = defaultdict(list)
for sample in dev_samples:
    samples_by_db_id[sample.db_id].append(sample)

# x holds the number of dev samples per database.
x = []
for db_id, samples in samples_by_db_id.items():
    x.append(len(samples))

# mean, std, min, max of samples per database
print(np.mean(x), np.std(x), np.min(x), np.max(x))

26.468354430379748 15.462355628942769 3 93


In [119]:
# with open(proj_path / 'data' / 'pkl_files' / 'bird_train_parsed.pkl', 'rb') as f:
#     train_parsed = pickle.load(f)

# # prediction parsed
# with open(proj_path / 'data' / 'pkl_files' / 'bird_dev_parsed.pkl', 'rb') as f:
#     dev_parsed = pickle.load(f)

In [17]:
eval_path = proj_path / 'experiments' / 'bird' / 'evals' / 'zero_shot'

# Each matching file is read line by line, one JSON document per line,
# even though the files carry a .json extension.
# `df` starts as a list of records and becomes a DataFrame below.
df = []
for p in eval_path.glob('bird_dev_*.json'):
    with p.open() as f:
        for line in f:
            eval_data = json.loads(line)
            df.append(eval_data)

# Persist the combined evaluation rows next to the per-database files.
df = pd.DataFrame(df)
df.to_csv(eval_path / 'bird_dev.csv', index=False)

In [16]:
df['gold_complexity'].agg(['mean', 'std', 'min', 'max', 'median'])

mean      0.450361
std       0.055482
min       0.318118
max       0.726155
median    0.446118
Name: gold_complexity, dtype: float64

In [31]:
prediction_path = proj_path / 'experiments' / 'bird' / 'predictions' / 'create_bo'
# bos maps a database id to the list loaded from its 'bird_train_bo_<db_id>.json' file.
bos = defaultdict(list)
for p in prediction_path.glob('bird_train_bo_*.json'):
    with p.open() as f:
        temp = json.load(f)
    
    # Everything after the third underscore of the stem is used as the db_id key.
    bos[p.stem.split('_', 3)[-1]] = temp

# with (prediction_path / 'final_bird_train_bo.json').open('w') as f:
#     json.dump(bos, f, indent=4)

In [178]:
vector_store = get_vector_store({'address': bos['address'][:10]})

In [179]:
[b['sample_id'] for b in bos['address'][:10]]

[5156, 5211, 5227, 5091, 5152, 5128, 5200, 5119, 5194, 5141]

In [161]:
bos['address'][:10][0]

{'sample_id': 5156,
 'vt': "SELECT area_code.area_code, country.county FROM area_code INNER JOIN country AS T2 ON T1.zip_code = T2.zip_code INNER JOIN zip_data AS T3 ON T1.zip_code = T3.zip_code WHERE zip_data.city = '[placeholder-type:string]'",
 'ba': "The virtual table provides the area code and county information for a specific city based on its zip code. It combines data from the 'area_code', 'country', and 'zip_data' tables, filtering results to match the specified city name."}

In [180]:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
# First-stage retriever: top-3 similarity search with a relevance-score cutoff
# and a metadata filter on sample_id (the exclusion list is empty here).
base_retriever = vector_store.as_retriever(
    search_type='similarity_score_threshold', 
    search_kwargs={
        'k': 3,
        'score_threshold': 0.3, 'filter': {'sample_id': {'$nin': []}}
    }
)

# 'lambda_mult': 0.5  'score_threshold': 0.0
# 'filter': {'sample_id': {'$in': [5156]}}}
# Second stage: a cross-encoder reranks the retrieved documents and keeps the best one.
# NOTE(review): this rebinds the global `model`, which another cell uses for the structured-output LLM.
model = HuggingFaceCrossEncoder(model_name='cross-encoder/ms-marco-MiniLM-L-6-v2')
compressor = CrossEncoderReranker(model=model, top_n=1)
retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=base_retriever
)

In [181]:
q = 'what is the aliases of cities along with their elevation?'
x = base_retriever.invoke(q)
x

No relevant docs were retrieved using the relevance score threshold 0.3


[]

In [182]:
x = vector_store.similarity_search_with_relevance_scores(
    q, k=2, filter={'sample_id': {'$nin': [5152, 5211, 5194]}})
x
# similarity_search_with_relevance_scores

[]

In [148]:
docs_and_similarities = [
    (doc, similarity)
    for doc, similarity in x
    if similarity >= 0.5
]
docs_and_similarities

[(Document(metadata={'sample_id': 5152, 'db_id': 'address', 'vt': 'SELECT alias.alias, zip_data.elevation FROM alias INNER JOIN zip_data AS T2 ON T1.zip_code = T2.zip_code WHERE alias.zip_code = [placeholder-type:numeric]'}, page_content="The virtual table describes the aliases of cities along with their elevation from the 'zip_data' table. The query joins the 'alias' table with the 'zip_data' table based on the zip code, filtering for a specific zip code using a placeholder for numeric values."),
  0.7825041385389271),
 (Document(metadata={'sample_id': 5211, 'db_id': 'address', 'vt': 'SELECT alias.alias FROM alias INNER JOIN zip_data AS T2 ON T1.zip_code = T2.zip_code WHERE zip_data.population_2020 = (SELECT MAX(zip_data.population_2020) FROM zip_data)'}, page_content="The virtual table retrieves the aliases of cities from the 'alias' table that correspond to the zip codes with the highest population recorded in 2020 from the 'zip_data' table. The query uses an inner join to connect t

In [125]:
compressor([y[0] for y in x])

TypeError: 'CrossEncoderReranker' object is not callable

In [20]:
ContextualCompressionRetriever.__bases__

(langchain_core.retrievers.BaseRetriever,)

In [71]:
prediction_path = proj_path / 'experiments' / 'bird' / 'predictions' / 'zero_shot'
# Merge the per-database prediction files into one list, without the rationale text.
predictions = []
for p in prediction_path.glob('bird_dev_*.json'):
    with open(p) as f:
        pred = json.load(f)
        new_pred = []
        for x in pred:
            # Use a default so records that carry no rationale do not raise KeyError.
            x.pop('rationale', None)
            new_pred.append(x)
        predictions.extend(new_pred)

# Write the merged predictions as JSON Lines (one record per line).
with open(prediction_path / 'final_bird_dev.jsonl', 'w') as f:
    for p in predictions:
        f.write(json.dumps(p) + '\n')

len(predictions)

2091

In [69]:
predictions[0]

{'sample_id': 8773,
 'db_id': 'food_inspection',
 'gold_sql': "SELECT COUNT(owner_state) FROM businesses WHERE owner_state = 'CA'",
 'pred_sql': "SELECT COUNT(DISTINCT owner_name) AS owner_count FROM businesses WHERE owner_state = 'California';"}

In [206]:
# typ = 'test'
# samples = test_samples
# predictions = []
# with open(prediction_path / f'final_bird_{typ}.jsonl', 'r') as f:
#     preds = f.readlines()
#     for p in preds:
#         pred = json.loads(p)
#         found = False
#         for sample in samples:
#             if sample.sample_id == pred['sample_id']:
#                 pred['gold_sql'] = sample.final.sql
#                 found = True
#                 break
#         if not found:
#             raise ValueError(f"Sample ID {pred['sample_id']} not found")
        
#         predictions.append(pred)

# with open(prediction_path / f'final_bird_{typ}.jsonl', 'w') as f:
#     for p in predictions:
#         f.write(json.dumps(p) + '\n')

In [3]:
import sqlglot
import sqlglot.expressions as exp
from src.eval_utils import result_eq, check_if_exists_orderby

def get_pred_results(
        proj_path: Path,
        predictions: list[dict],
        tables: dict[str, DatabaseModel],
        ds: str = 'bird' # spider or bird
    ):
    """Execute predicted and gold SQL per database and score them by result equality.

    Args:
        proj_path: project root; the SQLite files are looked up under ``data/``.
        predictions: records with ``db_id``, ``sample_id``, ``pred_sql`` and ``gold_sql``.
        tables: db_id -> database model; its ``foreign_keys`` are passed to the database.
        ds: ``'bird'`` (train databases first, dev databases as fallback);
            any other value uses the spider database layout.

    Returns:
        ``(output_results, error_infos)``. ``output_results`` has one dict per
        sample whose gold SQL executed (``sample_id``, ``db_id``, ``score``,
        ``gold_sql``, ``pred_sql``). ``error_infos`` maps each failure kind
        (``pred_exec``, ``gold_exec``, ``python_script``, ``result``) to a list
        of ``(db_id, sample_id)`` pairs.
    """
    output_results = []
    error_infos = {
        'pred_exec': [],
        'gold_exec': [],
        'python_script': [],
        'result': []
    }
    predictions_by_db_id = defaultdict(list)
    for pred in predictions:
        predictions_by_db_id[pred['db_id']].append(pred)

    for db_id, preds in predictions_by_db_id.items():
        foreign_keys = tables[db_id].foreign_keys
        if ds == 'bird':
            try:
                database = SqliteDatabase(
                    db_file=str(proj_path / 'data' / ds / 'train' / 'train_databases' / db_id / f'{db_id}.sqlite'),
                    foreign_keys=foreign_keys
                )
            except Exception:
                # Not among the train databases: fall back to the dev databases.
                # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
                database = SqliteDatabase(
                    db_file=str(proj_path / 'data' / ds / 'dev' / 'dev_databases' / db_id / f'{db_id}.sqlite'),
                    foreign_keys=foreign_keys
                )
        else:
            database = SqliteDatabase(
                db_file=str(proj_path / 'data' / 'spider' / 'database' / db_id / f'{db_id}.sqlite'), 
                foreign_keys=foreign_keys
            )
        iterator = tqdm(preds, total=len(preds))
        for pred in iterator:
            iterator.set_description(f'{db_id} | pred_exec: {len(error_infos["pred_exec"])} | gold_exec: {len(error_infos["gold_exec"])} | python_script: {len(error_infos["python_script"])} | result: {len(error_infos["result"])}')

            pred_sql = pred['pred_sql']
            gold_sql = pred['gold_sql']
            sample_key = (db_id, pred['sample_id'])

            pred_failed = False
            try:
                pred_result = database.execute(pred_sql, rt_pandas=False)
            except Exception:
                pred_failed = True
                error_infos['pred_exec'].append(sample_key)

            try:
                gold_result = database.execute(gold_sql, rt_pandas=False)
            except Exception:
                # Without a gold result the sample cannot be scored; it is
                # recorded as a gold error and left out of output_results.
                error_infos['gold_exec'].append(sample_key)
                continue

            if pred_failed:
                score = 0
            else:
                # Row order only matters when the gold query has an ORDER BY.
                exists_orderby = check_if_exists_orderby(gold_sql)
                try:
                    score = int(result_eq(pred_result, gold_result, order_matters=exists_orderby))
                except Exception as e:
                    print(f"An error occurred: {e}")
                    score = 0
                    error_infos['python_script'].append(sample_key)
                else:
                    if score == 0:
                        error_infos['result'].append(sample_key)

            output_results.append(
                {
                    'sample_id': pred['sample_id'], 
                    'db_id': db_id,
                    'score': score,
                    'gold_sql': gold_sql,
                    'pred_sql': pred_sql,
                }
            )

    return output_results, error_infos


In [4]:
prediction_path = proj_path / 'experiments' / 'bird' / 'predictions' / 'zero_shot'

def load_predictions(prediction_path: Path, filename: str):
    """Read a JSON Lines prediction file into a list of dicts.

    Args:
        prediction_path: directory that contains the file.
        filename: name of the ``.jsonl`` file inside ``prediction_path``.

    Returns:
        One parsed object per non-blank line, in file order.
    """
    predictions = []
    with open(prediction_path / filename, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at the end of the file) are skipped
            # instead of crashing json.loads.
            if line:
                predictions.append(json.loads(line))
    return predictions
# Use the load_predictions helper defined in this cell rather than repeating its read loop.
predictions = load_predictions(prediction_path, 'final_bird_dev.jsonl')

# output_results, error_infos = get_pred_results(proj_path, predictions, tables, ds='bird')

In [59]:
from run_evaluation import get_prediction_parsed_sql, get_target_parsed_sql
from src.eval_utils import get_complexity, get_all_partial_score

In [None]:
target_parsed, error_ids = get_target_parsed_sql([x for x in dev_samples if x.sample_id in [9503, 9468, 9494]], tables=bird_tables)
pred_parsed, error_ids = get_prediction_parsed_sql(predictions[:3], tables=bird_tables)

california_schools: 100%|██████████| 3/3 [00:00<00:00, 456.88it/s]


california_schools: 100%|██████████| 3/3 [00:00<00:00, 444.52it/s]


In [57]:
db_id = 'california_schools'

In [60]:
# Compare the parsed prediction with the parsed gold query for the first three
# predictions and print the overall / structural / semantic scores.
for pred in predictions[:3]:
    sample_id = pred['sample_id']
    target_o = target_parsed[db_id][sample_id]
    pred_o = pred_parsed[db_id][sample_id]
    # Samples whose parsed prediction is falsy are skipped silently.
    if pred_o:
        _, all_score = get_all_partial_score(pred_o, target_o)

        print(f"Sample ID: {sample_id}")
        print(f'f1={all_score["overall"]} | {all_score["structural"]:.4f} | {all_score["semantic"]:.4f}')

Sample ID: 9503
f1=0.7798 | 0.7657 | 0.8017
Sample ID: 9468
f1=0.4271 | 0.4170 | 0.5096


In [26]:
predictions[2]

{'sample_id': 9494,
 'rationale': ['Identify the relevant table: frpm contains the enrollment data.',
  "Determine the columns needed: 'enrollment (ages 5-17)' for the number of students, 'academic year' to filter by the specific year, and 'district type' to specify the type of school.",
  'Filter the records for the academic year 2014-2015.',
  "Filter for the district type 'State Special Schools'.",
  "Filter for the county name 'Fremont'.",
  'Use SUM() to aggregate the total enrollment for the specified filters.'],
 'pred_sql': "SELECT SUM(enrollment (ages 5-17)) AS total_enrollment\nFROM frpm\nWHERE academic year = '2014-2015' AND district type = 'State Special Schools' AND county name = 'Fremont';",
 'db_id': 'california_schools',
 'gold_sql': 'SELECT T1."Enrollment (Ages 5-17)" FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.EdOpsCode = \'SSS\' AND T2.City = \'Fremont\' AND T1."Academic Year" BETWEEN 2014 AND 2015'}

In [35]:
sql = """
SELECT SUM("enrollment (ages 5-17)") AS total_enrollment
FROM frpm
WHERE "academic year" = '2014-2015' AND "district type" = 'State Special Schools' AND "county name" = 'Fremont';
"""
x = {'sample_id': 9494,
 'rationale': ['Identify the relevant table: frpm contains the enrollment data.',
  "Determine the columns needed: 'enrollment (ages 5-17)' for the number of students, 'academic year' to filter by the specific year, and 'district type' to specify the type of school.",
  'Filter the records for the academic year 2014-2015.',
  "Filter for the district type 'State Special Schools'.",
  "Filter for the county name 'Fremont'.",
  'Use SUM() to aggregate the total enrollment for the specified filters.'],
 'pred_sql': sql,
 'db_id': 'california_schools',
 'gold_sql': 'SELECT T1."Enrollment (Ages 5-17)" FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.EdOpsCode = \'SSS\' AND T2.City = \'Fremont\' AND T1."Academic Year" BETWEEN 2014 AND 2015'}

parsed, error_ids = get_prediction_parsed_sql([x], tables=bird_tables)

california_schools: 100%|██████████| 1/1 [00:00<00:00, 465.31it/s]


In [33]:
db_id = 'california_schools'
tables = bird_tables
db_schema = get_schema_str(
    schema=tables[db_id].db_schema, 
    foreign_keys=tables[db_id].foreign_keys,
    col_explanation=tables[db_id].col_explanation
)

In [29]:
database = SqliteDatabase(
    db_file=str(proj_path / 'data' / 'bird' / 'dev' / 'dev_databases' / 'california_schools' / 'california_schools.sqlite'),
    foreign_keys=bird_tables['california_schools'].foreign_keys
)

In [30]:
database.execute("""
SELECT "enrollment (ages 5-17)"
FROM frpm
LIMIT 1
""")

Unnamed: 0,Enrollment (Ages 5-17)
0,1070.0


In [28]:
error_ids

[('california_schools',
  9494,
  'Invalid expression / Unexpected token. Line 4, Col: 19.\n  \nSELECT SUM("enrollment (ages 5-17)") AS total_enrollment\nFROM frpm\nWHERE academic \x1b[4myear\x1b[0m = \'2014-2015\' AND district type = \'State Special Schools\' AND county name = \'Fremont\';\n')]

In [None]:
from run_evaluation import get_parsed_sql

# Parse the predictions only when no cached pickle exists, then always load from the cache.
parsed_path = proj_path / 'data' / 'pkl_files' 
file_name = f'bird_dev_parsed_pred.pkl'
if not (parsed_path / file_name).exists():
    pred_parsed, error_ids = get_parsed_sql(predictions, tables)
    with open(parsed_path / file_name, 'wb') as f:
        pickle.dump(pred_parsed, f)
    print(f'Error parsing pred bird_dev: {len(error_ids)}')

# NOTE(review): pickle.load executes arbitrary code from the file; only load caches written by this project.
with (parsed_path / file_name).open('rb') as f:
    pred_parsed = pickle.load(f)

# Predict BO

In [117]:
def get_vector_store(bos: dict[str, list[dict[str, str]]]):
    """Index every business-object description in a FAISS vector store.

    Each record's ``ba`` text becomes the document content; ``sample_id``,
    the owning ``db_id`` and the ``vt`` SQL are stored as metadata.

    Args:
        bos: mapping of db_id to records with ``sample_id``, ``ba`` and ``vt`` keys.

    Returns:
        The FAISS store built from all documents with OpenAI embeddings.
    """
    documents = [
        Document(
            doc_id=entry['sample_id'],
            page_content=entry['ba'],
            metadata={
                'sample_id': entry['sample_id'],
                'db_id': db_id,
                'vt': entry['vt'],
            },
        )
        for db_id, entries in bos.items()
        for entry in entries
    ]
    return FAISS.from_documents(documents, embedding=OpenAIEmbeddings())

In [None]:
vectorstore = get_vector_store(res)

In [156]:
def predict_sql_bo(
    to_pred_samples: list[SpiderSample|BirdSample],
    tables: dict[str, DatabaseModel],
    vectorstore: FAISS,
    chain: RunnableSequence,
    prediction_path: Path,
    file_name: str = '[args.ds]_[args.type]',
    n_retrieval: int = 3,
    score_threshold: float = 0.65,
):
    """Predict SQL per database, using retrieved descriptions as hints.

    Samples are grouped by ``db_id``. For every sample the retriever is queried
    with the question, the hits are rendered as a JSON hint, and ``chain`` is
    invoked with the schema string, the question and the hint. The predictions
    of one database are written to
    ``prediction_path / f'{file_name}_{db_id}.json'`` once all of its samples
    are done.

    Args:
        to_pred_samples: Samples to predict; each needs ``db_id``,
            ``sample_id`` and ``final.question``.
        tables: Maps ``db_id`` to an object exposing ``db_schema``,
            ``foreign_keys`` and ``col_explanation``.
        vectorstore: Store whose retriever is filtered on the ``db_id`` metadata.
        chain: Runnable whose output exposes ``rationale`` and ``full_sql_query``.
        prediction_path: Directory that receives (and is scanned for) result files.
        file_name: Prefix of the per-database result files.
        n_retrieval: ``k`` passed to the retriever.
        score_threshold: ``score_threshold`` passed to the retriever.

    Returns:
        None. Results are only written to disk.
    """
    # Restart from checkpoint: every `<file_name>_<db_id>.json` already on disk
    # marks a finished database. The db id is everything after the prefix -
    # splitting on '_' would truncate ids such as `california_schools`. The
    # prefix is compared as plain text so glob metacharacters in `file_name`
    # (e.g. the default's brackets) cannot distort the match.
    prefix = f'{file_name}_'
    processed_db_ids = {
        p.stem[len(prefix):]
        for p in prediction_path.glob('*.json')
        if p.stem.startswith(prefix)
    }

    samples_by_db_id = defaultdict(list)
    for sample in to_pred_samples:
        if sample.db_id not in processed_db_ids:
            samples_by_db_id[sample.db_id].append(sample)

    for db_id, samples in samples_by_db_id.items():
        # Only retrieve descriptions that belong to the database being queried.
        retriever = vectorstore.as_retriever(
            search_kwargs={'k': n_retrieval, 'score_threshold': score_threshold, 'filter': {'db_id': db_id}}
        )
        schema_str = get_schema_str(
            schema=tables[db_id].db_schema,
            foreign_keys=tables[db_id].foreign_keys,
            col_explanation=tables[db_id].col_explanation
        )
        results = []
        for sample in tqdm(samples, total=len(samples), desc=f"{db_id}"):
            question = sample.final.question
            docs = retriever.invoke(question)
            # The hint is the retrieved descriptions plus their virtual tables,
            # keyed by retrieval rank.
            hint = '\nDescriptions and Virtual Tables:\n'
            hint += json.dumps({j: {'description': doc.page_content, 'virtual_table': doc.metadata['vt']} for j, doc in enumerate(docs)}, indent=4)
            hint += '\n'
            input_data = {'schema': schema_str, 'input_query': question, 'hint': hint}
            output = chain.invoke(input=input_data)

            results.append({
                'sample_id': sample.sample_id,
                'rationale': output.rationale,
                'pred_sql': output.full_sql_query,
            })

        # One file per database doubles as the checkpoint marker read above.
        with open(prediction_path / f'{file_name}_{db_id}.json', 'w') as f:
            json.dump(results, f, indent=4)

In [157]:

# `vectorstore` is expected to exist already (built in an earlier cell).
# Disabled alternative that rebuilds it from the saved descriptions:
# with open(proj_path / 'data' / 'pkl_files' / 'bird_train_bo.json', 'r') as f:
#     bos = json.load(f)
# vectorstore = get_vector_store(bos)

# Experiment folders; create the output directories that are still missing.
data_path = proj_path / 'data' / 'bird'
experiment_folder = proj_path / 'experiments' / 'bird'
prediction_path = experiment_folder / 'predictions' / 'zero_shot_hint'
eval_path = experiment_folder / 'evals'
for p in (prediction_path, eval_path):
    if p.exists():
        continue
    p.mkdir(parents=True)

# Prompt -> chat model constrained to the `SQLResponse` structure.
prompt = PromptTemplate(
    input_variables=['schema', 'input_query', 'hint'],
    template=Prompts.zero_shot_hints_inference,
)
model_openai = ChatOpenAI(model='gpt-4o-mini', temperature=0.0, frequency_penalty=0.1)
model = model_openai.with_structured_output(SQLResponse)
chain = prompt | model

# Retrieval settings forwarded to `predict_sql_bo`.
n_retrieval = 3
score_threshold = 0.65

# Only the first ten dev samples are predicted here.
predict_sql_bo(
    to_pred_samples=dev_samples[:10],
    tables=bird_tables,
    vectorstore=vectorstore,
    chain=chain,
    prediction_path=prediction_path,
    file_name='bird_dev',
    n_retrieval=n_retrieval,
    score_threshold=score_threshold,
)

movie_platform: 100%|██████████| 10/10 [00:04<00:00,  2.34it/s]


In [138]:
# Rebuild the hint for one sample and print it to see what the model receives.
# NOTE(review): `retriever`, `sample`, `db_schema` and `row` are not defined in
# this cell; they must be left over from other cells.
docs = retriever.invoke(sample.final.question)
hint = ''.join([
    '\nDescriptions and Virtual Tables:\n',
    json.dumps(
        {
            j: {'description': doc.page_content, 'virtual_table': doc.metadata['vt']}
            for j, doc in enumerate(docs)
        },
        indent=4,
    ),
    '\n',
])
input_data = dict(schema=db_schema, input_query=row['question'], hint=hint)
output = chain.invoke(input=input_data)

print(hint)


Descriptions and Virtual Tables:
{
    "0": {
        "description": "The virtual table retrieves the titles of movies that have been rated, filtering by a specific rating timestamp and grouping the results by movie title. The results are ordered by the count of ratings for each movie title, and a limit is applied to restrict the number of returned titles.",
        "virtual_table": "SELECT movies.movie_title FROM ratings INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE ratings.rating_timestamp_utc LIKE '[placeholder-type:string]' GROUP BY movies.movie_title ORDER BY COUNT(movies.movie_title) LIMIT [placeholder-type:numeric]"
    },
    "1": {
        "description": "The virtual table provides a count of users who have rated a specific movie, identified by its title, while also filtering for users who were trialists at the time of rating. It combines data from the 'ratings' and 'movies' tables to achieve this.",
        "virtual_table": "SELECT COUNT(ratings.user_id) FROM ra

# Similarity between datasets

In [125]:
def get_parsed_sql(samples: list, tables: dict):
    """Parse the gold SQL of every sample with ``extract_all``.

    Args:
        samples: Iterable of samples exposing ``db_id``, ``sample_id`` and
            ``final.sql``.
        tables: Maps ``db_id`` to an object exposing ``db_schema``.

    Returns:
        A ``(parsed, error_ids)`` tuple. ``parsed`` maps
        ``db_id -> sample_id -> extraction result``; samples that failed to
        parse (or yielded no selection) are stored as ``None``. ``error_ids``
        lists one ``(db_id, sample_id, message)`` tuple per failure.
    """
    error_ids = []
    parsed = defaultdict(dict)
    iterator = tqdm(samples, total=len(samples))
    for sample in iterator:
        db_id = sample.db_id
        sample_id = sample.sample_id
        iterator.set_description(f"{db_id}")
        schema = Schema(tables[db_id].db_schema)
        sql_i = sample.final.sql
        try:
            ei = extract_all(sql_i, schema)
            # Explicit raise instead of `assert`, so the check survives `python -O`.
            if len(ei['sel']) == 0:
                raise ValueError(f'No selection found-{db_id}-{sample_id}')
        except Exception as e:
            error_ids.append((db_id, sample_id, str(e)))
            # `parsed[db_id]` is a dict, so record the placeholder under the
            # sample id (the former `.append(None)` raised AttributeError here).
            parsed[db_id][sample_id] = None
            continue
        parsed[db_id][sample_id] = ei
    return parsed, error_ids

# Parse every split. Each split keeps its own failure list so that a later call
# no longer overwrites the failures of an earlier one.
train_parsed, train_error_ids = get_parsed_sql(train_samples, bird_tables)
dev_parsed, dev_error_ids = get_parsed_sql(dev_samples, bird_tables)
test_parsed, test_error_ids = get_parsed_sql(test_samples, bird_tables)
# Backward compatibility: `error_ids` used to end up holding the test-split failures.
error_ids = test_error_ids

movie_platform:   0%|          | 0/6341 [00:00<?, ?it/s]

debit_card_specializing: 100%|██████████| 6341/6341 [00:21<00:00, 299.00it/s]    
debit_card_specializing: 100%|██████████| 2091/2091 [00:05<00:00, 408.98it/s]    
debit_card_specializing: 100%|██████████| 2193/2193 [00:09<00:00, 225.64it/s]    


In [129]:
# Disabled: snippet that writes the parsed splits to data/pkl_files.
# for split_name, split_parsed in [('train', train_parsed), ('dev', dev_parsed), ('test', test_parsed)]:
#     with open(proj_path / 'data' / 'pkl_files' / f'bird_{split_name}_parsed.pkl', 'wb') as f:
#         pickle.dump(split_parsed, f)

# Reload the dev and test parses from their pickles (train is not reloaded here).
with (proj_path / 'data' / 'pkl_files' / 'bird_dev_parsed.pkl').open('rb') as f:
    dev_parsed = pickle.load(f)

with (proj_path / 'data' / 'pkl_files' / 'bird_test_parsed.pkl').open('rb') as f:
    test_parsed = pickle.load(f)

In [None]:
from itertools import combinations, product
from collections import defaultdict
from src.eval_utils import get_all_partial_score

def measure_inter_score(parsed1: dict[str, tuple], parsed2: dict[str, tuple]):
    """Score every parsed query of ``parsed1`` against every one of ``parsed2``.

    Both arguments map a ``db_id`` to the parsed queries of that database,
    given either as a sequence or as a mapping (e.g. ``sample_id -> parsed``);
    for a mapping, its values are used in iteration order.

    Args:
        parsed1: Parsed queries whose items index the rows of the matrices.
        parsed2: Parsed queries whose items index the columns of the matrices.

    Returns:
        A dict ``db_id -> {'semantic', 'struct', 'overall'}`` where every value
        is a float32 array of shape ``(n1, n2)``. Pairs in which either item is
        ``None`` are not scored and keep the value 0.

    Raises:
        AssertionError: If the two inputs hold a different number of databases.
    """
    assert len(parsed1) == len(parsed2), f"Length mismatch-1: {len(parsed1)} 2:{len(parsed2)}"

    def _items(container):
        # Containers keyed by sample id cannot be indexed by position directly.
        if isinstance(container, dict):
            return list(container.values())
        return list(container)

    results = {}
    for db_id in list(parsed1):
        o1 = _items(parsed1[db_id])
        o2 = _items(parsed2[db_id])
        n1 = len(o1)
        n2 = len(o2)
        semantic_sim = np.zeros((n1, n2), dtype=np.float32)
        structural_sim = np.zeros((n1, n2), dtype=np.float32)
        overall_sim = np.zeros((n1, n2), dtype=np.float32)

        # Feed the pairs lazily; materialising all n1 * n2 index pairs is unnecessary.
        iterator = tqdm(product(range(n1), range(n2)), total=n1 * n2, desc=f"{db_id}")
        for i, j in iterator:
            ei = o1[i]
            ej = o2[j]
            if ei is None or ej is None:
                # Placeholder for a query that could not be parsed: leave the 0.
                continue

            _, final_score = get_all_partial_score(ei, ej, use_bert=True)

            structural_sim[i, j] = final_score['structural']
            semantic_sim[i, j] = final_score['semantic']
            overall_sim[i, j] = final_score['overall']

        results[db_id] = {
            'semantic': semantic_sim,
            'struct': structural_sim,
            'overall': overall_sim
        }
    return results

# Score the dev parses against the test parses and pickle the resulting matrices.
results = measure_inter_score(dev_parsed, test_parsed)
with open(proj_path / 'data' / 'pkl_files' / 'bird_dev_test_similarity.pkl', 'wb') as f:
    pickle.dump(results, f)

# Complexity between datasets

In [12]:
def measure_complexity(samples, tables):
    """Return the complexity of each sample's SQL, in input order.

    For every sample the schema of its database (``tables[db_id].db_schema``)
    is wrapped in ``Schema``, the SQL in ``final.sql`` is parsed with
    ``extract_all`` and the parse is scored by ``get_complexity``. A failing
    sample raises; nothing is caught here.
    """
    def _score(sample):
        db_schema = Schema(tables[sample.db_id].db_schema)
        return get_complexity(extract_all(sample.final.sql, db_schema))

    return [_score(sample) for sample in tqdm(samples, total=len(samples))]

# One complexity list per split; each call parses all of the split's gold SQL.
train_complexities = measure_complexity(train_samples, bird_tables)
dev_complexities = measure_complexity(dev_samples, bird_tables)
test_complexities = measure_complexity(test_samples, bird_tables)

100%|██████████| 6341/6341 [00:10<00:00, 631.17it/s]
100%|██████████| 2091/2091 [00:03<00:00, 589.55it/s]
100%|██████████| 2193/2193 [00:03<00:00, 591.64it/s]


In [22]:
# Mean, standard deviation and median of the complexity scores, one line per split.
# The split labels are padded to equal width so the printed rows line up.
for n, c in [
    ('train', train_complexities),
    ('dev  ', dev_complexities),
    ('test ', test_complexities),
]:
    print(f'[{n}] Mean={np.mean(c):.4f} +/-{np.std(c):.4f}, Median={np.median(c):.4f}')

[train] Mean=0.2753 +/-0.0476, Median=0.2710
[dev  ] Mean=0.2758 +/-0.0471, Median=0.2710
[test ] Mean=0.2760 +/-0.0477, Median=0.2709


In [None]:
stats = defaultdict(list)
for s in dev_samples:
    stats[s.db_id].append(s)