## Load Data

In [1]:
import json
import numpy as np
import pandas as pd

from pathlib import Path
from src.db_utils import get_schema_str, get_data_dict, get_schema_str_with_tables
from src.database import SqliteDatabase, DuckDBDatabase
from src.sparc_preprocess import (
    load_sparc_data,
    process_all_tables, 
    filter_samples_by_count, 
    process_samples, 
    split_train_dev
)

# duckdb.sql('INSTALL sqlite')
# duckdb.sql('SET GLOBAL sqlite_all_varchar = true;')

# Locate the SParC dataset relative to the directory the notebook is run from.
proj_path = Path().resolve()
sparc_path = proj_path.joinpath('data', 'sparc')

# Read the table definitions together with the original train/dev splits.
tables, train_data, dev_data = load_sparc_data(sparc_path)
print(f'Number of train: {len(train_data)} | Number of dev: {len(dev_data)}')

sparc_tables = process_all_tables(tables)
# Pool both original splits and filter them by count (n=5).
all_data = filter_samples_by_count(train_data + dev_data, n=5)
# Convert the pooled records into processed samples.
sparc_samples = process_samples(all_data)
# Re-split the processed samples into train/dev with ratio 0.8.
train_samples, dev_samples = split_train_dev(sparc_samples, ratio=0.8)

Number of train: 3034 | Number of dev: 422


In [2]:
# Open one SQLite database and list the tables it exposes.
db_id = 'hospital_1'
db_file = str(sparc_path.joinpath('database', db_id, f'{db_id}.sqlite'))
database = SqliteDatabase(
    db_file,
    foreign_keys=sparc_tables[db_id].foreign_keys,
)
database.table_cols.keys()

dict_keys(['Physician', 'Department', 'Affiliated_With', 'Procedures', 'Trained_In', 'Patient', 'Nurse', 'Appointment', 'Medication', 'Prescribes', 'Block', 'Room', 'On_Call', 'Stay', 'Undergoes'])

## Workload Analysis

In [3]:
from src.sparc_preprocess import SparcSample, QuestionSQL

def format_interactions(interactions: list[QuestionSQL]) -> str:
    """Render a sequence of question/SQL turns as numbered text.

    Each turn contributes two lines, ``[i-Question] ...`` and ``[i-SQL]: ...``,
    where ``i`` is the zero-based position of the turn.

    Args:
        interactions: Items exposing ``question`` and ``sql`` attributes.

    Returns:
        The concatenated lines with leading/trailing whitespace stripped;
        an empty string when ``interactions`` is empty.
    """
    rendered = [
        f'[{turn}-Question] {step.question}\n[{turn}-SQL]: {step.sql}\n'
        for turn, step in enumerate(interactions)
    ]
    return ''.join(rendered).strip()

# Load the column descriptions; they are indexed by db_id below.
with proj_path.joinpath('db_data', 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

# Pick the first training sample and look up its table metadata and descriptions.
idx = 0
data = train_samples[idx]
table = sparc_tables[data.db_id]
col_explanation = all_descriptions[data.db_id]

# Build the schema string, including keys and column explanations.
schema_str = get_schema_str(
    schema=table.db_schema,
    foreign_keys=table.foreign_keys,
    primary_keys=table.primary_keys,
    col_explanation=col_explanation,
)
database = SqliteDatabase(
    str(sparc_path.joinpath('database', data.db_id, f'{data.db_id}.sqlite')),
    foreign_keys=table.foreign_keys,
)

# Show the interaction history followed by the final question/SQL pair.
workload = format_interactions(data.interactions)
print(workload, '\n')
print(f'[Final]\nQuestion: {data.final.question}\nSQL: {data.final.sql}\n')

[0-Question] What is the number of employees in each department?
[0-SQL]: SELECT count(departmentID) FROM department GROUP BY departmentID
[1-Question] Which department has the most employees? Give me the department name.
[1-SQL]: SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1; 

[Final]
Question: Find the department with the most employees.
SQL: SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;



In [4]:
# Run the selected sample's final SQL on its database; the result is the cell output.
database.execute(data.final.sql)

Unnamed: 0,Name
0,General Medicine


## Schema description

In [5]:
with proj_path.joinpath('db_data', 'sparc_description.json').open() as f:
    all_descriptions = json.load(f)

# Preview only the first 300 characters of the hospital_1 schema string.
print(
    get_schema_str(
        schema=sparc_tables['hospital_1'].db_schema,
        col_explanation=all_descriptions['hospital_1'],
    )[:300]
)

[Table and Columns]
Table Name: Physician
  - 'EmployeeID'(text): Unique identifier for each physician.
  - 'Name'(text): Full name of the physician.
  - 'Position'(text): Job title or role of the physician.
  - 'SSN'(text): Social Security Number of the physician.
Table Name: Department
  - 'Depart


In [12]:
import os 
from dotenv import load_dotenv, find_dotenv
from collections import defaultdict
from tqdm import tqdm
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAI
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser

## Add GOOGLE_API_KEY to .env file
_ = load_dotenv(find_dotenv())

## Test with ChatGoogleGenerativeAI

In [39]:
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAI

# Chat model client used for the translation smoke test below.
# temperature=0 is presumably chosen for repeatable output — confirm.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    # other params...
)

# Structured-output schema for the translation test: a language code plus the
# translated text. This class is handed to `with_structured_output` in this cell.
# NOTE(review): documented with `#` comments rather than a docstring, since a
# model docstring may end up in the schema sent to the LLM — confirm.
class OutputFormat(BaseModel):
    lang: str = Field(description='The target language in UPPERCASE')
    text: str = Field(description='The translated text.')

# Wrapper model holding a list of OutputFormat items.
# NOTE(review): not referenced by any visible cell (the chain is built from
# OutputFormat directly) — confirm whether it is still needed.
class Response(BaseModel):
    output: list[OutputFormat]
        
# Rebind `llm` to the structured-output variant; invoking the chain below yields
# an OutputFormat instance (see the cell output).
llm = llm.with_structured_output(OutputFormat)

from langchain_core.prompts import ChatPromptTemplate

# Two-message prompt: a system instruction parameterised by the language pair,
# followed by the user's text.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant that translates {input_language} to {output_language}."),
    ("human", "{input}"),
])

# Pipe the prompt into the structured-output model and run one example.
chain = prompt | llm
output = chain.invoke(dict(
    input_language="English",
    output_language="German",
    input="I love programming.",
))
output

OutputFormat(lang='DE', text='Ich liebe Programmieren.')

## Basic Prompt

In [41]:
# Structured-output schema for SQL generation: a single field with the full query.
# Redefines the `OutputFormat` name used earlier in the notebook for translation.
class OutputFormat(BaseModel):
    full_sql_query: str = Field(description='The full SQL query.')

# Wrapper model holding a list of OutputFormat items.
# NOTE(review): not referenced by any visible cell — confirm whether it is still needed.
class Response(BaseModel):
    output: list[OutputFormat]

# System message carries the task description and the schema placeholder;
# the human message carries the natural-language question.
template = [
    (
        "system",
        '''### TASK
You are tasked with generating a SQL query according to a user input request.

You will be provided an input NL query.

### SCHEMA
You are working with the following schema:
{schema}
''',
    ),
    ("human", "{input_query}"),
]
prompt = ChatPromptTemplate.from_messages(template)

model_gemini = ChatGoogleGenerativeAI(
    model='gemini-1.5-flash',
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
)

# Wrap the model so responses are returned as OutputFormat, then chain it to the prompt.
model = model_gemini.with_structured_output(OutputFormat)
chain = prompt | model

# Generate SQL for the first ten training samples and collect prediction/gold pairs.
all_full_sql = []
train_subsamples = train_samples[0:10]
for idx, data in enumerate(tqdm(train_subsamples)):
    x = data.final
    db_id = data.db_id
    db_schema = get_schema_str(
        schema=sparc_tables[db_id].db_schema,
        col_explanation=all_descriptions[db_id],
    )
    output = chain.invoke({'schema': db_schema, 'input_query': x.question})
    # One record per sample; `sql_idx` is the position within the subsample.
    all_full_sql.append({
        'sql_idx': idx,
        'db_id': db_id,
        'question': x.question,
        'full_sql_query': output.full_sql_query,
        'gold_sql': x.sql,
    })
all_full_sql

100%|██████████| 10/10 [00:08<00:00,  1.17it/s]


[{'sql_idx': 0,
  'db_id': 'hospital_1',
  'question': 'Find the department with the most employees.',
  'full_sql_query': 'SELECT Department.Name FROM Department JOIN Affiliated_With ON Department.DepartmentID = Affiliated_With.Department GROUP BY Department.Name ORDER BY COUNT(Affiliated_With.Physician) DESC LIMIT 1',
  'gold_sql': 'SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;'},
 {'sql_idx': 1,
  'db_id': 'hospital_1',
  'question': 'Tell me the employee id of the head of the department with the least employees.',
  'full_sql_query': 'SELECT Head FROM Department GROUP BY Head ORDER BY COUNT(*) ASC LIMIT 1',
  'gold_sql': 'SELECT head FROM department GROUP BY departmentID ORDER BY count(departmentID) LIMIT 1;'},
 {'sql_idx': 2,
  'db_id': 'hospital_1',
  'question': 'Find the name and position of the head of the department with the least employees.',
  'full_sql_query': 'SELECT T1.Name, T1.Position FROM Physician AS T1 JOIN Department A

In [44]:
## database execution evaluation
from src.evaluate import compare_execution

# Execute each predicted query and its gold query on the record's own database,
# score the pair with `compare_execution`, and record why a zero score occurred.
output_results = []
for data in tqdm(all_full_sql, total=len(all_full_sql)):
    sql_idx = data['sql_idx']
    db_id = data['db_id']
    # Look up the foreign keys for THIS record's database. The previous code read
    # them from the notebook-global `table`, which only ever described the
    # database of the sample inspected in an earlier cell.
    database = SqliteDatabase(
        str(sparc_path / 'database' / db_id / f'{db_id}.sqlite'),
        foreign_keys=sparc_tables[db_id].foreign_keys,
    )
    error_info = None
    try:
        pred_result = database.execute(data['full_sql_query'])
        gold_result = database.execute(data['gold_sql'])
        try:
            score = compare_execution(pred_result, gold_result)
        except Exception as e:
            # Both queries ran; the comparison itself failed.
            print(f"An error occurred: {e}")
            score = 0
            error_info = 'Python Script Error:' + str(e)
    except Exception as e:
        # Either the predicted or the gold query could not be executed.
        print(f"An error occurred: {e}")
        score = 0
        error_info = 'Database Execution Error:' + str(e)
    if score == 0 and error_info is None:
        # Both queries executed and compared cleanly, but the results differ.
        error_info = 'Result Error'
    output_results.append(
        {
            "instance_id": sql_idx,
            "score": score,
            "pred_sql": data['full_sql_query'],
            "error_info": error_info
        }
    )

print({item['instance_id']: item['score'] for item in output_results})
# Mean score over all records; guard against an empty prediction list.
score = (
    sum(item['score'] for item in output_results) / len(output_results)
    if output_results
    else 0.0
)
print(f"Final score: {score}")


100%|██████████| 10/10 [00:00<00:00, 64.15it/s]

An error occurred: Execution failed on sql 'SELECT T1.Name, T2.Name FROM Physician AS T1 JOIN Affiliated_With AS T2 ON T1.EmployeeID = T2.Physician WHERE T2.PrimaryAffiliation = "Yes"': no such column: T2.Name
{0: 1, 1: 1, 2: 0, 3: 0, 4: 1, 5: 1, 6: 0, 7: 1, 8: 0, 9: 1}
Final score: 0.6





## Chain of Thought Prompt

## Schema Linking: Single-Column Schema Linking (SCSL)
- Identifies the relevance of each column independently of the rest of the schema.

## Schema Linking: Table-to-Column Schema Linking (TCSL)
- First identifies the relevant tables, then the relevant columns within those tables.