In [1]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell executes, so edits to the project modules imported
# in this notebook are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [8]:
import sys
from pathlib import Path
# Resolve the project root from the current working directory; the `src`
# imports and the `data/` paths used in this notebook are relative to it.
proj_path = Path('.').resolve()
# Register the root on sys.path only once: an unconditional append would add
# a duplicate entry every time this cell is re-run.
if str(proj_path) not in sys.path:
    sys.path.append(str(proj_path))

import json
from tqdm import tqdm
import numpy as np
import pandas as pd
from typing import Optional
from collections import defaultdict
from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableSequence
from langchain_core.prompts import PromptTemplate

_ = load_dotenv(find_dotenv())

from src.db_utils import get_schema_str, get_data_dict
from src.pymodels import DatabaseModel, QuestionSQL, SparcSample, SpiderSample, BirdSample, Description
from src.prompts import Prompts
from src.database import SqliteDatabase
from src.data_preprocess import (
    load_raw_data,
    process_all_tables,
    filter_samples_by_count_spider_bird,
    process_samples_bird,
    split_train_dev_test,
    save_samples_spider_bird,
    load_samples_spider_bird,
)

from src.parsing_sql import Schema, extract_all
from src.eval_utils import get_complexity

from copy import deepcopy
# Directory holding the raw `bird` data inside the project tree.
bird_path = proj_path / 'data' / 'bird'
# The loader returns the table definitions plus the train and dev data;
# load_test=False is passed, so no test split is unpacked here.
tables, train_data, dev_data = load_raw_data(bird_path, load_test=False)

# Read the descriptions as UTF-8 explicitly. Without an encoding argument,
# open() uses the platform's locale encoding, which can fail (or silently
# mis-decode) on non-ASCII text when the locale is not UTF-8.
descriptions_path = proj_path / 'data' / 'bird_description.json'
with descriptions_path.open(encoding='utf-8') as f:
    all_descriptions = json.load(f)

# Build the processed table objects, attaching the loaded descriptions.
bird_tables = process_all_tables(tables, descriptions=all_descriptions)

In [3]:
seed = 42
experiment_folder = proj_path / 'experiments' / f'bird_{seed}'
# Train and dev data are pooled before filtering; the pooled set is re-split
# into train/dev/test further down.
all_data = filter_samples_by_count_spider_bird(train_data+dev_data, n=10)

# The skip file holds one integer id per line. Blank lines (for example extra
# trailing newlines) are ignored rather than crashing on int('').
with open(proj_path / 'data' / 'bird_skip.txt') as f:
    skip = [int(line.strip()) for line in f if line.strip()]

bird_samples = process_samples_bird(all_data, bird_tables, skip=skip)
# NOTE(review): `seed` is defined above but is not passed to
# split_train_dev_test - confirm the split is reproducible by other means.
# Presumably the share not covered by train_ratio + dev_ratio goes to test.
train_samples, dev_samples, test_samples = split_train_dev_test(bird_samples, train_ratio=0.6, dev_ratio=0.2)

# Persist each split next to the raw data, then report the split sizes.
save_samples_spider_bird(train_samples, proj_path / 'data' / 'bird_train.json')
save_samples_spider_bird(dev_samples, proj_path / 'data' / 'bird_dev.json')
save_samples_spider_bird(test_samples, proj_path / 'data' / 'bird_test.json')
print(len(train_samples), len(dev_samples), len(test_samples))

  0%|          | 0/10956 [00:00<?, ?it/s]

100%|██████████| 10956/10956 [00:03<00:00, 2939.89it/s]

6341 2091 2193





In [12]:
def measure_complexity(samples, tables):
    """Compute one complexity value per sample, in input order.

    Args:
        samples: Sized iterable of samples; each one must expose ``db_id``
            and ``final.sql``.
        tables: Mapping from ``db_id`` to an object exposing ``db_schema``.

    Returns:
        A list with the ``get_complexity`` result for every sample.
    """
    def _score(sample):
        # A fresh Schema is built for every sample from its database entry.
        db_schema = Schema(tables[sample.db_id].db_schema)
        parsed = extract_all(sample.final.sql, db_schema)
        return get_complexity(parsed)

    progress = tqdm(samples, total=len(samples))
    return [_score(sample) for sample in progress]

# Score every split against the same processed table definitions.
train_complexities = measure_complexity(samples=train_samples, tables=bird_tables)
dev_complexities = measure_complexity(samples=dev_samples, tables=bird_tables)
test_complexities = measure_complexity(samples=test_samples, tables=bird_tables)

100%|██████████| 6341/6341 [00:10<00:00, 631.17it/s]
100%|██████████| 2091/2091 [00:03<00:00, 589.55it/s]
100%|██████████| 2193/2193 [00:03<00:00, 591.64it/s]


In [22]:
# Summarise each split's complexity distribution. The labels are padded to a
# common width so the printed rows line up.
split_stats = (
    ('train', train_complexities),
    ('dev  ', dev_complexities),
    ('test ', test_complexities),
)
for label, values in split_stats:
    mean, std, median = np.mean(values), np.std(values), np.median(values)
    print(f'[{label}] Mean={mean:.4f} +/-{std:.4f}, Median={median:.4f}')

[train] Mean=0.2753 +/-0.0476, Median=0.2710
[dev  ] Mean=0.2758 +/-0.0471, Median=0.2710
[test ] Mean=0.2760 +/-0.0477, Median=0.2709
