In [1]:
import ast, tqdm, pandas as pd

from llm import call_llm
from sql import run_query
from eval_metrics import get_error_distribution, get_output_accuracy

In [2]:
finetuned = pd.read_csv('./dev_set_finetuned.csv')
finetuned['gold_query'] = finetuned['conversations'].apply(lambda x: ast.literal_eval(x)[1]['value'])
finetuned['prediction'] = finetuned['prediction'].apply(lambda x: x.split('<|im_start|>assistant')[1].strip())
finetuned['db_definition'] = finetuned['conversations'].apply(lambda x: ast.literal_eval(x)[0]['value'].split("Based on the SQL db schema given below, you have to answer the question that follows it. Your answer should be a valid, correct SQL query. You are provided with a HINT to generate the SQL query.\n\n")[1].split("QUESTION:")[0].strip())
finetuned['rectified_prediction'] = ''
finetuned

Unnamed: 0,database,conversations,prediction,gold_query,db_definition,rectified_prediction
0,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT MAX(T2.`Free Meal Count (K-12)`) FROM s...,SELECT `Free Meal Count (K-12)` / `Enrollment ...,SCHEMA: Tables in the database california_scho...,
1,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT Percent (%) Eligible Free (K-12) FROM s...,SELECT `Free Meal Count (Ages 5-17)` / `Enroll...,SCHEMA: Tables in the database california_scho...,
2,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT T1.Zip FROM schools AS T1 INNER JOIN fr...,SELECT T2.Zip FROM frpm AS T1 INNER JOIN schoo...,SCHEMA: Tables in the database california_scho...,
3,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT T1.MailStreet FROM california_schools A...,SELECT T2.MailStreet FROM frpm AS T1 INNER JOI...,SCHEMA: Tables in the database california_scho...,
4,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT T2.Phone FROM frpm AS T1 INNER JOIN sch...,SELECT T2.Phone FROM frpm AS T1 INNER JOIN sch...,SCHEMA: Tables in the database california_scho...,
...,...,...,...,...,...,...
1529,debit_card_specializing,"[{'from': 'human', 'value': 'Based on the SQL ...","SELECT T1.Amount, SUM(T1.Amount) FROM transact...","SELECT SUM(T1.Price) , SUM(IIF(T3.Date = '2012...",SCHEMA: Tables in the database debit_card_spec...,
1530,debit_card_specializing,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT DISTINCT T2.Description FROM transactio...,SELECT T2.Description FROM transactions_1k AS ...,SCHEMA: Tables in the database debit_card_spec...,
1531,debit_card_specializing,"[{'from': 'human', 'value': 'Based on the SQL ...","SELECT T1.CustomerID, T3.Price, T2.Currency FR...","SELECT T2.CustomerID, SUM(T2.Price / T2.Amount...",SCHEMA: Tables in the database debit_card_spec...,
1532,debit_card_specializing,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT T2.Country FROM transactions_1k AS T1 I...,SELECT T2.Country FROM transactions_1k AS T1 I...,SCHEMA: Tables in the database debit_card_spec...,


In [None]:
result = {
    'gold_query': [],
    'finetuned': []
}
for idx, row in tqdm.tqdm(finetuned.iterrows(), desc='Running SQL Queries', total=finetuned.shape[0]):
    gold_output = run_query(row['database'], row['gold_query'])
    result['gold_query'].append(gold_output)

    finetuned_output = run_query(row['database'], row['prediction'])
    if finetuned_output[1] == -1:
        schema = row['db_definition']
        query_to_rectify = row['prediction']
        query_error = finetuned_output[0]
        
        query_rectification_prompt = f"You are an expert in rectifying incorrect SQL queries based on db schema and error.\n\nRead schema below:\n\n{schema}\n\nYou need to rectify this query: {query_to_rectify}\n\nIt is giving following error: {query_error}"
        rectified_query = call_llm(query_rectification_prompt)
        if '```sql' in rectified_query:
            try:
                rectified_query = rectified_query.split('```sql')[1].split('```')[0].strip()
            except:
                rectified_query = rectified_query

        finetuned['rectified_prediction'][idx] = rectified_query
        finetuned_output = run_query(row['database'], rectified_query)
        result['finetuned'].append(finetuned_output)
    else:
        result['finetuned'].append(finetuned_output)

Running SQL Queries: 100%|██████████| 1534/1534 [52:27<00:00,  2.05s/it] 


In [26]:
finetuned_errors = get_error_distribution(result['finetuned'])
finetuned_errors

('ERROR: no such function: FRPMCount', -1)
('ERROR: no such function: Enrollment', -1)
('ERROR: table district already exists', -1)
('ERROR: Incorrect number of bindings supplied. The current statement uses 1, and there are 0 supplied.', -1)
('ERROR: attempt to write a readonly database', -1)
('ERROR: ambiguous column name: molecule_id', -1)
('ERROR: attempt to write a readonly database', -1)
('ERROR: attempt to write a readonly database', -1)
('ERROR: table badges already exists', -1)
('ERROR: ambiguous column name: id', -1)
('ERROR: no such function: lapTimes', -1)
('ERROR: SELECTs to the left and right of UNION ALL do not have the same number of result columns', -1)
('ERROR: table Examination already exists', -1)
('ERROR: no such function: YEAR', -1)
('ERROR: unrecognized token: ":"', -1)
('ERROR: attempt to write a readonly database', -1)
('ERROR: misuse of aggregate function COUNT()', -1)


{'syntax': 27, 'column': 181, 'table': 13, 'misc': 17}

In [31]:
print("Execution Accuracy Finetuned with Self Correction Loop:", round(100 * (finetuned.shape[0] - sum(finetuned_errors.values())) / finetuned.shape[0], 2), "%")

Execution Accuracy Finetuned with Self Correction Loop: 84.49 %


In [28]:
get_output_accuracy(result['finetuned'], result['gold_query'])

'32.14% Accuracy'

In [29]:
finetuned.to_csv('./dev_set_finetuned_with_self_correction_loop.csv')