## Load Data

In [1]:
import json
import numpy as np
import pandas as pd

from pathlib import Path
from src.db_utils import get_schema_str, get_data_dict, get_schema_str_with_tables
from src.database import SqliteDatabase, DuckDBDatabase
from src.sparc_preprocess import (
    load_sparc_data,
    process_all_tables, 
    filter_samples_by_count, 
    process_samples, 
    split_train_dev
)

# duckdb.sql('INSTALL sqlite')
# duckdb.sql('SET GLOBAL sqlite_all_varchar = true;')

proj_path = Path('.').resolve()
sparc_path = proj_path / 'data' / 'sparc'

tables, train_data, dev_data = load_sparc_data(sparc_path)
print(f'Number of train: {len(train_data)} | Number of dev: {len(dev_data)}')

sparc_tables = process_all_tables(tables)
# filter samples by count, must have at least 5 samples
all_data = filter_samples_by_count(train_data+dev_data, n=5)
# process samples -> {db_id: list of samples}
sparc_samples = process_samples(all_data)
# change train/dev by sample
train_samples, dev_samples = split_train_dev(sparc_samples, ratio=0.8)

Number of train: 3034 | Number of dev: 422


In [2]:
db_id = 'hospital_1'
db_file = str(sparc_path / 'database' / db_id / f'{db_id}.sqlite')
database = SqliteDatabase(db_file, foreign_keys=sparc_tables[db_id].foreign_keys)
database.table_cols.keys()

dict_keys(['Physician', 'Department', 'Affiliated_With', 'Procedures', 'Trained_In', 'Patient', 'Nurse', 'Appointment', 'Medication', 'Prescribes', 'Block', 'Room', 'On_Call', 'Stay', 'Undergoes'])

In [3]:
database.execute('SELECT * FROM Department LIMIT 5;')

Unnamed: 0,DepartmentID,Name,Head
0,1,General Medicine,4
1,2,Surgery,7
2,3,Psychiatry,9


## Workload Analysis

In [4]:
from src.sparc_preprocess import SparcSample, QuestionSQL

def format_interactions(interactions: list[QuestionSQL]) -> str:
    workload = ''
    for i, interaction in enumerate(interactions):
        workload += f'[{i}-Question] {interaction.question}\n[{i}-SQL]: {interaction.sql}\n'
    return workload.strip()

with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

idx = 0
data = train_samples[idx]
table = sparc_tables[data.db_id]
col_explanation = all_descriptions[data.db_id]
# create schema string
schema_str = get_schema_str(
    schema=table.db_schema, 
    foreign_keys=table.foreign_keys,
    primary_keys=table.primary_keys,
    col_explanation=col_explanation
)
database = SqliteDatabase(str(sparc_path / 'database' / data.db_id / f'{data.db_id}.sqlite'), foreign_keys=table.foreign_keys)
workload = format_interactions(data.interactions)
print(workload, '\n')
print(f'[Final]\nQuestion: {data.final.question}\nSQL: {data.final.sql}\n')

[0-Question] What is the number of employees in each department?
[0-SQL]: SELECT count(departmentID) FROM department GROUP BY departmentID
[1-Question] Which department has the most employees? Give me the department name.
[1-SQL]: SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1; 

[Final]
Question: Find the department with the most employees.
SQL: SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;



In [5]:
database.execute(data.interactions[0].sql)

Unnamed: 0,count(departmentID)
0,1
1,1
2,1


In [6]:
database.execute(data.interactions[1].sql)

Unnamed: 0,Name
0,General Medicine


In [7]:
database.execute(data.final.sql)

Unnamed: 0,Name
0,General Medicine


## Schema description

In [5]:
with (proj_path / 'db_data' / 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

print(get_schema_str(
    schema=sparc_tables['hospital_1'].db_schema, 
    col_explanation=all_descriptions['hospital_1'])[:300]
)

[Table and Columns]
Table Name: Physician
  - 'EmployeeID'(text): Unique identifier for each physician.
  - 'Name'(text): Full name of the physician.
  - 'Position'(text): Job title or role of the physician.
  - 'SSN'(text): Social Security Number of the physician.
Table Name: Department
  - 'Depart


In [6]:
import os 
from dotenv import load_dotenv, find_dotenv
from collections import defaultdict
from tqdm import tqdm
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

_ = load_dotenv(find_dotenv())

## Basic Prompt

In [11]:
class OutputFormat(BaseModel):
    full_sql_query: str = Field(description='The full SQL query.')

class Response(BaseModel):
    output: list[OutputFormat]

template = '''### TASK
You are tasked with generating a SQL query according to a user input request.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

model = model_openai.with_structured_output(Response)
chain = (prompt | model)

all_full_sql = list()
train_subsamples = train_samples[0:10]
for idx in tqdm(range(len(train_subsamples))):
    data = train_subsamples[idx]
    x = data.final
    db_id = data.db_id
    db_schema = get_schema_str(
        schema=sparc_tables[db_id].db_schema, 
        col_explanation=all_descriptions[db_id]
    )
    input_data = {'schema': db_schema, 'input_query': x.question}
    #print(input_data)
    output = chain.invoke(input=input_data).output
    #print(output)
    full_sql_output = {}
    full_sql_output['sql_idx'] = idx
    full_sql_output['db_id'] = db_id
    full_sql_output['question'] = x.question
    full_sql_output['full_sql_query'] = output[0].full_sql_query
    full_sql_output['gold_sql'] = x.sql
    all_full_sql.append(full_sql_output)
all_full_sql

100%|██████████| 10/10 [00:09<00:00,  1.05it/s]


[{'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'full_sql_query': 'SELECT d.Name, COUNT(a.Physician) AS EmployeeCount\nFROM Department d\nJOIN Affiliated_With a ON d.DepartmentID = a.Department\nGROUP BY d.Name\nORDER BY EmployeeCount DESC\nLIMIT 1;',
  'gold_sql': 'SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;'},
 {'sql_idx': 1,
  'db_id': 'hospital_1',
  'question': 'Tell me the employee id of the head of the department with the least employees.',
  'full_sql_query': 'SELECT Head FROM Department WHERE DepartmentID = (SELECT DepartmentID FROM Affiliated_With GROUP BY Department ORDER BY COUNT(Physician) ASC LIMIT 1)',
  'gold_sql': 'SELECT head FROM department GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;'},
 {'sql_idx': 2,
  'db_id': 'hospital_1',
  'question': 'Find the name and position of the head of the department with the least employees.',
  'full_sql_query': "S

In [14]:
## database execution evaluation
from src.evaluate import compare_execution

output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    db_id = data['db_id']
    database = SqliteDatabase(str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'), foreign_keys=table.foreign_keys)
    pred_result = database.execute(data['full_sql_query'])
    #print(pred_result)
    gold_result = database.execute(data['gold_sql'])
    #print(gold_result)
    error_info = None
    try:
        score = compare_execution(pred_result, gold_result)
    except Exception as e:
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Python Script Error:' + str(e)
    if score == 0 and error_info is None:
        error_info = 'Result Error' 
    output_results.append(
        {
            "instance_id": sql_idx, 
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )
    
print({item['instance_id']: item['score'] for item in output_results})      
score = sum([item['score'] for item in output_results]) / len(output_results)
print(f"Final score: {score}")


100%|██████████| 10/10 [00:00<00:00, 65.01it/s]

{0: 1, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 0, 7: 0, 8: 0, 9: 0}
Final score: 0.3





## Chain of Thought Prompt

In [15]:
class OutputFormat(BaseModel):
    full_sql_query: str = Field(description='The full SQL query.')
    rationale: str = Field(description='The step-by-step reasoning to generate the SQL query.')

class Response(BaseModel):
    output: list[OutputFormat]

template = '''### TASK
You are tasked with generating a SQL query according to a user input request.
You should work in step-by-step reasoning before coming to the full SQL query.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the step-by-step reasoning to generate the SQL query>",
    "full_sql_query": "<str: the full SQL query>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['schema', 'input_query']
)

model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

model = model_openai.with_structured_output(Response)
chain = (prompt | model)

all_full_sql = list()
train_subsamples = train_samples[0:10]
for idx in tqdm(range(len(train_subsamples))):
    data = train_subsamples[idx]
    x = data.final
    db_id = data.db_id
    db_schema = get_schema_str(
        schema=sparc_tables[db_id].db_schema, 
        col_explanation=all_descriptions[db_id]
    )
    input_data = {'schema': db_schema, 'input_query': x.question}
    #print(input_data)
    output = chain.invoke(input=input_data).output
    #print(output)
    full_sql_output = {}
    full_sql_output['sql_idx'] = idx
    full_sql_output['db_id'] = db_id
    full_sql_output['question'] = x.question
    full_sql_output['rationale'] = output[0].rationale
    full_sql_output['full_sql_query'] = output[0].full_sql_query
    full_sql_output['gold_sql'] = x.sql
    all_full_sql.append(full_sql_output)
all_full_sql

100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[{'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'rationale': "To find the department with the most employees, we need to count the number of physicians affiliated with each department. We can achieve this by joining the 'Department' table with the 'Affiliated_With' table on the department identifier. We will group the results by department name and count the number of physicians in each department. Finally, we will order the results in descending order based on the count and limit the output to the top result.",
  'full_sql_query': 'SELECT d.Name, COUNT(a.Physician) AS EmployeeCount\nFROM Department d\nJOIN Affiliated_With a ON d.DepartmentID = a.Department\nGROUP BY d.Name\nORDER BY EmployeeCount DESC\nLIMIT 1;',
  'gold_sql': 'SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;'},
 {'sql_idx': 1,
  'db_id': 'hospital_1',
  'question': 'Tell me the employee id of the head of the departmen

In [16]:
## database execution evaluation
from src.evaluate import compare_execution

output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    db_id = data['db_id']
    database = SqliteDatabase(str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'), foreign_keys=table.foreign_keys)
    pred_result = database.execute(data['full_sql_query'])
    #print(pred_result)
    gold_result = database.execute(data['gold_sql'])
    #print(gold_result)
    error_info = None
    try:
        score = compare_execution(pred_result, gold_result)
    except Exception as e:
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Python Script Error:' + str(e)
    if score == 0 and error_info is None:
        error_info = 'Result Error' 
    output_results.append(
        {
            "instance_id": sql_idx, 
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )
    
print({item['instance_id']: item['score'] for item in output_results})      
score = sum([item['score'] for item in output_results]) / len(output_results)
print(f"Final score: {score}")


100%|██████████| 10/10 [00:00<00:00, 63.54it/s]

{0: 1, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 0, 7: 0, 8: 0, 9: 1}
Final score: 0.4





## Schema Linking: Single-Column Schema Linking (SCSL)
- identifying relevance of a particular column independent of the rest of the schema

In [29]:
class OutputFormat(BaseModel):
    rationale: str = Field(description='The reasoning behind decision.')
    relevant: bool = Field(description='relevant or not.')
    column: str = Field(description='The column name and its type.')

class Response(BaseModel):
    output: list[OutputFormat]

template = '''### TASK
You are tasked with identifying whether or 
not a candidate column from a schema is 
related to the provided input request.

You will be provided:
- An input NL query.
- The schema

### FORMATTING
Your output should be a list of the following JSON format:
[{{
    "column": "<str: the column name and its type>",
    "rationale": "<str: the reasoning behind decision>",
    "relevant": "<bool: relevant or not>"
}},...]

### OUTPUT
<INPUT QUERY>: {input_query}
<SCHEMA>: {schema}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['input_query', 'schema']
)

model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

model = model_openai.with_structured_output(Response)
chain = (prompt | model)

all_results = list()
train_subsamples = train_samples[0:1]
for idx in tqdm(range(len(train_subsamples))):
    data = train_subsamples[idx]
    x = data.final
    db_id = data.db_id
    db_schema = get_schema_str(
        schema=sparc_tables[db_id].db_schema, 
        col_explanation=all_descriptions[db_id]
    )
    input_data = {'schema': db_schema, 'input_query': x.question}
    #print(input_data)
    outputs = chain.invoke(input=input_data).output
    #print(output)
    for output in outputs:
        full_output = {}
        full_output['sql_idx'] = idx
        full_output['db_id'] = db_id
        full_output['question'] = x.question
        full_output['column'] = output.column
        full_output['rationale'] = output.rationale
        full_output['relevant'] = output.relevant
        full_output['gold_sql'] = x.sql
        all_results.append(full_output)
all_results

100%|██████████| 1/1 [00:04<00:00,  4.98s/it]


[{'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'column': 'DepartmentID(number)',
  'rationale': 'The query is asking for information about departments, specifically which one has the most employees. The DepartmentID is essential for identifying departments in the schema.',
  'relevant': True,
  'gold_sql': 'SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;'},
 {'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'column': 'Name(number)',
  'rationale': "The query requires the name of the department with the most employees, making this column relevant as it provides the department's name.",
  'relevant': True,
  'gold_sql': 'SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;'},
 {'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'column': 'Head

## Schema Linking: Table-to-Column Schema Linking (TCSL): 
- first identifying relevant tables then relevant columns.

In [7]:
class OutputFormat(BaseModel):
    rationale: str = Field(description='The reasoning behind decision.')
    tables: list[str] = Field(description='List of relevant tables.')

class Response(BaseModel):
    output: list[OutputFormat]

template = '''### TASK
You are tasked with identifying which tables 
from a schema are related to the provided input request.

You will be provided:
- An input NL query.
- The schema

### FORMATTING
Your output should be of the following JSON format:
{{
    "rationale": "<str: the reasoning behind decision>",
    "tables": "<list[str]: relevant tables>"
}}

### OUTPUT
<INPUT QUERY>: {input_query}
<SCHEMA>: {schema}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['input_query', 'schema']
)

model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

model = model_openai.with_structured_output(Response)
chain = (prompt | model)

all_results = list()
train_subsamples = train_samples[0:1]
for idx in tqdm(range(len(train_subsamples))):
    data = train_subsamples[idx]
    x = data.final
    db_id = data.db_id
    db_schema = get_schema_str(
        schema=sparc_tables[db_id].db_schema, 
        col_explanation=all_descriptions[db_id]
    )
    input_data = {'schema': db_schema, 'input_query': x.question}
    #print(input_data)
    output = chain.invoke(input=input_data).output
    #print(output)
    full_output = {}
    full_output['sql_idx'] = idx
    full_output['db_id'] = db_id
    full_output['question'] = x.question
    full_output['tables'] = output[0].tables
    full_output['rationale'] = output[0].rationale
    full_output['gold_sql'] = x.sql
    all_results.append(full_output)
all_results

100%|██████████| 1/1 [00:01<00:00,  1.70s/it]


[{'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'tables': ['Department', 'Affiliated_With', 'Physician'],
  'rationale': "To find the department with the most employees, we need to look at the 'Department' table to identify departments and the 'Affiliated_With' table to link physicians to their respective departments. The 'Physician' table is also relevant as it contains the employee identifiers for physicians, which will help in counting the number of employees in each department.",
  'gold_sql': 'SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;'}]

In [8]:
class OutputFormat(BaseModel):
    rationale: str = Field(description='The reasoning behind decision.')
    table: str = Field(description='The relevant table.')
    columns: list[str] = Field(description='List of relevant columns.')

class Response(BaseModel):
    output: list[OutputFormat]

template = '''### TASK
You are tasked with identifying which columns 
from a schema are related to the provided input request.

You will be provided:
- An input NL query.
- The schema

### FORMATTING
Your output should be list of the following JSON format:
[{{
    "rationale": "<str: the reasoning behind decision>",
    "table": "<str: the table name>",
    "columns": "<list[str]: relevant columns>"
}},...]

### OUTPUT
<INPUT QUERY>: {input_query}
<SCHEMA>: {schema}
<OUTPUT>: 
'''

prompt = PromptTemplate(
    template=template,
    input_variables=['input_query', 'schema']
)

model_openai = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.0,
)

model = model_openai.with_structured_output(Response)
chain = (prompt | model)

all_columns_results = list()
for idx in tqdm(range(len(all_results))):
    data = all_results[idx]
    db_id = data['db_id']
    db_schema = get_schema_str_with_tables(
        schema=sparc_tables[db_id].db_schema,
        table_list=data['tables'],
        col_explanation=all_descriptions[db_id]
    )
    input_data = {'schema': db_schema, 'input_query': x.question}
    print(input_data)
    outputs = chain.invoke(input=input_data).output
    #print(outputs)
    for output in outputs:
        full_output = {}
        full_output['sql_idx'] = data['sql_idx']
        full_output['db_id'] = db_id
        full_output['question'] = data['question']
        full_output['table'] = output.table
        full_output['columns'] = output.columns
        full_output['rationale'] = output.rationale
        full_output['gold_sql'] = data['gold_sql']
        all_columns_results.append(full_output)
all_columns_results

  0%|          | 0/1 [00:00<?, ?it/s]

{'schema': "[Table and Columns]\nTable Name: Physician\n  - 'EmployeeID'(text): Unique identifier for each physician.\n  - 'Name'(text): Full name of the physician.\n  - 'Position'(text): Job title or role of the physician.\n  - 'SSN'(text): Social Security Number of the physician.\nTable Name: Department\n  - 'DepartmentID'(number): Unique identifier for each department.\n  - 'Name'(number): Name of the department.\n  - 'Head'(number): Identifier for the head of the department.\nTable Name: Affiliated_With\n  - 'Physician'(text): Identifier for the physician.\n  - 'Department'(text): Identifier for the department.\n  - 'PrimaryAffiliation'(text): Indicates if this is the primary affiliation.", 'input_query': 'Find the department with the most employees.'}


100%|██████████| 1/1 [00:01<00:00,  1.80s/it]


[{'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'table': 'Department',
  'columns': ['DepartmentID', 'Name'],
  'rationale': "The query asks for the department with the most employees, which requires information about departments and their associated physicians. The 'Affiliated_With' table links physicians to departments, and the 'Department' table provides the department details.",
  'gold_sql': 'SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;'},
 {'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'table': 'Affiliated_With',
  'columns': ['Physician', 'Department'],
  'rationale': "The 'Affiliated_With' table is crucial as it connects physicians to their respective departments, allowing us to count the number of employees (physicians) in each department.",
  'gold_sql': 'SELECT name FROM department GROUP BY departmentID ORDER BY cou