In [1]:
import sqlite3
import numpy as np
import time
import pickle
import json
from func_timeout import func_timeout, FunctionTimedOut
from concurrent.futures import ProcessPoolExecutor, as_completed
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

# Load the pickled list of "mini-dev" example indices as a NumPy array.
# In the visible cells it is referenced only by the commented-out
# `for k in mini_dev_index` loops in the generate_* helpers.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load files from a trusted source.
with open("/data/home/vkropoti/sql_data/mini-dev-index", "rb") as fp:   # Unpickling
    mini_dev_index = np.array(pickle.load(fp))

def load_json(dir):
    """Read a JSON file and return the decoded Python object.

    Args:
        dir: Path of the JSON file to read (despite the name, this is a
            file path, not a directory).

    Returns:
        The Python value that the file's top-level JSON value decodes to.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(dir, "r", encoding="utf-8") as j:
        return json.load(j)
    
def generate_bd_list(path_dev_json):
    """Return the `db_id` of every example in the dev JSON file, in file order.

    Args:
        path_dev_json: Path of the dev-set JSON file (read via `load_json`).

    Returns:
        A list with one database id per example.
    """
    examples = load_json(path_dev_json)
    return [example['db_id'] for example in examples]

def generate_sql_gt_list(path_dev_json):
    """Return the reference `SQL` string of every example in the dev JSON file.

    Args:
        path_dev_json: Path of the dev-set JSON file (read via `load_json`).

    Returns:
        A list with one ground-truth query per example, in file order.
    """
    examples = load_json(path_dev_json)
    return [example['SQL'] for example in examples]
    
def calculate_ex(predicted_res, ground_truth_res):
    """Return 1 when both results contain the same distinct rows, else 0.

    Both inputs are converted to sets before comparison, so row order and
    duplicate rows do not influence the outcome.
    """
    same_rows = set(predicted_res) == set(ground_truth_res)
    return 1 if same_rows else 0

def _run_sql_pair(db_file, sql_predict, sql_gt, timeout):
    """Run the predicted and the reference query against one SQLite file.

    Returns:
        ``(predicted_rows, reference_rows)`` as returned by ``fetchall()``.

    Raises:
        sqlite3.Error: If either statement fails, is aborted, or tries to write.
    """
    conn = sqlite3.connect(db_file)
    try:
        # Model-generated SQL must never modify the benchmark database:
        # with query_only enabled, writing statements fail instead of running.
        conn.execute("PRAGMA query_only = ON")
        # Let SQLite abort a statement that outlives the time budget; the
        # handler is polled periodically and a truthy return cancels the query.
        deadline = time.monotonic() + timeout
        conn.set_progress_handler(lambda: time.monotonic() > deadline, 10000)
        pred = conn.execute(sql_predict).fetchall()
        real = conn.execute(sql_gt).fetchall()
        return pred, real
    finally:
        # Always release the connection, including when a query raised.
        conn.close()


def sql_worker(args, timeout=5, max_sql_len=512):
    """Score one predicted query against its reference query.

    Args:
        args: Tuple ``(i, path_sql_dbs, db, sql_gt, sql_predict)`` where ``i``
            is the example index, ``path_sql_dbs`` the directory containing
            ``<db>/<db>.sqlite``, and the last two items are the reference and
            the predicted SQL text.
        timeout: Time budget in seconds for executing both queries.
        max_sql_len: Predictions longer than this many characters are not
            executed.

    Returns:
        ``(i, ex, executed)``: ``ex`` is the `calculate_ex` score (0 or 1) and
        ``executed`` is 1 only when both queries ran to completion. Any
        failure (timeout, SQL error, missing or over-long prediction) yields
        ``(i, 0, 0)``.
    """
    i, path_sql_dbs, db, sql_gt, sql_predict = args
    # A missing (non-string) or over-long prediction counts as not executed.
    if not isinstance(sql_predict, str) or len(sql_predict) > max_sql_len:
        return i, 0, 0

    db_file = f'{path_sql_dbs}/{db}/{db}.sqlite'
    try:
        # Arguments are passed in the order the helper declares them:
        # predicted query first, reference query second.
        results_pred, results_gt = func_timeout(
            timeout, _run_sql_pair, args=(db_file, sql_predict, sql_gt, timeout)
        )
        return i, calculate_ex(results_pred, results_gt), 1
    except (FunctionTimedOut, Exception):
        # Deliberately best-effort: every failure is scored as 0 / not executed.
        return i, 0, 0

def calculate_ex_values(path_sql_dbs, bd_list, sql_gt_list, sql_predict, max_workers=None):
    """Evaluate every prediction in parallel worker processes.

    Args:
        path_sql_dbs: Directory containing ``<db>/<db>.sqlite`` files.
        bd_list: Database id for each example.
        sql_gt_list: Reference SQL for each example (same indexing).
        sql_predict: Predicted SQL for each example (same indexing).
        max_workers: Optional cap on the number of worker processes; ``None``
            keeps the ProcessPoolExecutor default.

    Returns:
        ``(result, executed)``: two lists of 0/1 flags, one entry per example.
        Examples that were skipped or whose worker failed stay at 0 in both.
    """
    n_examples = len(bd_list)
    result = [0] * n_examples
    executed = [0] * n_examples

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                sql_worker,
                (i, path_sql_dbs, db, sql_gt_list[i], sql_predict[i])
            )
            for i, db in enumerate(bd_list)
            # Skip missing (non-string) and over-long predictions up front so a
            # single bad entry cannot abort the whole evaluation; 512 is the
            # same limit sql_worker applies.
            if isinstance(sql_predict[i], str) and len(sql_predict[i]) <= 512
        ]

        for future in as_completed(futures):
            try:
                # Futures yielded by as_completed are already finished, so
                # result() returns (or raises) immediately.
                i, ex_val, exec_flag = future.result()
            except Exception:
                # Best effort: a crashed worker leaves its example at 0 / 0.
                continue
            result[i] = ex_val
            executed[i] = exec_flag

    return result, executed
    
def print_and_save(name, path_sql_dbs, bd_list, sql_gt_list, sql_predict, path_to_save_scores=None, path_to_save_executed=None):
    """Evaluate predictions, print the summary metrics and optionally save them.

    Args:
        name: Label printed in front of each metric line.
        path_sql_dbs: Directory containing ``<db>/<db>.sqlite`` files.
        bd_list: Database id for each example.
        sql_gt_list: Reference SQL for each example.
        sql_predict: Predicted SQL for each example.
        path_to_save_scores: If given, the per-example EX flags are pickled
            to this path.
        path_to_save_executed: If given, the per-example "executed" flags are
            pickled to this path.

    Returns:
        ``(result, executed)`` exactly as returned by `calculate_ex_values`.
    """
    result, executed = calculate_ex_values(path_sql_dbs, bd_list, sql_gt_list, sql_predict)

    # First line: mean EX in percent; second line: share of examples whose
    # queries executed successfully.
    print(f"{name} DEV Финальный результат EX: {np.mean(result)*100:.2f}")
    print(f"{name} процент запросов, которые успешно выполнились DEV: {np.mean(executed)*100:.2f}%")

    # The save paths used to be accepted but ignored; honour them now, using
    # the same pickle format the notebook uses for its other artefacts.
    if path_to_save_scores is not None:
        with open(path_to_save_scores, "wb") as fp:
            pickle.dump(result, fp)
    if path_to_save_executed is not None:
        with open(path_to_save_executed, "wb") as fp:
            pickle.dump(executed, fp)

    return result, executed

In [2]:
# Evaluation inputs: the dev-set JSON (each entry carries `db_id` and the
# reference `SQL`) and the directory that holds one `<db_id>/<db_id>.sqlite`
# file per database.
path_dev_json = '/home/vkropoti/vllm/dev.json'
# Alternative dev-set files kept for reference:
# '/data/home/vkropoti/sql_data/dev_spider.json'
# '/home/vkropoti/vllm/dev.json'
path_sql_dbs = '/data/home/vkropoti/sql_data/dev_databases/'
# Alternative database roots kept for reference:
# '/data/home/vkropoti/sql_data/data_bases_spider/'
# '/data/home/vkropoti/sql_data/dev_databases/'
# Per-example database ids and reference queries; both lists follow the
# order of the entries in the dev JSON.
bd_list = generate_bd_list(path_dev_json)
sql_gt_list = generate_sql_gt_list(path_dev_json)

In [3]:
# sql_gt_list

In [4]:
# bd_list

In [53]:
# Select which saved model answers to evaluate: `model` and `version` are
# interpolated into `path`, the file the next cell unpickles.
model = "seeklhy:OmniSQL-7B"
# Other model names kept for reference:
# "DeepCoder-14B-Preview"
# "Qwen3-14B"
# model_add = "MiMo-7B-RL"
# type_cot =  'all_new'
version = 2
path = f"/data/home/vkropoti/sql_llm_answers/bird_dev_m_schema/sql_all_{model}-reasoning-v{version}"

In [54]:
# Load the model's raw answers as a NumPy array (presumably one entry per
# dev example, in dev-set order — TODO confirm against how the file was written).
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load files from a trusted source.
with open(path, "rb") as fp:   # Unpickling
    sql_predict = np.array(pickle.load(fp))

In [55]:
print(sql_predict[0])

To solve this problem, we need to determine the highest eligible free rate for K-12 students in the schools located in Alameda County. The eligible free rate for K-12 is calculated as the `Free Meal Count (K-12)` divided by the `Enrollment (K-12)`. Here is the step-by-step approach to construct the SQL query:

1. **Identify the relevant table**: The `frpm` table contains the necessary fields to calculate the eligible free rate for K-12 students.
2. **Filter by County**: We need to select only the records where the `County Name` is 'Alameda'.
3. **Calculate the eligible free rate**: This is done by dividing `Free Meal Count (K-12)` by `Enrollment (K-12)`.
4. **Find the highest eligible free rate**: We need to use the `MAX` function to get the highest eligible free rate.
5. **Ensure the calculation is correct**: Make sure to avoid division by zero by checking that `Enrollment (K-12)` is not zero.

Given these steps, the query can be simplified to avoid unnecessary joins and calculations.

In [56]:
_,_ = print_and_save(model,path_sql_dbs,bd_list,sql_gt_list,sql_predict)

seeklhy:OmniSQL-7B DEV Финальный результат EX: 0.00
seeklhy:OmniSQL-7B процент запросов, которые успешно выполнились DEV: 0.00%


In [57]:
def parse_sql(x):
    """Extract the SQL statement from a model answer.

    If the answer contains ``` fences, the first non-empty segment after the
    first fence is used; otherwise the whole answer is used as-is. A leading
    language tag (``sql``, ``sqlite`` or ``sqlite3``, in any letter case) is
    removed from the chosen text.

    Args:
        x: Raw model answer.

    Returns:
        The extracted SQL text. Always a string: a non-string input, or an
        answer whose fenced segments are all empty, yields ``""`` instead of
        ``None`` so downstream ``len()`` checks keep working.
    """
    if not isinstance(x, str):
        return ""

    parts = x.split('```')
    if len(parts) > 1:
        # parts[0] is the text before the first fence; take the first segment
        # after it that is not blank.
        stripped = (part.strip() for part in parts[1:])
        answer = next((part for part in stripped if part), "")
    else:
        answer = x

    # Drop the fence's language tag; longer tags are tried first so that
    # "sqlite" is not cut down to "ite".
    head = answer[:7].lower()
    for tag in ("sqlite3", "sqlite", "sql"):
        if head.startswith(tag):
            return answer[len(tag):].strip()
    return answer

In [58]:
%%time
parsed_arr = [parse_sql(x) for x in sql_predict]

CPU times: user 38.5 ms, sys: 83 μs, total: 38.5 ms
Wall time: 38.4 ms


In [59]:
parsed_arr[27]

"SELECT\n    s.sname AS SchoolName,\n    AVG(ss.AvgScrWrite) AS AvgWritingScore,\n    s.Phone AS CommunicationNumber\nFROM\n    schools s\nJOIN\n    satscores ss ON s.CDSCode = ss.cds\nWHERE\n    (STRFTIME('%Y', s.OpenDate) > '1991') OR (STRFTIME('%Y', s.ClosedDate) < '2000')\nGROUP BY\n    s.sname,\n    s.Phone\nORDER BY\n    AvgWritingScore DESC;"

In [37]:
result1, executed = print_and_save(model,path_sql_dbs,bd_list,sql_gt_list,parsed_arr)

seeklhy:OmniSQL-7B DEV Финальный результат EX: 53.52
seeklhy:OmniSQL-7B процент запросов, которые успешно выполнились DEV: 91.26%


In [45]:
result2, executed = print_and_save(model,path_sql_dbs,bd_list,sql_gt_list,parsed_arr)

seeklhy:OmniSQL-7B DEV Финальный результат EX: 53.32
seeklhy:OmniSQL-7B процент запросов, которые успешно выполнились DEV: 90.03%


In [60]:
result3, executed = print_and_save(model,path_sql_dbs,bd_list,sql_gt_list,parsed_arr)

seeklhy:OmniSQL-7B DEV Финальный результат EX: 55.28
seeklhy:OmniSQL-7B процент запросов, которые успешно выполнились DEV: 90.03%


In [61]:
# Any-of-three upper bound: an example counts as correct when at least one of
# the three evaluation runs above scored it 1. This is NOT the EX of a single
# run and should not be reported as such.
# NOTE(review): the three runs were given the same inputs, so the spread
# between them presumably comes from queries hitting the 5 s timeout while the
# worker pool is busy — confirm.
np.mean(np.max([result1,result2,result3],axis=0))

np.float64(0.6740547588005215)

In [36]:
# NOTE(review): `result` is not assigned in any cell visible here (only
# result1 / result2 / result3 are), so this cell relies on leftover kernel
# state and would fail after Restart & Run All.
result[27]

0

In [147]:
print(parsed_arr[0])

SELECT MAX(`Free Meal Count (K-12` / `Enrollment (K-12`) FROM frpm WHERE `County Name` = 'Alameda';


In [145]:
# NOTE(review): `result` is not assigned in any cell visible here (only
# result1 / result2 / result3 are); this relies on leftover kernel state.
result[0:10]

[0, 0, 0, 1, 0, 1, 1, 1, 1, 1]

In [155]:
print(parsed_arr[1])

SELECT ("Free Meal Count (Ages 5-17)" / "Enrollment (Ages 5-17)") AS eligible_rate
FROM frpm
JOIN schools ON frpm.CDSCode = schools.CDSCode
WHERE schools.GSserved = '9-12'
ORDER BY eligible_rate ASC
LIMIT 3;


In [301]:
parsed_arr[30:40]

['SELECT schools.City, frpm.`Enrollment (K-12)` FROM frpm JOIN schools ON frpm.CDSCode = schools.CDSCode ORDER BY frpm.`Enrollment (K-12)` ASC LIMIT 5;',
 np.str_('SELECT (`Free Meal Count (K-12)` / `Enrollment (K-12)`) AS eligible_free_rate FROM frpm ORDER BY `Enrollment (K-12)` DESC LIMIT 2 OFFSET 9;'),
 "SELECT schools.School, frpm.`FRPM Count (K-12)`, frpm.`Enrollment (K-12)`, \n       (frpm.`FRPM Count (K-12)` / frpm.`Enrollment (K-12)` * 100) AS MealRate\nFROM frpm\nJOIN schools ON frpm.CDSCode = schools.CDSCode\nWHERE schools.SOC = '66'\nORDER BY frpm.`FRPM Count (K-12)` DESC\nLIMIT 5;",
 'SELECT schools.Website, schools.School FROM schools\nJOIN frpm ON schools.CDSCode = frpm.CDSCode\nWHERE frpm.`Free Meal Count (Ages 5-17)` BETWEEN 1900 AND 2000;',
 "SELECT ROUND(`Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)`, 2) AS FreeRate FROM frpm JOIN schools ON frpm.CDSCode = schools.CDSCode WHERE `AdmFName1` = 'Kacey' AND `AdmLName1` = 'Gibson';",
 np.str_('SELECT schools.AdmE

In [165]:
data = load_json(path_dev_json)

In [166]:
k=1

In [167]:
data[k]

{'question_id': 1,
 'db_id': 'california_schools',
 'question': 'Please list the lowest three eligible free rates for students aged 5-17 in continuation schools.',
 'evidence': 'Eligible free rates for students aged 5-17 = `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)`',
 'SQL': "SELECT `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)` FROM frpm WHERE `Educational Option Type` = 'Continuation School' AND `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)` IS NOT NULL ORDER BY `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)` ASC LIMIT 3",
 'difficulty': 'moderate'}

In [168]:
print(sql_gt_list[k])

SELECT `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)` FROM frpm WHERE `Educational Option Type` = 'Continuation School' AND `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)` IS NOT NULL ORDER BY `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)` ASC LIMIT 3


In [192]:
new_answer = """SELECT TOP 3 
       frpm.Percent (%) Eligible FRPM (Ages 5-17) 
FROM frpm 
JOIN schools 
ON frpm.CDSCode = schools.CDSCode
WHERE schools.School = 'Continuation'
ORDER BY Percent (%) Eligible FRPM (Ages 5-17) DESC;

"""

In [193]:
def execute_query(path_sql_dbs,db,sql_predict,sql_gt):
    """Run a predicted and a reference query on ``<path_sql_dbs>/<db>/<db>.sqlite``.

    Args:
        path_sql_dbs: Directory containing one sub-directory per database.
        db: Database id (used for both the sub-directory and the file name).
        sql_predict: Predicted SQL text, executed first.
        sql_gt: Reference SQL text, executed second.

    Returns:
        ``(pred, real)``: the rows fetched for the predicted and the reference
        query respectively.

    Raises:
        sqlite3.Error: If either statement fails or attempts to write.
    """
    conn = sqlite3.connect(f'{path_sql_dbs}/{db}/{db}.sqlite')
    try:
        # The queries under test may be model-generated; forbid writes so they
        # cannot alter the benchmark database.
        conn.execute("PRAGMA query_only = ON")
        pred = conn.execute(sql_predict).fetchall()
        real = conn.execute(sql_gt).fetchall()
        return pred, real
    finally:
        # Close the connection even when a query raised.
        conn.close()

In [194]:
results_pred, results_gt = func_timeout(5, execute_query, args=(path_sql_dbs, bd_list[k], new_answer,sql_gt_list[k] ,))

OperationalError: near "3": syntax error

In [188]:
calculate_ex(results_pred,results_gt)

0

In [183]:
print(results_pred)

[]


In [184]:
print(results_gt)

[(0.043478260869565216,), (0.07042253521126761,), (0.11363636363636363,)]


In [276]:
results_pred, results_gt = func_timeout(5, execute_query, args=(path_sql_dbs, bd_list[k], parsed_arr[k],sql_gt_list[k] ,))

In [277]:
print(results_pred)
print(results_gt)

[(None, 'Roy A. Johnson High', '01611500130047'), (None, 'Young Adult Program', '01611760127233'), (None, 'Acalanes Center for Independent Study', '07616300107524')]
[(0.043478260869565216,), (0.07042253521126761,), (0.11363636363636363,)]


In [43]:
print(parse_sql(sql_predict[30]))

SELECT major.major_name FROM member JOIN major ON member.link_to_major = major.major_id WHERE member.first_name = 'Angela' AND member.last_name = 'Sanders';


In [27]:
print(sql_predict[5])

To find the gas consumption peak month for SME customers in 2013, I'll join the `customers` and `yearmonth` tables, filter for the year 2013 and SME segment, then calculate the total consumption per month and find the maximum.

```
-- Thoughts: Join yearmonth and customers, filter for 2013 and SME, group by month, find max consumption.
-- SQL Query:
SELECT substr(y.Date, 1, 4) AS Year, substr(y.Date, 5, 2) AS Month, sum(y.Consumption) AS TotalConsumption
FROM yearmonth y
JOIN customers c ON y.CustomerID = c.CustomerID
WHERE substr(y.Date, 1, 4) = '2013' AND c.Segment = 'SME'
GROUP BY substr(y.Date, 1, 4), substr(y.Date, 5, 2)
HAVING Year = '2013'
ORDER BY TotalConsumption DESC
LIMIT 1;
```

The query extracts the year and month from the Date string, sums the consumption for each month, and returns the month with the highest total, which is the peak consumption month for SMEs in 2013.
