# Playground

In [2]:
from pathlib import Path
import pandas as pd
from core.dbhandler import SQLiteDatabase
from core.birdeval import evaluate

# --- Configuration ---------------------------------------------------------
# All paths are relative to the current working directory, so the notebook
# must be started from the project root.
INPUT_PATH = Path('data/bird-minidev')
BIRD_QUESTION_FILENAME = 'dev.json'
DATABASES_FOLDERNAME = 'dev_databases'
USE_CACHED_SCHEMA = INPUT_PATH / 'aug-minidev/aug.json'       # Use pre-generated schema instead of augmenting with LLM from scratch
DB_EXEC_TIMEOUT = 30.0                                      # maximum number of seconds a query execution is allowed to take

# Folder holding one entry per database.
DATABASES_PATH = INPUT_PATH / DATABASES_FOLDERNAME

# Only directories are treated as databases, so stray files in the folder
# (e.g. OS metadata files) do not become bogus database ids. Sorting makes the
# order of `db_names` / `databases` reproducible across runs and machines.
# NOTE(review): assumes every database lives in its own sub-directory named
# after its db_id -- confirm against SQLiteDatabase.
db_names: list[str] = sorted(
    entry.name for entry in DATABASES_PATH.iterdir() if entry.is_dir()
)

# Map each database id to its handler object.
databases: dict[str, SQLiteDatabase] = {
    db_id: SQLiteDatabase(db_id, DATABASES_PATH, DB_EXEC_TIMEOUT, USE_CACHED_SCHEMA)
    for db_id in db_names
}

In [4]:
# Load the BIRD questions from row position 300 onwards. `reset_index()` keeps
# the original row position in an `index` column and renumbers the rows from 0.
# Paths are built from the config constants / relative to the project root
# instead of a user-specific absolute path, so the cell runs on any checkout.
df = pd.read_json(INPUT_PATH / BIRD_QUESTION_FILENAME)[300:].reset_index()

# Previously generated SQL, stored as a single-column frame.
clean = pd.read_json(Path('results/starter1_clean.json'))

# Attach the generated SQL next to the questions (aligned on the row index).
df['clean'] = clean
df
# labels, report = evaluate(df, databases, 30, f'clean')

Unnamed: 0,index,question_id,db_id,question,evidence,SQL,difficulty,clean
0,300,539,codebase_community,"Who is the owner of the post ""Eliciting priors...","""Eliciting priors from experts"" is the Title o...",SELECT T2.DisplayName FROM posts AS T1 INNER J...,simple,SELECT u.DisplayName \nFROM posts p \nINNER JO...
1,301,537,codebase_community,How many posts does the user csgillespie own?,"""csgillespie"" is the DisplayName of user",SELECT COUNT(T1.id) FROM posts AS T1 INNER JOI...,simple,SELECT COUNT(T2.DisplayName) \nFROM users AS T...
2,302,544,codebase_community,What is the display name of the user who last ...,"""Examples for teaching: Correlation does not m...",SELECT T2.DisplayName FROM posts AS T1 INNER J...,moderate,SELECT T2.DisplayName \nFROM posts AS T1 \nINN...
3,303,547,codebase_community,"Among the posts owned by an elder user, how ma...",elder users refers to Age > 65; Score of over ...,SELECT COUNT(T1.Id) FROM posts AS T1 INNER JOI...,simple,SELECT COUNT(DISTINCT T2.Id) \nFROM users AS T...
4,304,549,codebase_community,"From which post is the tag ""bayesian"" excerpte...","""bayesian"" is the TagName; excerpt from refers...",SELECT T2.Body FROM tags AS T1 INNER JOIN post...,simple,SELECT p.Body \nFROM tags t \nJOIN posts p ON ...
...,...,...,...,...,...,...,...,...
195,495,173,financial,How often does account number 3 request an acc...,k_symbol refers to the purpose of payments,"SELECT T1.frequency, T2.k_symbol FROM account ...",challenging,-- Query to find the frequency of account stat...
196,496,186,financial,What percentage of male clients request for we...,Percentage of male clients = [count(male clien...,SELECT CAST(SUM(T1.gender = 'M') AS REAL) * 10...,moderate,SELECT \n CAST(SUM(CASE WHEN T1.gender = 'M...
197,497,189,financial,Name the account numbers of female clients who...,Female refers to 'F' in the gender; A11 contai...,SELECT T3.account_id FROM client AS T1 INNER J...,moderate,SELECT DISTINCT T3.account_id \nFROM client AS...
198,498,192,financial,What is the average amount of loan which are s...,"status = 'C' stands for running contract, OK s...",SELECT AVG(T2.amount) FROM account AS T1 INNER...,moderate,SELECT AVG(amount) \nFROM loan \nWHERE status ...


In [19]:
# Inspect the column labels of the `clean` frame loaded above.
clean.columns

RangeIndex(start=0, stop=1, step=1)

In [None]:
# Query assigned to `y_pred` (presumably a model-generated prediction -- confirm):
# year difference between Laboratory.Date and Patient.Birthday plus
# Examination.Diagnosis, for Laboratory rows whose HGB equals the maximum HGB.
# It additionally joins Examination on both ID and examination date.
y_pred = '''
SELECT 
    (strftime('%Y', Laboratory.Date) - strftime('%Y', Patient.Birthday)) AS Age,
    Examination.Diagnosis
FROM Laboratory
JOIN Patient ON Laboratory.ID = Patient.ID
JOIN Examination ON Laboratory.ID = Examination.ID AND Laboratory.Date = Examination.`Examination Date`
WHERE Laboratory.HGB = (SELECT MAX(HGB) FROM Laboratory);
'''

# Query assigned to `y_true` (presumably the reference answer -- confirm):
# the same year difference plus Patient.Diagnosis, keeping only the single row
# with the highest HGB. Note that it reads Diagnosis from Patient, whereas
# `y_pred` reads it from Examination.
y_true = '''
SELECT STRFTIME('%Y', T2.Date) - STRFTIME('%Y', T1.Birthday), T1.Diagnosis 
FROM Patient AS T1 
INNER JOIN Laboratory AS T2 ON T1.ID = T2.ID 
ORDER BY T2.HGB
DESC LIMIT 1'''

def is_sql_same(database,  query_1: str, query_2: str) -> bool:
    """Run two SQL queries and report whether they return the same set of rows.

    Both result lists are printed as they are produced, which makes it easy to
    eyeball the difference between the queries in the notebook output.

    Args:
        database: Object exposing ``run_query(sql)`` that returns an iterable
            of hashable rows.
        query_1: First SQL statement to execute.
        query_2: Second SQL statement to execute.

    Returns:
        True if both queries execute and yield equal row sets (row order and
        duplicate rows are ignored). False if the row sets differ or either
        query raises ``sqlite3.OperationalError``.
    """
    # Imported locally because the notebook's import cell does not import
    # sqlite3; without it the `except` clause below raised NameError as soon
    # as a query failed, instead of returning False.
    import sqlite3

    try:
        res_1 = database.run_query(query_1)
        print(res_1, flush=True)
        res_2 = database.run_query(query_2)
        print(res_2, flush=True)
    except sqlite3.OperationalError as e:
        # A query that cannot be executed counts as a mismatch.
        print(f"{e.__class__.__name__} {e}")
        return False
    else:
        # Set comparison: order-insensitive and duplicate-insensitive.
        return set(res_1) == set(res_2)
    
# Compare the two queries on the thrombosis_prediction database; the bare call
# is the cell's last expression, so its boolean result is shown as the output.
is_sql_same(
    database=databases['thrombosis_prediction'],
    query_1=y_true,
    query_2=y_pred,
)

[(28, 'SLE')]
[]


False

In [None]:
# results = '\n\n'.join([zs_report, op_zs_report, mp_report, op_mp_report])
# print(results)

=== EX Results ===
Accuracy :  15.753%
Breakdown by Difficulty:
	simple:  24.490% (12 of 49)
	moderate:  15.385% (10 of 65)
	challenging:  3.125% (1 of 32)
=== end ===


=== EX Results ===
Accuracy :  16.438%
Breakdown by Difficulty:
	simple:  22.449% (11 of 49)
	moderate:  18.462% (12 of 65)
	challenging:  3.125% (1 of 32)
=== end ===


=== EX Results ===
Accuracy :  17.123%
Breakdown by Difficulty:
	simple:  28.571% (14 of 49)
	moderate:  15.385% (10 of 65)
	challenging:  3.125% (1 of 32)
=== end ===


=== EX Results ===
Accuracy :  17.123%
Breakdown by Difficulty:
	simple:  26.531% (13 of 49)
	moderate:  18.462% (12 of 65)
	challenging:  0.000% (0 of 32)
=== end ===



# Run Experiment

## GPT-4o Zero-shot

In [26]:
# if EXPERIMENT == 'zero-shot':
#     print(f"Experiment: {MODEL}_{EXPERIMENT}")
    
#     # Setup
#     df, db_names = read_dataset()
#     db_schemas   = fetch_BIRD_schemas(db_names)
#     print(f'{db_names=}, {len(df)=}')
    
#     client = get_openai_client()
#     agent = ZeroShotAgent(MODEL, client, get_db_cursor, db_schemas, OUTPUT_PATH)
#     evaluator = EvaluatorForBIRD(get_db_cursor)
    
#     # Generate
#     raw_responses = agent.batched_generate(df)
#     dump_to_json('raw_responses', raw_responses)

#     # Parse
#     print("Finished Generating. Attempting SQL auto-parsing...")
#     cleaned_sql = agent.auto_parse_sql_from_response(raw_responses)
#     dump_to_json('cleaned_sql', cleaned_sql)
#     print("SQL auto-parsing successful")

#     # Evaluate
#     df['prediction'] = cleaned_sql
#     df['label'] = evaluator.evaluate(df, pred_col_name='prediction')
    
#     # Save results
#     df.to_json(OUTPUT_PATH / f'{MODEL}_{EXPERIMENT}_df.json', orient='records')

## GPT-4o Zero-shot + Optimizer

In [None]:
# if EXPERIMENT == 'optimizer-agent':
#     print(f"Experiment: {MODEL}_{EXPERIMENT}")
    
#     # Setup
#     df, db_names = read_dataset()
#     db_schemas   = fetch_BIRD_schemas(db_names)
#     print(f'{db_names=}, {len(df)=}')
    
#     client = get_openai_client()
#     agent = OptimizerAgent(MODEL, client, get_db_cursor, db_schemas, OUTPUT_PATH)
#     evaluator = EvaluatorForBIRD(get_db_cursor)
    
#     # Generate
#     df = pd.read_json('gpt-4o_zero-shot_df.json')
#     raw_responses = agent.batched_generate(df)
#     dump_to_json('raw_responses', raw_responses)

#     # Parse
#     print(f"Finished Generating. Attempting SQL auto-parsing...")
#     cleaned_sql = agent.auto_parse_sql_from_response(raw_responses)
#     dump_to_json('cleaned_sql', cleaned_sql)
#     print(f"SQL auto-parsing successful")

#     # Evaluate
#     df['optimized'] = cleaned_sql
#     df['opt-label'] = evaluator.evaluate(df, pred_col_name='optimized')
    
#     # Save results
#     df.to_json(OUTPUT_PATH / f'{MODEL}_{EXPERIMENT}_df.json', orient='records')

## GPT-4o Multi-Agent Discussion

In [None]:
# if EXPERIMENT == 'discussion':
#     print(f"Experiment: {MODEL}_{EXPERIMENT}")
    
#     # Setup
#     df, db_names = read_dataset()
#     db_schemas   = fetch_BIRD_schemas(db_names)
#     print(f'{db_names=}, {len(df)=}')

#     client = get_openai_client()
#     multi_agent = MultiAgentDiscussion(MODEL, client, get_db_cursor, db_schemas, OUTPUT_PATH)
#     evaluator = EvaluatorForBIRD(get_db_cursor)


#     # Generate
#     raw_responses = multi_agent.batched_generate(df, rounds=3)
#     dump_to_json('raw_responses', raw_responses)

#     # Parse
#     print(f"Finished Generating. Attempting SQL auto-parse...")

#     starter_zero = multi_agent.auto_parse_sql_from_response([response['agent_zero_shot'][0] for response in raw_responses])
#     dump_to_json('cleaned_zeroshot_starter', starter_zero)

#     starter_meta = multi_agent.auto_parse_sql_from_response([response['agent_meta_prompt'][0] for response in raw_responses])
#     dump_to_json('cleaned_starter_meta', starter_meta)
    
#     cleaned_sql  = multi_agent.auto_parse_sql_from_response([response['verdict'] for response in raw_responses])
#     dump_to_json('cleaned_sql', cleaned_sql)

#     print(f"SQL auto-parsing successful\n\n")


#     # Evaluate results
#     print("Evaluating Zero-shot starter generated queries")
#     df['starter_zero_shot'] = starter_zero
#     df['zero_shot_labels']  = evaluator.evaluate(df, pred_col_name='starter_zero_shot')

#     print("Evaluating meta-prompt starter generated queries")
#     df['starter_meta_prompt'] = starter_meta
#     df['meta_prompt_labels']  = evaluator.evaluate(df, pred_col_name='starter_meta_prompt')

#     print("Evaluating Multi-Agent Discussion generated queries")
#     df['prediction'] = cleaned_sql
#     df['label']      = evaluator.evaluate(df, pred_col_name='prediction')


#     # Save results
#     df.to_json(OUTPUT_PATH / f'{MODEL}_{EXPERIMENT}_df.json', orient='records')

# Experiments:
- Zero-shot
    - with/without chain-of-thought (CoT) prompting
- Optimizer (on top of zero-shot)
- Multi-agent:
    - Zero-shot -> Optimizer -> Multi-agent Debate
    - Zero-shot -> Optimizer -> Multi-agent Discussion
    - Best of the above -> Optimizer
- Decomposition and Generation via Multi-agent Debate/Discussion
- Sparse Topology Multi-agent Debate/Discussion
- Augmenting schema with LLM calls:
    - Point out relationships (graph idea)
    - Write short descriptions of tables and columns