# SQL agent

## Setup

### Imports

In [2]:
import os
import re
import pickle
from typing import (
    List,
    Any
)

import db_connect
from test_data import TestData
from prompt import (
    PREFIX,
    SUFFIX,
    TABLE_DESCRIPTIONS
)

from langsmith import Client

from langchain_core.language_models import BaseLanguageModel

from langchain.prompts import PromptTemplate

from langchain.agents.agent_types import AgentType
from langchain.agents.agent_toolkits.sql import base
from langchain.agents.agent import AgentExecutor
from langchain.agents import create_openai_tools_agent, create_react_agent, create_sql_agent
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools.sql_database import tool as sql_tools

from langchain_community.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_experimental.llms.ollama_functions import (
    OllamaFunctions,
    DEFAULT_RESPONSE_FUNCTION
)

from langchain_core.output_parsers import StrOutputParser

from langchain_core.messages.human import HumanMessage
from langchain_core.messages.ai import AIMessage

from langchain_experimental.llms.ollama_functions import convert_to_ollama_tool
from langchain.schema.agent import AgentFinish

from langchain_core.pydantic_v1 import BaseModel, Field

from langchain.tools import (tool, BaseTool)

from prompt_generator import PromptGenerator
from examples import FEW_SHOT_EXAMPLES

from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.agents.format_scratchpad.openai_functions import (
    format_to_openai_function_messages,
)
from langchain.agents.output_parsers.openai_functions import (
    OpenAIFunctionsAgentOutputParser,
)

from langchain_core.runnables.config import RunnableConfig
from langchain.callbacks.tracers import ConsoleCallbackHandler

from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

### LangSmith

In [3]:
# Group LangSmith traces for this notebook under the "text2sql" project,
# then create the LangSmith client.
os.environ.update(LANGCHAIN_PROJECT="text2sql")
client = Client()

### Load models

In [4]:
# Local Ollama models. temperature=0 is pinned on every wrapper (including the
# tool-calling ones the agent actually uses) so that runs comparing models
# are as repeatable as possible.
llama3_embeddings = OllamaEmbeddings(model="llama3:8b", temperature=0)

llama3 = Ollama(model="llama3:instruct", temperature=0)
llama3_q8 = Ollama(model="llama3:8b-instruct-q8_0", temperature=0)

# Tool-calling wrappers; format="json" requests JSON-mode output from Ollama.
llama3__with_functions = OllamaFunctions(
    model="llama3:instruct", format="json", temperature=0
)
llama3_q8__with_functions = OllamaFunctions(
    model="llama3:8b-instruct-q8_0", format="json", temperature=0
)

### Connect to the DB with a read-only role

In [5]:
# Database handle from the local db_connect helper; it is used below via
# .run(), .get_table_info(), .get_usable_table_names() and .dialect.
# NOTE(review): the heading says this connects with a read-only role —
# confirm that in db_connect.get_db().
db = db_connect.get_db()

#### Check connection

In [17]:
# Smoke test: fetch every row of `passenger` to confirm the connection works.
db.run("select * from passenger")

"[(1, 'John'), (2, 'James'), (3, 'Poul'), (4, 'Christofer'), (5, 'Superman'), (6, 'Donald'), (7, 'Douglas'), (8, 'Dwight'), (9, 'Earl'), (10, 'Edgar'), (11, 'Edmund'), (12, 'Edwin'), (13, 'Elliot'), (14, 'Eric'), (15, 'Ernest'), (16, 'Ethan'), (17, 'Ezekiel'), (18, 'Felix'), (19, 'Franklin'), (20, 'Frederick'), (21, 'Gabriel'), (22, 'Joseph'), (23, 'Joshua'), (24, 'Julian'), (25, 'Alice'), (26, 'Bob'), (27, 'Charlie'), (28, 'David'), (29, 'Emily'), (30, 'Frank'), (31, 'George'), (32, 'Helen'), (33, 'Irene'), (34, 'Jack'), (35, 'Kate'), (36, 'Leo'), (37, 'Mary'), (38, 'Nancy'), (39, 'Oliver'), (40, 'Paul'), (41, 'Qiana'), (42, 'Robert'), (43, 'Samantha'), (44, 'Thomas'), (45, 'Victoria')]"

-----

## Create agent

### Prompt and error handler

In [18]:
# Few-shot prompt builder. set_example_selector is given the example pool,
# the embedding model and k=3 (presumably the number of examples retrieved
# per question — confirm in prompt_generator.PromptGenerator).
prompt_generator = PromptGenerator().set_example_selector(
    k=3,
    examples=FEW_SHOT_EXAMPLES,
    embedding_llm=llama3_embeddings,
)

### Create functions for `OllamaFunctions`

#### Pydantic classes (functions)

In [19]:
# Function schema advertised to the LLM for executing a SQL query. The class
# docstring presumably becomes the function description the model sees (via
# convert_to_ollama_tool), so it must name tools this agent really exposes.
class QuerySQLDatabaseTool(BaseModel):
    """INPUT to this tool is a DETAILED and CORRECT SQL query, 
    OUTPUT is a RESULT FROM the DATABASE. If the query is not 
    correct, an error message will be returned. If an error is 
    returned, rewrite the query, check the query, and try again. 
    If you encounter an issue with Unknown column 'xxxx' in 
    'field list', use InfoSQLDatabaseTool to query the correct table fields."""
    tool_input: str = Field(
        description="Detailed and correct SQL query."
    )


# Function schema advertised to the LLM for fetching table schemas. The class
# docstring presumably becomes the function description the model sees, so
# it points at ListSQLDatabaseTool, the table-listing tool that exists here.
class InfoSQLDatabaseTool(BaseModel):
    """INPUT to this tool is a comma-separated LIST OF TABLES, 
    output is the schema and sample rows for those tables. 
    Be sure that the tables actually exist by calling 
    ListSQLDatabaseTool first! Example Input: table1, table2, table3.
    This tool will not help if you need to get information about a column!"""
    tool_input: str = Field(
        description="Comma-separated list of tables. Example: table1, table2, table3"
    )


# Function schema advertised to the LLM for listing tables. The tool needs no
# real input, so its only field is described as an empty string.
class ListSQLDatabaseTool(BaseModel):
    """INPUT to this tool is a EMPTY STRING, 
    OUTPUT is a comma-separated LIST OF TABLES 
    in database. Use this tool to select the 
    tables needed to respond to the user."""
    tool_input: str = Field(
        description="Empty string",
    )


# Function schema advertised to the LLM for validating a query before it is
# run. The class docstring presumably becomes the function description, so it
# names QuerySQLDatabaseTool, the execution tool that exists in this agent.
class QuerySQLCheckerTool(BaseModel):
    """INPUT to this tool is a QUERY to check, 
    OUTPUT is a CORRECT QUERY in database. 
    Use this tool to double check if 
    your query is correct before executing it. 
    Always use this tool before executing a 
    query with QuerySQLDatabaseTool!"""
    tool_input: str = Field(
        description="SQL query that needs to be checked before execution."
    )
    

# Plain-reply function schema. When the `functions` list is built, the
# converted entry for this class is renamed to DEFAULT_RESPONSE_FUNCTION's name.
class ConversationalResponse(BaseModel):
    """INPUT to this tool is a message to 
    the user Use this tool to text replies 
    to the user or if no other tool is 
    suitable for you. For example, 
    if you are ready to answer a 
    question. Or if the user just 
    wants to chat."""
    message_to_user: str = Field(description="Message or response or answer to user.")



# Convert every schema class to an Ollama function definition, in this order.
# ConversationalResponse must stay last: the final entry is renamed to the
# built-in default-response function name.
functions = [
    convert_to_ollama_tool(schema)
    for schema in (
        QuerySQLDatabaseTool,
        InfoSQLDatabaseTool,
        ListSQLDatabaseTool,
        QuerySQLCheckerTool,
        ConversationalResponse,
    )
]
functions[-1]["name"] = DEFAULT_RESPONSE_FUNCTION["name"]

#### Tools

In [138]:
# Leading words of lines inside a CREATE TABLE body that introduce constraints
# or clauses rather than column names. Entries are upper-case because
# exclude_key_words_from_list compares them against upper-cased tokens.
key_words = (
    "CONSTRAINT CHECK UNIQUE PRIMARY FOREIGN "
    "EXCLUDE DEFERRABLE NOT INITIALLY LIKE"
).split()

def exclude_key_words_from_list(
    list: List[str], 
    key_words: List[str]
) -> List[str]:
    """Return the entries of ``list`` that are not SQL keywords.

    An entry is dropped when its upper-cased form is contained in
    ``key_words``. Only the entries are upper-cased, so ``key_words`` must
    already be upper-case to match. The input order is preserved.
    """
    blocked = frozenset(key_words)
    return [token for token in list if token.upper() not in blocked]

def extract_names_from_db(create_table_query: str, key_words_to_exclude: List[str]):
    """Extract table and column names from ``CREATE TABLE`` text.

    Args:
        create_table_query: Text containing ``CREATE TABLE name (...)``
            statements whose column/constraint lines start with a tab, and
            optionally ``/* ... */`` comment blocks (sample rows), which are
            ignored.
        key_words_to_exclude: Upper-case keywords that start constraint lines
            (e.g. ``CONSTRAINT``). Lines starting with them are not columns.

    Returns:
        A list of ``{"name": table_name, "columns": [column_name, ...]}``
        dicts in order of appearance. Tables whose body cannot be delimited
        are skipped.
    """
    # Strip parenthesised integers such as the "(60)" in VARCHAR(60).
    clean_sql = re.sub(r"\(\d+\)", "", create_table_query)

    # Strip /* ... */ blocks (sample rows). DOTALL lets "." cross newlines,
    # avoiding the exponential backtracking of a "(.|\s)*?" alternation.
    clean_sql = re.sub(r"/\*.*?\*/", "", clean_sql, flags=re.DOTALL)

    table_names = re.findall(r"CREATE TABLE (\w+) \(", clean_sql)

    res = []
    for name in table_names:
        # Table body: from "(" up to the first ")" preceded by whitespace.
        match = re.search(
            rf"CREATE TABLE {re.escape(name)} \((\s+.*?)\s+\)",
            clean_sql,
            flags=re.DOTALL,
        )
        if match is None:
            # Body could not be delimited; skip instead of crashing on .group().
            continue

        # First word of every tab-indented line: column names plus constraint
        # keywords, which are filtered out next.
        columns_part = re.findall(r"\n\t(\w+)", match.group(1))

        res.append({
            "name": name,
            "columns": exclude_key_words_from_list(
                columns_part, key_words_to_exclude
            ),
        })

    return res

In [20]:
@tool
def QuerySQLDatabaseTool(query: str):
    """INPUT to this tool is a DETAILED and CORRECT SQL query, 
    OUTPUT is a RESULT FROM the DATABASE. If the query is not 
    correct, an error message will be returned. If an error is 
    returned, rewrite the query, check the query, and try again. 
    If you encounter an issue with Unknown column 'xxxx' in 
    'field list', use InfoSQLDatabaseTool to query the correct table fields."""
    # Returns {"type": "ok" | "error", "result": ...}; format_to_ollama_chat_messages
    # branches on "type" to pick its hint for the model.
    try:
        output = db.run(query)
    except Exception as e:
        # Deliberately broad: any database error is reported back to the agent
        # so it can rewrite the query, instead of aborting the run.
        return {"type": "error", "result": str(e)}
    return {"type": "ok", "result": output}

@tool
def InfoSQLDatabaseTool(input_list: str):
    """INPUT to this tool is a comma-separated LIST OF TABLES, 
    output is the schema and sample rows for those tables. 
    Be sure that the tables actually exist by calling 
    ListSQLDatabaseTool first! Example Input: table1, table2, table3.
    This tool will not help if you need to get information about a column!"""
    # LLM input is often noisy ('["passenger"]', 'passenger table', ...), so
    # extract identifier-like tokens and keep only real tables. Whole-token
    # matching keeps "pass_in_trip" from also selecting "trip", which a plain
    # substring test would do.
    requested = {token.lower() for token in re.findall(r"\w+", str(input_list))}
    available = db.get_usable_table_names()
    selected = [name for name in available if name.lower() in requested]

    if not selected:
        # Tell the agent what exists instead of querying an empty table name.
        return (
            "Error: none of the requested tables were found. "
            f"Available tables: {', '.join(available)}"
        )

    return db.get_table_info_no_throw(selected)

@tool
def ListSQLDatabaseTool(empty_string: str = ""):
    """INPUT to this tool is a EMPTY STRING, 
    OUTPUT is a comma-separated LIST OF TABLES 
    in database. Use this tool to select the 
    tables needed to respond to the user."""
    # The argument is accepted only to satisfy the tool-call shape; it is unused.
    usable_tables = db.get_usable_table_names()
    return ", ".join(usable_tables)

# Prompt for the SQL-checking LLM. Template variables: {query}, {dialect} and
# {db_names}; QuerySQLCheckerTool fills db_names via .partial() and supplies
# the other two at invoke time.
QUERY_CHECKER = """
{query}
Double check the {dialect} query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

Also check that all column and table names correspond to the database, these are the ones you need to check:

{db_names}

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only.

SQL Query: """

checker_prompt = PromptTemplate.from_template(QUERY_CHECKER)

@tool
def QuerySQLCheckerTool(query_to_check: str):
    """INPUT to this tool is a QUERY to check, 
    OUTPUT is a CORRECT QUERY in database. 
    Use this tool to double check if 
    your query is correct before executing it. 
    Always use this tool before executing a 
    query with QuerySQLDatabaseTool!"""
    # Real table/column names from the current schema, so the checker LLM can
    # catch misspelled identifiers. extract_names_from_db returns a list of
    # {"name": str, "columns": [str, ...]} dicts, one per table.
    tables = extract_names_from_db(db.get_table_info(), key_words)
    db_names = "".join(
        f"Table name: {table['name']}\n\tColumn names: {', '.join(table['columns'])}\n"
        for table in tables
    )

    chain = checker_prompt.partial(db_names=db_names) | llama3 | StrOutputParser()
    return chain.invoke({
        "query": query_to_check,
        "dialect": db.dialect,
    })

# The @tool-wrapped callables, in the same order as their schemas in `functions`.
tools = [
    QuerySQLDatabaseTool, InfoSQLDatabaseTool,
    ListSQLDatabaseTool, QuerySQLCheckerTool,
]

#### Routing

### Function to build an agent with the provided LLM

In [21]:
def format_to_ollama_chat_messages(intermediate_steps):
    """Render agent (action, observation) steps as alternating chat messages.

    For each step, an ``AIMessage`` records which tool was called with which
    input. It is followed by a ``HumanMessage`` that carries the tool output
    together with a tool-specific hint nudging the model toward its next move.

    Args:
        intermediate_steps: Sequence of ``(action, observation)`` pairs, where
            ``action`` exposes ``.tool`` and ``.tool_input``. For
            ``QuerySQLDatabaseTool`` the observation is expected to be the
            ``{"type": "ok" | "error", "result": ...}`` dict that tool returns.

    Returns:
        List of ``AIMessage`` / ``HumanMessage`` objects for ``agent_scratchpad``.
    """
    agent_scratchpad = []

    for action, tool_result in intermediate_steps:
        agent_scratchpad.append(
            AIMessage(
                content=(
                    f"I use this tool \"{action.tool}\" "
                    f"with this input\nInput: {action.tool_input}"
                )
            )
        )

        # Per-step default, so an unrecognised tool or result type never
        # leaves `human_message` unset or reuses the previous step's hint.
        human_message = "Here is the output of the tool."

        if action.tool == "QuerySQLDatabaseTool" and isinstance(tool_result, dict):
            if tool_result.get("type") == "ok":
                human_message = (
                    "This is similar to the database "
                    "response. Maybe there's something I need there. "
                    "I can't see it, please help me understand, if "
                    "there is an answer to my question, please give it "
                    "to me, use it for this: \"__conversational_response\""
                )
            elif tool_result.get("type") == "error":
                human_message = (
                    "Oops.. There seems to be a mistake here. "
                    "Apparently the request was incorrect, try to fix it, "
                    "maybe it will work and try again! Maybe it will help "
                    "you \"InfoSQLDatabaseTool\" to better understand the structure."
                )
            # Show the model only the payload, not the status wrapper.
            tool_result = tool_result.get("result")
        elif action.tool == "InfoSQLDatabaseTool":
            human_message = (
                "It looks like the structure of the database "
                "tables(s). Maybe this will help you write the right query? "
                "Please make sure you remember the names!"
            )
        elif action.tool == "ListSQLDatabaseTool":
            human_message = (
                "It looks like a list of table names. "
                "Maybe it will help to get an answer. I think if you don't "
                "have enough information, you can use this: \"InfoSQLDatabaseTool\""
            )
        elif action.tool == "QuerySQLCheckerTool":
            human_message = (
                "It looks like an SQL query! I think it's the "
                "right one. It's worth running it and seeing what it gives, "
                "maybe there's an answer to my question. Help me, please!"
            )

        agent_scratchpad.append(
            HumanMessage(
                content=(
                    "I can help you execute the tool. Give me a second... "
                    "And so, I think I managed to call it, but I can't read "
                    "what's written there, only a smart AI can understand it."
                    f"\nI checked, the result of the tool\n{human_message}"
                    f"\n\nResult of tool: \n{tool_result}"
                )
            )
        )

    return agent_scratchpad

def get_agent(llm: OllamaFunctions, functions: List[Any]) -> Runnable:
    """Build a tool-calling agent runnable around ``llm``.

    The pipeline is: inject ``agent_scratchpad`` (rendered from the input's
    ``intermediate_steps``), then the few-shot prompt, then ``llm`` with
    ``functions`` bound as tools, then an output parser yielding an agent
    action or an ``AgentFinish``.

    Args:
        llm: Tool-calling Ollama chat model.
        functions: Tool/function definitions to bind (here, the dicts produced
            by ``convert_to_ollama_tool``).

    Returns:
        The composed runnable. It is not an ``AgentExecutor``; the caller
        drives the tool loop itself.
    """
    # bind_tools returns a new runnable; the original model is not mutated.
    llm_with_tools = llm.bind_tools(tools=functions)

    agent = (
        RunnablePassthrough.assign(
            agent_scratchpad=lambda inputs: format_to_ollama_chat_messages(
                inputs["intermediate_steps"]
            )
        )
        | prompt_generator.get_prompt()
        | llm_with_tools
        | OpenAIFunctionsAgentOutputParser()
    )

    return agent

In [22]:
# Agent used by run_agent below: llama3 (instruct) with the schema functions.
agent = get_agent(llm=llama3__with_functions, functions=functions)

def run_agent(
    user_input,
    *,
    dialect: str = "PostgreSQL",
    top_k: str = "3",
    max_iterations: int = 15,
):
    """Drive the module-level ``agent`` in a manual tool-calling loop.

    Each iteration invokes the agent. If it returns an ``AgentFinish``, that
    is returned. Otherwise the requested tool is run and the
    ``(action, observation)`` pair is appended to the intermediate steps.

    Args:
        user_input: Question for the agent. An empty value falls back to
            ``TestData.QUESTIONS[0]``.
        dialect: SQL dialect name passed to the prompt.
        top_k: Value passed to the prompt's ``top_k`` variable.
        max_iterations: Maximum number of agent invocations before giving up.

    Returns:
        The agent's ``AgentFinish``, or a synthetic one with a stop message if
        ``max_iterations`` is reached.

    Raises:
        ValueError: If the agent requests a tool that is not registered.
    """
    # Built once rather than on every loop iteration.
    tools_by_name = {
        "QuerySQLDatabaseTool": QuerySQLDatabaseTool,
        "InfoSQLDatabaseTool": InfoSQLDatabaseTool,
        "ListSQLDatabaseTool": ListSQLDatabaseTool,
        "QuerySQLCheckerTool": QuerySQLCheckerTool,
    }
    question = user_input or TestData.QUESTIONS[0]
    intermediate_steps = []

    for _ in range(max_iterations):
        result = agent.invoke({
            "dialect": dialect,
            "input": question,
            "top_k": top_k,
            "intermediate_steps": intermediate_steps,
        })
        if isinstance(result, AgentFinish):
            return result
        print(f"result: {result}")

        selected_tool = tools_by_name.get(result.tool)
        if selected_tool is None:
            raise ValueError(f"Agent requested unknown tool {result.tool!r}")
        observation = selected_tool.run(result.tool_input)
        intermediate_steps.append((result, observation))

    # Bounded loop: a model that never finishes must not hang the notebook.
    return AgentFinish(
        return_values={"output": "Agent stopped due to iteration limit."},
        log="",
    )

-----
## Test the agent with different models and save the results

In [23]:
# Show the expected answer for the first test question, for comparison.
print(TestData.ANSWER[0])

Here are the names you asked for: 'Alice', 'Bob', 'Charlie', 'Christofer', 'David'


In [24]:
# Ask the first test question explicitly rather than relying on run_agent
# to pick a default question.
res = run_agent(TestData.QUESTIONS[0])

{'tool': 'ListSQLDatabaseTool', 'tool_input': ''}
result: tool='ListSQLDatabaseTool' tool_input={} log='\nInvoking: `ListSQLDatabaseTool` with `{}`\n\n\n' message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'ListSQLDatabaseTool', 'arguments': ''}}, id='run-411b502b-8db4-49aa-b399-0361d090725f-0')]
{'tool': 'QuerySQLDatabaseTool', 'tool_input': 'SELECT name FROM passenger ORDER BY name LIMIT 5;'}
result: tool='QuerySQLDatabaseTool' tool_input='SELECT name FROM passenger ORDER BY name LIMIT 5;' log='\nInvoking: `QuerySQLDatabaseTool` with `SELECT name FROM passenger ORDER BY name LIMIT 5;`\n\n\n' message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'QuerySQLDatabaseTool', 'arguments': '"SELECT name FROM passenger ORDER BY name LIMIT 5;"'}}, id='run-5632bfd3-4de2-4de7-8eca-5f193c24afac-0')]
{'tool': 'QuerySQLCheckerTool', 'tool_input': 'SELECT p.first_name, p.last_name FROM passenger p ORDER BY p.first_name, p.last_name LIMIT 5;'}
# Display the value returned by run_agent.
res

In [None]:
# Print the value returned by run_agent.
print(res)