In [1]:
import ast, pandas as pd
from llm import call_llm
from sql import run_query

In [2]:
dev_set = pd.read_csv('./dev_set.csv')
dev_set['conversations'] = dev_set['conversations'].apply(lambda x: ast.literal_eval(x))

In [3]:
selected_sample = dev_set.sample(n=1)
print(f"Database Name: `{selected_sample['database'].iloc[0]}`")
print()
print(f"INPUT:\n{selected_sample['conversations'].iloc[0][0]['value']}")
print()
print(f"OUTPUT:\n{selected_sample['conversations'].iloc[0][1]['value']}")

Database Name: `codebase_community`

INPUT:
Based on the SQL db schema given below, you have to answer the question that follows it. Your answer should be a valid, correct SQL query. You are provided with a HINT to generate the SQL query.

SCHEMA: Tables in the database codebase_community: badges, comments, postHistory, postLinks, posts, tags, users, votes
-------------------------
CREATE query for table: badges

CREATE TABLE badges ( Id INTEGER not null primary key, UserId INTEGER null, Name TEXT null, Date DATETIME null, foreign key (UserId) references users (Id) on update cascade on delete cascade )
-------------------------
CREATE query for table: comments

CREATE TABLE comments ( Id INTEGER not null primary key, PostId INTEGER null, Score INTEGER null, Text TEXT null, CreationDate DATETIME null, UserId INTEGER null, UserDisplayName TEXT null, foreign key (PostId) references posts (Id) on update cascade on delete cascade, foreign key (UserId) references users (Id) on update cascade

In [4]:
generated_query = call_llm(selected_sample.iloc[0]['conversations'][0]['value'])
generated_query

"To find the most valuable post in 2010, you can use the following SQL query:\n\n```sql\nSELECT p.Id, u.OwnerDisplayName\nFROM posts p\nJOIN votes v ON p.Id = v.PostId\nWHERE YEAR(p.CreationDate) = 2010\nORDER BY v.VoteTypeId DESC\nLIMIT 1;\n```\n\nThis query joins the `posts` table with the `votes` table on the `PostId` column. It then filters the results to only include posts from 2010 by using the `YEAR` function on the `CreationDate` column. The results are ordered by the `VoteTypeId` in descending order (most valuable first), and since we're ordering for the most valuable post, we can use a limit of 1.\n\nNote that if there are multiple votes with the same maximum value for a particular vote type, this query will only return one of them."

In [6]:
generated_query_output = run_query(selected_sample['database'].iloc[0], generated_query)
generated_query_output

('ERROR: near "To": syntax error', -1)

In [7]:
if generated_query_output[1] == -1:
    schema = selected_sample.iloc[0]['conversations'][0]['value'].split("Based on the SQL db schema given below, you have to answer the question that follows it. Your answer should be a valid, correct SQL query. You are provided with a HINT to generate the SQL query.\n\n")[1].split("QUESTION:")[0].strip()
    query_to_rectify = generated_query_output
    query_error = generated_query_output[0]
    
    query_rectification_prompt = f"You are an expert in rectifying incorrect SQL queries based on db schema and error.\n\nRead schema below:\n\n{schema}\n\nYou need to rectify this query: {query_to_rectify}\n\nIt is giving following error: {query_error}"
    rectified_query = call_llm(query_rectification_prompt)
    if '```sql' in rectified_query:
        try:
            rectified_query = rectified_query.split('```sql')[1].split('```')[0].strip()
        except:
            rectified_query = rectified_query

    rectified_query_output = run_query(selected_sample['database'].iloc[0], rectified_query)

In [8]:
rectified_query

"SELECT c.*\nFROM comments c\nJOIN votes v ON c.Id = v.PostId AND v.VoteTypeId = 2 -- Assuming '2' is the id for upvote\nWHERE v.UserId IS NOT NULL AND c.Score > 0;"

In [9]:
rectified_query_output

([], 0.007)