## Load Data

In [1]:
import json
import numpy as np
import pandas as pd

from pathlib import Path
from src.db_utils import get_schema_str
from src.database import SqliteDatabase
from src.spider_sparc_preprocess import (
    load_spider_sparc_data,
    process_all_tables, 
    load_samples_spider,
    load_samples_sparc,
    filter_samples_by_count_sparc,
    filter_samples_by_count_spider, 
    process_samples_sparc,
    process_samples_spider, 
    split_train_dev_test,
    save_samples_spider
)

proj_path = Path('.').resolve()

## Spider Dataset

In [2]:
# # spider dataset
# spider_path = proj_path / 'data' / 'spider'
# tables, train_data, dev_data = load_spider_sparc_data(spider_path)

# with (proj_path / 'data' / 'description.json').open() as f:
#     all_descriptions = json.load(f)
# spider_tables = process_all_tables(tables, descriptions=all_descriptions)

# all_data = filter_samples_by_count_spider(train_data+dev_data, n=10)
# # process samples -> {db_id: list of samples}
# # skip = [3146, 4690, 4691]
# spider_samples = process_samples_spider(all_data, spider_tables, skip=[])
# # change train/dev by sample
# train_samples, dev_samples, test_samples = split_train_dev_test(spider_samples, train_ratio=0.8, dev_ratio=0.1)
# print(f'Number of train: {len(train_samples)} | Number of dev: {len(dev_samples)} | Number of test: {len(test_samples)}')

# save_samples_spider(train_samples, proj_path / 'data' / 'spider_train.json')
# save_samples_spider(dev_samples, proj_path / 'data' / 'spider_dev.json')
# save_samples_spider(test_samples, proj_path / 'data' / 'spider_test.json')


100%|██████████| 8023/8023 [00:01<00:00, 4236.60it/s]

Number of train: 6369 | Number of dev: 747 | Number of test: 907





In [3]:
with (proj_path / 'data' / 'spider' / f'tables.json').open() as f:
    tables = json.load(f)

with (proj_path / 'data' / 'description.json').open() as f:
    all_descriptions = json.load(f)
spider_tables = process_all_tables(tables, descriptions=all_descriptions)

train_samples = load_samples_spider(proj_path / 'data' / 'spider_train.json')
dev_samples = load_samples_spider(proj_path / 'data' / 'spider_dev.json')
test_samples = load_samples_spider(proj_path / 'data' / 'spider_test.json')
print(f'Number of train: {len(train_samples)} | Number of dev: {len(dev_samples)} | Number of test: {len(test_samples)}')

Number of train: 6369 | Number of dev: 747 | Number of test: 907


In [4]:
from collections import defaultdict, Counter

# how many multiple tables joined for each database?
counter = defaultdict(Counter)
for s in train_samples:
    counter[s.db_id].update([len(s.final.source_tables)])

counter['hospital_1']

Counter({1: 32, 2: 30, 3: 16, 4: 2})

# Common Interest Detection

## Augment Dataset for training a cross-encoder

* if two questions share the same source table, they are considered as a co-related pair
* using jaccard similarity to label the common interest: 
    * e.g., $q_1$ has three tables $t_1, t_2, t_3$, $q_2$ has two tables $t_1, t_2$, then the jaccard similarity is $2/3$

In [5]:
from collections import defaultdict
from itertools import groupby, combinations, product
from src.spider_sparc_preprocess import SpiderSample, SparcSample

def jaccard_similarity(i_tables: str|set, j_tables: str|set):
    def preprocess(tables: str):
        return set([t.strip() for t in tables.split(',')])
    # Get the number of common tables
    i_set = preprocess(i_tables) if isinstance(i_tables, str) else i_tables
    j_set = preprocess(j_tables) if isinstance(j_tables, str) else j_tables

    common_tables = i_set.intersection(j_set)
    union_tables = i_set.union(j_set)
    return len(common_tables) / len(union_tables)

def curate_samples(samples: list) -> list[dict]:

    dataset = []
    for db_id, group_samples in groupby(samples, key=lambda x: x.db_id):
        # schema_str = get_schema_str(spider_tables[db_id].db_schema, col_fmt='', skip_type=True, remove_meta=True)
        data_dict = defaultdict(list)
        for tbls, samples in groupby(group_samples, key=lambda x: x.final.source_tables):
            tbls = ', '.join(tbls)
            for s in samples:
                data_dict[tbls].append(s.final.question)
        
        for i_tables, j_tables in combinations(data_dict.keys(), 2):
            similarity = jaccard_similarity(i_tables, j_tables)
            i_data = data_dict[i_tables]
            j_data = data_dict[j_tables]
            for i, j in product(i_data, j_data):
                dataset.append(
                    {
                        'db_id': db_id,
                        'sentence1': i,
                        'sentence2': j,
                        'label': similarity,
                        'tables1': i_tables,
                        'tables2': j_tables
                    }
                )
    return dataset

train_dataset = curate_samples(train_samples)
dev_dataset = curate_samples(dev_samples)
test_dataset = curate_samples(test_samples)

with (proj_path / 'data' / 'spider_common_interest_train.json').open('w') as f:
    json.dump(train_dataset, f, indent=4)

with (proj_path / 'data' / 'spider_common_interest_dev.json').open('w') as f:
    json.dump(dev_dataset, f, indent=4)

with (proj_path / 'data' / 'spider_common_interest_test.json').open('w') as f:
    json.dump(test_dataset, f, indent=4)

In [6]:
len(train_dataset), len(dev_dataset), len(test_dataset)

(137797, 1278, 1800)

## Prepare Dataset

In [8]:
from datasets import load_dataset

ds = load_dataset('json', 
    data_files={'train': str(proj_path / 'data' / 'spider_common_interest_train.json'), 
                'validation': str(proj_path / 'data' / 'spider_common_interest_dev.json'),
                'test': str(proj_path / 'data' / 'spider_common_interest_test.json')})

Generating train split: 0 examples [00:00, ? examples/s]

DatasetGenerationError: An error occurred while generating the dataset

In [None]:
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, SentenceTransformer, losses, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments


def load_dataset(path: Path):
    with path.open('r') as f:
        dataset = json.load(f)
        
    
    
    return dataset

with (proj_path / 'data' / 'spider_common_interest_train.json').open() as f:
    train_dataset = json.load(f)

with (proj_path / 'data' / 'spider_common_interest_dev.json').open() as f:
    dev_dataset = json.load(f)


model_name = 'all-MiniLM-L6-v2'
train_batch_size = 128  # The larger you select this, the better the results (usually). But it requires more GPU memory
max_seq_length = 75
num_epochs = 1

model = SentenceTransformer(model_name)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
loss = losses.CosineSimilarityLoss()


In [3]:
db_id = 'hospital_1'
train_subsamples = list(filter(lambda x: x.db_id == db_id, train_samples))
dev_subsamples = list(filter(lambda x: x.db_id == db_id, dev_samples))

s = get_schema_str(spider_tables[db_id].db_schema, col_fmt='', skip_type=True)

In [4]:
from itertools import groupby

for k, g in groupby(train_subsamples, key=lambda x: x.final.source_tables):
    if len(k) > 1:
        print(k, len(list(g)))


print()
for k, g in groupby(dev_subsamples, key=lambda x: x.final.source_tables):
    if len(k) > 1:
        print(k, len(list(g)))

# TODO: remove alias from table names

['department', 'physician'] 2
['patient', 'appointment'] 4
['appointment', 'physician'] 4
['department', 'affiliated_with', 'physician'] 2
['patient', 'appointment'] 2
['patient', 'physician', 'prescribes'] 2
['stay', 'patient', 'medication', 'prescribes'] 2
['nurse', 'appointment'] 2
['patient', 'physician'] 4
['block', 'room'] 4
['medication', 'prescribes', 'physician'] 4
['medication', 'prescribes'] 2
['stay', 'patient', 'undergoes'] 2
['nurse', 'undergoes'] 2
['prescribes', 'physician'] 2
['department', 'affiliated_with'] 2
['procedures', 'trained_in', 'physician'] 6

['procedures', 'trained_in', 'physician'] 2
['trained_in', 'procedures', 'physician'] 6
['department', 'affiliated_with', 'physician'] 4
['patient', 'medication', 'prescribes'] 4
['nurse', 'on_call'] 2


Loss: 

* MultipleNegativesRankingLoss

In [None]:
# https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/nli/training_nli_v3.py
# https://www.sbert.net/docs/package_reference/sentence_transformer/losses.html#sentence_transformers.losses.MultipleNegativesRankingLoss
import tqdm
from itertools import groupby, combinations
from src.spider_sparc_preprocess import SpiderSample, SparcSample
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, SentenceTransformer, losses, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import BinaryClassificationEvaluator

def curate_dataset(samples: list[SpiderSample|SparcSample]):
    dataset = []
    for db_id, group_samples in groupby(samples, key=lambda x: x.db_id):
        schema_str = get_schema_str(spider_tables[db_id].db_schema, col_fmt='', skip_type=True, remove_meta=True)
        # positive pairs
        for tbls, samples in groupby(group_samples, key=lambda x: x.final.source_tables):
            questions = [schema_str + '\n' + s.final.question for s in samples]
            pairs = list(combinations(questions, 2))
            for p in pairs:
                dataset.append(InputExample(texts=p, label=1))
        # negative pairs

    return dataset

train_data = curate_dataset(train_samples)
dev_data = curate_dataset(dev_samples)

model_name = 'all-MiniLM-L6-v2'
train_batch_size = 128  # The larger you select this, the better the results (usually). But it requires more GPU memory
max_seq_length = 75
num_epochs = 1

model = SentenceTransformer(model_name)
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)

# Use the denoising auto-encoder loss
train_loss = losses.MultipleNegativesRankingLoss(model)

# model.fit(
#     train_objectives=[(train_dataloader, train_loss)], epochs=1, show_progress_bar=True
# )

In [62]:
data

[{'db_id': 'department_management',
  'sentence1': 'How many heads of the departments are older than 56 ?',
  'sentence2': 'List the creation year, name and budget of each department.',
  'label': 0.0,
  'tables1': 'head',
  'tables2': 'department'},
 {'db_id': 'department_management',
  'sentence1': 'How many heads of the departments are older than 56 ?',
  'sentence2': 'What are the maximum and minimum budget of the departments?',
  'label': 0.0,
  'tables1': 'head',
  'tables2': 'department'},
 {'db_id': 'department_management',
  'sentence1': 'How many heads of the departments are older than 56 ?',
  'sentence2': 'What is the average number of employees of the departments whose rank is between 10 and 15?',
  'label': 0.0,
  'tables1': 'head',
  'tables2': 'department'},
 {'db_id': 'department_management',
  'sentence1': 'How many heads of the departments are older than 56 ?',
  'sentence2': 'In which year were most departments established?',
  'label': 0.0,
  'tables1': 'head',
  '

In [57]:
j_data

['How many acting statuses are there?']

In [56]:
list(product(i_data, j_data))

[("Show the name and number of employees for the departments managed by heads whose temporary acting value is 'Yes'?",
  'How many acting statuses are there?'),
 ('How many departments are led by heads who are not mentioned?',
  'How many acting statuses are there?')]

In [46]:
# algorithm
# if table in in the joined tables, they are related

['How many heads of the departments are older than 56 ?',
 'List the name, born state and age of the heads of departments ordered by age.',
 'What are the names of the heads who are born outside the California state?',
 'What are the names of the states where at least 3 heads were born?']

In [52]:
def jaccard_similarity(i_tables: str|set, j_tables: str|set):
    def preprocess(tables: str):
        return set([t.strip() for t in tables.split(',')])
    # Get the number of common tables
    i_set = preprocess(i_tables) if isinstance(i_tables, str) else i_tables
    j_set = preprocess(j_tables) if isinstance(j_tables, str) else j_tables

    common_tables = i_set.intersection(j_set)
    union_tables = i_set.union(j_set)
    return len(common_tables) / len(union_tables)

# Calculate similarity between all pairs of tables
table_pairs = []
for i_tables, j_tables in combinations(data_dict.keys(), 2):
    similarity = jaccard_similarity(i_tables, j_tables)
    table_pairs.append((i_tables, j_tables, similarity))

print(table_pairs)

[('head', 'department', 0.0), ('head', 'department, management, head', 0.3333333333333333), ('head', 'department, management', 0.0), ('head', 'management', 0.0), ('department', 'department, management, head', 0.3333333333333333), ('department', 'department, management', 0.5), ('department', 'management', 0.0), ('department, management, head', 'department, management', 0.6666666666666666), ('department, management, head', 'management', 0.3333333333333333), ('department, management', 'management', 0.5)]


In [17]:
binary_acc_evaluator = BinaryClassificationEvaluator.from_input_examples(dev_data, name='dev')
results = binary_acc_evaluator(model)

In [18]:
results

{'dev_cosine_accuracy': 0.9995553579368608,
 'dev_cosine_accuracy_threshold': 0.5259714126586914,
 'dev_cosine_f1': 0.9997776295307983,
 'dev_cosine_f1_threshold': 0.5259714126586914,
 'dev_cosine_precision': 1.0,
 'dev_cosine_recall': 0.9995553579368608,
 'dev_cosine_ap': 1.0,
 'dev_dot_accuracy': 0.9995553579368608,
 'dev_dot_accuracy_threshold': 0.5259714722633362,
 'dev_dot_f1': 0.9997776295307983,
 'dev_dot_f1_threshold': 0.5259714722633362,
 'dev_dot_precision': 1.0,
 'dev_dot_recall': 0.9995553579368608,
 'dev_dot_ap': 1.0,
 'dev_manhattan_accuracy': 0.9995553579368608,
 'dev_manhattan_accuracy_threshold': 15.076079368591309,
 'dev_manhattan_f1': 0.9997776295307983,
 'dev_manhattan_f1_threshold': 15.076079368591309,
 'dev_manhattan_precision': 1.0,
 'dev_manhattan_recall': 0.9995553579368608,
 'dev_manhattan_ap': 1.0,
 'dev_euclidean_accuracy': 0.9995553579368608,
 'dev_euclidean_accuracy_threshold': 0.9736785888671875,
 'dev_euclidean_f1': 0.9997776295307983,
 'dev_euclidean_f1

In [8]:
import sqlglot
import sqlglot.expressions as exp
from sqlglot.optimizer import optimize
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema

RULES = (
    # qualify,
    # pushdown_projections,
    # normalize,
    # unnest_subqueries,
    # pushdown_predicates,
    # optimize_joins,
    # eliminate_subqueries,
    # merge_subqueries,
    # eliminate_joins,
    # eliminate_ctes,
    # quote_identifiers,
    # annotate_types,
    # canonicalize,
    # simplify,
)
{"sample_id": 1272, "db_id": "apartment_rentals", "final": 
 {"question": "Show the apartment numbers of apartments with unit status availability of both 0 and 1.", 
"sql": "SELECT T1.apt_number FROM Apartments AS T1 JOIN View_Unit_Status AS T2 ON T1.apt_id  =  T2.apt_id WHERE T2.available_yn  =  0 INTERSECT SELECT T1.apt_number FROM Apartments AS T1 JOIN View_Unit_Status AS T2 ON T1.apt_id  =  T2.apt_id WHERE T2.available_yn  =  1", "source_tables": ["t1", "view_unit_status", "t1", "view_unit_status", "apartments"]}}

db_id = 'apartment_rentals'
schema = spider_tables[db_id].db_schema
sql = "SELECT T1.apt_number FROM Apartments AS T1 JOIN View_Unit_Status AS T2 ON T1.apt_id  =  T2.apt_id WHERE T2.available_yn  =  0 INTERSECT SELECT T1.apt_number FROM Apartments AS T1 JOIN View_Unit_Status AS T2 ON T1.apt_id  =  T2.apt_id WHERE T2.available_yn  =  1"
sql = sqlglot.parse_one(sql, read='sqlite')
print(sql.sql(pretty=True))
tbls = [x.this.this.lower() for x in list(sql.find_all(exp.Table))]

expression = ' '.join([x.sql() for x in sql.find_all(*[exp.From, exp.Join, exp.Where])])
print(tbls)
print(expression)

SELECT
  T1.apt_number
FROM Apartments AS T1
JOIN View_Unit_Status AS T2
  ON T1.apt_id = T2.apt_id
WHERE
  T2.available_yn = 0
INTERSECT
SELECT
  T1.apt_number
FROM Apartments AS T1
JOIN View_Unit_Status AS T2
  ON T1.apt_id = T2.apt_id
WHERE
  T2.available_yn = 1
['apartments', 'view_unit_status', 'apartments', 'view_unit_status']
FROM Apartments AS T1 JOIN View_Unit_Status AS T2 ON T1.apt_id = T2.apt_id WHERE T2.available_yn = 0 FROM Apartments AS T1 JOIN View_Unit_Status AS T2 ON T1.apt_id = T2.apt_id WHERE T2.available_yn = 1


In [6]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/nli-deberta-v3-large')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/nli-deberta-v3-large')

features = tokenizer(['A man is eating pizza', 'A black race car starts up in front of a crowd of people.'], ['A man eats something', 'A man is driving down a lonely road.'],  padding=True, truncation=True, return_tensors="pt")

model.eval()
with torch.no_grad():
    scores = model(**features).logits
    label_mapping = ['contradiction', 'entailment', 'neutral']
    labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1)]
    print(labels)


[Table and Columns]
Physician: EmployeeID, Name, Position, SSN
Department: DepartmentID, Name, Head
Affiliated_With: Physician, Department, PrimaryAffiliation
Procedures: Code, Name, Cost
Trained_In: Physician, Treatment, CertificationDate, CertificationExpires
Patient: SSN, Name, Address, Phone, InsuranceID, PCP
Nurse: EmployeeID, Name, Position, Registered, SSN
Appointment: AppointmentID, Patient, PrepNurse, Physician, Start, End, ExaminationRoom
Medication: Code, Name, Brand, Description
Prescribes: Physician, Patient, Medication, Date, Appointment, Dose
Block: BlockFloor, BlockCode
Room: RoomNumber, RoomType, BlockFloor, BlockCode, Unavailable
On_Call: Nurse, BlockFloor, BlockCode, OnCallStart, OnCallEnd
Stay: StayID, Patient, Room, StayStart, StayEnd
Undergoes: Patient, Procedures, Stay, DateUndergoes, Physician, AssistingNurse


In [10]:
spider_tables[train_samples[0].db_id].db_schema

DatabaseModel(db_id='department_management', db_schema={'department': {'Department_ID': 'text', 'Name': 'text', 'Creation': 'text', 'Ranking': 'text', 'Budget_in_Billions': 'text', 'Num_Employees': 'text'}, 'head': {'head_ID': 'number', 'name': 'number', 'born_state': 'number', 'age': 'number'}, 'management': {'department_ID': 'text', 'head_ID': 'text', 'temporary_acting': 'text'}}, col_explanation={'department': {'Department_ID': 'Unique identifier for each department.', 'Name': 'Name of the department.', 'Creation': 'Date when the department was established.', 'Ranking': 'Ranking of the department based on performance.', 'Budget_in_Billions': 'Annual budget allocated to the department in billions.', 'Num_Employees': 'Total number of employees in the department.'}, 'head': {'head_ID': 'Unique identifier for each department head.', 'name': 'Name of the department head.', 'born_state': 'State where the department head was born.', 'age': 'Age of the department head.'}, 'management': {'de

In [11]:
import sqlglot
import sqlglot.expressions as exp
from sqlglot.optimizer import optimize

def extract_used_table(sql: str, schema: dict) -> list[str]:
    sql = optimize(sqlglot.parse_one(sql, read='sqlite'), schema=schema)
    tbls = [x.this.this for x in list(sql.find_all(exp.Table))]
    return tbls

extract_used_table(train_samples[0].final.sql, spider_tables[train_samples[0].db_id].db_schema)

['head']

# Sparc Dataset

In [None]:
sparc_path = proj_path / 'data' / 'sparc'

tables, train_data, dev_data = load_spider_sparc_data(sparc_path)

with (proj_path / 'db_data' / 'description.json').open() as f:
    all_descriptions = json.load(f)

sparc_tables = process_all_tables(tables, descriptions=all_descriptions)
# filter samples by count, must have at least 5 samples
all_data = filter_samples_by_count_sparc(train_data+dev_data, n=5)
# process samples -> {db_id: list of samples}
sparc_samples = process_samples_sparc(all_data, sparc_tables)
# change train/dev by sample
train_samples, dev_samples = split_train_dev(sparc_samples, ratio=0.8)

print(f'Number of train: {len(train_samples)} | Number of dev: {len(dev_samples)}')

db_id = 'hospital_1'
db_file = str(sparc_path / 'database' / db_id / f'{db_id}.sqlite')
database = SqliteDatabase(db_file, foreign_keys=sparc_tables[db_id].foreign_keys)
print(database.table_cols.keys())
database.execute('SELECT * FROM Department LIMIT 5;')

## Workload Analysis

In [7]:
from src.spider_sparc_preprocess import SparcSample, QuestionSQL

def format_interactions(interactions: list[QuestionSQL]) -> str:
    workload = ''
    for i, interaction in enumerate(interactions):
        workload += f'[{i}-Question] {interaction.question}\n[{i}-SQL]: {interaction.sql}\n'
    return workload.strip()

with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

db_id = 'hospital_1'
train_subsamples = list(filter(lambda x: x.db_id == db_id, train_samples))
dev_subsamples = list(filter(lambda x: x.db_id == db_id, dev_samples))

table = sparc_tables[db_id]
col_explanation = all_descriptions[db_id]
# create schema string
schema_str = get_schema_str(
    schema=table.db_schema, 
    foreign_keys=table.foreign_keys,
    primary_keys=table.primary_keys,
    col_explanation=col_explanation
)  
database = SqliteDatabase(str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'), foreign_keys=table.foreign_keys)

data = train_samples[2]
workload = format_interactions(data.interactions)
print(workload, '\n')
print(f'[Final]\nQuestion: {data.final.question}\nSQL: {data.final.sql}\n')

[0-Question] How many employees does each department have?
[0-SQL]: SELECT count(departmentID) FROM department GROUP BY departmentID
[1-Question] Which department has the smallest number of employees?
[1-SQL]: SELECT * FROM department GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;
[2-Question] Tell me the name and position of the head of this department.
[2-SQL]: SELECT T2.name ,  T2.position FROM department AS T1 JOIN physician AS T2 ON T1.head  =  T2.EmployeeID GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1; 

[Final]
Question: Find the name and position of the head of the department with the least employees.
SQL: SELECT T2.name ,  T2.position FROM department AS T1 JOIN physician AS T2 ON T1.head  =  T2.EmployeeID GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;



In [8]:
database.execute(data.interactions[0].sql)

Unnamed: 0,count(departmentID)
0,1
1,1
2,1


In [9]:
database.execute(data.interactions[1].sql)

Unnamed: 0,DepartmentID,Name,Head
0,1,General Medicine,4


In [10]:
database.execute(data.interactions[2].sql)

Unnamed: 0,Name,Position
0,Percival Cox,Senior Attending Physician


In [11]:
database.execute(data.final.sql)

Unnamed: 0,Name,Position
0,Percival Cox,Senior Attending Physician


### 1. Common table extraction

* find the common table used in the question-sql workloads: All joined tables

In [12]:
import sqlglot
import sqlglot.expressions as exp
from sqlglot.diff import Keep
from sqlglot.optimizer import optimize
from collections import Counter

def extract_table_expression(x: QuestionSQL, schema: dict) -> str:
    sql = optimize(sqlglot.parse_one(x.sql, read='sqlite'), schema=schema)
    tbls = [x.this.this for x in list(sql.find_all(exp.Table))]
    expression = ' '.join([x.sql() for x in sql.find_all(*[exp.From, exp.Join])])
    return ','.join(tbls), expression

def get_sources(data: SparcSample, schema: dict) -> list[tuple[str, list[str]]]:
    sources = []
    for x in data.interactions:
        tbls, expression = extract_table_expression(x, schema)
        sources.append({'question': x.question, 'table': tbls, 'expression': expression})
    return sources

db_id = 'hospital_1'
train_subsamples = list(filter(lambda x: x.db_id == db_id, train_samples))
dev_subsamples = list(filter(lambda x: x.db_id == db_id, dev_samples))
table = sparc_tables[db_id]
database = SqliteDatabase(str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'), foreign_keys=table.foreign_keys)

used_tables = Counter()
for data in train_subsamples:
    sources = get_sources(data, table.db_schema)
    used = [x['table'] for x in sources]
    used_tables.update(used)

print(f'# of train workloads: {len(train_subsamples)}')
print(f'# of used tables: {len(used_tables)}\n-----------------')
for k, v in used_tables.most_common():
    print(f'{k}: {v}')

# of train workloads: 25
# of used tables: 22
-----------------
appointment: 9
department: 7
physician: 6
stay: 6
physician,patient: 6
block,room: 4
room: 4
medication: 3
appointment,patient: 2
appointment,physician: 2
physician,affiliated_with,department: 2
nurse,appointment: 2
prescribes,medication: 2
physician,prescribes,medication: 2
department,physician: 1
physician,appointment,physician: 1
patient,appointment: 1
prescribes,physician: 1
patient,prescribes,physician: 1
stay,patient,prescribes: 1
stay,patient,prescribes,medication: 1
medication,prescribes: 1


In [21]:
# a nested query
sql = """
SELECT * FROM (
    SELECT * FROM Department
) AS A
WHERE A.department_id = 1;
"""

sql = sqlglot.parse_one(sql, read='sqlite')
tbls = [x.this.this for x in list(sql.find_all(exp.Table))]
expression = ' '.join([x.sql() for x in sql.find_all(*[exp.From, exp.Join])])

In [22]:
tbls

['Department']

In [39]:
for i, data in enumerate(train_subsamples):
    print(f'[{i:02d} - Question] {data.final.question}')
    print(f'[{i:02d} - SQL] {data.final.sql}')

[00 - Question] Find the department with the most employees.
[00 - SQL] SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;
[01 - Question] Tell me the employee id of the head of the department with the least employees.
[01 - SQL] SELECT head FROM department GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;
[02 - Question] Find the name and position of the head of the department with the least employees.
[02 - SQL] SELECT T2.name ,  T2.position FROM department AS T1 JOIN physician AS T2 ON T1.head  =  T2.EmployeeID GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;
[03 - Question] List the names of patients who have made appointments.
[03 - SQL] SELECT name FROM appointment AS T1 JOIN patient AS T2 ON T1.patient  =  T2.ssn
[04 - Question] Which patients made more than one appointment? Tell me the name and phone number of these patients.
[04 - SQL] SELECT name ,  phone FROM appointment AS T1 JOIN patient AS T2 ON T1.patient  = 

In [17]:
data = train_samples[2]
print(sqlglot.transpile(data.interactions[2].sql, write="sqlite", identify=False, pretty=True)[0])
sql = optimize(sqlglot.parse_one(data.interactions[2].sql, read='sqlite'), schema=table.db_schema)

SELECT
  T2.name,
  T2.position
FROM department AS T1
JOIN physician AS T2
  ON T1.head = T2.EmployeeID
GROUP BY
  departmentID
ORDER BY
  COUNT(departmentID)
LIMIT 1


### 2. Extract Term - Expression

In [18]:
import os 
from dotenv import load_dotenv, find_dotenv
from collections import defaultdict
from tqdm import tqdm
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

_ = load_dotenv(find_dotenv())

In [31]:
class TermExpressions(BaseModel):
    rationale: str = Field(description='The reasoning behind the decision.')
    index: int = Field(description='Index of the question-sql pair.')
    term: str = Field(description='A declarative form of the natural language term.')
    expression: str = Field(description='SQL expression that refers to the term.')

class Response(BaseModel):
    output: list[TermExpressions]
    
template = '''### Task
You are tasked with identifying the partial term - partial expression relationship to represent the common interest query.
You will be proveded several pairs of question and SQL with index. Do not extract the FROM and JOIN clauses.
There could be multiple terms and expressions in a single question-SQL pair.

### Formatting
Your output should be of the following list of JSON format:
[
    {{
        "rationale": <str: the reasoning behind decision>,
        "index: <int: the index of the question-sql pair>,
        "term": <str: a partial natural language term>,
        "expression" : <str: a partial SQL expression that refer to the term>,
    }}, ...
]


### Output
<QUESTION-SQL>:\n{workload}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['workload']
)

model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

model = model_openai.with_structured_output(Response)
chain = (prompt | model)

all_term_expression = defaultdict(list)
for data in tqdm(train_subsamples, total=len(train_subsamples)):
    workload = format_interactions(data.interactions)
    term_expression = chain.invoke(input={'workload': workload}).output
    all_term_expression[data.sample_id] = term_expression
    # for x in data.interactions:
    #     input_data = {'question': x.question, 'sql': x.sql}
    #     term_expression = chain.invoke(input=input_data).output
    #     tbls, _ = extract_table_expression(x, table.db_schema)
    #     all_term_expression[tbls].append(term_expression)

  0%|          | 0/25 [00:00<?, ?it/s]

100%|██████████| 25/25 [01:21<00:00,  3.27s/it]


In [35]:
for sample_id, term_exps in all_term_expression.items():
    for x in term_exps:
        print(f'[Sample {sample_id:02d}-{x.index}]: {x.term} - {x.expression}')

[Sample 00-0]: employees - count(departmentID)
[Sample 00-0]: department - GROUP BY departmentID
[Sample 00-1]: most employees - ORDER BY count(departmentID) DESC LIMIT 1
[Sample 00-1]: department name - name
[Sample 00-1]: department - GROUP BY departmentID
[Sample 01-0]: employees - count(departmentID)
[Sample 01-0]: department - GROUP BY departmentID
[Sample 01-1]: least employees - ORDER BY count(departmentID) LIMIT 1
[Sample 01-1]: department - GROUP BY departmentID
[Sample 01-2]: head - head
[Sample 01-2]: department - GROUP BY departmentID
[Sample 02-0]: employees - count(departmentID)
[Sample 02-0]: department - GROUP BY departmentID
[Sample 02-1]: smallest number of employees - ORDER BY count(departmentID) LIMIT 1
[Sample 02-1]: department - GROUP BY departmentID
[Sample 02-2]: name - T2.name
[Sample 02-2]: position - T2.position
[Sample 02-2]: department - GROUP BY departmentID
[Sample 03-0]: patient id - patient
[Sample 03-1]: names of patients - name
[Sample 03-1]: appointm

In [13]:
with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

print(get_schema_str(
    schema=sparc_tables['hospital_1'].db_schema, 
    col_explanation=all_descriptions['hospital_1'])[:300]
)

[Table and Columns]
Table Name: Physician
  - 'EmployeeID'(text): Unique identifier for each physician.
  - 'Name'(text): Full name of the physician.
  - 'Position'(text): Job title or role of the physician.
  - 'SSN'(text): Social Security Number of the physician.
Table Name: Department
  - 'Depart


# Query Access Area

In [18]:
dtype_functions = {
    'numeric': pd.to_numeric,
    'datetime': pd.to_datetime
}

def null_percentage(s: pd.Series) -> float:
    return s.isnull().sum() / len(s)

column_info = {}
for col in df.columns:
    # dtype
    null_index = df[col].isnull()
    for logical_type in ['numeric', 'datetime', 'text']:
        if logical_type in ['numeric', 'datetime']:
            try:
                df.loc[~null_index, col] = dtype_functions[logical_type](df.loc[~null_index, col], errors='raise')
                attribute_type = 'ordinal'
                break
            except ValueError as e:
                # print(f'-- {col}: {logical_type} {e}')
                continue
            except TypeError as e:
                # print(f'-- {col}: {logical_type} {e}')
                continue
        else:
            attribute_type = 'nominal'
            break
    print(f'{col}: {logical_type} {attribute_type}')
    # unique values
    unique_values = df[col].unique()
    # min, max
    min_val = df[col].min()
    max_val = df[col].max()
    # null percentage
    null_percent = null_percentage(df[col])

    column_info[col] = {
        'logical_type': logical_type,
        'attribute_type': attribute_type,
        'unique_values': unique_values,
        'min': min_val,
        'max': max_val,
        'null_percentage': null_percent
    }

[(13216584, 100000001, 101, 1, '2008-04-24 10:00', '2008-04-24 11:00', 'A')]