In [1]:
%load_ext autoreload
%autoreload 2
import json
import sqlparse
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict
from src.database import SqliteDatabase
from src.eval import result_eq, check_if_exists_orderby
from src.eval_complexity import eval_all
from src.process_sql import get_schema, Schema
from src.parsing_sql import (
    extract_selection, 
    extract_condition, 
    extract_aggregation, 
    extract_nested_setoperation, 
    extract_others,
    extract_aliases,
)

# Absolute path of the directory the notebook is run from; every data and
# experiment path below is built relative to it.
proj_path = Path('.').resolve()

# Level labels used for each task family. The family key itself never appears
# in the task name; it only selects which labels are enumerated.
levels_by_type = {
    '_c': ['low', 'mid', 'high'],
    '_t': ['1', '2', '3+'],
}

# Enumerate every experiment name: family -> description variant -> top-k -> level.
all_tasks = []
for typ, iterator in levels_by_type.items():
    for typ2 in ('desc', 'descvt'):
        for n_retrieval in (1, 3):
            for level in iterator:
                task_name = f'sql_gen_hint_top{n_retrieval}_{level}_{typ2}'
                all_tasks.append(task_name)

In [14]:
from collections import defaultdict

In [None]:
# Failure log for the evaluation: one list per failure category plus the set of
# sample ids whose SQL could not be parsed (process_task skips those ids).
error_infos = {
    'pred_exec': [],
    'result': [],
    'parsing_sql': [],
    'error_samples': set()
}

# filter parsing errors
# Dry-run the SQL parser over every sample of every task. Nothing is scored
# here; the only outcome is the list of parse failures recorded in error_infos.
for task in all_tasks:
    task_file = proj_path / 'experiments' / f'{task}.jsonl'
    with open(task_file, 'r') as f:
        for line in tqdm(f, desc=task):
            x = json.loads(line)
            db_dir = proj_path / 'data' / 'spider' / 'database' / x['db_id']
            schema = Schema(get_schema(str(db_dir / f'{x["db_id"]}.sqlite')))

            parsed_result = {}
            for s in ('gold', 'pred'):
                try:
                    statement = sqlparse.parse(x[f'{s}_sql'].strip())[0]
                    aliases = extract_aliases(statement)
                    # Extractors run in a fixed order so the first failure
                    # (the one that gets recorded) is deterministic.
                    parts = {
                        'selection': extract_selection(statement, aliases, schema),
                        'condition': extract_condition(statement),
                        'aggregation': extract_aggregation(statement, aliases, schema),
                        'nested': extract_nested_setoperation(statement),
                    }
                    others = extract_others(statement, aliases, schema)

                    for part_name, part in parts.items():
                        parsed_result[f'{s}_{part_name}'] = part
                    parsed_result[f'{s}_others'] = {
                        key: others[key] for key in ('distinct', 'order by', 'limit')
                    }
                except Exception as e:
                    # Record which side (gold/pred) failed and why, then stop
                    # looking at this sample: one failure is enough to exclude it.
                    sample_id = x['sample_id']
                    error_infos['parsing_sql'].append((sample_id, s, str(e)))
                    error_infos['error_samples'].add(sample_id)
                    break

print(f'Parsing SQL errors: {len(error_infos["parsing_sql"])}')

In [16]:
# Display the ids of the samples whose gold or predicted SQL failed to parse.
# `error_samples` is already a set; set() here only makes a copy for display.
set(error_infos['error_samples'])

{'dev.225',
 'dev.229',
 'dev.311',
 'dev.312',
 'dev.742',
 'dev.743',
 'train.1602',
 'train.167',
 'train.168',
 'train.205',
 'train.2328',
 'train.2500',
 'train.3672',
 'train.4363',
 'train.6014',
 'train.6330',
 'train.897',
 'train.898'}

In [18]:
# Name of one experiment file (without the .jsonl suffix).
# NOTE(review): presumably fed to process_task(task, error_infos) for a
# single-task trial run; that call is not visible in these cells — confirm.
task = 'sql_gen_hint_top1_low_desc'

In [20]:
# process single task

def _parse_sql_parts(sql, schema):
    """Split one SQL string into the five parts that eval_all compares."""
    statement = sqlparse.parse(sql.strip())[0]
    aliases = extract_aliases(statement)
    selection = extract_selection(statement, aliases, schema)
    condition = extract_condition(statement)
    aggregation = extract_aggregation(statement, aliases, schema)
    nested = extract_nested_setoperation(statement)
    others = extract_others(statement, aliases, schema)
    return {
        'selection': selection,
        'condition': condition,
        'aggregation': aggregation,
        'nested': nested,
        'others': {
            'distinct': others['distinct'],
            'order by': others['order by'],
            'limit': others['limit']
        },
    }


def process_task(task, error_infos):
    """Score every sample of one experiment file.

    Reads ``experiments/<task>.jsonl`` under ``proj_path``. For each sample it
    computes the partial-match scores returned by ``eval_all`` and an execution
    score (1 if the predicted and gold queries return equal results, else 0).

    Args:
        task: Experiment name; the file ``<task>.jsonl`` is read.
        error_infos: Dict whose ``'error_samples'`` entry holds the sample ids
            to skip. It is only read here and is returned unchanged.

    Returns:
        ``(task_results, error_infos)``. ``task_results`` maps each of
        ``sample_id``, ``score``, ``s_sel``, ``s_cond``, ``s_agg``, ``s_nest``
        and ``s_oth`` to a list. All lists have the same length, one entry per
        scored sample.

    Samples whose gold query cannot be executed are left out of every column.
    Samples whose predicted query cannot be executed get ``score == 0``.
    Parsing errors are not caught: samples that fail to parse are expected to
    be listed in ``error_infos['error_samples']`` already.
    """
    task_results = {
        'sample_id': [],
        'score': [],
        's_sel': [], 's_cond': [], 's_agg': [], 's_nest': [], 's_oth': [],
    }
    with open(proj_path / 'experiments' / f'{task}.jsonl', 'r') as f:
        for line in tqdm(f):
            x = json.loads(line)
            if x['sample_id'] in error_infos['error_samples']:
                continue

            # The same database file backs both the schema and the execution.
            db_file = str(
                proj_path / 'data' / 'spider' / 'database' / x['db_id'] / f'{x["db_id"]}.sqlite'
            )
            schema = Schema(get_schema(db_file))

            # parsing sql
            parsed_result = {}
            for s in ['gold', 'pred']:
                for part_name, part in _parse_sql_parts(x[f'{s}_sql'], schema).items():
                    parsed_result[f'{s}_{part_name}'] = part

            # partial & complexity eval
            # k=6 is forwarded as-is; its meaning is defined by eval_all.
            eval_res = eval_all(parsed_result, k=6)
            partial = eval_res['score']

            # Execution (prediction first, then gold, as before)
            database = SqliteDatabase(db_file)
            pred_failed = False
            try:
                pred_result = database.execute(x['pred_sql'], rt_pandas=False)
            except Exception:
                pred_failed = True

            try:
                gold_result = database.execute(x['gold_sql'], rt_pandas=False)
            except Exception:
                # No reference result: drop the sample entirely. Nothing has
                # been appended yet, so every column stays the same length.
                continue

            if pred_failed:
                score = 0
            else:
                exists_orderby = check_if_exists_orderby(x['gold_sql'])
                score = int(result_eq(pred_result, gold_result, order_matters=exists_orderby))

            # Append all columns together so row i describes the same sample
            # in every list.
            task_results['sample_id'].append(x['sample_id'])
            task_results['score'].append(score)
            task_results['s_sel'].append(partial['selection'])
            task_results['s_cond'].append(partial['condition'])
            task_results['s_agg'].append(partial['aggregation'])
            task_results['s_nest'].append(partial['nested'])
            task_results['s_oth'].append(partial['others'])

    return task_results, error_infos

pred_exec: 29 | result: 588: : 710it [00:10, 65.52it/s] 


In [21]:
# Tabulate the per-sample score lists as a DataFrame.
# NOTE(review): `task_results` is never assigned at top level in the visible
# cells; presumably it comes from
# `task_results, error_infos = process_task(task, error_infos)` — confirm, and
# add that call so the notebook survives Restart & Run All.
pd.DataFrame(task_results)

Unnamed: 0,sample_id,db_id,score,s_sel,s_cond,s_agg,s_nest,s_oth
0,train.56,student_assessment,0,0.000000,1.0,0.0,1.0,0.5
1,train.64,student_assessment,0,0.000000,1.0,1.0,1.0,1.0
2,train.65,student_assessment,0,1.000000,0.0,1.0,1.0,1.0
3,train.76,student_assessment,0,0.000000,1.0,1.0,1.0,0.5
4,train.77,student_assessment,0,0.666667,1.0,1.0,1.0,0.5
...,...,...,...,...,...,...,...,...
687,dev.954,dog_kennels,0,1.000000,0.5,1.0,1.0,1.0
688,dev.958,dog_kennels,0,1.000000,1.0,1.0,1.0,0.5
689,dev.960,dog_kennels,0,1.000000,0.0,1.0,0.0,0.0
690,dev.997,dog_kennels,0,0.000000,1.0,1.0,1.0,0.5


In [11]:
# One list per output column; the loop below appends one entry per task.
results = {
    'task': [],
    'bo_topk': [], 'bo_level': [], 'bo_desc_vt': [], 
    'ex_acc': [], 'pm_sel': [], 'pm_cond': [], 'pm_agg': [], 'pm_nest': [], 'pm_oth': [],
    'ex_acc_low': [], 'ex_acc_mid': [], 'ex_acc_high': [], 'ex_acc_1': [], 'ex_acc_2': [], 'ex_acc_3+': [],
    'pm_sel_low': [], 'pm_sel_mid': [], 'pm_sel_high': [], 'pm_sel_1': [], 'pm_sel_2': [], 'pm_sel_3+': [],
    'pm_cond_low': [], 'pm_cond_mid': [], 'pm_cond_high': [], 'pm_cond_1': [], 'pm_cond_2': [], 'pm_cond_3+': [],
    'pm_agg_low': [], 'pm_agg_mid': [], 'pm_agg_high': [], 'pm_agg_1': [], 'pm_agg_2': [], 'pm_agg_3+': [],
    'pm_nest_low': [], 'pm_nest_mid': [], 'pm_nest_high': [], 'pm_nest_1': [], 'pm_nest_2': [], 'pm_nest_3+': [],
    'pm_oth_low': [], 'pm_oth_mid': [], 'pm_oth_high': [], 'pm_oth_1': [], 'pm_oth_2': [], 'pm_oth_3+': [],
}

# Task names have the form '<prefix><topk>_<level>_<desc|descvt>'.
task_prefix = 'sql_gen_hint_top'

# real eval
for task in all_tasks:
    print(task)
    # str.lstrip() removes any leading characters from a *set*, not a literal
    # prefix, so it would also eat the start of the payload if that happened to
    # begin with one of those letters. Slice the known prefix off instead.
    if not task.startswith(task_prefix):
        raise ValueError(f'unexpected task name: {task!r}')
    topk, level, desc_vt = task[len(task_prefix):].split('_')[:3]

    # Append only after parsing succeeded so the columns stay aligned.
    results['task'].append(task)
    results['bo_topk'].append(int(topk))
    results['bo_level'].append(level)
    results['bo_desc_vt'].append(desc_vt)

{'sample_id': 'train.56',
 'db_id': 'student_assessment',
 'question': 'which course has most number of registered students?',
 'rationale': ['We need to find out which course has the most number of registered students.',
  "To do this, we will use the 'Student_Course_Registrations' table, which contains the relationship between students and courses.",
  'We will count the number of students registered for each course using COUNT(student_id).',
  'We will group the results by course_id to get the count for each course.',
  'To find the course with the most students, we will order the results in descending order based on the student count.',
  'Finally, we will limit the results to 1 to get only the course with the highest number of registered students.'],
 'gold_sql': 'SELECT T1.course_name FROM courses AS T1 JOIN student_course_registrations AS T2 ON T1.course_id = T2.course_Id GROUP BY T1.course_id ORDER BY count(*) DESC LIMIT 1',
 'source_tables': ['courses', 'student_course_registr

In [7]:
# Inspect the output of one eval_all call.
# NOTE(review): in the visible cells `eval_res` is only assigned inside
# process_task, so this cell relies on a leftover kernel variable (and its
# execution count, In [7], is out of order) — confirm or remove.
eval_res

{'score': {'selection': 0.0,
  'condition': 1.0,
  'aggregation': 0.0,
  'nested': 1.0,
  'others': 0.5},
 'complexity': {'selection': [0.2799999999999999, 0.4705882352941176],
  'condition': [0.0, 0.0],
  'aggregation': [0.2799999999999999, 0.2799999999999999],
  'nested': [0.0, 0.0],
  'others': [0.2799999999999999, 0.2799999999999999]}}