In [1]:
import ast, tqdm, pandas as pd
from sql import run_query
from eval_metrics import get_error_distribution, get_output_accuracy

In [2]:
zero_few = pd.read_csv('./dev_set_zero_few_shot.csv')
zero_few['gold_query'] = zero_few['conversations'].apply(lambda x: ast.literal_eval(x)[1]['value'])

In [3]:
result = {
    'zero_shot': [],
    'few_shot': [],
    'gold_query': []
}

In [4]:
for idx, row in tqdm.tqdm(zero_few.iterrows(), desc='Running SQL Queries', total=zero_few.shape[0]):
    gold_output = run_query(row['database'], row['gold_query'])
    zero_shot_output = run_query(row['database'], row['zero_shot'])
    few_shot_output = run_query(row['database'], row['few_shot'])

    result['gold_query'].append(gold_output)
    result['zero_shot'].append(zero_shot_output)
    result['few_shot'].append(few_shot_output)

Running SQL Queries: 100%|██████████| 1534/1534 [05:08<00:00,  4.97it/s] 


In [5]:
result['zero_shot']

[('ERROR: near "cds": syntax error', -1),
 ('ERROR: no such column: T1.Percent (%) Eligible Free (Ages 5-17)', -1),
 ('ERROR: no such table: california_schools.frpm', -1),
 ([('43466 Business Park Drive',)], 0.0037),
 ('ERROR: no such column: Chartered', -1),
 ('ERROR: no such column: cdsextended', -1),
 ('ERROR: no such column: T3.SchoolName', -1),
 ([('(213) 241-1000',)], 0.0006),
 ('ERROR: near "Count": syntax error', -1),
 ('ERROR: near "(": syntax error', -1),
 ('ERROR: near "(": syntax error', -1),
 ('ERROR: near "5": syntax error', -1),
 ('ERROR: no such column: T1.Percent (%) Eligible Free (Ages 5-17)', -1),
 ([('(714) 220-3055',), ('(562) 229-7745',), ('(408) 366-7700',)], 0.004),
 ('ERROR: near "To": syntax error', -1),
 ('ERROR: no such column: T1.District', -1),
 ('ERROR: no such column: T2.DistrictCode', -1),
 ('ERROR: near "To": syntax error', -1),
 ('ERROR: no such column: T1.County', -1),
 ([('(408) 366-7700',)], 0.0007),
 ('ERROR: no such table: california_schools', -1

In [6]:
result['few_shot']

[('ERROR: no such table: works', -1),
 ('ERROR: no such table: app_events', -1),
 ('ERROR: near "School": syntax error', -1),
 ('ERROR: near "Here": syntax error', -1),
 ('ERROR: near "Here": syntax error', -1),
 ('ERROR: near "Here": syntax error', -1),
 ('ERROR: near "To": syntax error', -1),
 ('ERROR: near "*": syntax error', -1),
 ('ERROR: near "Based": syntax error', -1),
 ('ERROR: near "Based": syntax error', -1),
 ('ERROR: no such table: patients', -1),
 ('ERROR: near "Here": syntax error', -1),
 ('ERROR: near "Here": syntax error', -1),
 ('ERROR: near "Here": syntax error', -1),
 ('ERROR: near "To": syntax error', -1),
 ('ERROR: no such table: Users', -1),
 ('ERROR: near "I": syntax error', -1),
 ('ERROR: near "Based": syntax error', -1),
 ('ERROR: no such table: Crime', -1),
 ('ERROR: near "Here": syntax error', -1),
 ('ERROR: near "Based": syntax error', -1),
 ('ERROR: no such table: phone_brand_device_model2', -1),
 ('ERROR: near "To": syntax error', -1),
 ('ERROR: near "Her

In [7]:
result['gold_query']

[([(1.0,)], 0.0072),
 ([(0.043478260869565216,), (0.07042253521126761,), (0.11363636363636363,)],
  0.001),
 ([('93726-5309',),
   ('93628-9602',),
   ('93706-2611',),
   ('93726-5208',),
   ('93706-2819',)],
  0.0013),
 ([('14429 South Downey Avenue',)], 0.0093),
 ([(None,),
   ('(510) 596-8901',),
   (None,),
   ('(510) 686-4131',),
   ('(510) 452-2063',),
   ('(510) 842-1181',),
   ('(510) 748-4008',),
   ('(510) 748-4017',),
   ('(510) 995-4300',),
   ('(510) 748-4314',),
   ('(510) 809-9800',),
   ('(510) 809-9800',),
   ('(510) 300-1340',),
   ('(510) 300-1560',),
   ('(510) 931-7868',),
   ('(510) 543-4124',),
   ('(510) 370-3334',),
   ('(925) 443-1690',),
   ('(510) 635-7170',),
   ('(510) 562-5238',),
   ('(510) 382-9932',),
   ('(510) 562-8225',),
   ('(510) 658-2900',),
   ('(510) 904-6440',),
   ('(510) 893-8701',),
   ('(510) 893-8701',),
   ('(510) 285-7511',),
   ('(510) 893-8700',),
   ('(510) 874-7255',),
   ('(510) 436-5487',),
   ('(510) 992-7800',),
   ('(510) 879-

In [8]:
gold_errors = get_error_distribution(result['gold_query'])
gold_errors

{'syntax': 0, 'column': 0, 'table': 0, 'misc': 0}

In [9]:
zero_shot_errors = get_error_distribution(result['zero_shot'])
zero_shot_errors

('ERROR: no such function: CharterSchool', -1)
('ERROR: incomplete input', -1)
('ERROR: ambiguous column name: CDSCode', -1)
('ERROR: aggregate functions are not allowed in the GROUP BY clause', -1)
('ERROR: no such function: YEAR', -1)
('ERROR: misuse of aggregate: SUM()', -1)
('ERROR: misuse of aggregate: AVG()', -1)
('ERROR: ambiguous column name: id', -1)
('ERROR: ambiguous column name: id', -1)
('ERROR: no such function: YEAR', -1)
('ERROR: no such function: SUBTRACT', -1)
('ERROR: no such function: YEAR', -1)
('ERROR: no such function: YEAR', -1)
('ERROR: no such function: YEAR', -1)
('ERROR: no such function: STRIP', -1)
('ERROR: misuse of aggregate function AVG()', -1)
('ERROR: ambiguous column name: CustomerID', -1)


{'syntax': 361, 'column': 358, 'table': 15, 'misc': 17}

In [10]:
print("Execution Accuracy Zero Shot:", round(100 * (zero_few.shape[0] - sum(zero_shot_errors.values())) / zero_few.shape[0], 2), "%")

Execution Accuracy Zero Shot: 51.04 %


In [11]:
few_shot_errors = get_error_distribution(result['few_shot'])
few_shot_errors

('ERROR: no such function: SUBTRACT', -1)
('ERROR: incomplete input', -1)


{'syntax': 1043, 'column': 34, 'table': 425, 'misc': 2}

In [12]:
print("Execution Accuracy Few Shot:", round(100 * (zero_few.shape[0] - sum(few_shot_errors.values())) / zero_few.shape[0], 2), "%")

Execution Accuracy Few Shot: 1.96 %


In [13]:
get_output_accuracy(result['few_shot'], result['gold_query'])

'0.26% Accuracy'

In [14]:
get_output_accuracy(result['zero_shot'], result['gold_query'])

'18.06% Accuracy'

In [15]:
finetuned = pd.read_csv('./dev_set_finetuned.csv')
finetuned['gold_query'] = finetuned['conversations'].apply(lambda x: ast.literal_eval(x)[1]['value'])
finetuned['prediction'] = finetuned['prediction'].apply(lambda x: x.split('<|im_start|>assistant')[1].strip())
finetuned

Unnamed: 0,database,conversations,prediction,gold_query
0,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT MAX(T2.`Free Meal Count (K-12)`) FROM s...,SELECT `Free Meal Count (K-12)` / `Enrollment ...
1,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT Percent (%) Eligible Free (K-12) FROM s...,SELECT `Free Meal Count (Ages 5-17)` / `Enroll...
2,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT T1.Zip FROM schools AS T1 INNER JOIN fr...,SELECT T2.Zip FROM frpm AS T1 INNER JOIN schoo...
3,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT T1.MailStreet FROM california_schools A...,SELECT T2.MailStreet FROM frpm AS T1 INNER JOI...
4,california_schools,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT T2.Phone FROM frpm AS T1 INNER JOIN sch...,SELECT T2.Phone FROM frpm AS T1 INNER JOIN sch...
...,...,...,...,...
1529,debit_card_specializing,"[{'from': 'human', 'value': 'Based on the SQL ...","SELECT T1.Amount, SUM(T1.Amount) FROM transact...","SELECT SUM(T1.Price) , SUM(IIF(T3.Date = '2012..."
1530,debit_card_specializing,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT DISTINCT T2.Description FROM transactio...,SELECT T2.Description FROM transactions_1k AS ...
1531,debit_card_specializing,"[{'from': 'human', 'value': 'Based on the SQL ...","SELECT T1.CustomerID, T3.Price, T2.Currency FR...","SELECT T2.CustomerID, SUM(T2.Price / T2.Amount..."
1532,debit_card_specializing,"[{'from': 'human', 'value': 'Based on the SQL ...",SELECT T2.Country FROM transactions_1k AS T1 I...,SELECT T2.Country FROM transactions_1k AS T1 I...


In [16]:
result = {
    'gold_query': [],
    'finetuned': []
}
for idx, row in tqdm.tqdm(finetuned.iterrows(), desc='Running SQL Queries', total=finetuned.shape[0]):
    gold_output = run_query(row['database'], row['gold_query'])
    finetuned_output = run_query(row['database'], row['prediction'])

    result['gold_query'].append(gold_output)
    result['finetuned'].append(finetuned_output)

Running SQL Queries: 100%|██████████| 1534/1534 [04:59<00:00,  5.12it/s] 


In [17]:
finetuned_errors = get_error_distribution(result['finetuned'])
finetuned_errors

('ERROR: no such function: FRPMCount', -1)
('ERROR: ambiguous column name: atom_id', -1)
('ERROR: unrecognized token: "\'dip"', -1)
('ERROR: misuse of aggregate function MIN()', -1)
('ERROR: 1st ORDER BY term does not match any column in the result set', -1)
('ERROR: misuse of aggregate function MIN()', -1)
('ERROR: no such function: T2', -1)
('ERROR: unrecognized token: "\'+\'\'"', -1)
('ERROR: misuse of aggregate function COUNT()', -1)


{'syntax': 35, 'column': 321, 'table': 45, 'misc': 9}

In [18]:
print("Execution Accuracy Finetuned:", round(100 * (finetuned.shape[0] - sum(finetuned_errors.values())) / finetuned.shape[0], 2), "%")

Execution Accuracy Finetuned: 73.27 %


In [19]:
get_output_accuracy(result['finetuned'], result['gold_query'])

'30.12% Accuracy'