In [19]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI
from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.chat_models import ChatOpenAI

# Add the parent directory to the module search path so `common.utils` can be imported
import sys
sys.path.insert(0, '../')
from common.utils import Utils


In [20]:

import os

# Connection URI for the database the agents query. Set the SQL_AGENT_DB_URI
# environment variable to point at another database (any SQLAlchemy URL, e.g.
# "mysql+mysqlconnector://user:password@localhost:3306/sys") so that real
# credentials never have to be written into the notebook. When the variable is
# unset, the local PostgreSQL `test` database is used, as before.
DB_URI = os.environ.get(
    "SQL_AGENT_DB_URI",
    'postgresql+psycopg2://test:test@localhost/test',
)
db = SQLDatabase.from_uri(DB_URI)


In [21]:
# Chat model used by the agents below; temperature 0 keeps sampling as
# deterministic as the API allows. The API key is read through the project's
# Utils helper rather than being written into the notebook.
llm = ChatOpenAI(
    model='gpt-4-0613',
    temperature=0,
    openai_api_key=Utils.get_openai_key(),
)

# Toolkit that supplies the SQL tools for `db`; it is handed the same model.
toolkit = SQLDatabaseToolkit(
    db=db,
    llm=llm,
)

In [55]:
"""SQL agent."""
from typing import Any, Dict, List, Optional, Sequence

from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import (
    SQL_FUNCTIONS_SUFFIX,
    SQL_PREFIX,
    SQL_SUFFIX,
)
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, SystemMessage
from langchain.tools import BaseTool



In [75]:
from time import sleep

In [99]:

def create_sql_agent(
    llm: BaseLanguageModel,
    toolkit: SQLDatabaseToolkit,
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = SQL_PREFIX,
    suffix: Optional[str] = None,
    format_instructions: str = FORMAT_INSTRUCTIONS,
    input_variables: Optional[List[str]] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    extra_tools: Sequence[BaseTool] = (),
    **kwargs: Any,
) -> AgentExecutor:
    """Construct an SQL agent from an LLM and tools.

    Notebook-local factory that also prints the generated prompt (framed by
    separator lines) so the exact text sent to the model can be inspected.

    Args:
        llm: Language model that drives the agent.
        toolkit: SQL toolkit; provides the tools (``get_tools()``) and the
            ``dialect`` substituted into ``prefix``.
        agent_type: ``ZERO_SHOT_REACT_DESCRIPTION`` or ``OPENAI_FUNCTIONS``.
        callback_manager: Forwarded to the chain/agent and to the executor.
        prefix: Prompt prefix; formatted with ``dialect`` and ``top_k``.
        suffix: Prompt suffix. When falsy, ``SQL_SUFFIX`` (zero-shot) or
            ``SQL_FUNCTIONS_SUFFIX`` (OpenAI functions) is used.
        format_instructions: Format instructions for the zero-shot prompt only.
        input_variables: Input variables for the zero-shot prompt only; the
            OpenAI-functions prompt always uses ``input`` and
            ``agent_scratchpad``.
        top_k: Value substituted for ``top_k`` in ``prefix``.
        max_iterations: Forwarded to the ``AgentExecutor``.
        max_execution_time: Forwarded to the ``AgentExecutor``.
        early_stopping_method: Forwarded to the ``AgentExecutor``.
        verbose: Forwarded to the ``AgentExecutor``.
        agent_executor_kwargs: Extra keyword arguments for
            ``AgentExecutor.from_agent_and_tools``.
        extra_tools: Tools appended after the toolkit's own tools.
        **kwargs: Extra keyword arguments for the agent constructor.

    Returns:
        An ``AgentExecutor`` wrapping the configured agent and its tools.

    Raises:
        ValueError: If ``agent_type`` is not one of the two supported types.
    """
    print('**' * 20)
    tools = toolkit.get_tools() + list(extra_tools)
    prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
    agent: BaseSingleActionAgent
    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix or SQL_SUFFIX,
            format_instructions=format_instructions,
            input_variables=input_variables,
        )
        print(prompt)
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        print('**' * 20)
        agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
        print('**' * 20)

    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        messages = [
            SystemMessage(content=prefix),
            HumanMessagePromptTemplate.from_template("{input}"),
            AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
        # Fixed variable set for the functions prompt; a separate local name is
        # used so the caller-supplied `input_variables` is not silently rebound.
        prompt_variables = ["input", "agent_scratchpad"]
        _prompt = ChatPromptTemplate(input_variables=prompt_variables, messages=messages)
        print(_prompt)
        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=_prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
    else:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")

    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )

## 1. Agent without a Supporting Vector Store

In [101]:
# Baseline agent: only the SQL toolkit's tools, no vector-store helpers.
# The custom suffix replaces the default one and asks for the final SQL plus
# an explanation in the answer.
agent_executor_1 = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    suffix="When giving the final answer, also append the final SQL and explain it.",
    verbose=True,
)


****************************************
input_variables=['agent_scratchpad', 'input'] input_types={'agent_scratchpad': typing.List[typing.Union[langchain.schema.messages.AIMessage, langchain.schema.messages.HumanMessage, langchain.schema.messages.ChatMessage, langchain.schema.messages.SystemMessage, langchain.schema.messages.FunctionMessage]]} messages=[SystemMessage(content='You are an agent designed to interact with a SQL database.\nGiven an input question, create a syntactically correct postgresql query to run, then look at the results of the query and return the answer.\nUnless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.\nYou can order the results by a relevant column to return the most interesting examples in the database.\nNever query for all the columns from a specific table, only ask for the relevant columns given the question.\nYou have access to tools for interacting with the database.\nOnly use the bel

In [98]:
print("""input_variables=['agent_scratchpad', 'input'] template='You are an agent designed to interact with a SQL database.\nGiven an input question, create a syntactically correct postgresql query to run, then look at the results of the query and return the answer.\nUnless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.\nYou can order the results by a relevant column to return the most interesting examples in the database.\nNever query for all the columns from a specific table, only ask for the relevant columns given the question.\nYou have access to tools for interacting with the database.\nOnly use the below tools. Only use the information returned by the below tools to construct your final answer.\nYou MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n\nDO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n\nIf the question does not seem related to the database, just return "I don\'t know" as the answer.\n\n\nsql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column \'xxxx\' in \'field list\', use sql_db_schema to query the correct table fields.\nsql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: \'table1, table2, table3\'\nsql_db_list_tables: Input is an empty string, output is a comma separated list of tables in the database.\nsql_db_query_checker: Use this tool to double check if your query is correct before executing it. 
Always use this tool before executing a query with sql_db_query!\n\nUse the following format:\n\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can repeat N times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\n\nBegin!\n\nQuestion: {input}\nThought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.\n{agent_scratchpad}'
""")

input_variables=['agent_scratchpad', 'input'] template='You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct postgresql query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question d

In [48]:

# Put the test question to the baseline agent and keep its answer in `response`.
response = agent_executor_1.run("How many employees are there whose departments are in America?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m
Observation: [38;5;200m[1;3mcountries, departments, employees, locations, t_stock[0m
Thought:[32;1m[1;3mI can query the "departments" table to find the departments of employees. I should also query the "locations" table to find the location of each department.
Action: sql_db_schema
Action Input: "departments, locations"[0m
Observation: [33;1m[1;3m
CREATE TABLE departments (
	department_id INTEGER NOT NULL, 
	department_name VARCHAR(30) NOT NULL, 
	manager_id INTEGER, 
	location_id INTEGER, 
	CONSTRAINT dept_id_pk PRIMARY KEY (department_id)
)

/*
3 rows from departments table:
department_id	department_name	manager_id	location_id
10	Administration	200	1700
20	Marketing	201	1800
30	Purchasing	114	1700
*/


CREATE TABLE locations (
	location_id INTEGER NOT NULL, 
	street_address VARCHAR(40), 
	postal_code VARCHAR(12), 
	city VARCHAR(30) NOT NULL, 
	state_province VARCHAR(25

In [49]:
print(response)

{'input': "How many employees's departments are in america", 'output': "There are 23 employees' departments in America."}


## 2. Agent with Vector Store Search for Similar Question-SQL Pairs

### 2.1 Create a Vector Store of Question-SQL Pairs

In [27]:
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document

import os

embeddings = OpenAIEmbeddings(openai_api_key=Utils.get_openai_key())

# Known-good question -> SQL pairs. Each question becomes a searchable document
# and its SQL is carried along in the document metadata.
few_shots = {
    'Find all employees whose departments are in America': "SELECT COUNT(*) FROM employees e JOIN departments d ON e.department_id = d.department_id WHERE d.location_id IN (SELECT location_id FROM locations WHERE country_id = 'US');",
}
few_shot_docs = [
    Document(page_content=question, metadata={"sql_query": sql_query})
    for question, sql_query in few_shots.items()
]

# Load the persisted FAISS index when it exists; otherwise build it from the
# few-shot documents and save it, so the notebook also runs on a fresh checkout.
# NOTE: an existing index is reused as-is. Delete the directory to rebuild it
# after editing `few_shots`.
SIMILAR_EXAMPLE_INDEX_DIR = "similar_example_store_index"
if os.path.isdir(SIMILAR_EXAMPLE_INDEX_DIR):
    vector_db = FAISS.load_local(SIMILAR_EXAMPLE_INDEX_DIR, embeddings)
else:
    vector_db = FAISS.from_documents(few_shot_docs, embeddings)
    vector_db.save_local(SIMILAR_EXAMPLE_INDEX_DIR)
retriever = vector_db.as_retriever()

from langchain.agents.agent_toolkits import create_retriever_tool

# Description attached to the tool; it tells the agent what the tool is for and
# what to pass as input.
tool_description = """
This tool will help you understand similar examples to adapt them to the user question.
Input to this tool should be the user question.
"""

# Wrap the few-shot retriever as a tool named `sql_get_similar_examples`.
retriever_tool_example = create_retriever_tool(
    retriever,
    name="sql_get_similar_examples",
    description=tool_description,
)

In [31]:
# Extra tools for this agent: just the similar-example retriever tool.
custom_tool_list = [retriever_tool_example]

# Suffix passed as `suffix` to create_sql_agent. It tells the agent to consult
# the similar examples first, fall back to inspecting tables and schemas, and
# append and explain the final SQL in its answer.
custom_suffix_example = """
I should first get the similar examples I know.
If the examples are enough to construct the query, I can build it.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables.

When giving the final answer, also append the final SQL and explain it.
"""

# Agent 2: OpenAI-functions SQL agent with the similar-example tool added to the
# toolkit's own tools.
agent_2 = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    extra_tools=custom_tool_list,
    suffix=custom_suffix_example,
)

`agent_scratchpad` should be a variable in prompt.input_variables. Did not find it, so adding it at the end.


In [None]:
# Ask agent 2 the same kind of question; `response` is rebound to its answer.
response = agent_2.run("How many employees are there, whose departments are in america")

In [30]:
print(response)

There are 69 employees whose departments are located in America.

The SQL query used to get this answer is:
```sql
SELECT COUNT(*) 
FROM employees e 
JOIN departments d ON e.department_id = d.department_id 
WHERE d.location_id IN (SELECT location_id FROM locations WHERE country_id = 'US');
```
This query first joins the employees and departments tables on the department_id. It then filters for departments that are located in America (represented by the country_id 'US'). The count of these employees is then returned.


## 3. Agent with a Vector Store for Checking Filter Conditions

In [5]:
import ast
import re


def run_query_save_results(db, query):
    """Run ``query`` and return its cell values as a flat list of cleaned strings.

    ``db.run`` is expected to return the rows as the string form of a list of
    tuples (for example ``"[('US',), ('UK',)]"``), which is parsed with
    ``ast.literal_eval``. Falsy cells (``None``, ``''``, ``0``) are dropped, and
    standalone runs of digits are removed from every remaining value before it
    is whitespace-stripped.

    Args:
        db: Object exposing ``run(query)`` that returns the result as a string.
        query: SQL text passed unchanged to ``db.run``.

    Returns:
        A flat list of cleaned string values in row order, or an empty list
        when the query produced no rows.
    """
    raw = db.run(query)
    # A query with no rows comes back as an empty string, which literal_eval
    # rejects with SyntaxError; treat it as "no values" instead.
    if not raw or (isinstance(raw, str) and not raw.strip()):
        return []
    rows = ast.literal_eval(raw)
    values = [cell for row in rows for cell in row if cell]
    return [re.sub(r"\b\d+\b", "", value).strip() for value in values]


# Pull the values the name-search index is built from: country names,
# country ids and department names, one query each and in that order.
country_name, country_id, department_name = [
    run_query_save_results(db, sql)
    for sql in (
        "SELECT country_name From countries",
        "SELECT country_id From countries",
        "SELECT  department_name FROM departments",
    )
]

In [14]:
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS


import os

# Every value the agent may need to spell correctly when writing a filter.
texts = country_name + country_id + department_name

embeddings = OpenAIEmbeddings(openai_api_key=Utils.get_openai_key())

# Load the persisted FAISS index when it exists; otherwise embed `texts`, build
# the index and save it, so the notebook also runs on a fresh checkout.
# NOTE: an existing index is reused as-is. Delete the directory to rebuild it
# after the underlying tables change.
NAME_SEARCH_INDEX_DIR = "name_search_store_index"
if os.path.isdir(NAME_SEARCH_INDEX_DIR):
    vector_db_1 = FAISS.load_local(NAME_SEARCH_INDEX_DIR, embeddings)
else:
    vector_db_1 = FAISS.from_texts(texts, embeddings)
    vector_db_1.save_local(NAME_SEARCH_INDEX_DIR)

# Expose the index to the agent as a retriever tool named `name_search`.
retriever_1 = vector_db_1.as_retriever()
retriever_tool_name = create_retriever_tool(
    retriever_1,
    name="name_search",
    description="use to learn how a piece of data is actually written, can be from names, surnames addresses etc",
)
custom_tool_list = [retriever_tool_name]

In [17]:
from langchain.agents import create_sql_agent, AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.utilities import SQLDatabase
from langchain.chat_models import ChatOpenAI

# db = SQLDatabase.from_uri("sqlite:///Chinook.db")
# Chat model for this section (temperature 0). This rebinds `llm`, replacing the
# model configured earlier in the notebook.
llm = ChatOpenAI(model_name="gpt-4", temperature=0, openai_api_key = Utils.get_openai_key())

# Rebuild the toolkit around the existing `db` connection and the new `llm`.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# Suffix passed as `suffix` to create_sql_agent. It tells the agent to check the
# spelling of proper nouns with the `name_search` tool before filtering on
# them, then inspect tables and schemas, and append and explain the final SQL.
custom_suffix = """
If a query statement asks for me to filter based on proper nouns, I should first check the spelling using the name_search tool.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables.

When giving the final answer, also append the final SQL to it and explain the SQL.
"""

# Agent 3: OpenAI-functions SQL agent with the tools in `custom_tool_list` added.
# NOTE(review): the import at the top of this cell rebinds `create_sql_agent` to
# the version from `langchain.agents`, so the notebook-local definition above is
# no longer the one called from here on. Confirm this is intended.
agent_3 = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    extra_tools=custom_tool_list,
    suffix=custom_suffix,
)

In [18]:
# Ask agent 3 a question whose filter term ("British") is not a literal value in
# the question-to-SQL sense, and print the answer it returns.
print(agent_3.run("How many employees's departments are in British"))



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `name_search` with `British`


[0m[33;1m[1;3m[Document(page_content='UK'), Document(page_content='US'), Document(page_content='United Kingdom'), Document(page_content='Australia')][0m[32;1m[1;3m
Invoking: `sql_db_list_tables` with ``
responded: The term "British" could refer to the United Kingdom. Let's proceed with this assumption and check the tables in the database.

[0m[38;5;200m[1;3mcountries, departments, employees, locations, t_stock[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `employees, departments, locations`
responded: The relevant tables for this query seem to be "employees", "departments", and "locations". Let's check the schema of these tables to understand their structure and relationships.

[0m[33;1m[1;3m
CREATE TABLE departments (
	department_id INTEGER NOT NULL, 
	department_name VARCHAR(30) NOT NULL, 
	manager_id INTEGER, 
	location_id INTEGER, 
	CONSTRAINT dept_id_pk PRIMARY KEY 

## 4. Agent with Vector Stores Supporting Similar-Query Search and Filter-Condition Checking

In [7]:

# Suffix combining both strategies: consult the similar examples first, and
# check proper-noun spellings with the `name_search` tool before filtering.
# This rebinds `custom_suffix` from the previous section.
# NOTE(review): the "Otherwise, I can then look at the tables..." and "Then I
# should query the schema..." sentences appear twice in this prompt. Confirm
# whether the repetition is intended.
custom_suffix = """
I should first get the similar examples I know.
If the examples are enough to construct the query, I can build it.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables.

If a user asks for me to filter based on proper nouns, I should first check the spelling using the name_search tool.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables.

When giving the final answer, also append the final SQL to it and explain the SQL in natural languages.
"""

# Both retriever tools: name search and similar-example lookup.
custom_tool_list = [retriever_tool_name, retriever_tool_example]

# Combined agent: OpenAI-functions SQL agent with both retriever tools added.
agent_compose = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    extra_tools=custom_tool_list,
    suffix=custom_suffix,
)

In [51]:
# Run the combined agent; as the last expression in the cell, the returned
# answer is shown as the cell's output.
agent_compose.run("How many employees's departments are in America")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_get_similar_examples` with `How many employees's departments are in America`


[0m[38;5;200m[1;3m[Document(page_content='Find all employees whose departments are in America', metadata={'sql_query': "SELECT COUNT(*) FROM employees e JOIN departments d ON e.department_id = d.department_id WHERE d.location_id IN (SELECT location_id FROM locations WHERE country_id = 'US');"})][0m[32;1m[1;3m
Invoking: `sql_db_query_checker` with `SELECT COUNT(*) FROM employees e JOIN departments d ON e.department_id = d.department_id WHERE d.location_id IN (SELECT location_id FROM locations WHERE country_id = 'US');`
responded: The similar example provided gives me a good starting point to construct the SQL query. The query is counting the number of employees whose departments are located in America. The 'US' country_id is used to represent America in the database. Let's construct the query.

[0m[36;1m[1;3mSELECT COUNT(*) 

"There are 69 employees whose departments are located in America.\n\nThe SQL query used to find this is: \n```sql\nSELECT COUNT(*) \nFROM employees e \nJOIN departments d ON e.department_id = d.department_id \nWHERE d.location_id IN (SELECT location_id FROM locations WHERE country_id = 'US');\n```\nThis query first joins the employees and departments tables on the department_id. It then filters for departments that are located in America, represented by the country_id 'US' in the locations table. The count of these employees is then returned."

In [10]:
response = "There are 35 employees whose departments are in the United Kingdom.\n\nThe SQL query I used to get this answer is:\n```sql\nSELECT COUNT(*) \nFROM employees e \nJOIN departments d ON e.department_id = d.department_id \nWHERE d.location_id IN (\n    SELECT location_id \n    FROM locations \n    WHERE country_id = 'UK'\n);\n```\nThis query first joins the employees and departments tables on the department_id column. Then it filters the rows where the location_id of the department is in the list of location_ids that are in the UK (obtained from the locations table). Finally, it counts the number of such employees."
print(response)

There are 35 employees whose departments are in the United Kingdom.

The SQL query I used to get this answer is:
```sql
SELECT COUNT(*) 
FROM employees e 
JOIN departments d ON e.department_id = d.department_id 
WHERE d.location_id IN (
    SELECT location_id 
    FROM locations 
    WHERE country_id = 'UK'
);
```
This query first joins the employees and departments tables on the department_id column. Then it filters the rows where the location_id of the department is in the list of location_ids that are in the UK (obtained from the locations table). Finally, it counts the number of such employees.
