# Playground

In [18]:
from pathlib import Path
from enum import Enum
import pandas as pd
from typing import Optional
import sqlite3
import json
from func_timeout import func_timeout, FunctionTimedOut


class SQLiteDatabase:
    """ Class for dealing with sqlite3 databases. Provides SQL execution capabilities and access to schema.

        The database file is expected at ``input_path/db_id/db_id.sqlite``.

        Attributes
        ----------
            db_id: str
                name of database
            input_path: Path
                parent directory of database folder
            exec_timeout: float
                default maximum number of seconds a query may run
            raw_schema: dict[str, str]
                unaugmented CREATE statements indexed by table name, read from db_id.sqlite
            descriptions: dict[str, str]
                contents of input_path/db_id/database_description/<table_name>.csv
                indexed by table name (or a "does not exist" message when no file is found)
            schema: dict[str, str]
                either raw_schema, or the entry for db_id read from the cached-schema json
    """

    def __init__(self, db_id: str, input_path: Path, exec_timeout: float = 30.0, use_cached_schema: Optional[Path] = None) -> None:
        """ Read the raw schema and table descriptions, then select the schema to expose.

            Parameters
            ----------
                db_id: str
                    name of database; database must exist in input_path/db_id/db_id.sqlite
                input_path: Path
                    parent directory of database folder
                exec_timeout: float
                    maximum number of seconds for query to return a result
                use_cached_schema: Path | None
                    path to a json file containing a dict of db_id: schema. When given,
                    its entry for db_id is used as `schema` instead of raw_schema.

            Raises
            ------
                TimeoutError
                    if reading the schema from the database times out
                KeyError
                    if the cached-schema file has no entry for db_id
        """
        self.db_id = db_id
        self.input_path = input_path
        self.exec_timeout = exec_timeout

        self.raw_schema: dict[str, str] = self.__fetch_raw_schema()
        self.descriptions: dict[str, str] = self.__fetch_db_descriptions()

        if use_cached_schema:
            with open(use_cached_schema, 'r', encoding='utf-8') as f:
                self.schema = json.load(f)[db_id]
        else:
            self.schema = self.raw_schema

    def __getitem__(self, table_name: str) -> str:
        """ Return the schema of a table in the database. """
        return self.schema[table_name]

    def __str__(self) -> str:
        """ Returns the database schema as a human-readable/executable string. """
        return "\n\n".join(self.schema.values())

    def run_query(self, sql: str, timeout: Optional[float] = None) -> list[tuple]:
        """ Executes SQL query and fetches all rows.

            Returns [("Error: timedout",)] when the query does not finish within
            `timeout` seconds (falls back to exec_timeout when `timeout` is None or 0).
            sqlite3 exceptions raised by the query propagate to the caller.
        """
        # Built with an f-string: with_suffix() would cut a db_id such as "v1.2" down to "v1".
        db_path = self.input_path / self.db_id / f"{self.db_id}.sqlite"

        def execute_sql() -> list[tuple]:
            conn = sqlite3.connect(db_path, uri=True)
            try:
                # `with conn` only commits / rolls back the transaction ...
                with conn:
                    return conn.execute(sql).fetchall()
            finally:
                # ... it does not close the connection, so close it explicitly.
                conn.close()

        try:
            rows = func_timeout(timeout=(timeout or self.exec_timeout), func=execute_sql)
        except FunctionTimedOut:
            rows = [("Error: timedout", )]
        return rows

    def __fetch_raw_schema(self) -> dict[str, str]:
        """ Returns a dict of schema of all tables in a .sqlite database indexed by table name.

            Raises TimeoutError if the schema query itself times out.
        """
        # One query for names and CREATE statements together: no per-table round trip and
        # no table name spliced into SQL text (which breaks on names containing a quote).
        rows = self.run_query(
            "SELECT name, sql FROM sqlite_master "
            "WHERE type='table' AND name != 'sqlite_sequence';"
        )
        if rows == [("Error: timedout", )]:
            raise TimeoutError(f"Timed out while reading the schema of database '{self.db_id}'.")
        return dict(rows)

    def __fetch_db_descriptions(self) -> dict[str, str]:
        """ Returns a dict of database_descriptions from each table_name.csv as strings.

            File names are matched case-insensitively. A table without a description file
            maps to a message saying the file does not exist.
        """
        def resolve_case_insensitive(filepath: Path) -> Optional[Path]:
            """ Return an existing path equal to `filepath` up to letter case, or None. """
            if filepath.exists():
                return filepath
            # Common casings first, in a fixed order of preference.
            stem = filepath.stem
            for variant in (stem.capitalize(), stem.title(), stem.upper(), stem.lower()):
                candidate = filepath.with_stem(variant)
                if candidate.exists():
                    return candidate
            # Any other casing (e.g. "fooBar.csv"): scan the directory.
            if filepath.parent.is_dir():
                wanted = filepath.name.casefold()
                for entry in sorted(filepath.parent.iterdir()):
                    if entry.name.casefold() == wanted:
                        return entry
            return None

        description_dir = self.input_path / self.db_id / 'database_description'
        descriptions = {}
        for table in self.raw_schema:
            # f-string instead of with_suffix(): table names may contain dots.
            filepath = description_dir / f"{table}.csv"
            found = resolve_case_insensitive(filepath)
            if found is None:
                descriptions[table] = f'Descriptions file for table at {filepath} does not exist.'
            else:
                # errors='ignore': undecodable bytes in a description file are dropped.
                with open(found, 'r', errors='ignore') as file:
                    descriptions[table] = file.read()

        return descriptions


# Dataset layout and execution settings for the BIRD mini-dev benchmark.
INPUT_PATH = Path('data/bird-minidev')
BIRD_QUESTION_FILENAME = 'dev.json'
DATABASES_FOLDERNAME = 'dev_databases'
USE_CACHED_SCHEMA = INPUT_PATH / 'aug-minidev/aug.json'       # Use pre-generated schema instead of augmenting with LLM from scratch
DB_EXEC_TIMEOUT = 30.0                                      # maximum number of seconds a query execution is allowed to take

# Keep only folders that actually hold <name>/<name>.sqlite. A bare iterdir() also
# returns stray files and non-database folders, and SQLiteDatabase would then try to
# open a database that is not there.
db_names: list[str] = [
    entry.name
    for entry in (INPUT_PATH / DATABASES_FOLDERNAME).iterdir()
    if (entry / f"{entry.name}.sqlite").is_file()
]

# One SQLiteDatabase per discovered database, keyed by db_id.
databases: dict[str, SQLiteDatabase] = {
    db_id: SQLiteDatabase(db_id, (INPUT_PATH / DATABASES_FOLDERNAME), DB_EXEC_TIMEOUT, USE_CACHED_SCHEMA)
    for db_id in db_names
}

In [19]:
# print() applies str() itself, which goes through SQLiteDatabase.__str__.
print(databases['thrombosis_prediction'])

/*
The Examination table is designed to store various medical examination results for patients. 
It includes information related to autoimmune and coagulation disorders, such as anti-Cardiolipin antibodies 
and anti-nucleus antibodies, as well as diagnostic and symptomatic data. This table is linked to the Patient 
table via the ID column, which serves as a foreign key. The relationship ensures that any updates or deletions 
in the Patient table cascade to maintain data integrity across the database.
*/

CREATE TABLE Examination
(
    ID INTEGER null, -- identification of the patient
    `Examination Date` DATE null, -- Examination Date
    `aCL IgG` REAL null, -- anti-Cardiolipin antibody (IgG) concentration
    `aCL IgM` REAL null, -- anti-Cardiolipin antibody (IgM) concentration
    ANA INTEGER null, -- anti-nucleus antibody concentration
    `ANA Pattern` TEXT null, -- pattern observed in the sheet of ANA examination
    `aCL IgA` INTEGER null, -- anti-Cardiolipin antibody (IgA) co

In [None]:
# Candidate query under test (presumably a model prediction, going by the name).
# It additionally joins Examination on ID and date, and filters on the maximum HGB.
y_pred = '''
SELECT 
    (strftime('%Y', Laboratory.Date) - strftime('%Y', Patient.Birthday)) AS Age,
    Examination.Diagnosis
FROM Laboratory
JOIN Patient ON Laboratory.ID = Patient.ID
JOIN Examination ON Laboratory.ID = Examination.ID AND Laboratory.Date = Examination.`Examination Date`
WHERE Laboratory.HGB = (SELECT MAX(HGB) FROM Laboratory);
'''

# Reference query (presumably the gold label): takes the single row with the highest
# HGB via ORDER BY ... DESC LIMIT 1 and reads Diagnosis from Patient.
y_true = '''
SELECT STRFTIME('%Y', T2.Date) - STRFTIME('%Y', T1.Birthday), T1.Diagnosis 
FROM Patient AS T1 
INNER JOIN Laboratory AS T2 ON T1.ID = T2.ID 
ORDER BY T2.HGB
DESC LIMIT 1'''

def is_sql_same(database,  query_1: str, query_2: str) -> bool:
    """ Executes both SQL queries and returns True only if their result sets match.

        Rows are compared as sets, so row order and duplicate rows are ignored.
        Both result lists are printed as they arrive.

        Parameters
        ----------
            database:
                object exposing run_query(sql) -> list[tuple]
            query_1, query_2: str
                the SQL statements to compare

        Returns
        -------
            False if either query raises a sqlite3 error or timed out,
            otherwise whether the two sets of rows are equal.
    """
    # Sentinel row that run_query returns in place of results when a query times out.
    timed_out = [("Error: timedout", )]
    try:
        res_1 = database.run_query(query_1)
        print(res_1, flush=True)
        res_2 = database.run_query(query_2)
        print(res_2, flush=True)
    except (sqlite3.Error, sqlite3.Warning) as e:
        # Any SQLite failure (bad syntax, unknown column, multiple statements, ...)
        # means the queries cannot be called equal.
        print(f"{e.__class__.__name__} {e}")
        return False

    # Two timed-out queries would otherwise compare equal through the shared sentinel.
    if res_1 == timed_out or res_2 == timed_out:
        return False
    return set(res_1) == set(res_2)
    
# Reference query goes in as query_1 and the candidate as query_2, so the reference's rows are printed first.
is_sql_same(databases['thrombosis_prediction'], y_true, y_pred)

[(28, 'SLE')]
[]


False

In [22]:
# llm = LLM(SupportedModels.OpenAI.GPT4o, api_key=api_keys.OPENAI_API_KEY, )
# augmenter = SchemaAugmenter(llm, databases, INPUT_PATH/DATABASES_FOLDERNAME)

# raw, aug = augmenter.augment_all(save=True)

Client Instantiated: OpenAI API key valid.


In [24]:
# ddd = {db_id: db.schema for db_id, db in databases.items()}
# augmenter.dump_to_json(INPUT_PATH/DATABASES_FOLDERNAME/"Augmented/db", ddd)


In [None]:
import re
def are_create_statements_equivalent(create_stmt1, create_stmt2):
    """Return True if two CREATE TABLE statements define the same table.

    Comments, whitespace, letter case and the table name itself are ignored.
    Both statements are executed against a throwaway in-memory SQLite database,
    so a statement that is not valid SQL makes the result False.

    Parameters
    ----------
        create_stmt1, create_stmt2: str
            CREATE TABLE statements, optionally carrying /* */ and -- comments

    Returns
    -------
        bool: True when the normalized definitions match; False when they differ
        (both normalized definitions are printed) or when SQLite rejects either one.
    """
    # Table name forms accepted: bare, "double quoted", `backticked`, 'single quoted'
    # or [bracketed], with an optional IF NOT EXISTS, in any letter case.
    table_name_pattern = re.compile(
        r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
        r"(?:\"[^\"]+\"|`[^`]+`|'[^']+'|\[[^\]]+\]|\w+)",
        flags=re.IGNORECASE,
    )

    def replace_table_name(create_stmt, new_table_name):
        """Swap the table name of a CREATE TABLE statement for `new_table_name`."""
        return table_name_pattern.sub(f"CREATE TABLE {new_table_name}", create_stmt, count=1)

    def preprocess_sql(sql):
        """Remove comments and insignificant whitespace, and neutralize the table name."""
        # Remove multi-line comments
        sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
        # Remove single-line comments
        sql = re.sub(r"--.*?$", "", sql, flags=re.MULTILINE)
        # Collapse runs of whitespace
        sql = re.sub(r"\s+", " ", sql).strip()
        # Drop whitespace around parentheses and commas so "( a , b )" equals "(a,b)"
        sql = re.sub(r"\s*([(),])\s*", r"\1", sql)
        # Change table name so both statements carry the same one
        return replace_table_name(sql, 'tbl_test')

    def stored_schema(connection, table_name):
        """Return the normalized definition SQLite recorded for `table_name`, or None."""
        row = connection.execute(
            "SELECT sql FROM sqlite_master WHERE name=?", (table_name,)
        ).fetchone()
        return None if row is None else preprocess_sql(row[0])

    # Create an in-memory SQLite database
    conn = sqlite3.connect(':memory:')
    try:
        # Give the two tables distinct names so both can live in one database
        conn.execute(replace_table_name(preprocess_sql(create_stmt1), "test_table1"))
        conn.execute(replace_table_name(preprocess_sql(create_stmt2), "test_table2"))

        # Read back the definitions as SQLite recorded them
        schema1 = stored_schema(conn, "test_table1")
        schema2 = stored_schema(conn, "test_table2")
        if schema1 is None or schema2 is None:
            # The rename did not take effect, i.e. the input was not a CREATE TABLE
            # statement in a recognized form.
            return False

        # Compare normalized schema definitions
        if schema1.lower() == schema2.lower():
            return True
        print(schema1)
        print(schema2)
        return False
    except sqlite3.Error as e:
        print(f"SQLite error: {e}", flush=True)
        return False
    finally:
        conn.close()


filename = "/home/fahim/Documents/sql-gen/data/bird-minidev/augmented/aug.json"
with open(filename, 'r', encoding='utf-8') as f:
    aug = json.load(f)

# Check every augmented schema against the schema read from the .sqlite file itself.
# Entries are paired by db_id and table name: zipping the two dicts pairs them by
# position and silently compares unrelated tables whenever the orders differ.
# raw_schema is used because db.schema may already be a cached augmented schema.
for db_id, db in databases.items():
    aug_db = aug.get(db_id)
    if aug_db is None:
        print(f"No augmented schemas found for database {db_id}.")
        continue
    for tbl, schema in db.raw_schema.items():
        aug_schema = aug_db.get(tbl)
        if aug_schema is None:
            print(f"No augmented schema found for table {tbl} of {db_id}.")
            continue
        if not are_create_statements_equivalent(schema, aug_schema):
            print(f"Augmented schema {tbl} of {db_id} contains errors.")


In [8]:
# Inspect the augmented schemas loaded from aug.json for the formula_1 database.
aug['formula_1']

{'circuits': "/* \n   Table: circuits\n   Purpose: This table stores information about racing circuits, including their unique identifiers, names, locations, and geographic coordinates. \n   Relationship: The circuits table is related to the races table through the circuitId, which is a foreign key in the races table. \n*/\n\nCREATE TABLE circuits\n(\n    circuitId  INTEGER\n        primary key autoincrement, -- unique identification number of the circuit\n    circuitRef TEXT default '' not null, -- circuit reference name\n    name       TEXT default '' not null, -- full name of circuit\n    location   TEXT, -- location of circuit\n    country    TEXT, -- country of circuit\n    lat        REAL, -- latitude of location of circuit\n    lng        REAL, -- longitude of location of circuit\n    alt        INTEGER, -- altitude of the circuit location (not useful)\n    url        TEXT default '' not null -- url\n        unique\n)",
 'constructors': "/*\nThe constructors table stores informa

In [None]:
import sqlite3

import pandas as pd

# The databases live under dev_databases/; dev.json is the questions file, not a folder.
db = '/home/fahim/Documents/sql-gen/data/bird-minidev/dev_databases/formula_1/formula_1.sqlite'

# pd.read_sql_table needs an SQLAlchemy connectable and was being called without any
# connection at all. read_sql_query accepts a plain sqlite3 connection.
conn = sqlite3.connect(db)
try:
    df = pd.read_sql_query('SELECT * FROM circuits', conn)
finally:
    conn.close()

In [19]:
import json
import pandas as pd

# Folder holding the cleaned predictions of each run; the merged frame is written back here.
results_dir = '/home/fahim/Documents/sql-gen/results/llama3.1:8b_zeroshot-metaprompt-optimizer'

df = pd.read_json('/home/fahim/Documents/sql-gen/data/bird-minidev/dev.json')
# .copy() makes the filtered frame independent of the original, so adding columns
# below does not write to a slice (SettingWithCopyWarning).
df = df[df['db_id'].isin(['formula_1', 'debit_card_specializing', 'thrombosis_prediction'])].copy()

contents = {}
files = ['zs', 'op_zs', 'mp', 'op_mp']

# Each <name>_clean.json is attached as a pred_<name> column.
# NOTE(review): this assumes every file holds one prediction per remaining row, in row order — confirm.
for file in files:
    filename = f'{results_dir}/{file}_clean.json'
    with open(filename, 'r', encoding='utf-8') as f:
        content = json.load(f)
    df[f"pred_{file}"] = content

df.to_json(f'{results_dir}/final_df_temp.json', orient='records')

In [20]:
# Display the question frame together with its pred_* columns.
df

Unnamed: 0,question_id,db_id,question,evidence,SQL,difficulty,pred_zs,pred_op_zs,pred_mp,pred_op_mp
0,1471,debit_card_specializing,What is the ratio of customers who pay in EUR ...,ratio of customers who pay in EUR against cust...,"SELECT CAST(SUM(IIF(Currency = 'EUR', 1, 0)) A...",simple,SELECT \n SUM(CASE WHEN Currency = 'EUR' TH...,SELECT \n SUM(CASE WHEN Currency = 'EUR' TH...,SELECT SUM(CASE WHEN Currency = 'EUR' THEN 1 E...,SELECT CAST(SUM(CASE WHEN t1.Currency = 'EUR' ...
1,1472,debit_card_specializing,"In 2012, who had the least consumption in LAM?",Year 2012 can be presented as Between 201201 A...,SELECT T1.CustomerID FROM customers AS T1 INNE...,moderate,SELECT c.CustomerID \nFROM customers c \nJOIN ...,SELECT c.CustomerID \nFROM customers c \nJOIN ...,SELECT CustomerID\nFROM yearmonth\nWHERE Date ...,SELECT CustomerID\nFROM yearmonth\nWHERE Date ...
2,1473,debit_card_specializing,What was the average monthly consumption of cu...,Average Monthly consumption = AVG(Consumption)...,SELECT AVG(T2.Consumption) / 12 FROM customers...,moderate,"SELECT \n c.CustomerID,\n AVG(y.Consumpt...","SELECT \n c.CustomerID,\n AVG(y.Consumpt...","SELECT \n ym.CustomerID,\n AVG(ym.Consum...",WHERE ym.Date >= '20130101' AND ym.Date <= '20...
3,1476,debit_card_specializing,What was the difference in gas consumption bet...,Year 2012 can be presented as Between 201201 A...,"SELECT SUM(IIF(T1.Currency = 'CZK', T2.Consump...",challenging,-- Get Consumption of CZK-paying Customers in ...,SELECT \n SUM(CASE WHEN C.Currency = 'CZK' AN...,SELECT \n (SELECT SUM(Consumption) FROM yea...,SELECT \n SUM(CASE WHEN t1.Currency = 'EUR'...
4,1479,debit_card_specializing,Which year recorded the most consumption of ga...,The first 4 strings of the Date values in the ...,"SELECT SUBSTR(T2.Date, 1, 4) FROM customers AS...",moderate,"SELECT SUBSTR(Date, 1, 4) AS Year, SUM(Consump...",-- Create an index on Date column for better q...,"SELECT \n SUBSTR(Date, 1, 4) AS Year,\n ...","SELECT \n strftime('%Y', Date) AS Year,\n ..."
...,...,...,...,...,...,...,...,...,...,...
241,1001,formula_1,What is full name of the racer who ranked 1st ...,Ranked 1st in the 3rd qualifying race refer to...,"SELECT T2.forename, T2.surname FROM qualifying...",challenging,"SELECT r.name, d.forename, d.surname \nFROM qu...","SELECT r.name, d.forename, d.surname \nFROM qu...","SELECT D.forename, D.surname \nFROM drivers D\...","SELECT D.forename, D.surname \nFROM drivers D\..."
242,1002,formula_1,"As of the present, what is the full name of th...",full name refers to forename+surname; Youngest...,"SELECT T1.forename, T1.surname, T1.nationality...",moderate,"SELECT \n D.forename,\n D.surname,\n ...","SELECT \n D.forename,\n D.surname,\n ...","SELECT D.forename, D.surname, D.nationality, R...","SELECT D.forename, D.surname, D.nationality, R..."
243,1003,formula_1,How many accidents did the driver who had the ...,number of accidents refers to the number where...,SELECT COUNT(T1.driverId) FROM results AS T1 I...,moderate,SELECT raceId FROM races WHERE name LIKE '%Can...,SELECT r.raceID \nFROM races r \nJOIN results ...,"SELECT COUNT(r.statusId) AS NumAccidents, r.ra...","SELECT COUNT(*) AS NumAccidents, r.raceId, r.d..."
244,1011,formula_1,Which top 20 driver created the shortest lap t...,shortest lap time refers to MIN(time); the tim...,"WITH lap_times_in_seconds AS (SELECT driverId,...",challenging,"SELECT D.forename, D.surname, LT.time \nFROM d...","WITH min_lap_time AS (\n SELECT LT.time, D....","SELECT \n D.forename,\n D.surname,\n ...","SELECT \n D.driverId,\n MIN(LT.time) as ..."


In [None]:
# results = '\n\n'.join([zs_report, op_zs_report, mp_report, op_mp_report])
# print(results)

=== EX Results ===
Accuracy :  15.753%
Breakdown by Difficulty:
	simple:  24.490% (12 of 49)
	moderate:  15.385% (10 of 65)
	challenging:  3.125% (1 of 32)
=== end ===


=== EX Results ===
Accuracy :  16.438%
Breakdown by Difficulty:
	simple:  22.449% (11 of 49)
	moderate:  18.462% (12 of 65)
	challenging:  3.125% (1 of 32)
=== end ===


=== EX Results ===
Accuracy :  17.123%
Breakdown by Difficulty:
	simple:  28.571% (14 of 49)
	moderate:  15.385% (10 of 65)
	challenging:  3.125% (1 of 32)
=== end ===


=== EX Results ===
Accuracy :  17.123%
Breakdown by Difficulty:
	simple:  26.531% (13 of 49)
	moderate:  18.462% (12 of 65)
	challenging:  0.000% (0 of 32)
=== end ===



# Run Experiment

## GPT-4o Zero-shot

In [26]:
# if EXPERIMENT == 'zero-shot':
#     print(f"Experiment: {MODEL}_{EXPERIMENT}")
    
#     # Setup
#     df, db_names = read_dataset()
#     db_schemas   = fetch_BIRD_schemas(db_names)
#     print(f'{db_names=}, {len(df)=}')
    
#     client = get_openai_client()
#     agent = ZeroShotAgent(MODEL, client, get_db_cursor, db_schemas, OUTPUT_PATH)
#     evaluator = EvaluatorForBIRD(get_db_cursor)
    
#     # Generate
#     raw_responses = agent.batched_generate(df)
#     dump_to_json('raw_responses', raw_responses)

#     # Parse
#     print("Finished Generating. Attempting SQL auto-parsing...")
#     cleaned_sql = agent.auto_parse_sql_from_response(raw_responses)
#     dump_to_json('cleaned_sql', cleaned_sql)
#     print("SQL auto-parsing successful")

#     # Evaluate
#     df['prediction'] = cleaned_sql
#     df['label'] = evaluator.evaluate(df, pred_col_name='prediction')
    
#     # Save results
#     df.to_json(OUTPUT_PATH / f'{MODEL}_{EXPERIMENT}_df.json', orient='records')

## GPT-4o Zero-shot + Optimizer

In [None]:
# if EXPERIMENT == 'optimizer-agent':
#     print(f"Experiment: {MODEL}_{EXPERIMENT}")
    
#     # Setup
#     df, db_names = read_dataset()
#     db_schemas   = fetch_BIRD_schemas(db_names)
#     print(f'{db_names=}, {len(df)=}')
    
#     client = get_openai_client()
#     agent = OptimizerAgent(MODEL, client, get_db_cursor, db_schemas, OUTPUT_PATH)
#     evaluator = EvaluatorForBIRD(get_db_cursor)
    
#     # Generate
#     df = pd.read_json('gpt-4o_zero-shot_df.json')
#     raw_responses = agent.batched_generate(df)
#     dump_to_json('raw_responses', raw_responses)

#     # Parse
#     print(f"Finished Generating. Attempting SQL auto-parsing...")
#     cleaned_sql = agent.auto_parse_sql_from_response(raw_responses)
#     dump_to_json('cleaned_sql', cleaned_sql)
#     print(f"SQL auto-parsing successful")

#     # Evaluate
#     df['optimized'] = cleaned_sql
#     df['opt-label'] = evaluator.evaluate(df, pred_col_name='optimized')
    
#     # Save results
#     df.to_json(OUTPUT_PATH / f'{MODEL}_{EXPERIMENT}_df.json', orient='records')

## GPT-4o Multi-Agent Discussion

In [None]:
# if EXPERIMENT == 'discussion':
#     print(f"Experiment: {MODEL}_{EXPERIMENT}")
    
#     # Setup
#     df, db_names = read_dataset()
#     db_schemas   = fetch_BIRD_schemas(db_names)
#     print(f'{db_names=}, {len(df)=}')

#     client = get_openai_client()
#     multi_agent = MultiAgentDiscussion(MODEL, client, get_db_cursor, db_schemas, OUTPUT_PATH)
#     evaluator = EvaluatorForBIRD(get_db_cursor)


#     # Generate
#     raw_responses = multi_agent.batched_generate(df, rounds=3)
#     dump_to_json('raw_responses', raw_responses)

#     # Parse
#     print(f"Finished Generating. Attempting SQL auto-parse...")

#     starter_zero = multi_agent.auto_parse_sql_from_response([response['agent_zero_shot'][0] for response in raw_responses])
#     dump_to_json('cleaned_zeroshot_starter', starter_zero)

#     starter_meta = multi_agent.auto_parse_sql_from_response([response['agent_meta_prompt'][0] for response in raw_responses])
#     dump_to_json('cleaned_starter_meta', starter_meta)
    
#     cleaned_sql  = multi_agent.auto_parse_sql_from_response([response['verdict'] for response in raw_responses])
#     dump_to_json('cleaned_sql', cleaned_sql)

#     print(f"SQL auto-parsing successful\n\n")


#     # Evaluate results
#     print("Evaluating Zero-shot starter generated queries")
#     df['starter_zero_shot'] = starter_zero
#     df['zero_shot_labels']  = evaluator.evaluate(df, pred_col_name='starter_zero_shot')

#     print("Evaluating meta-prompt starter generated queries")
#     df['starter_meta_prompt'] = starter_meta
#     df['meta_prompt_labels']  = evaluator.evaluate(df, pred_col_name='starter_meta_prompt')

#     print("Evaluating Multi-Agent Discussion generated queries")
#     df['prediction'] = cleaned_sql
#     df['label']      = evaluator.evaluate(df, pred_col_name='prediction')


#     # Save results
#     df.to_json(OUTPUT_PATH / f'{MODEL}_{EXPERIMENT}_df.json', orient='records')

# Experiments:
- Zero-shot
    - with/without chain-of-thought (CoT)
- Optimizer (on top of zero-shot)
- Multi-agent:
    - Zero-shot -> Optimizer -> Multi-agent Debate
    - Zero-shot -> Optimizer -> Multi-agent Discussion
    - Best of the above -> Optimizer
- Decomposition and Generation via Multi-agent Debate/Discussion
- Sparse Topology Multi-agent Debate/Discussion
- Augmenting schema with LLM calls:
    - Point out relationships (graph idea)
    - Write short descriptions of tables and columns