* gpt-4o-mini
* gemini-1.5-pro
* gemini-1.5-flash

In [7]:
# create database
import duckdb
from src.database import Database

import json
from pathlib import Path
from src.db_utils import get_schema_str, get_data_dict
from tqdm import tqdm

# NOTE(review): `data_type` is not referenced by any cell visible here — confirm it is still needed.
data_type = 'train'
def load_sparc_data(data_path: Path):
    """Load the SParC table definitions and the train/dev splits from disk.

    Args:
        data_path: Directory containing `tables.json`, `train.json` and `dev.json`.

    Returns:
        A `(data_tables, train_data, dev_data)` tuple holding the parsed JSON
        content of the three files, in that order.

    Raises:
        FileNotFoundError: If any of the three files is missing.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    def _read_json(file_name: str):
        # Explicit UTF-8 so parsing does not depend on the platform's default encoding.
        with (data_path / file_name).open(encoding='utf-8') as f:
            return json.load(f)

    return tuple(_read_json(name) for name in ('tables.json', 'train.json', 'dev.json'))

# Install DuckDB's SQLite extension (presumably what `Database` relies on to read .sqlite files — confirm).
duckdb.sql('INSTALL sqlite')

# Paths are resolved from the current working directory, so the data is expected under ./data/sparc.
proj_path = Path('.').resolve()
sparc_path = proj_path / 'data' / 'sparc'

# Open one example database; the path layout used here is database/<db_id>/<db_id>.sqlite.
db_id = 'hospital_1'
db = Database(db_file=str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'))

# Load the table metadata plus both splits and report the split sizes.
tables, train_data, dev_data = load_sparc_data(sparc_path)
print(f'Number of train: {len(train_data)} | Number of dev: {len(dev_data)}')

Number of train: 3034 | Number of dev: 422


In [8]:
from collections import Counter, defaultdict
from pydantic import BaseModel

class DatabaseModel(BaseModel):
    """Schema metadata for a single database.

    Instances are created in `process_all_tables` from the dict returned by
    `get_data_dict` (keys `schema`, `col_explanation`, `foreign_keys`,
    `primary_keys`).
    """
    # Identifier of the database this metadata belongs to.
    db_id: str
    # presumably {table_name: {column_name: column_type}} — confirm against get_data_dict
    db_schema: dict[str, dict[str, str]]
    # presumably {column_name: explanation} — confirm against get_data_dict
    col_explanation: dict[str, str]
    # Foreign-key entries as strings; exact format is defined by get_data_dict.
    foreign_keys: list[str]
    # Primary-key entries as strings; exact format is defined by get_data_dict.
    primary_keys: list[str]

class QuestionSQL(BaseModel):
    """A natural-language question paired with its SQL query."""
    # The utterance text.
    question: str
    # The SQL query for the utterance (passed through `preprocess_sql` when built by `process_samples`).
    sql: str

class SparcSample(BaseModel):
    """One SParC sample: a sequence of question/SQL turns plus a final question/SQL pair."""
    # Defaults to -1 when not provided; `process_samples` sets it to the sample's position in its input list.
    sample_id: int = -1
    # Identifier of the database the sample refers to.
    db_id: str
    # Question/SQL pairs built from the raw sample's 'interaction' list.
    interactions: list[QuestionSQL]
    # Question/SQL pair built from the raw sample's 'final' entry.
    final: QuestionSQL

def preprocess_sql(sql: str) -> str:
    """Normalise a SQL string: turn every double quote into a single quote and trim outer whitespace."""
    single_quoted = sql.translate({ord('"'): "'"})
    return single_quoted.strip()

def process_all_tables(tables: list) -> dict[str, DatabaseModel]:
    """Build a `DatabaseModel` for every table definition, keyed by database id.

    Args:
        tables: Raw table definitions; each entry must have a 'db_id' key and be
            accepted by `get_data_dict`.

    Returns:
        A plain dict mapping `db_id` to its `DatabaseModel`. If the same
        `db_id` occurs more than once, the last entry wins.
    """
    # A plain dict is used on purpose: `DatabaseModel` has only required fields,
    # so it cannot serve as a zero-argument default factory, and an unknown
    # db_id should surface as an ordinary KeyError.
    database: dict[str, DatabaseModel] = {}
    for table in tables:
        db_id = table['db_id']
        data_dict = get_data_dict(table)
        database[db_id] = DatabaseModel(
            db_id=db_id,
            db_schema=data_dict['schema'],
            col_explanation=data_dict['col_explanation'],
            foreign_keys=data_dict['foreign_keys'],
            primary_keys=data_dict['primary_keys'],
        )
    return database

def filter_samples_by_count(all_data: list, n: int=5) -> list:
    """Keep only the samples whose database occurs at least `n` times in `all_data`.

    Args:
        all_data: Raw samples, each a mapping with a 'database_id' key.
        n: Minimum number of samples a database needs for its samples to be kept.

    Returns:
        A new list with the surviving samples in their original order; the
        input list is not modified.
    """
    samples_per_db = Counter(sample['database_id'] for sample in all_data)
    return [sample for sample in all_data if samples_per_db[sample['database_id']] >= n]

def process_samples(all_data: list) -> dict[str, list[SparcSample]]:
    """Convert raw samples into `SparcSample` objects grouped by database id.

    Each sample's `sample_id` is its position in `all_data`; every SQL string
    is passed through `preprocess_sql`.
    """
    grouped = defaultdict(list)
    for position, record in enumerate(all_data):
        database_key = record['database_id']

        # Intermediate turns of the interaction.
        turns = []
        for turn in record['interaction']:
            turns.append(QuestionSQL(question=turn['utterance'], sql=preprocess_sql(turn['query'])))

        # Final question/SQL pair of the sample.
        closing = QuestionSQL(
            question=record['final']['utterance'],
            sql=preprocess_sql(record['final']['query']),
        )

        grouped[database_key].append(
            SparcSample(sample_id=position, db_id=database_key, interactions=turns, final=closing)
        )
    return grouped

def split_train_dev(sparc_samples: dict, ratio: float=0.8):
    """Split each database's samples into a leading train part and a trailing dev part.

    For every database the first `int(len(samples) * ratio)` samples go to
    train and the remainder to dev; the assertion fails if a database would
    contribute nothing to dev. Returns `(train_samples, dev_samples)` as two
    flat lists.
    """
    train_samples, dev_samples = [], []
    for db_id, db_samples in sparc_samples.items():
        cut = int(len(db_samples) * ratio)
        head, tail = db_samples[:cut], db_samples[cut:]
        assert len(tail) > 0, f'Not enough samples for dev set: {db_id}'
        train_samples += head
        dev_samples += tail
    return train_samples, dev_samples

# Build per-database schema metadata -> {db_id: DatabaseModel}
sparc_tables = process_all_tables(tables)
# Pool the original train and dev splits, keeping only databases with at least 5 samples
all_data = filter_samples_by_count(train_data+dev_data, n=5)
# Convert the raw samples -> {db_id: list of SparcSample}
sparc_samples = process_samples(all_data)
# Re-split into train/dev per database: the first 80% of each database's samples go to train
train_samples, dev_samples = split_train_dev(sparc_samples, ratio=0.8)

## Augmentation

In [9]:
from dotenv import load_dotenv, find_dotenv

from langchain_core.pydantic_v1 import BaseModel as LCBaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

# Load variables from the .env file located by find_dotenv(); the result is bound to `_`
# so it is not echoed as cell output (presumably this supplies the API key for ChatOpenAI — confirm).
_ = load_dotenv(find_dotenv())

In [4]:
def get_sparc_schema_description(proj_path: Path, sparc_tables: dict) -> dict:
    """Ask an LLM for a one-line description of every column and save the result as JSON.

    For each database in `sparc_tables`, the schema string produced by
    `get_schema_str` is sent to `gpt-4o-mini`; the parsed JSON replies are
    collected per `db_id` and written to
    `<proj_path>/db_data/sparc_description.json`.

    Args:
        proj_path: Project root; the output is written to its `db_data` sub-directory.
        sparc_tables: Mapping of `db_id` to an object exposing a `db_schema` attribute.

    Returns:
        Mapping of `db_id` to the parsed JSON description returned by the model
        (the same content that is written to disk).
    """

    # Output schema handed to the JSON output parser.
    class Description(LCBaseModel):
        output: dict[str, dict[str, str]] = Field(description='Description of each column for all tables in the database')

    # NOTE(review): the prompt asks for an `output` key while its example omits it, and
    # "proveded" is a typo. The text is deliberately left untouched here so that newly
    # generated descriptions stay comparable with previously generated ones.
    template = '''### Task
    You are tasked with writing one line short description for each column name in a database to help users understand the data better.
    You will be proveded a schema with table names and column names.

    ### Formatting
    Your output should be of the following JSON format with `output` key and value as a dictionary of table names and column names with their descriptions.:
    {{
        "<table_name1>" : {{
            "<column_name>": <str: the one line short description of column>,
            ...
        }},
        ...
    }} 

    ### Output
    <SCHEMA>:\n{schema}
    <OUTPUT>: 
    '''

    prompt = PromptTemplate(
        template=template,
        input_variables=['schema']
    )

    model_openai = ChatOpenAI(
        model='gpt-4o-mini',
        temperature=0.2,
    )

    chain = (prompt | model_openai | JsonOutputParser(pydantic_object=Description))

    output_path = proj_path / 'db_data' / 'sparc_description.json'
    # Create the target directory before any API call is made, so a missing
    # directory cannot discard the generated descriptions at write time.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    all_descriptions = {}
    for db_id, database_model in tqdm(sparc_tables.items(), total=len(sparc_tables)):
        schema_desc = chain.invoke(input={'schema': get_schema_str(database_model.db_schema)})
        all_descriptions[db_id] = schema_desc

    with output_path.open('w') as f:
        json.dump(all_descriptions, f, indent=4)

    # Return the mapping as the `-> dict` annotation promises.
    return all_descriptions

# get_sparc_schema_description(proj_path, sparc_tables)

In [5]:
# Reload the column descriptions from the same JSON file that get_sparc_schema_description writes.
with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

# Render the hospital_1 schema with the generated descriptions passed as column explanations.
print(get_schema_str(schema=sparc_tables['hospital_1'].db_schema, col_explanation=all_descriptions['hospital_1']))

Table Name: Physician
  - 'EmployeeID'(text): Unique identifier for each physician.
  - 'Name'(text): Full name of the physician.
  - 'Position'(text): Job title or role of the physician.
  - 'SSN'(text): Social Security Number of the physician.
Table Name: Department
  - 'DepartmentID'(number): Unique identifier for each department.
  - 'Name'(number): Name of the department.
  - 'Head'(number): Identifier for the head of the department.
Table Name: Affiliated_With
  - 'Physician'(text): Identifier for the physician.
  - 'Department'(text): Identifier for the department.
  - 'PrimaryAffiliation'(text): Indicates if this is the primary affiliation.
Table Name: Procedures
  - 'Code'(text): Unique code for each medical procedure.
  - 'Name'(text): Name of the medical procedure.
  - 'Cost'(text): Cost associated with the procedure.
Table Name: Trained_In
  - 'Physician'(number): Identifier for the physician.
  - 'Treatment'(number): Identifier for the treatment.
  - 'CertificationDate'(nu