In [1]:
import warnings
warnings.filterwarnings('ignore')

### Prepare LLM

In [2]:
import json
import os
from typing import Annotated
from autogen import ConversableAgent, initiate_chats
from app.core.tools.semantic_search_tools import SemanticSearchTool

# Only install the placeholder when OPENAI_API_KEY is not already set, so a
# real key exported in the environment is never overwritten. Prefer exporting
# the key in your shell over pasting it into the notebook.
os.environ.setdefault('OPENAI_API_KEY', 'your-openai-api-key')

# Model settings passed as llm_config to each agent.
llm_config = {"model": "gpt-3.5-turbo"}

# Note: raise n_results (e.g. n_results=50) if the search results do not
# include the tables needed for the final query.
db_schema_search_tool = SemanticSearchTool(n_results=10)

### Prepare experiment data

In [3]:
from spider_env import SpiderEnv

# Open the Spider environment, using 'spider' as its cache directory.
spider = SpiderEnv(cache_dir='spider')

# Draw a random Spider question: reset() yields an (observation, info) pair,
# with the question text stored under observation["instruction"].
observation, info = spider.reset()
question = observation["instruction"]

# Show the question together with its reference query and reference result.
print(f'question: {question}')
print(f"gold query: {info['gold_query']}")
print(f"gold result: {info['gold_result']}")

Loading cached Spider dataset from spider
Schema file not found for spider/spider/database/twitter_1
Schema file not found for spider/spider/database/company_1
Schema file not found for spider/spider/database/chinook_1
Schema file not found for spider/spider/database/flight_4
Schema file not found for spider/spider/database/small_bank_1
Schema file not found for spider/spider/database/epinions_1
Schema file not found for spider/spider/database/icfp_1
question: Find the famous titles of artists that do not have any volume.
gold query: SELECT Famous_Title FROM artist WHERE Artist_ID NOT IN(SELECT Artist_ID FROM volume)
gold result: [('Antievangelistical Process (re-release)',), ('Antithesis of All Flesh',)]


### Create Agents

In [4]:
# Load the per-role agent definitions. The file is decoded explicitly as
# UTF-8 instead of relying on the platform's default locale encoding.
with open('app/core/agents/autogen.json', encoding='utf-8') as f:
    agent_config = json.load(f)

# Keyword arguments shared by every agent: the same LLM settings and no
# prompting for human input.
_common_agent_kwargs = {"llm_config": llm_config, "human_input_mode": 'NEVER'}

# Role-specific constructor arguments come from the matching entry of the
# config file; the shared settings are layered on top.
search_agent = ConversableAgent(
    **agent_config['assistant_dba'],
    **_common_agent_kwargs,
)
sql_writer_agent = ConversableAgent(
    **agent_config['senior_sql_writer'],
    **_common_agent_kwargs,
)
qa_agent = ConversableAgent(
    **agent_config['senior_qa_engineer'],
    **_common_agent_kwargs,
)

# User proxy ("Admin"): defined inline rather than in the config file, with
# code execution switched off.
user_proxy = ConversableAgent(
    name="Admin",
    system_message="Give the question, and send instructions to SQL writer to generate a sql query script.",
    code_execution_config=False,
    **_common_agent_kwargs,
)

### Register Tools

In [5]:
@search_agent.register_for_llm(description='Function for searching relevant database/table schemas')
@user_proxy.register_for_execution()
def semantic_search(
    question: Annotated[str, 'A question']
) -> Annotated[str, 'Result of relevant table schemas to the question']:
    """Return the table schemas relevant to ``question``.

    Registered with ``search_agent`` for the LLM side and with ``user_proxy``
    for execution; the lookup itself is delegated to
    ``db_schema_search_tool``.
    """
    matching_schemas = db_schema_search_tool(question)
    return matching_schemas

### Prepare Sequential Chats

In [6]:
# Instruction shared by the SQL-writing and QA summaries. The template uses
# double quotes so that a summary following it is valid JSON (the last chat's
# summary is parsed with json.loads later in the notebook).
_JSON_ONLY_PROMPT = (
    "Return the target database and SQL query script into as JSON object only, DO NOT explain reason: "
    "{\"database\": \"\", \"sql\": \"\"}"
)

chats = [
    # Step 1: the user proxy asks the search agent for the relevant schemas.
    {
        "sender": user_proxy,
        "recipient": search_agent,
        "message": f'Find the relevant table schemas to the question: {question}',
        "summary_method": "reflection_with_llm",
        "max_turns": 2,
        "clear_history" : True
    },
    # Step 2: the search agent hands over to the SQL writer.
    {
        "sender": search_agent,
        "recipient": sql_writer_agent,
        "message": f'Based on the table schemas, write a SQL query script to answer the question: {question}',
        "summary_method": "reflection_with_llm",
        "summary_args": {
            "summary_prompt" : _JSON_ONLY_PROMPT,
        },
        "max_turns": 1,
        "clear_history" : False
    },
    # Step 3: the QA agent reviews the query and, if needed, fixes it.
    {
        "sender": sql_writer_agent,
        "recipient": qa_agent,
        "message": f'Review the SQL query script to be sure it can answer the question: {question}',
        "max_turns": 1,
        "summary_method": "reflection_with_llm",
        "summary_args": {
            # The trailing space keeps the two sentences from running together
            # ("script.Return") once the strings are concatenated.
            "summary_prompt" : "If the SQL query script has to be adjusted. Fix the SQL query script. "
                             + _JSON_ONLY_PROMPT,
        },
    },
]

### Initiate Chat

In [7]:
# Run the chats configured in `chats`; the results are kept in `chat_results`,
# whose last entry's summary is read further down in the notebook.
chat_results = initiate_chats(chats)

[34m
********************************************************************************[0m
[34mStarting a new chat....[0m
[34m
********************************************************************************[0m
[33mAdmin[0m (to Assistant_Search_Engineer):

Find the relevant table schemas to the question: Find the famous titles of artists that do not have any volume.

--------------------------------------------------------------------------------
[33mAssistant_Search_Engineer[0m (to Admin):

[32m***** Suggested tool call (call_LukfK6QZPFf4i3jHDfsJNZuo): semantic_search *****[0m
Arguments: 
{"question":"Find the famous titles of artists that do not have any volume."}
[32m********************************************************************************[0m

--------------------------------------------------------------------------------
[35m
>>>>>>>> EXECUTING FUNCTION semantic_search...[0m
[33mAdmin[0m (to Assistant_Search_Engineer):

[33mAdmin[0m (to Assistant_Search_En

In [8]:
# Show the summary of the last chat result (presumably the QA review step,
# which is the last entry in `chats`).
print(chat_results[-1].summary)

{"database": "music_4", "sql": "SELECT DISTINCT a.Famous_Title FROM artist a LEFT JOIN volume v ON a.Artist_ID = v.Artist_ID WHERE v.Volume_ID IS NULL AND a.Famous_Title IS NOT NULL AND v.Artist_ID IS NULL;"}


### Compare with gold result

In [9]:
import sqlite3
from contextlib import closing

# The last chat's summary is parsed as JSON and read for its "database" and
# "sql" keys.
summary = json.loads(chat_results[-1].summary)
db_path = f"spider/spider/database/{summary['database']}/{summary['database']}.sqlite"

# The SQL text comes from an LLM, so the database is opened read-only: a stray
# DROP/UPDATE cannot alter the benchmark data, and a wrong database name fails
# instead of silently creating an empty file. closing() releases the
# connection even if the query raises.
with closing(sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)) as con:
    cursor = con.cursor()
    cursor.execute(summary['sql'])
    results = cursor.fetchall()
print(results)

[('Antievangelistical Process (re-release)',), ('Antithesis of All Flesh',)]


In [10]:
# Print the reference ("gold") result stored in `info`, for comparison with
# the rows returned by the generated query.
print(info['gold_result'])

[('Antievangelistical Process (re-release)',), ('Antithesis of All Flesh',)]
