## Load Data

In [1]:
import json
import numpy as np
import pandas as pd

from pathlib import Path
from src.db_utils import get_schema_str, get_data_dict
from src.database import SqliteDatabase, DuckDBDatabase
from src.sparc_preprocess import (
    load_sparc_data,
    process_all_tables, 
    filter_samples_by_count_sparc,
    filter_samples_by_count_spider, 
    process_samples_sparc,
    process_samples_spider, 
    split_train_dev
)

proj_path = Path('.').resolve()

## Sparc Dataset

In [2]:
sparc_path = proj_path / 'data' / 'sparc'

tables, train_data, dev_data = load_sparc_data(sparc_path)
print(f'Number of train: {len(train_data)} | Number of dev: {len(dev_data)}')

with (proj_path / 'db_data' / 'description.json').open() as f:
    all_descriptions = json.load(f)

sparc_tables = process_all_tables(tables, descriptions=all_descriptions)
# filter samples by count, must have at least 5 samples
all_data = filter_samples_by_count_sparc(train_data+dev_data, n=5)
# process samples -> {db_id: list of samples}
sparc_samples = process_samples_sparc(all_data)
# change train/dev by sample
train_samples, dev_samples = split_train_dev(sparc_samples, ratio=0.8)

Number of train: 3034 | Number of dev: 422


In [3]:
db_id = 'hospital_1'
db_file = str(sparc_path / 'database' / db_id / f'{db_id}.sqlite')
database = SqliteDatabase(db_file, foreign_keys=sparc_tables[db_id].foreign_keys)
print(database.table_cols.keys())
database.execute('SELECT * FROM Department LIMIT 5;')

dict_keys(['Physician', 'Department', 'Affiliated_With', 'Procedures', 'Trained_In', 'Patient', 'Nurse', 'Appointment', 'Medication', 'Prescribes', 'Block', 'Room', 'On_Call', 'Stay', 'Undergoes'])


Unnamed: 0,DepartmentID,Name,Head
0,1,General Medicine,4
1,2,Surgery,7
2,3,Psychiatry,9


## Spider Dataset

In [4]:
spider_path = proj_path / 'data' / 'spider'
tables, train_data, dev_data = load_sparc_data(spider_path)
print(f'Number of train: {len(train_data)} | Number of dev: {len(dev_data)}')

with (proj_path / 'db_data' / 'description.json').open() as f:
    all_descriptions = json.load(f)
spider_tables = process_all_tables(tables, descriptions=all_descriptions)

all_data = filter_samples_by_count_spider(train_data+dev_data, n=5)
# process samples -> {db_id: list of samples}
spider_samples = process_samples_spider(all_data)
# change train/dev by sample
train_samples, dev_samples = split_train_dev(spider_samples, ratio=0.8)

Number of train: 7000 | Number of dev: 1034


## Workload Analysis

In [23]:
from src.sparc_preprocess import SparcSample, QuestionSQL

def format_interactions(interactions: list[QuestionSQL]) -> str:
    workload = ''
    for i, interaction in enumerate(interactions):
        workload += f'[{i}-Question] {interaction.question}\n[{i}-SQL]: {interaction.sql}\n'
    return workload.strip()

with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

idx = 1000
data = train_samples[idx]
table = sparc_tables[data.db_id]
col_explanation = all_descriptions[data.db_id]
# create schema string
schema_str = get_schema_str(
    schema=table.db_schema, 
    foreign_keys=table.foreign_keys,
    primary_keys=table.primary_keys,
    col_explanation=col_explanation
)
database = SqliteDatabase(str(sparc_path / 'database' / data.db_id / f'{data.db_id}.sqlite'), foreign_keys=table.foreign_keys)
workload = format_interactions(data.interactions)
print(workload, '\n')
print(f'[Final]\nQuestion: {data.final.question}\nSQL: {data.final.sql}\n')

[0-Question] How many games are there?
[0-SQL]: SELECT COUNT(*) FROM basketball_match
[1-Question] What is the lowest acc percent among the competitions?
[1-SQL]: SELECT MIN(acc_percent) FROM basketball_match
[2-Question] Can you order the schools by acc percent in descending order?
[2-SQL]: SELECT Team_Name FROM basketball_match ORDER BY acc_percent DESC
[3-Question] What is the highest acc percent socre?
[3-SQL]: SELECT acc_percent FROM basketball_match ORDER BY acc_percent DESC LIMIT 1 

[Final]
Question: What is the highest acc percent score in the competition?
SQL: SELECT acc_percent FROM basketball_match ORDER BY acc_percent DESC LIMIT 1



In [24]:
database.execute(data.interactions[0].sql)

Unnamed: 0,COUNT(*)
0,4


In [25]:
database.execute(data.interactions[1].sql)

Unnamed: 0,MIN(acc_percent)
0,0.563


In [26]:
database.execute(data.interactions[2].sql)

Unnamed: 0,Team_Name
0,North Carolina
1,Duke
2,Clemson
3,Virginia Tech


In [27]:
database.execute(data.final.sql)

Unnamed: 0,ACC_Percent
0,0.875


### 1. Common table extraction

* find the common table used in the question-sql workloads: All joined tables

In [28]:
import sqlglot
import sqlglot.expressions as exp
from sqlglot.diff import Keep
from sqlglot.optimizer import optimize
from collections import Counter

def extract_table_expression(x: QuestionSQL, schema: dict) -> str:
    sql = optimize(sqlglot.parse_one(x.sql, read='sqlite'), schema=schema)
    tbls = [x.this.this for x in list(sql.find_all(exp.Table))]
    expression = ' '.join([x.sql() for x in sql.find_all(*[exp.From, exp.Join])])
    return ','.join(tbls), expression

def get_sources(data: SparcSample, schema: dict) -> list[tuple[str, list[str]]]:
    sources = []
    for x in data.interactions:
        tbls, expression = extract_table_expression(x, schema)
        sources.append({'question': x.question, 'table': tbls, 'expression': expression})
    return sources

db_id = 'hospital_1'
train_subsamples = list(filter(lambda x: x.db_id == db_id, train_samples))
dev_subsamples = list(filter(lambda x: x.db_id == db_id, dev_samples))
table = sparc_tables[db_id]
database = SqliteDatabase(str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'), foreign_keys=table.foreign_keys)

used_tables = Counter()
for data in train_subsamples:
    sources = get_sources(data, table.db_schema)
    used = [x['table'] for x in sources]
    used_tables.update(used)

print(f'# of train workloads: {len(train_subsamples)}')
print(f'# of used tables: {len(used_tables)}\n-----------------')
for k, v in used_tables.most_common():
    print(f'{k}: {v}')

# of train workloads: 25
# of used tables: 22
-----------------
appointment: 9
department: 7
physician: 6
stay: 6
physician,patient: 6
block,room: 4
room: 4
medication: 3
appointment,patient: 2
appointment,physician: 2
physician,affiliated_with,department: 2
nurse,appointment: 2
prescribes,medication: 2
physician,prescribes,medication: 2
department,physician: 1
physician,appointment,physician: 1
patient,appointment: 1
prescribes,physician: 1
patient,prescribes,physician: 1
stay,patient,prescribes: 1
stay,patient,prescribes,medication: 1
medication,prescribes: 1


In [29]:
for data in train_subsamples:
    sources = get_sources(data, table.db_schema)
    used = [x['table'] for x in sources]
    if any([',' in x for x in used]):
        break

workload = format_interactions(data.interactions)
print(workload, '\n')

[0-Question] How many employees does each department have?
[0-SQL]: SELECT count(departmentID) FROM department GROUP BY departmentID
[1-Question] Which department has the smallest number of employees?
[1-SQL]: SELECT * FROM department GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;
[2-Question] Tell me the name and position of the head of this department.
[2-SQL]: SELECT T2.name ,  T2.position FROM department AS T1 JOIN physician AS T2 ON T1.head  =  T2.EmployeeID GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1; 



In [30]:
print(sqlglot.transpile(data.interactions[2].sql, write="sqlite", identify=False, pretty=True)[0])
sql = optimize(sqlglot.parse_one(data.interactions[2].sql, read='sqlite'), schema=table.db_schema)

SELECT
  T2.name,
  T2.position
FROM department AS T1
JOIN physician AS T2
  ON T1.head = T2.EmployeeID
GROUP BY
  departmentID
ORDER BY
  COUNT(departmentID)
LIMIT 1


### 2. Extract Term - Expression

In [31]:
import os 
from dotenv import load_dotenv, find_dotenv
from collections import defaultdict
from tqdm import tqdm
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

_ = load_dotenv(find_dotenv())

In [34]:
class TermExpressions(BaseModel):
    rationale: str = Field(description='The reasoning behind the decision.')
    term: str = Field(description='A declarative form of the natural language term.')
    expression: str = Field(description='SQL expression that refers to the term.')

class Response(BaseModel):
    output: list[TermExpressions]

template = '''### Task
You are tasked with identifying which SQL expression(except FROM and JOIN clause) is related to the knowledge term of the natural language question.
You will be proveded a pair of question and SQL.
The term should be declarative form(e.g., "how many students" -> "the number of students").
There could be multiple terms that related to the SQL expression, vice versa.

### Formatting
Your output should be of the following list of JSON format:
[{{
    "rationale": "<str: the reasoning behind decision>",
    "term": "<str: a declarative form of natural language term>",
    "expression" : "<str: SQL expression that refer to the term>",
}}, ...]

### Output
<QUESTION>: {question}
<SQL>: {sql}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['question', 'sql']
)

model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.3,
)

model = model_openai.with_structured_output(Response)
chain = (prompt | model)

all_term_expression = defaultdict(list)
for data in tqdm(train_subsamples, total=len(train_subsamples)):
    for x in data.interactions:
        input_data = {'question': x.question, 'sql': x.sql}
        term_expression = chain.invoke(input=input_data).output
        tbls, _ = extract_table_expression(x, table.db_schema)
        all_term_expression[tbls].append(term_expression)

  0%|          | 0/25 [00:00<?, ?it/s]

100%|██████████| 25/25 [01:17<00:00,  3.12s/it]


In [35]:
all_term_expression

defaultdict(list,
            {'department': [[TermExpressions(rationale='The SQL expression uses the COUNT function to aggregate the number of employees grouped by departmentID, which directly relates to the question asking for the number of employees in each department.', term='the number of employees in each department', expression='count(departmentID)')],
              [TermExpressions(rationale='The SQL expression counts the number of employees in each department and orders them to find the one with the most employees.', term='the department with the most employees', expression='count(departmentID)')],
              [TermExpressions(rationale="The SQL expression uses the COUNT function to aggregate the number of employees in each department, which directly relates to the term 'the number of employees in each department'.", term='the number of employees in each department', expression='count(departmentID)')],
              [TermExpressions(rationale="The SQL expression uses 'count(

In [13]:
with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

print(get_schema_str(
    schema=sparc_tables['hospital_1'].db_schema, 
    col_explanation=all_descriptions['hospital_1'])[:300]
)

[Table and Columns]
Table Name: Physician
  - 'EmployeeID'(text): Unique identifier for each physician.
  - 'Name'(text): Full name of the physician.
  - 'Position'(text): Job title or role of the physician.
  - 'SSN'(text): Social Security Number of the physician.
Table Name: Department
  - 'Depart


# Query Access Area

In [18]:
dtype_functions = {
    'numeric': pd.to_numeric,
    'datetime': pd.to_datetime
}

def null_percentage(s: pd.Series) -> float:
    return s.isnull().sum() / len(s)

column_info = {}
for col in df.columns:
    # dtype
    null_index = df[col].isnull()
    for logical_type in ['numeric', 'datetime', 'text']:
        if logical_type in ['numeric', 'datetime']:
            try:
                df.loc[~null_index, col] = dtype_functions[logical_type](df.loc[~null_index, col], errors='raise')
                attribute_type = 'ordinal'
                break
            except ValueError as e:
                # print(f'-- {col}: {logical_type} {e}')
                continue
            except TypeError as e:
                # print(f'-- {col}: {logical_type} {e}')
                continue
        else:
            attribute_type = 'nominal'
            break
    print(f'{col}: {logical_type} {attribute_type}')
    # unique values
    unique_values = df[col].unique()
    # min, max
    min_val = df[col].min()
    max_val = df[col].max()
    # null percentage
    null_percent = null_percentage(df[col])

    column_info[col] = {
        'logical_type': logical_type,
        'attribute_type': attribute_type,
        'unique_values': unique_values,
        'min': min_val,
        'max': max_val,
        'null_percentage': null_percent
    }

[(13216584, 100000001, 101, 1, '2008-04-24 10:00', '2008-04-24 11:00', 'A')]