In [1]:
# Re-import edited project modules (src/, run_*.py) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# seed = 42
# all_data = filter_samples_by_count_spider_bird(train_data+dev_data, n=10)

# with open(proj_path / 'data' / 'bird_skip.txt') as f:
#     skip = [int(line.strip()) for line in f]

# bird_samples = process_samples_bird(all_data, bird_tables, skip=skip)
# train_samples, dev_samples, test_samples = split_train_dev_test(bird_samples, train_ratio=0.6, dev_ratio=0.2, seed=seed)

# save_samples_spider_bird(train_samples, proj_path / 'data' / 'bird_train.json')
# save_samples_spider_bird(dev_samples, proj_path / 'data' / 'bird_dev.json')
# save_samples_spider_bird(test_samples, proj_path / 'data' / 'bird_test.json')
# print(len(train_samples), len(dev_samples), len(test_samples))

In [3]:
import sys
from pathlib import Path
proj_path = Path('.').resolve()
sys.path.append(str(proj_path))

import json
import pickle
from tqdm import tqdm
import numpy as np
import pandas as pd
from typing import Optional
from collections import defaultdict
from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.runnables import RunnableSequence
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.prompts import PromptTemplate


_ = load_dotenv(find_dotenv())

from src.db_utils import get_schema_str, get_data_dict
from src.pymodels import (
    DatabaseModel, 
    SpiderSample, 
    BirdSample, 
    BODescription,
    SQLResponse
)
from src.prompts import Prompts
from src.database import SqliteDatabase
from src.data_preprocess import (
    load_raw_data,
    process_all_tables,
    filter_samples_by_count_spider_bird,
    process_samples_bird,
    split_train_dev_test,
    save_samples_spider_bird,
    load_samples_spider_bird,
)

from src.parsing_sql import Schema, extract_all
from src.eval_utils import get_complexity, result_eq, check_if_exists_orderby
from run_bo_sql import get_vector_store
from copy import deepcopy
# Load the raw BIRD release (train/dev only; load_test=False).
bird_path = proj_path / 'data' / 'bird'
tables, train_data, dev_data = load_raw_data(bird_path, load_test=False)

# Column descriptions, passed into table processing below.
with (proj_path / 'data' / 'bird_description.json').open() as f:
    all_descriptions = json.load(f)

bird_tables = process_all_tables(tables, descriptions=all_descriptions)
# Pre-split samples written by the (commented-out) split cell at the top of the notebook.
train_samples = load_samples_spider_bird(proj_path / 'data' / 'bird_train.json')
dev_samples = load_samples_spider_bird(proj_path / 'data' / 'bird_dev.json')
test_samples = load_samples_spider_bird(proj_path / 'data' / 'bird_test.json')

In [119]:
# with open(proj_path / 'data' / 'pkl_files' / 'bird_train_parsed.pkl', 'rb') as f:
#     train_parsed = pickle.load(f)

# # prediction parsed
# with open(proj_path / 'data' / 'pkl_files' / 'bird_dev_parsed.pkl', 'rb') as f:
#     dev_parsed = pickle.load(f)

In [17]:
eval_path = proj_path / 'experiments' / 'bird' / 'evals' / 'zero_shot'

# Each per-database eval file is JSON Lines: one record per line.
# sorted(): glob order is filesystem-dependent; sort so the CSV row order is reproducible.
eval_records = []
for eval_file in sorted(eval_path.glob('bird_dev_*.json')):
    with eval_file.open() as f:
        eval_records.extend(json.loads(line) for line in f if line.strip())

df = pd.DataFrame(eval_records)
df.to_csv(eval_path / 'bird_dev.csv', index=False)

In [16]:
# Summary statistics of the gold-SQL complexity over the merged dev eval records.
df['gold_complexity'].agg(['mean', 'std', 'min', 'max', 'median'])

mean      0.450361
std       0.055482
min       0.318118
max       0.726155
median    0.446118
Name: gold_complexity, dtype: float64

In [31]:
prediction_path = proj_path / 'experiments' / 'bird' / 'predictions' / 'create_bo'
# db_id -> list of BO records; one JSON file per database.
bos = defaultdict(list)
# sorted(): glob order is filesystem-dependent; sort so `bos` (and any index built
# from it) is assembled in a reproducible order.
for bo_file in sorted(prediction_path.glob('bird_train_bo_*.json')):
    with bo_file.open() as f:
        # 'bird_train_bo_<db_id>' -> '<db_id>'; maxsplit=3 keeps db_ids containing '_' intact.
        bos[bo_file.stem.split('_', 3)[-1]] = json.load(f)

# with (prediction_path / 'final_bird_train_bo.json').open('w') as f:
#     json.dump(bos, f, indent=4)

In [49]:
# Index only the 'address' database's BOs for a quick retrieval sanity check.
# NOTE(review): `get_vector_store` is imported from run_bo_sql but also redefined later in
# this notebook; which version runs here depends on execution order.
vectorstore = get_vector_store({'address': bos['address']})

In [51]:
# Sanity check: an off-topic query still returns its nearest BO together with a score.
# NOTE(review): FAISS's default strategy is a distance (lower = closer) — confirm before thresholding.
results = vectorstore.similarity_search_with_score(
    "Will it be hot tomorrow?", k=1
)

In [52]:
# Display the (Document, score) pairs from the search above.
results

[(Document(metadata={'sample_id': 5169, 'db_id': 'address', 'vt': "SELECT COUNT(zip_data.zip_code) FROM state INNER JOIN zip_data AS T2 ON T1.abbreviation = T2.state WHERE state.name = '[placeholder-type:string]' AND zip_data.daylight_savings = '[placeholder-type:string]' AND zip_data.region = '[placeholder-type:string]'"}, page_content="The virtual table counts the number of zip codes from the 'zip_data' table that are associated with a specific state, while also filtering for those that observe daylight savings and belong to a particular region. The placeholders represent the state name, daylight savings status, and region respectively."),
  0.59726036)]

In [39]:
from langchain_core.runnables import chain

@chain
def retriever(query: str) -> list[Document]:
    """Similarity search that attaches each hit's score to its metadata.

    The score is stored under ``metadata["score"]`` so it survives into downstream
    runnables. Returns an empty list when the search yields no hits (the previous
    ``zip(*...)`` unpacking raised ValueError in that case).
    """
    docs = []
    for doc, score in vectorstore.similarity_search_with_score(query):
        doc.metadata["score"] = score
        docs.append(doc)
    return docs

In [55]:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
# Two-stage retrieval: the vector store proposes candidates passing a relevance-score
# threshold (search_type='similarity_score_threshold'), then a cross-encoder re-ranks
# them and keeps the top 3.
base_retriever = vectorstore.as_retriever(
    search_type='similarity_score_threshold', search_kwargs={'score_threshold': 0.6}
)
# NOTE: rebinds the global `model`; a later cell reassigns it to an OpenAI model.
model = HuggingFaceCrossEncoder(model_name='cross-encoder/ms-marco-MiniLM-L-6-v2')
compressor = CrossEncoderReranker(model=model, top_n=3)
# NOTE: rebinds `retriever`, shadowing the @chain `retriever` defined in an earlier cell.
retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=base_retriever
)

In [56]:
# Probe the first-stage retriever alone (no re-ranking).
x = base_retriever.invoke('How many zip data are there?')

In [58]:
# Metadata of the top hit: sample_id, db_id and the virtual-table SQL ('vt').
x[0].metadata

{'sample_id': 5189,
 'db_id': 'address',
 'vt': 'SELECT DISTINCT zip_data.state FROM state INNER JOIN zip_data AS T2 ON T1.abbreviation = T2.state WHERE zip_data.female_population > (SELECT AVG(zip_data.female_population) FROM zip_data)'}

In [20]:
# Debug introspection: confirm the compression retriever is a BaseRetriever.
ContextualCompressionRetriever.__bases__

(langchain_core.retrievers.BaseRetriever,)

In [71]:
prediction_path = proj_path / 'experiments' / 'bird' / 'predictions' / 'zero_shot'
predictions = []
for pred_file in prediction_path.glob('bird_dev_*.json'):
    with open(pred_file) as f:
        for pred in json.load(f):
            # Drop the LLM rationale to keep the merged file compact; tolerate records
            # that were already stripped instead of raising KeyError.
            pred.pop('rationale', None)
            predictions.append(pred)

# Merged predictions, one JSON object per line.
with open(prediction_path / 'final_bird_dev.jsonl', 'w') as f:
    f.writelines(json.dumps(pred) + '\n' for pred in predictions)

len(predictions)

2091

In [69]:
# Inspect one merged prediction record.
predictions[0]

{'sample_id': 8773,
 'db_id': 'food_inspection',
 'gold_sql': "SELECT COUNT(owner_state) FROM businesses WHERE owner_state = 'CA'",
 'pred_sql': "SELECT COUNT(DISTINCT owner_name) AS owner_count FROM businesses WHERE owner_state = 'California';"}

In [206]:
# typ = 'test'
# samples = test_samples
# predictions = []
# with open(prediction_path / f'final_bird_{typ}.jsonl', 'r') as f:
#     preds = f.readlines()
#     for p in preds:
#         pred = json.loads(p)
#         found = False
#         for sample in samples:
#             if sample.sample_id == pred['sample_id']:
#                 pred['gold_sql'] = sample.final.sql
#                 found = True
#                 break
#         if not found:
#             raise ValueError(f"Sample ID {pred['sample_id']} not found")
        
#         predictions.append(pred)

# with open(prediction_path / f'final_bird_{typ}.jsonl', 'w') as f:
#     for p in predictions:
#         f.write(json.dumps(p) + '\n')

In [3]:
import sqlglot
import sqlglot.expressions as exp
from src.eval_utils import result_eq, check_if_exists_orderby

def _resolve_db_file(proj_path: Path, db_id: str, ds: str) -> Path:
    """Locate the SQLite file for `db_id`.

    For BIRD a database may live in either the train or the dev release, so both are
    checked (train first) by file existence. Any other `ds` uses the Spider layout.

    Raises:
        FileNotFoundError: BIRD database found in neither split.
    """
    if ds != 'bird':
        return proj_path / 'data' / 'spider' / 'database' / db_id / f'{db_id}.sqlite'
    candidates = [
        proj_path / 'data' / ds / 'train' / 'train_databases' / db_id / f'{db_id}.sqlite',
        proj_path / 'data' / ds / 'dev' / 'dev_databases' / db_id / f'{db_id}.sqlite',
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f'No SQLite file for {db_id!r}; looked in: {[str(c) for c in candidates]}'
    )


def get_pred_results(
        proj_path: Path,
        predictions: list[dict],
        tables: dict[str, DatabaseModel],
        ds: str = 'bird' # spider or bird
    ):
    """Execute predicted and gold SQL and score predictions by result equality.

    Args:
        proj_path: project root; databases are resolved under ``proj_path / 'data'``.
        predictions: records with 'sample_id', 'db_id', 'pred_sql' and 'gold_sql'.
        tables: db_id -> DatabaseModel (its foreign_keys are passed to SqliteDatabase).
        ds: 'bird' or 'spider'; selects the database directory layout.

    Returns:
        (output_results, error_infos):
        - output_results: one dict per scored sample with 'sample_id', 'db_id', 'score'
          (0/1), 'gold_sql', 'pred_sql'. Samples whose gold SQL fails are omitted.
        - error_infos: (db_id, sample_id) lists under 'pred_exec' (pred SQL failed),
          'gold_exec' (gold SQL failed), 'python_script' (comparison raised) and
          'result' (executed but results differ).
    """
    output_results = []
    error_infos = {
        'pred_exec': [],
        'gold_exec': [],
        'python_script': [],
        'result': []
    }
    predictions_by_db_id = defaultdict(list)
    for pred in predictions:
        predictions_by_db_id[pred['db_id']].append(pred)

    for db_id, preds in predictions_by_db_id.items():
        # Resolve the file explicitly rather than catching constructor errors: a bare
        # `except:` hid unrelated failures and could open an empty, freshly created DB.
        database = SqliteDatabase(
            db_file=str(_resolve_db_file(proj_path, db_id, ds)),
            foreign_keys=tables[db_id].foreign_keys
        )
        iterator = tqdm(preds, total=len(preds))
        for pred in iterator:
            iterator.set_description(f'{db_id} | pred_exec: {len(error_infos["pred_exec"])} | gold_exec: {len(error_infos["gold_exec"])} | python_script: {len(error_infos["python_script"])} | result: {len(error_infos["result"])}')

            sample_id = pred['sample_id']
            pred_sql = pred['pred_sql']
            gold_sql = pred['gold_sql']

            # Prediction runs first so a failing prediction is recorded even when
            # the gold query also fails.
            pred_failed = False
            try:
                pred_result = database.execute(pred_sql, rt_pandas=False)
            except Exception:
                pred_failed = True
                error_infos['pred_exec'].append((db_id, sample_id))

            try:
                gold_result = database.execute(gold_sql, rt_pandas=False)
            except Exception:
                # A broken gold query cannot be scored; drop the sample entirely.
                error_infos['gold_exec'].append((db_id, sample_id))
                continue

            if pred_failed:
                score = 0
            else:
                exists_orderby = check_if_exists_orderby(gold_sql)
                try:
                    score = int(result_eq(pred_result, gold_result, order_matters=exists_orderby))
                except Exception as e:
                    print(f"An error occurred: {e}")
                    score = 0
                    error_infos['python_script'].append((db_id, sample_id))
                else:
                    if score == 0:
                        error_infos['result'].append((db_id, sample_id))

            output_results.append(
                {
                    'sample_id': sample_id,
                    'db_id': db_id,
                    'score': score,
                    'gold_sql': gold_sql,
                    'pred_sql': pred_sql,
                }
            )

    return output_results, error_infos


In [4]:
# Directory holding the merged zero-shot dev predictions (final_bird_dev.jsonl).
prediction_path = proj_path / 'experiments' / 'bird' / 'predictions' / 'zero_shot'

def load_predictions(prediction_path: Path, filename: str) -> list[dict]:
    """Load a JSON Lines prediction file (one JSON object per line).

    Args:
        prediction_path: directory containing the file.
        filename: name of the ``.jsonl`` file inside ``prediction_path``.

    Returns:
        The decoded records in file order. Blank or whitespace-only lines are skipped
        instead of raising ``json.JSONDecodeError``.
    """
    with open(prediction_path / filename, 'r') as f:
        return [json.loads(line) for line in f if line.strip()]
# Reuse the loader defined above instead of duplicating its JSON Lines parsing.
predictions = load_predictions(prediction_path, 'final_bird_dev.jsonl')

# output_results, error_infos = get_pred_results(proj_path, predictions, tables, ds='bird')

In [59]:
from run_evaluation import get_prediction_parsed_sql, get_target_parsed_sql
from src.eval_utils import get_complexity, get_all_partial_score

In [None]:
# Parse gold and predicted SQL for the first three predictions only (debugging).
# Derive the sample ids from the predictions so gold and pred always cover the same samples.
debug_preds = predictions[:3]
debug_ids = {p['sample_id'] for p in debug_preds}
target_parsed, target_error_ids = get_target_parsed_sql([s for s in dev_samples if s.sample_id in debug_ids], tables=bird_tables)
pred_parsed, error_ids = get_prediction_parsed_sql(debug_preds, tables=bird_tables)

california_schools: 100%|██████████| 3/3 [00:00<00:00, 456.88it/s]


california_schools: 100%|██████████| 3/3 [00:00<00:00, 444.52it/s]


In [57]:
# Database of the debug samples above; used by the inspection cells below.
db_id = 'california_schools'

In [60]:
# Partial-match scores (overall | structural | semantic) for the debug predictions.
# Each prediction's own db_id is used, and samples whose gold or predicted SQL failed
# to parse are reported instead of raising KeyError.
for pred in predictions[:3]:
    sample_id = pred['sample_id']
    pred_db_id = pred['db_id']
    target_o = target_parsed.get(pred_db_id, {}).get(sample_id)
    pred_o = pred_parsed.get(pred_db_id, {}).get(sample_id)
    if not (pred_o and target_o):
        print(f"Sample ID: {sample_id} | skipped (prediction or gold SQL not parsed)")
        continue

    _, all_score = get_all_partial_score(pred_o, target_o)

    print(f"Sample ID: {sample_id}")
    print(f'f1={all_score["overall"]:.4f} | {all_score["structural"]:.4f} | {all_score["semantic"]:.4f}')

Sample ID: 9503
f1=0.7798 | 0.7657 | 0.8017
Sample ID: 9468
f1=0.4271 | 0.4170 | 0.5096


In [58]:
# Inspect the parsed gold SQL structures for the debug database.
target_parsed[db_id]

{9503: defaultdict(set,
             {'aliases': {'table': {'T1': 'frpm', 'T2': 'schools'},
               'column': {}},
              'distinct': False,
              'limit': True,
              'table_asts': {('frpm',
                Table(
                  this=Table(
                    this=Identifier(this=frpm, quoted=False))),
                'from'),
               ('schools',
                Table(
                  this=Table(
                    this=Identifier(this=schools, quoted=False))),
                'join')},
              'sel': {'__frpm.low grade__',
               '__frpm.school name__',
               '__schools.city__'},
              'sel_asts': {('__frpm.low grade__',
                Column(
                  this=Identifier(this=low grade, quoted=True),
                  table=Identifier(this=frpm, quoted=False)),
                '<select>'),
               ('__frpm.school name__',
                Column(
                  this=Identifier(this=school name,

In [26]:
# Third debug prediction (sample 9494): its SQL failed to parse, so no score is printed above.
predictions[2]

{'sample_id': 9494,
 'rationale': ['Identify the relevant table: frpm contains the enrollment data.',
  "Determine the columns needed: 'enrollment (ages 5-17)' for the number of students, 'academic year' to filter by the specific year, and 'district type' to specify the type of school.",
  'Filter the records for the academic year 2014-2015.',
  "Filter for the district type 'State Special Schools'.",
  "Filter for the county name 'Fremont'.",
  'Use SUM() to aggregate the total enrollment for the specified filters.'],
 'pred_sql': "SELECT SUM(enrollment (ages 5-17)) AS total_enrollment\nFROM frpm\nWHERE academic year = '2014-2015' AND district type = 'State Special Schools' AND county name = 'Fremont';",
 'db_id': 'california_schools',
 'gold_sql': 'SELECT T1."Enrollment (Ages 5-17)" FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.EdOpsCode = \'SSS\' AND T2.City = \'Fremont\' AND T1."Academic Year" BETWEEN 2014 AND 2015'}

In [35]:
sql = """
SELECT SUM("enrollment (ages 5-17)") AS total_enrollment
FROM frpm
WHERE "academic year" = '2014-2015' AND "district type" = 'State Special Schools' AND "county name" = 'Fremont';
"""
x = {'sample_id': 9494,
 'rationale': ['Identify the relevant table: frpm contains the enrollment data.',
  "Determine the columns needed: 'enrollment (ages 5-17)' for the number of students, 'academic year' to filter by the specific year, and 'district type' to specify the type of school.",
  'Filter the records for the academic year 2014-2015.',
  "Filter for the district type 'State Special Schools'.",
  "Filter for the county name 'Fremont'.",
  'Use SUM() to aggregate the total enrollment for the specified filters.'],
 'pred_sql': sql,
 'db_id': 'california_schools',
 'gold_sql': 'SELECT T1."Enrollment (Ages 5-17)" FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.EdOpsCode = \'SSS\' AND T2.City = \'Fremont\' AND T1."Academic Year" BETWEEN 2014 AND 2015'}

parsed, error_ids = get_prediction_parsed_sql([x], tables=bird_tables)

california_schools: 100%|██████████| 1/1 [00:00<00:00, 465.31it/s]


In [33]:
db_id = 'california_schools'
# NOTE(review): rebinds the global `tables` (raw output of load_raw_data) to the processed
# `bird_tables`; every cell run after this one sees the processed version.
tables = bird_tables
# Prompt-ready schema string (with foreign keys and column explanations) for db_id.
db_schema = get_schema_str(
    schema=tables[db_id].db_schema, 
    foreign_keys=tables[db_id].foreign_keys,
    col_explanation=tables[db_id].col_explanation
)

In [29]:
# Open the dev copy of california_schools directly for ad-hoc queries.
database = SqliteDatabase(
    db_file=str(proj_path / 'data' / 'bird' / 'dev' / 'dev_databases' / 'california_schools' / 'california_schools.sqlite'),
    foreign_keys=bird_tables['california_schools'].foreign_keys
)

In [30]:
database.execute("""
SELECT "enrollment (ages 5-17)"
FROM frpm
LIMIT 1
""")

Unnamed: 0,Enrollment (Ages 5-17)
0,1070.0


In [28]:
# (db_id, sample_id, message) for each prediction that failed to parse.
error_ids

[('california_schools',
  9494,
  'Invalid expression / Unexpected token. Line 4, Col: 19.\n  \nSELECT SUM("enrollment (ages 5-17)") AS total_enrollment\nFROM frpm\nWHERE academic \x1b[4myear\x1b[0m = \'2014-2015\' AND district type = \'State Special Schools\' AND county name = \'Fremont\';\n')]

In [None]:
from run_evaluation import get_parsed_sql

# Cache the parsed predictions: parse once, then reuse the pickle on later runs.
parsed_path = proj_path / 'data' / 'pkl_files'
file_name = 'bird_dev_parsed_pred.pkl'
cache_file = parsed_path / file_name
if cache_file.exists():
    # NOTE: pickle can execute code on load; only load caches this project wrote.
    with cache_file.open('rb') as f:
        pred_parsed = pickle.load(f)
else:
    parsed_path.mkdir(parents=True, exist_ok=True)
    pred_parsed, error_ids = get_parsed_sql(predictions, tables)
    with cache_file.open('wb') as f:
        pickle.dump(pred_parsed, f)
    print(f'Error parsing pred bird_dev: {len(error_ids)}')

# Create virtual tables (VT) and their descriptions (BA)

In [111]:
import sqlglot
from collections import defaultdict
from src.parsing_sql import (
    Schema, _format_expression, extract_aliases
)
from run_bo_sql import create_vt_ba

In [112]:
# LLM chain that describes a virtual table; output is parsed into a BODescription.
prompt = PromptTemplate(
    template=Prompts.bo_description,
    input_variables=['schema', 'virtual_table']
)
model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
    frequency_penalty=0.1,
)
model = model_openai.with_structured_output(BODescription)
# NOTE: rebinds `chain`, shadowing the langchain_core `chain` decorator imported earlier.
chain = (prompt | model)
# Smoke test: create virtual tables and descriptions for the first 10 training samples only.
res = create_vt_ba(samples=train_samples[:10], tables=bird_tables, chain=chain)

100%|██████████| 10/10 [00:32<00:00,  3.20s/it]


# Predict BO

In [117]:
def get_vector_store(bos: dict[str, list[dict[str, str]]]):
    """Build an in-memory FAISS index over BO descriptions.

    Args:
        bos: db_id -> list of records with 'sample_id', 'ba' (description text that
            gets embedded) and 'vt' (virtual-table SQL).

    Returns:
        A FAISS vector store embedded with OpenAIEmbeddings; each document's metadata
        carries 'sample_id', 'db_id' and 'vt'.

    NOTE: this shadows the `get_vector_store` imported from run_bo_sql.
    """
    # NOTE(review): `doc_id` does not look like a Document field and may be silently
    # ignored; sample_id is also kept in metadata, which is what downstream code reads.
    documents = [
        Document(
            doc_id=record['sample_id'],
            page_content=record['ba'],
            metadata={
                'sample_id': record['sample_id'],
                'db_id': db_id,
                'vt': record['vt']
            }
        )
        for db_id, db_records in bos.items()
        for record in db_records
    ]
    return FAISS.from_documents(documents, embedding=OpenAIEmbeddings())

In [None]:
# Index only the BOs created above (`res`, first 10 train samples); rebinds the global vectorstore.
vectorstore = get_vector_store(res)

In [156]:
def _format_bo_hint(docs: list[Document]) -> str:
    """Render retrieved BO documents as the 'hint' text passed to the SQL prompt."""
    payload = {
        j: {'description': doc.page_content, 'virtual_table': doc.metadata['vt']}
        for j, doc in enumerate(docs)
    }
    return '\nDescriptions and Virtual Tables:\n' + json.dumps(payload, indent=4) + '\n'


def predict_sql_bo(
    to_pred_samples: list[SpiderSample|BirdSample],
    tables: dict[str, DatabaseModel],
    vectorstore: FAISS,
    chain: RunnableSequence,
    prediction_path: Path,
    file_name: str = '[args.ds]_[args.type]',
    n_retrieval: int = 3,
    score_threshold: float = 0.65,
):
    """Predict SQL per sample, using retrieved BO descriptions as prompt hints.

    For every database, retrieval is restricted to that db_id. The results are written
    to ``prediction_path / f'{file_name}_{db_id}.json'``, one file per database. Databases
    that already have such a file are skipped, so an interrupted run can be restarted.

    Args:
        to_pred_samples: samples to predict (read: db_id, sample_id, final.question, final.sql).
        tables: db_id -> DatabaseModel used to build the schema string.
        vectorstore: BO index whose metadata contains 'db_id' and 'vt'.
        chain: prompt | model runnable returning an object with `rationale` and
            `full_sql_query`.
        prediction_path: output directory.
        file_name: prefix of the per-database output files.
        n_retrieval: number of BO documents retrieved per question.
        score_threshold: forwarded to the vector store search.
            NOTE(review): with FAISS's default distance strategy this acts as a
            maximum distance (lower = closer); confirm.
    """
    # Recover processed db_ids by stripping the known prefix. Splitting on '_' would
    # truncate ids such as 'california_schools' to 'schools' and break the restart.
    prefix = f'{file_name}_'
    processed_db_ids = {p.stem[len(prefix):] for p in prediction_path.glob(f'{prefix}*')}
    to_pred_samples = [sample for sample in to_pred_samples if sample.db_id not in processed_db_ids]

    samples_by_db_id = defaultdict(list)
    for sample in to_pred_samples:
        samples_by_db_id[sample.db_id].append(sample)

    for db_id, samples in samples_by_db_id.items():
        retriever = vectorstore.as_retriever(
            search_kwargs={'k': n_retrieval, 'score_threshold': score_threshold, 'filter': {'db_id': db_id}}
        )
        schema_str = get_schema_str(
            schema=tables[db_id].db_schema, 
            foreign_keys=tables[db_id].foreign_keys,
            col_explanation=tables[db_id].col_explanation
        )
        results = []
        for sample in tqdm(samples, total=len(samples), desc=f"{db_id}"):
            question = sample.final.question
            hint = _format_bo_hint(retriever.invoke(question))
            input_data = {'schema': schema_str, 'input_query': question, 'hint': hint}
            output = chain.invoke(input=input_data)

            results.append({
                'sample_id': sample.sample_id,
                'rationale': output.rationale,
                'pred_sql': output.full_sql_query,
                # Added so the output can be scored directly by get_pred_results.
                'db_id': db_id,
                'gold_sql': sample.final.sql,
            })

        with open(prediction_path / f'{file_name}_{db_id}.json', 'w') as f:
            json.dump(results, f, indent=4)

In [157]:

# Alternative: rebuild the vector store from saved BOs instead of `res`:
# with open(proj_path / 'data' / 'pkl_files' / 'bird_train_bo.json', 'r') as f:
#     bos = json.load(f)
# vectorstore = get_vector_store(bos)


data_path = proj_path / 'data' / 'bird'
experiment_folder = proj_path / 'experiments' / 'bird'
prediction_path = experiment_folder / 'predictions' / 'zero_shot_hint'
eval_path = experiment_folder / 'evals'
for p in [prediction_path, eval_path]:
    # exist_ok makes the cell idempotent without a separate exists() check.
    p.mkdir(parents=True, exist_ok=True)

prompt = PromptTemplate(
    template=Prompts.zero_shot_hints_inference,
    input_variables=['schema', 'input_query', 'hint'],
)

model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
    frequency_penalty=0.1,
)

model = model_openai.with_structured_output(SQLResponse)
chain = (prompt | model)

# Retrieval settings forwarded to predict_sql_bo.
n_retrieval = 3
score_threshold = 0.65

# Smoke run on the first 10 dev samples.
predict_sql_bo(
    to_pred_samples=dev_samples[:10],
    tables=bird_tables,
    vectorstore=vectorstore,
    chain=chain,
    prediction_path=prediction_path,
    n_retrieval=n_retrieval,
    score_threshold=score_threshold,
    file_name='bird_dev',
)

movie_platform: 100%|██████████| 10/10 [00:04<00:00,  2.34it/s]


In [138]:
# Debug one BO-hinted prediction end to end, mirroring predict_sql_bo for a single sample.
# Everything is derived from one explicit sample. The previous version used an undefined
# `row`, a stale `sample`, and the california_schools schema regardless of the sample's db.
sample = dev_samples[0]
question = sample.final.question
sample_retriever = vectorstore.as_retriever(
    search_kwargs={'k': n_retrieval, 'score_threshold': score_threshold, 'filter': {'db_id': sample.db_id}}
)
docs = sample_retriever.invoke(question)
hint = '\nDescriptions and Virtual Tables:\n'
hint += json.dumps({j: {'description': doc.page_content, 'virtual_table': doc.metadata['vt']} for j, doc in enumerate(docs)}, indent=4)
hint += '\n'
sample_schema = get_schema_str(
    schema=bird_tables[sample.db_id].db_schema,
    foreign_keys=bird_tables[sample.db_id].foreign_keys,
    col_explanation=bird_tables[sample.db_id].col_explanation
)
input_data = {'schema': sample_schema, 'input_query': question, 'hint': hint}
output = chain.invoke(input=input_data)

print(hint)


Descriptions and Virtual Tables:
{
    "0": {
        "description": "The virtual table retrieves the titles of movies that have been rated, filtering by a specific rating timestamp and grouping the results by movie title. The results are ordered by the count of ratings for each movie title, and a limit is applied to restrict the number of returned titles.",
        "virtual_table": "SELECT movies.movie_title FROM ratings INNER JOIN movies AS T2 ON T1.movie_id = T2.movie_id WHERE ratings.rating_timestamp_utc LIKE '[placeholder-type:string]' GROUP BY movies.movie_title ORDER BY COUNT(movies.movie_title) LIMIT [placeholder-type:numeric]"
    },
    "1": {
        "description": "The virtual table provides a count of users who have rated a specific movie, identified by its title, while also filtering for users who were trialists at the time of rating. It combines data from the 'ratings' and 'movies' tables to achieve this.",
        "virtual_table": "SELECT COUNT(ratings.user_id) FROM ra

# Similarity between datasets

In [125]:
def get_parsed_sql(samples: list, tables: dict):
    """Parse each sample's gold SQL (``sample.final.sql``) with ``extract_all``.

    NOTE: this shadows the `get_parsed_sql` imported from run_evaluation, which is
    called with prediction dicts; this version expects sample objects.

    Args:
        samples: objects exposing ``db_id``, ``sample_id`` and ``final.sql``.
        tables: db_id -> table model exposing ``db_schema``.

    Returns:
        (parsed, error_ids):
        - parsed: db_id -> {sample_id: extracted structure, or None if parsing failed}.
        - error_ids: (db_id, sample_id, message) for each failure, including SQL
          whose extraction has no selected columns.
    """
    error_ids = []
    parsed = defaultdict(dict)
    iterator = tqdm(samples, total=len(samples))
    for sample in iterator:
        db_id = sample.db_id
        sample_id = sample.sample_id
        iterator.set_description(f"{db_id}")
        schema = Schema(tables[db_id].db_schema)
        try:
            ei = extract_all(sample.final.sql, schema)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not len(ei['sel']) > 0:
                raise ValueError(f'No selection found-{db_id}-{sample_id}')
        except Exception as e:
            error_ids.append((db_id, sample_id, str(e)))
            # parsed[db_id] is a dict keyed by sample_id; `.append` raised AttributeError here.
            parsed[db_id][sample_id] = None
            continue
        parsed[db_id][sample_id] = ei
    return parsed, error_ids

# Keep each split's parse failures; a single `error_ids` was overwritten by every call.
train_parsed, train_error_ids = get_parsed_sql(train_samples, bird_tables)
dev_parsed, dev_error_ids = get_parsed_sql(dev_samples, bird_tables)
test_parsed, test_error_ids = get_parsed_sql(test_samples, bird_tables)

movie_platform:   0%|          | 0/6341 [00:00<?, ?it/s]

debit_card_specializing: 100%|██████████| 6341/6341 [00:21<00:00, 299.00it/s]    
debit_card_specializing: 100%|██████████| 2091/2091 [00:05<00:00, 408.98it/s]    
debit_card_specializing: 100%|██████████| 2193/2193 [00:09<00:00, 225.64it/s]    


In [129]:
# Persist the parsed splits (uncomment to regenerate the pickles).
# with open(proj_path / 'data' / 'pkl_files' / 'bird_train_parsed.pkl', 'wb') as f:
#     pickle.dump(train_parsed, f)

# with open(proj_path / 'data' / 'pkl_files' / 'bird_dev_parsed.pkl', 'wb') as f:
#     pickle.dump(dev_parsed, f)

# with open(proj_path / 'data' / 'pkl_files' / 'bird_test_parsed.pkl', 'wb') as f:
#     pickle.dump(test_parsed, f)

# Reload parsed dev/test SQL from disk; this overwrites the dev_parsed/test_parsed
# computed in the previous cell.
with open(proj_path / 'data' / 'pkl_files' / 'bird_dev_parsed.pkl', 'rb') as f:
    dev_parsed = pickle.load(f)

with open(proj_path / 'data' / 'pkl_files' / 'bird_test_parsed.pkl', 'rb') as f:
    test_parsed = pickle.load(f)

In [None]:
from itertools import combinations, product
from collections import defaultdict
from src.eval_utils import get_all_partial_score

def measure_inter_score(parsed1: dict[str, tuple], parsed2: dict[str, tuple]):
    """Pairwise partial-match similarity between two parsed splits, per database.

    Args:
        parsed1, parsed2: db_id -> parsed SQL entries for that database. Entries may be
            a dict keyed by sample_id (as built by `get_parsed_sql` in this notebook) or
            a plain sequence; a None entry marks SQL that failed to parse.

    Returns:
        db_id -> dict with float32 arrays 'semantic', 'struct' and 'overall' of shape
        (n1, n2). Cell [i, j] scores parsed1's i-th entry against parsed2's j-th entry,
        and is NaN when either side is None. 'ids1'/'ids2' hold the sample_id (or the
        position, for sequences) of each row/column.
    """
    def _ids_and_entries(entries):
        # Index positionally. Dict keys are sample_ids, not 0..n-1, so the previous
        # `o1[i]` lookup raised KeyError on dict-based parses.
        if isinstance(entries, dict):
            return list(entries.keys()), list(entries.values())
        entries = list(entries)
        return list(range(len(entries))), entries

    results = {}
    assert len(parsed1) == len(parsed2), f"Length mismatch-1: {len(parsed1)} 2:{len(parsed2)}"
    for db_id in parsed1:
        ids1, o1 = _ids_and_entries(parsed1[db_id])
        ids2, o2 = _ids_and_entries(parsed2[db_id])
        n1, n2 = len(o1), len(o2)
        semantic_sim = np.full((n1, n2), np.nan, dtype=np.float32)
        structural_sim = np.full((n1, n2), np.nan, dtype=np.float32)
        overall_sim = np.full((n1, n2), np.nan, dtype=np.float32)

        for i, j in tqdm(product(range(n1), range(n2)), total=n1 * n2, desc=f"{db_id}"):
            ei = o1[i]
            ej = o2[j]
            if ei is None or ej is None:
                continue  # unparsed SQL: leave the cell as NaN

            _, final_score = get_all_partial_score(ei, ej, use_bert=True)

            structural_sim[i, j] = final_score['structural']
            semantic_sim[i, j] = final_score['semantic']
            overall_sim[i, j] = final_score['overall']

        results[db_id] = {
            'semantic': semantic_sim,
            'struct': structural_sim,
            'overall': overall_sim,
            'ids1': ids1,
            'ids2': ids2,
        }
    return results

# Dev x test pairwise SQL similarity per database, saved to a pickle for later analysis.
results = measure_inter_score(dev_parsed, test_parsed)
with (proj_path / 'data' / 'pkl_files' / 'bird_dev_test_similarity.pkl').open('wb') as f:
    pickle.dump(results, f)

# Complexity between datasets

In [12]:
def measure_complexity(samples, tables):
    """Return the complexity score of each sample's gold SQL, in input order.

    Each sample's ``final.sql`` is parsed against its database schema with
    ``extract_all`` and scored by ``get_complexity``. Parse errors propagate.
    """
    def _score(sample):
        db_schema = Schema(tables[sample.db_id].db_schema)
        return get_complexity(extract_all(sample.final.sql, db_schema))

    return [_score(sample) for sample in tqdm(samples, total=len(samples))]

# Gold-SQL complexity per sample for each split (compared across splits below).
train_complexities = measure_complexity(train_samples, bird_tables)
dev_complexities = measure_complexity(dev_samples, bird_tables)
test_complexities = measure_complexity(test_samples, bird_tables)

100%|██████████| 6341/6341 [00:10<00:00, 631.17it/s]
100%|██████████| 2091/2091 [00:03<00:00, 589.55it/s]
100%|██████████| 2193/2193 [00:03<00:00, 591.64it/s]


In [22]:
# Mean, standard deviation and median of gold-SQL complexity per split; names are left-padded to 5 chars.
split_complexities = {'train': train_complexities, 'dev': dev_complexities, 'test': test_complexities}
for split_name, values in split_complexities.items():
    print(f'[{split_name:<5}] Mean={np.mean(values):.4f} +/-{np.std(values):.4f}, Median={np.median(values):.4f}')

[train] Mean=0.2753 +/-0.0476, Median=0.2710
[dev  ] Mean=0.2758 +/-0.0471, Median=0.2710
[test ] Mean=0.2760 +/-0.0477, Median=0.2709


In [None]:
# Group dev samples by database id.
stats = defaultdict(list)
for s in dev_samples:
    stats[s.db_id].append(s)