* gpt-4o-mini
* gemini-1.5-pro
* gemini-1.5-flash

In [1]:
import json
import numpy as np
import pandas as pd

from pathlib import Path
from src.db_utils import get_schema_str, get_data_dict
from src.database import SqliteDatabase, DuckDBDatabase
from src.sparc_preprocess import (
    load_sparc_data,
    process_all_tables, 
    filter_samples_by_count, 
    process_samples, 
    split_train_dev
)

# duckdb.sql('INSTALL sqlite')
# duckdb.sql('SET GLOBAL sqlite_all_varchar = true;')

proj_path = Path('.').resolve()
sparc_path = proj_path / 'data' / 'sparc'

tables, train_data, dev_data = load_sparc_data(sparc_path)
print(f'Number of train: {len(train_data)} | Number of dev: {len(dev_data)}')

sparc_tables = process_all_tables(tables)
# filter samples by count, must have at least 5 samples
all_data = filter_samples_by_count(train_data+dev_data, n=5)
# process samples -> {db_id: list of samples}
sparc_samples = process_samples(all_data)
# change train/dev by sample
train_samples, dev_samples = split_train_dev(sparc_samples, ratio=0.8)

Number of train: 3034 | Number of dev: 422


In [6]:
db_id = 'hospital_1'
db_file = str(sparc_path / 'database' / db_id / f'{db_id}.sqlite')
database = SqliteDatabase(db_file, foreign_keys=sparc_tables[db_id].foreign_keys)
database.table_cols.keys()

dict_keys(['Physician', 'Department', 'Affiliated_With', 'Procedures', 'Trained_In', 'Patient', 'Nurse', 'Appointment', 'Medication', 'Prescribes', 'Block', 'Room', 'On_Call', 'Stay', 'Undergoes'])

In [10]:
database.execute('SELECT * FROM Department LIMIT 5;')

Unnamed: 0,DepartmentID,Name,Head
0,1,General Medicine,4
1,2,Surgery,7
2,3,Psychiatry,9


In [71]:
dtype_functions = {
    'numeric': pd.to_numeric,
    'datetime': pd.to_datetime
}

def null_percentage(s: pd.Series) -> float:
    return s.isnull().sum() / len(s)

column_info = {}
for col in df.columns:
    # dtype
    null_index = df[col].isnull()
    for logical_type in ['numeric', 'datetime', 'text']:
        if logical_type in ['numeric', 'datetime']:
            try:
                df.loc[~null_index, col] = dtype_functions[logical_type](df.loc[~null_index, col], errors='raise')
                attribute_type = 'ordinal'
                break
            except ValueError as e:
                # print(f'-- {col}: {logical_type} {e}')
                continue
            except TypeError as e:
                # print(f'-- {col}: {logical_type} {e}')
                continue
        else:
            attribute_type = 'nominal'
            break
    print(f'{col}: {logical_type} {attribute_type}')
    # unique values
    unique_values = df[col].unique()
    # min, max
    min_val = df[col].min()
    max_val = df[col].max()
    # null percentage
    null_percent = null_percentage(df[col])

    column_info[col] = {
        'logical_type': logical_type,
        'attribute_type': attribute_type,
        'unique_values': unique_values,
        'min': min_val,
        'max': max_val,
        'null_percentage': null_percent
    }

AppointmentID: numeric ordinal
Patient: numeric ordinal
PrepNurse: numeric ordinal
Physician: numeric ordinal
Start: datetime ordinal


## Augmentation

In [3]:
from dotenv import load_dotenv, find_dotenv

from langchain_core.pydantic_v1 import BaseModel as LCBaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

_ = load_dotenv(find_dotenv())

In [13]:
with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

print(get_schema_str(
    schema=sparc_tables['hospital_1'].db_schema, 
    col_explanation=all_descriptions['hospital_1'])[:300]
)

[Table and Columns]
Table Name: Physician
  - 'EmployeeID'(text): Unique identifier for each physician.
  - 'Name'(text): Full name of the physician.
  - 'Position'(text): Job title or role of the physician.
  - 'SSN'(text): Social Security Number of the physician.
Table Name: Department
  - 'Depart


In [15]:
sparc_tables['hospital_1'].foreign_keys

['Department.Head = Physician.EmployeeID',
 'Affiliated_With.Department = Department.DepartmentID',
 'Affiliated_With.Physician = Physician.EmployeeID',
 'Trained_In.Treatment = Procedures.Code',
 'Trained_In.Physician = Physician.EmployeeID',
 'Patient.PCP = Physician.EmployeeID',
 'Appointment.Physician = Physician.EmployeeID',
 'Appointment.PrepNurse = Nurse.EmployeeID',
 'Appointment.Patient = Patient.SSN',
 'Prescribes.Appointment = Appointment.AppointmentID',
 'Prescribes.Medication = Medication.Code',
 'Prescribes.Patient = Patient.SSN',
 'Prescribes.Physician = Physician.EmployeeID',
 'Room.BlockFloor = Block.BlockFloor',
 'Room.BlockCode = Block.BlockCode',
 'On_Call.BlockFloor = Block.BlockFloor',
 'On_Call.BlockCode = Block.BlockCode',
 'On_Call.Nurse = Nurse.EmployeeID',
 'Stay.Room = Room.RoomNumber',
 'Stay.Patient = Patient.SSN',
 'Undergoes.AssistingNurse = Nurse.EmployeeID',
 'Undergoes.Physician = Physician.EmployeeID',
 'Undergoes.Stay = Stay.StayID',
 'Undergoes.Pro

# Query Access Area

In [82]:
from src.sparc_preprocess import SparcSample, QuestionSQL

def format_interactions(interactions: list[QuestionSQL]) -> str:
    workload = ''
    for i, interaction in enumerate(interactions):
        workload += f'[{i}-Question] {interaction.question}\n[{i}-SQL]: {interaction.sql}\n'
    return workload.strip()

with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

idx = 1
data = train_samples[idx]
table = sparc_tables[data.db_id]
col_explanation = all_descriptions[data.db_id]
schema_str = get_schema_str(
    schema=table.db_schema, 
    foreign_keys=table.foreign_keys,
    primary_keys=table.primary_keys,
    col_explanation=col_explanation
)
workload = format_interactions(data.interactions)
print(workload, '\n')
print(f'[Final]\nQuestion: {data.final.question}\nSQL: {data.final.sql}\n')

[0-Question] How many employees does each department have?
[0-SQL]: SELECT count(departmentID) FROM department GROUP BY departmentID
[1-Question] Which department has the least employees?
[1-SQL]: SELECT * FROM department GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;
[2-Question] Who is the head of this department? Find the employee id.
[2-SQL]: SELECT head FROM department GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1; 

[Final]
Question: Tell me the employee id of the head of the department with the least employees.
SQL: SELECT head FROM department GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;



In [84]:
import sqlglot
import sqlglot.expressions as exp
from sqlglot.diff import Keep
from sqlglot.optimizer import optimize

def get_sources(data: SparcSample, schema: dict) -> list[tuple[str, list[str]]]:
    sources = []
    for x in data.interactions:
        sql = optimize(sqlglot.parse_one(x.sql, read='sqlite'), schema=schema)
        tbls = [x.this.this for x in list(sql.find_all(exp.Table))]
        sources.append((x.question, tbls))
    return sources

db_id = 'hospital_1'
train_subsamples = list(filter(lambda x: x.db_id == db_id, train_samples))
dev_subsamples = list(filter(lambda x: x.db_id == db_id, dev_samples))
table = sparc_tables[db_id]
database = Database(db_file=str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'))

# train_sources = []
# for data in train_subsamples:
#     train_sources.append(get_sources(data, schema=table.db_schema))
# dev_sources = []
# for data in dev_subsamples:
#     dev_sources.append(get_sources(data, schema=table.db_schema))

In [153]:
x = train_subsamples[4].interactions
schema = table.db_schema
sql1 = sqlglot.parse_one(x[0].sql, read='sqlite')
sql2 = sqlglot.parse_one(x[1].sql, read='sqlite')
print(sql1.sql())
print(sql2.sql())

SELECT COUNT(*) FROM appointment GROUP BY patient
SELECT * FROM appointment GROUP BY patient HAVING COUNT(*) > 1


In [155]:
import sqlite3

conn = sqlite3.connect(str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'))
cursor = conn.cursor()


In [160]:
pd.read_sql_query('SELECT name FROM appointment AS T1 JOIN patient AS T2 ON T1.patient = T2.ssn', conn)

Unnamed: 0,Name
0,John Smith
1,Grace Ritchie
2,John Smith
3,Dennis Doe
4,Dennis Doe
5,Random J. Patient
6,John Smith
7,Dennis Doe
8,Grace Ritchie


In [157]:
cursor.execute('SELECT name FROM appointment AS T1 JOIN patient AS T2 ON T1.patient = T2.ssn').fetchall()

[('John Smith',),
 ('Grace Ritchie',),
 ('John Smith',),
 ('Dennis Doe',),
 ('Dennis Doe',),
 ('Random J. Patient',),
 ('John Smith',),
 ('Dennis Doe',),
 ('Grace Ritchie',)]

In [159]:
cursor.execute('SELECT * FROM appointment GROUP BY patient HAVING COUNT(*) > 1').fetchall()

[(13216584, 100000001, 101, 1, '2008-04-24 10:00', '2008-04-24 11:00', 'A'),
 (26548913, 100000002, 101, 2, '2008-04-24 10:00', '2008-04-24 11:00', 'B'),
 (46846589, 100000004, 103, 4, '2008-04-25 10:00', '2008-04-25 11:00', 'B')]

In [120]:
diff1 = sqlglot.diff(sql1, sql2)
diff2 = sqlglot.diff(sql2, sql1)
diff1 = set(filter(lambda x: isinstance(x, Keep), diff1))
diff2 = set(filter(lambda x: isinstance(x, Keep), diff2))
filter(lambda x: type(x.source) in [exp.Table, exp.Group, exp.Join, exp] , diff1.intersection(diff2))

{Keep(source=Column(
   this=Identifier(this=departmentID, quoted=False)), target=Column(
   this=Identifier(this=departmentID, quoted=False))),
 Keep(source=Count(
   this=Column(
     this=Identifier(this=departmentID, quoted=False)),
   big_int=True), target=Count(
   this=Column(
     this=Identifier(this=departmentID, quoted=False)),
   big_int=True)),
 Keep(source=From(
   this=Table(
     this=Identifier(this=department, quoted=False))), target=From(
   this=Table(
     this=Identifier(this=department, quoted=False)))),
 Keep(source=Group(
   expressions=[
     Column(
       this=Identifier(this=departmentID, quoted=False))]), target=Group(
   expressions=[
     Column(
       this=Identifier(this=departmentID, quoted=False))])),
 Keep(source=Table(
   this=Identifier(this=department, quoted=False)), target=Table(
   this=Identifier(this=department, quoted=False)))}

In [126]:
x = train_subsamples[1].interactions
schema = table.db_schema
sql1 = sqlglot.parse_one(x[0].sql, read='sqlite')
sql2 = sqlglot.parse_one(x[1].sql, read='sqlite')
print(sql1.sql())
print(sql2.sql())

SELECT COUNT(departmentID) FROM department GROUP BY departmentID
SELECT * FROM department GROUP BY departmentID ORDER BY COUNT(departmentID) LIMIT 1


In [127]:
diff1 = sqlglot.diff(sql1, sql2)
diff2 = sqlglot.diff(sql2, sql1)
diff1 = set(filter(lambda x: isinstance(x, Keep), diff1))
diff2 = set(filter(lambda x: isinstance(x, Keep), diff2))
diff1.intersection(diff2)

{Keep(source=Column(
   this=Identifier(this=departmentID, quoted=False)), target=Column(
   this=Identifier(this=departmentID, quoted=False))),
 Keep(source=Count(
   this=Column(
     this=Identifier(this=departmentID, quoted=False)),
   big_int=True), target=Count(
   this=Column(
     this=Identifier(this=departmentID, quoted=False)),
   big_int=True)),
 Keep(source=From(
   this=Table(
     this=Identifier(this=department, quoted=False))), target=From(
   this=Table(
     this=Identifier(this=department, quoted=False)))),
 Keep(source=Group(
   expressions=[
     Column(
       this=Identifier(this=departmentID, quoted=False))]), target=Group(
   expressions=[
     Column(
       this=Identifier(this=departmentID, quoted=False))])),
 Keep(source=Table(
   this=Identifier(this=department, quoted=False)), target=Table(
   this=Identifier(this=department, quoted=False)))}

In [24]:
from tqdm import tqdm
from duckdb import ConversionException

errors = []
for db_id in tqdm(sparc_tables.keys(), total=len(sparc_tables)):
    table = sparc_tables[db_id]
    database = Database(db_file=str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'))
    try:
        database.get_table_summaries()
    except ConversionException as e:
        errors.append((e, db_id))
    except ValueError as e:
        errors.append((e, db_id))

  0%|          | 0/166 [00:00<?, ?it/s]

 17%|█▋        | 28/166 [00:04<00:24,  5.73it/s]


TypeMismatchException: Mismatch Type Error: Invalid type in column "If_Affirmative_Win": column was declared as integer, found "F" of type "text" instead.

In [17]:
import sqlite3

In [18]:
db = sqlite3.connect(str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'))

c = db.cursor()
c.execute('SELECT * FROM appointment LIMIT 1').fetchall()

[(13216584, 100000001, 101, 1, '2008-04-24 10:00', '2008-04-24 11:00', 'A')]