In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path

# Absolute project root. Every data/ and experiments/ path below is built from
# the kernel's working directory, so the notebook presumably runs from the repo root — TODO confirm.
proj_path = Path.cwd().resolve()

# Process data

In [3]:
import sqlparse
from src.database import SqliteDatabase
from src.eval import result_eq, check_if_exists_orderby
from src.eval_complexity import eval_all
from src.process_sql import get_schema, Schema
from src.parsing_sql import (
    extract_selection, 
    extract_condition, 
    extract_aggregation, 
    extract_nested_setoperation, 
    extract_others,
    extract_aliases,
)

# Common prefix of every experiment name, e.g. 'sql_gen_hint_top1_low_desc'.
TASK_PREFIX = 'sql_gen_hint_top'


def strip_task_prefix(task):
    """Return ``task`` without the leading ``'sql_gen_hint_top'`` (used as a short tqdm label).

    ``str.lstrip`` removes a *set of characters*, not a prefix, so
    ``task.lstrip('sql_gen_hint_top')`` only works while the character after
    the prefix is not one of those letters. Slice the prefix off instead.
    Names without the prefix are returned unchanged.
    """
    if task.startswith(TASK_PREFIX):
        return task[len(TASK_PREFIX):]
    return task


def spider_db_path(root, db_id):
    """Return the Spider SQLite file of ``db_id`` under the project directory ``root``."""
    return Path(root) / 'data' / 'spider' / 'database' / db_id / f'{db_id}.sqlite'


def load_schema(root, db_id):
    """Read the schema of database ``db_id`` from disk and wrap it in ``Schema``."""
    return Schema(get_schema(str(spider_db_path(root, db_id))))


def parse_sql_components(sql, schema):
    """Parse one SQL string into the components compared by ``eval_all``.

    Returns a dict with keys 'selection', 'condition', 'aggregation', 'nested'
    and 'others'. 'others' is reduced to its 'distinct', 'order by' and 'limit' entries.
    Any exception raised by the parser/extractors propagates to the caller.
    """
    statement = sqlparse.parse(sql.strip())[0]
    aliases = extract_aliases(statement)
    # Keep this call order: every extractor receives the same `statement`
    # object and may rely on being called in sequence — TODO confirm.
    selection = extract_selection(statement, aliases, schema)
    condition = extract_condition(statement)
    aggregation = extract_aggregation(statement, aliases, schema)
    nested = extract_nested_setoperation(statement)
    others = extract_others(statement, aliases, schema)
    return {
        'selection': selection,
        'condition': condition,
        'aggregation': aggregation,
        'nested': nested,
        'others': {
            'distinct': others['distinct'],
            'order by': others['order by'],
            'limit': others['limit'],
        },
    }


def error_check(proj_path, all_tasks):
    """Find samples whose gold or predicted SQL cannot be parsed.

    For every task, reads ``experiments/<task>.jsonl`` under ``proj_path`` and
    tries to parse both ``gold_sql`` and ``pred_sql`` of each sample.

    Returns:
        dict with
          'parsing_sql':   list of (sample_id, 'gold' | 'pred', error message)
          'error_samples': set of sample_ids that failed to parse in any task
          'empty_hint':    set of sample_ids whose 'hint' is the empty string
          'pred_exec', 'result': empty lists (kept for callers; not filled here)
    """
    error_infos = {
        'pred_exec': [],
        'result': [],
        'parsing_sql': [],
        'error_samples': set(),
        'empty_hint': set()
    }

    for task in all_tasks:
        with open(proj_path / 'experiments' / f'{task}.jsonl', 'r') as f:
            progress = tqdm(f, desc=task)
            for line in progress:
                x = json.loads(line)
                if x['hint'] == '':
                    error_infos['empty_hint'].add(x['sample_id'])
                schema = load_schema(proj_path, x['db_id'])

                for side in ('gold', 'pred'):
                    try:
                        parse_sql_components(x[f'{side}_sql'], schema)
                    except Exception as e:
                        error_infos['parsing_sql'].append((x['sample_id'], side, str(e)))
                        error_infos['error_samples'].add(x['sample_id'])
                        break  # one failure is enough to exclude the sample

                # No manual progress.update(): iterating over tqdm already advances it.
                progress.set_description_str(
                    f'{task} | error samples {len(error_infos["error_samples"])} '
                    f'| empty hints {len(error_infos["empty_hint"])}'
                )

    print(f'Parsing SQL errors: {len(error_infos["parsing_sql"])}')

    return error_infos


def process_task(task, error_infos, project_root=None):
    """Score every parseable sample of one experiment.

    Args:
        task: experiment name; reads ``experiments/<task>.jsonl``.
        error_infos: output of ``error_check``; samples listed in
            ``error_infos['error_samples']`` are skipped.
        project_root: project directory. Defaults to the notebook-level
            ``proj_path`` global (the previous, implicit behavior).

    Returns:
        dict of equal-length lists 'sample_id', 'score' (execution match, 0/1)
        and 's_sel', 's_cond', 's_agg', 's_nest', 's_oth' (partial-match scores
        from ``eval_all``), one entry per scored sample. Samples whose gold SQL
        fails to execute are left out of *every* list, so the dict can be passed
        straight to ``pd.DataFrame``.
    """
    root = Path(proj_path if project_root is None else project_root)
    task_results = {
        'sample_id': [],
        'score': [],
        's_sel': [], 's_cond': [], 's_agg': [], 's_nest': [], 's_oth': [],
    }
    with open(root / 'experiments' / f'{task}.jsonl', 'r') as f:
        for line in tqdm(f, desc=strip_task_prefix(task)):
            x = json.loads(line)
            if x['sample_id'] in error_infos['error_samples']:
                continue

            # partial & complexity eval on the parsed gold/pred pair
            schema = load_schema(root, x['db_id'])
            parsed_result = {
                f'{side}_{key}': value
                for side in ('gold', 'pred')
                for key, value in parse_sql_components(x[f'{side}_sql'], schema).items()
            }
            part_scores = eval_all(parsed_result, k=6)['score']

            # execution accuracy
            database = SqliteDatabase(str(spider_db_path(root, x['db_id'])))
            try:
                pred_result = database.execute(x['pred_sql'], rt_pandas=False)
                pred_ok = True
            except Exception:
                pred_ok = False

            try:
                gold_result = database.execute(x['gold_sql'], rt_pandas=False)
            except Exception:
                # Unscorable sample: skip it *before* appending anything so
                # that every list in task_results keeps the same length.
                continue

            if pred_ok:
                exists_orderby = check_if_exists_orderby(x['gold_sql'])
                score = int(result_eq(pred_result, gold_result, order_matters=exists_orderby))
            else:
                score = 0  # a prediction that fails to execute counts as wrong

            task_results['sample_id'].append(x['sample_id'])
            task_results['score'].append(score)
            task_results['s_sel'].append(part_scores['selection'])
            task_results['s_cond'].append(part_scores['condition'])
            task_results['s_agg'].append(part_scores['aggregation'])
            task_results['s_nest'].append(part_scores['nested'])
            task_results['s_oth'].append(part_scores['others'])

    return task_results


def process_all_exps(proj_path, all_tasks):
    """Evaluate every task and write ``experiments/bo_evals/<task>.csv`` under ``proj_path``."""
    error_infos = error_check(proj_path, all_tasks)
    out_dir = proj_path / 'experiments' / 'bo_evals'
    out_dir.mkdir(parents=True, exist_ok=True)
    for task in all_tasks:
        # Pass the root explicitly so parsing checks and scoring read the same files.
        task_results = process_task(task, error_infos, project_root=proj_path)
        pd.DataFrame(task_results).to_csv(out_dir / f'{task}.csv', index=False)

sql_gen_hint_top1_1_desc | error samples 0 | empty hints 0: 401it [00:01, 256.68it/s]
sql_gen_hint_top1_2_desc | error samples 1 | empty hints 0: 401it [00:01, 265.86it/s]
sql_gen_hint_top1_3+_desc | error samples 1 | empty hints 0: 401it [00:01, 264.61it/s]
sql_gen_hint_top3_1_desc | error samples 1 | empty hints 0: 401it [00:01, 260.51it/s]
sql_gen_hint_top3_2_desc | error samples 3 | empty hints 0: 401it [00:01, 259.03it/s]
sql_gen_hint_top3_3+_desc | error samples 3 | empty hints 0: 401it [00:01, 261.92it/s]
sql_gen_hint_top1_1_descvt | error samples 3 | empty hints 0: 401it [00:01, 256.23it/s]
sql_gen_hint_top1_2_descvt | error samples 4 | empty hints 0: 401it [00:01, 257.04it/s]
sql_gen_hint_top1_3+_descvt | error samples 5 | empty hints 0: 401it [00:01, 240.18it/s]
sql_gen_hint_top3_1_descvt | error samples 6 | empty hints 0: 401it [00:01, 266.58it/s]
sql_gen_hint_top3_2_descvt | error samples 6 | empty hints 0: 401it [00:01, 261.93it/s]
sql_gen_hint_top3_3+_descvt | error sampl

Parsing SQL errors: 11





In [69]:
# Experiment names to evaluate: sql_gen_hint_top{1,3}_{level}_{desc,descvt}.
typ = '_c'  # '_c' -> levels low/mid/high; any other value (e.g. '_t') -> levels 1/2/3+
iterator = ['1', '2', '3+']
if typ == '_c':
    iterator = ['low', 'mid', 'high']

all_tasks = [
    f'sql_gen_hint_top{n_retrieval}_{level}_{typ2}'
    for typ2 in ('desc', 'descvt')
    for n_retrieval in (1, 3)
    for level in iterator
]

# Uncomment to (re)build experiments/bo_evals/*.csv for these tasks.
# process_all_exps(proj_path, all_tasks)

# Aggregating the results

In [80]:
# Real eval: collect the per-task CSVs written by process_all_exps into one
# `results` table, compared against the scores in bo{typ}_eval.csv.
typ = '_c'  # '_c' -> split by complexity (cate_gold_c); '_t' -> split by table count (cate_len_tbls)
iterator = ['low', 'mid', 'high'] if typ == '_c' else ['1', '2', '3+']
all_tasks = [
    f'sql_gen_hint_top{n_retrieval}_{level}_{typ2}'
    for typ2 in ['desc', 'descvt']
    for n_retrieval in [1, 3]
    for level in iterator
]

col = 'cate_gold_c' if typ == '_c' else 'cate_len_tbls'
score_cols = ['score', 's_sel', 's_cond', 's_agg', 's_nest', 's_oth']
part_names = ['sel', 'cond', 'agg', 'nest', 'oth']

results = {
    'bo_topk': [], 'bo_level': [], 'bo_desc_vt': [],
    'count': [], 'ex_acc': [], 'pm_sel': [], 'pm_cond': [], 'pm_agg': [], 'pm_nest': [], 'pm_oth': [],
}
for l in iterator:
    results[f'ex_acc_{l}'] = []
    results[f'count_{l}'] = []
    for c in part_names:
        results[f'pm_{c}_{l}'] = []

df_test = pd.read_csv(proj_path / 'data' / 'spilt_in_domain' / f'bo{typ}_eval.csv')
baseline = df_test.groupby(col)[score_cols].mean() * 100
# Overall baseline means do not depend on the task; compute them once.
baseline_means = {c: df_test[c].mean() for c in score_cols}

task_prefix = 'sql_gen_hint_top'
for task in all_tasks:
    print(task)
    # 'sql_gen_hint_top1_low_desc' -> '1', 'low', 'desc'.
    # Slice the prefix off: str.lstrip strips a character set, not a prefix.
    assert task.startswith(task_prefix), task
    topk, level, desc_vt_key = task[len(task_prefix):].split('_')[:3]
    results['bo_topk'].append(int(topk))
    results['bo_level'].append(level)
    results['bo_desc_vt'].append(desc_vt_key)

    df = pd.read_csv(proj_path / 'experiments' / 'bo_evals' / f'{task}.csv')
    df = pd.merge(df, df_test.loc[:, ['sample_id', 'cate_len_tbls', 'cate_gold_c']], on='sample_id', how='left')
    results['count'].append(df.shape[0])
    # NOTE(review): overall 'ex_acc' is an absolute accuracy, while pm_* and
    # ex_acc_<level> below are differences from the baseline — confirm intended.
    results['ex_acc'].append(df['score'].mean() * 100)
    for c in part_names:
        results[f'pm_{c}'].append((df[f's_{c}'].mean() - baseline_means[f's_{c}']) * 100)

    # Reindex so every level in `iterator` is present (NaN instead of KeyError
    # when a level is missing from both this task and the baseline).
    g_score = (df.groupby(col)[score_cols].mean() * 100 - baseline).reindex(iterator)
    for l in iterator:
        results[f'ex_acc_{l}'].append(g_score.loc[l, 'score'])
        results[f'count_{l}'].append(df[df[col] == l].shape[0])
        for c in part_names:
            results[f'pm_{c}_{l}'].append(g_score.loc[l, f's_{c}'])
    

sql_gen_hint_top1_low_desc
sql_gen_hint_top1_mid_desc
sql_gen_hint_top1_high_desc
sql_gen_hint_top3_low_desc
sql_gen_hint_top3_mid_desc
sql_gen_hint_top3_high_desc
sql_gen_hint_top1_low_descvt
sql_gen_hint_top1_mid_descvt
sql_gen_hint_top1_high_descvt
sql_gen_hint_top3_low_descvt
sql_gen_hint_top3_mid_descvt
sql_gen_hint_top3_high_descvt


In [81]:
# Tidy the aggregated results and choose the column groups for each report sheet.
df = pd.DataFrame(results)
desc_vt = {'desc': 'BA', 'descvt': 'BA + VT'}
df['bo_desc_vt'] = df['bo_desc_vt'].map(desc_vt)
df['bo_level'] = df['bo_level'].str.capitalize()

# Level suffixes of the per-level columns for the split selected by `typ`
# (same rule that builds `iterator` in the aggregation cell).
levels = ['low', 'mid', 'high'] if typ == '_c' else ['1', '2', '3+']

idx_cols = ['bo_topk', 'bo_desc_vt', 'bo_level']
count_cols = ['count'] + [f'count_{l}' for l in levels]
r1 = [f'ex_acc_{l}' for l in levels] + ['ex_acc']
# The overall partial-match columns are identical for both splits.
r2 = ['pm_sel', 'pm_cond', 'pm_agg', 'pm_nest', 'pm_oth']
r3 = [f'pm_sel_{l}' for l in levels]
r4 = [f'pm_cond_{l}' for l in levels]
r5 = [f'pm_agg_{l}' for l in levels]
r6 = [f'pm_nest_{l}' for l in levels]
r7 = [f'pm_oth_{l}' for l in levels]

In [82]:
# Write the aggregated results to experiments/reports/bo_eval{typ}.xlsx, one sheet per view.
report_dir = proj_path / 'experiments' / 'reports'
report_dir.mkdir(parents=True, exist_ok=True)

# Human-readable headers for the identifying columns, shared by every sheet.
base_rename = {
    'bo_topk': 'Top-K',
    'bo_desc_vt': 'Prompt Type',
    'bo_level': 'Complexity Lv.' if typ == '_c' else 'Table Num.',
}
part_names = ['sel', 'cond', 'agg', 'nest', 'oth']

with pd.ExcelWriter(report_dir / f'bo_eval{typ}.xlsx') as writer:
    # Same index/header handling as the other sheets.
    (
        df.loc[:, idx_cols + count_cols]
        .rename(columns=base_rename)
        .to_excel(writer, sheet_name='count', index=False)
    )

    # Execution accuracy: per-level columns plus the overall value.
    df1 = df.loc[:, idx_cols + r1].round(2).rename(columns={
        **base_rename,
        'ex_acc': 'Overall',
        **{f'ex_acc_{l}': l.capitalize() for l in iterator},
    })
    df1.to_excel(writer, sheet_name='ex_acc', index=False)

    # Overall partial-match scores per SQL component.
    df2 = df.loc[:, idx_cols + r2].round(2).rename(columns={
        **base_rename,
        'pm_sel': 'Selection',
        'pm_cond': 'Condition',
        'pm_agg': 'Aggregation',
        'pm_nest': 'Nested',
        'pm_oth': 'Others',
    })
    df2.to_excel(writer, sheet_name='pm', index=False)

    # Partial-match scores per SQL component and level.
    df3 = df.loc[:, idx_cols + r3 + r4 + r5 + r6 + r7].round(2).rename(columns={
        **base_rename,
        **{
            f'pm_{c}_{l}': f'{c.capitalize()} {l.capitalize()}'
            for l in iterator
            for c in part_names
        },
    })
    df3.to_excel(writer, sheet_name='pm_detail', index=False)