# SQL agent

## Setup

### Imports

In [1]:
import os
import re
import pickle
from typing import (
    List,
    Any
)

import db_connect
from test_data import TestData
from prompt_generator import PromptGenerator
from examples import FEW_SHOT_EXAMPLES
from prompt import (
    PREFIX,
    SUFFIX,
    TABLE_DESCRIPTIONS
)

from langsmith import Client

from langchain_core.language_models import BaseLanguageModel

from langchain.prompts import (
    PromptTemplate, 
    ChatPromptTemplate
)

from langchain.agents.agent_types import AgentType
from langchain.agents.agent_toolkits.sql import base
from langchain.agents.agent import AgentExecutor
from langchain.agents import create_openai_tools_agent, create_react_agent, create_sql_agent
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools.sql_database import tool as sql_tools

from langchain_community.llms.ollama import Ollama
from langchain_community.chat_models.ollama import ChatOllama
from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_experimental.llms.ollama_functions import (
    OllamaFunctions,
    DEFAULT_RESPONSE_FUNCTION
)

from langchain_core.output_parsers import StrOutputParser

from langchain_core.messages.human import HumanMessage
from langchain_core.messages.ai import AIMessage

from langchain_experimental.llms.ollama_functions import convert_to_ollama_tool
from langchain.schema.agent import AgentFinish, AgentActionMessageLog

from langchain_core.pydantic_v1 import BaseModel, Field

from langchain.tools import (tool, BaseTool)

from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.agents.openai_tools.base import convert_to_openai_tool
from langchain.agents.format_scratchpad.openai_functions import (
    format_to_openai_function_messages, 
)
from langchain.agents.output_parsers.openai_functions import (
    OpenAIFunctionsAgentOutputParser,
)

from langchain_core.runnables.config import RunnableConfig
from langchain.callbacks.tracers import ConsoleCallbackHandler

from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

### LangSmith

In [2]:
os.environ["LANGCHAIN_PROJECT"] = "text2sql"
client = Client()

### Load models

In [3]:
llama3_embeddings = OllamaEmbeddings(
    model="llama3:8b", 
    temperature=0
)

llama3 = Ollama(
    model="llama3:instruct", 
    temperature=0
)
llama3_q8 = Ollama(
    model="llama3:8b-instruct-q8_0", 
    temperature=0
)

llama3__chat = ChatOllama(
    model="llama3:instruct", 
    temperature=0
)
llama3_q8__chat = ChatOllama(
    model="llama3:8b-instruct-q8_0", 
    temperature=0
)

llama3__with_functions = OllamaFunctions(
    model="llama3:instruct",
    format="json",
    temperature=0
)
llama3_q8__with_functions = OllamaFunctions(
    model="llama3:8b-instruct-q8_0",
    format="json",
    temperature=0
)

### Connect to DB with Readonly role

In [4]:
db = db_connect.get_db()

#### Check connection

In [5]:
db.run("select * from passenger")

"[(11, 'John'), (12, 'James'), (13, 'Poul'), (14, 'Christofer'), (15, 'Superman')]"

-----

## Create agent

### Prompt and error handler

In [6]:
prompt_generator = PromptGenerator().set_example_selector(
        examples=FEW_SHOT_EXAMPLES,
        embedding_llm=llama3_embeddings,
        k=3
)

### Create functions for `OllamaFunctions`

#### Pydentic classes (Functions)

In [7]:
class QuerySQLDatabaseTool(BaseModel):
    """INPUT to this tool is a DETAILED and CORRECT SQL query, 
    OUTPUT is a RESULT FROM the DATABASE. If the query is not 
    correct, an error message will be returned. If an error is 
    returned, rewrite the query, check the query, and try again. 
    If you encounter an issue with Unknown column 'xxxx' in 
    'field list', use sql_db_schema to query the correct table fields."""
    sql_query: str = Field(
        description="Detailed and correct SQL query."
    )


class InfoSQLDatabaseTool(BaseModel):
    """INPUT to this tool is a comma-separated LIST OF TABLES, 
    output is the schema and sample rows for those tables. 
    Be sure that the tables actually exist by calling 
    sql_db_list_tables first! Example Input: table1, table2, table3.
    This tool will not help if you need to get information about a column!"""
    list_of_tables: str = Field(
        description="Comma-separated list of tables. Example: table1, table2, table3"
    )


class ListSQLDatabaseTool(BaseModel):
    """INPUT to this tool is a EMPTY STRING, 
    OUTPUT is a comma-separated LIST OF TABLES 
    in database. Use this tool to select the 
    tables needed to respond to the user."""
    empty_string: str = Field(description="Empty string")


class QuerySQLCheckerTool(BaseModel):
    """INPUT to this tool is a QUERY to check, 
    OUTPUT is a CORRECT QUERY in database. 
    Use this tool to double check if 
    your query is correct before executing it. 
    Always use this tool before executing a 
    query with sql_db_query!"""
    query_to_check: str = Field(
        description="SQL query that needs to be checked before execution."
    )


functions = [
    QuerySQLDatabaseTool, 
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
]
functions = [convert_to_openai_tool(f)["function"] for f in functions]

#### Tools

In [8]:
key_words = [
    "CONSTRAINT",
    "CHECK",
    "UNIQUE",
    "PRIMARY",
    "FOREIGN",
    "EXCLUDE",
    "DEFERRABLE",
    "NOT",
    "INITIALLY",
    "LIKE",
]

def exclude_key_words_from_list(
    list: List[str], 
    key_words: List[str]
) -> List[str]:
    return [ 
        column_name_or_key_word 
        for column_name_or_key_word in list
        if column_name_or_key_word.upper() not in key_words
    ]

def extract_names_from_db(create_table_query: str, key_words_to_exclude: List[str]):
    # Removes VARCHAR(60) and others for subsequent processing
    clean_sql = re.sub(r"\(\d+\)", "", create_table_query)
    
    # Removes sample rows in table info
    clean_sql = re.sub(r"/\*(.|\s)*?\*/", "", clean_sql)
    
    # Get all table names
    table_names = re.findall(r"CREATE TABLE (\w+) \(", clean_sql)
    
    res = []
    
    for name in table_names:
        # Returns a description {names}
        table_sql = re.search(fr"CREATE TABLE {name} \((\s+(.|\s+)*?)\s+\)", clean_sql).group(1)
        
        # Gets all column names and constraints (this is noise)
        columns_part = re.findall(r"\n\t(\w+)", table_sql)
        
        res.append({
            "name": name,
            # Exclude constraints (noise)
            "columns": exclude_key_words_from_list(
                columns_part, key_words_to_exclude
            )
        })
    
    return res

In [9]:
def QuerySQLDatabaseTool(sql_query: str):
    """INPUT to this tool is a DETAILED and CORRECT SQL query, 
    OUTPUT is a RESULT FROM the DATABASE. If the query is not 
    correct, an error message will be returned. If an error is 
    returned, rewrite the query, check the query, and try again. 
    If you encounter an issue with Unknown column 'xxxx' in 
    'field list', use sql_db_schema to query the correct table fields."""
    
    res = {}
    
    try:
        res["type"] = "ok"
        res["result"] = db.run(sql_query)
    except Exception as e:
        res["type"] = "error"
        res["result"] = str(e)
        
    return res

def InfoSQLDatabaseTool(list_of_tables: str):
    """INPUT to this tool is a comma-separated LIST OF TABLES, 
    output is the schema and sample rows for those tables. 
    Be sure that the tables actually exist by calling 
    sql_db_list_tables first! Example Input: table1, table2, table3.
    This tool will not help if you need to get information about a column!"""
    
    res = ""
    res += "company, " if "company" in list_of_tables else ""
    res += "pass_in_trip, " if "pass_in_trip" in list_of_tables else ""
    res += "passenger, " if "passenger" in list_of_tables else ""
    res += "trip, " if "trip" in list_of_tables else ""
    
    res = res[:-2]
    
    return db.get_table_info_no_throw(
            [t.strip() for t in res.split(",")]
        )

def ListSQLDatabaseTool(empty_string: str = ""):
    """INPUT to this tool is a EMPTY STRING, 
    OUTPUT is a comma-separated LIST OF TABLES 
    in database. Use this tool to select the 
    tables needed to respond to the user."""
    return ", ".join(db.get_usable_table_names())

QUERY_CHECKER = """
Double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

Also check that all column and table names correspond to the database, these are the ones you need to check:

{db_names}

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

OUTPUT THE FINAL SQL QUERY ONLY."""

checker_prompt = ChatPromptTemplate.from_messages((
    ("system", QUERY_CHECKER),
    ("human", "SQL Query: SELECT name FROM passenger;"),
    ("ai", "SELECT passenger_name FROM passenger;"),
    ("human", "SQL Query: SELECT plane FROM flight;"),
    ("ai", "SELECT plane FROM trip;"),
    ("human", "SQL Query: {query}")
))

def QuerySQLCheckerTool(query_to_check: str):
    """INPUT to this tool is a QUERY to check, 
    OUTPUT is a CORRECT QUERY in database. 
    Use this tool to double check if 
    your query is correct before executing it. 
    Always use this tool before executing a 
    query with sql_db_query!"""
    db_names_list = extract_names_from_db(db.get_table_info(), key_words)
    
    db_names = ""
    for name_pair in db_names_list:
        db_names += f"Table name: {name_pair["name"]}\n\tColumn names: {name_pair["columns"]}\n"
    
    chain = (
        checker_prompt.partial(db_names=db_names) 
        | llama3__chat 
        | StrOutputParser()
    )
    return chain.invoke({
        "query": query_to_check,
        "dialect": db.dialect,
    })

tools = [
    QuerySQLDatabaseTool,
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
]

#### Routing

### Function to get agent with provided llm model

In [10]:
def format_to_ollama_chat_messages(intermediate_steps):
    agent_scratchpad = []
    
    for step in intermediate_steps:
        agent_scratchpad.append(
            AIMessage(
                content="{ " +\
                    f"\"tool\": \"{step[0].tool}\", " +\
                    f"\"tool_input\": \"{step[0].tool_input}\""
            )
        )
        
        tool_result = step[1]
        
        if step[0].tool == "QuerySQLDatabaseTool":
            if tool_result["type"] == "ok":
                human_message = "This is similar to the database " +\
                    "response. Maybe there's something I need there. " +\
                    "I can't see it, please help me understand, if " +\
                    "there is an answer to my question, please give it " +\
                    "to me, use it for this: \"__conversational_response\""
            elif tool_result["type"] == "error":
                human_message = "Oops.. There seems to be a mistake here. " +\
                    "Apparently the request was incorrect, try to fix it, " +\
                    "maybe it will work and try again! Maybe it will help " +\
                    "you \"InfoSQLDatabaseTool\" to better understand the structure." +\
                    "REWRITE QUERY AND COME BACK WITH ANSWER!"
            tool_result = tool_result["result"]
        elif step[0].tool == "InfoSQLDatabaseTool":
            human_message = "It looks like the structure of the database " +\
                "tables(s). Maybe this will help you write the right query? " +\
                "Please make sure you remember the names!"
        elif step[0].tool == "ListSQLDatabaseTool":
            human_message = "It looks like a list of table names. " +\
                "Maybe it will help to get an answer. I think if you don't " +\
                "have enough information, you can use this: \"InfoSQLDatabaseTool\""
        elif step[0].tool == "QuerySQLCheckerTool":
            human_message = "It looks like an SQL query! I think it's the " +\
                "right one. Try to execute this query."
        
        agent_scratchpad.append(
            HumanMessage(
                content="I can help you execute the tool. Give me a second... " +
                    "And so, I think I managed to call it, but I can't read " +
                    "what's written there, only a smart AI can understand it." +
                    f"\nI checked, the result of the tool\n{human_message}" +
                    f"\n\nResult of tool: \n{tool_result}"
            )
        )
        
    return agent_scratchpad

def get_agent(llm: OllamaFunctions, functions: List[BaseModel]) -> AgentExecutor:
    """Create and return agent with provided `llm` model"""

    llm_with_tools = llm.bind_tools(tools=functions + [DEFAULT_RESPONSE_FUNCTION])
    
    agent = (
        RunnablePassthrough.assign(
            agent_scratchpad=lambda steps: format_to_ollama_chat_messages(steps['intermediate_steps'])
        )
        | prompt_generator.get_prompt()
        | llm_with_tools
        | OpenAIFunctionsAgentOutputParser()
    )
    
    return agent

In [11]:
agent = get_agent(llama3__with_functions, functions)

def run_agent(user_input: str) -> AgentFinish:
    intermediate_steps = []
    while True:
        result = agent.invoke({
            "dialect": "PostgreSQL",
            "input": user_input,
            "top_k": "20", 
            "intermediate_steps": intermediate_steps
        })
        
        if isinstance(result, AgentFinish):
            return result
        else:
            result: AgentActionMessageLog = result
        
        tool = {
            "QuerySQLDatabaseTool": QuerySQLDatabaseTool, 
            "InfoSQLDatabaseTool": InfoSQLDatabaseTool,
            "ListSQLDatabaseTool": ListSQLDatabaseTool, 
            "QuerySQLCheckerTool": QuerySQLCheckerTool,
        }[result.tool]
        print(f"result\t{result}\n") # TODO: DELETE THIS ROW
        if isinstance(result.tool_input, str):
            observation = tool(result.tool_input)
        else:
            observation = tool(**result.tool_input)
        intermediate_steps.append((result, observation))

-----
## Testing an agent with different models and save the results

In [12]:
res = []

In [13]:
try:
    res.append(run_agent(TestData.QUESTIONS[0]))
except Exception as e:
    res.append(str(e))

in ollama	{'tool': 'QuerySQLDatabaseTool', 'tool_input': {'sql_query': 'SELECT name FROM passenger ORDER BY name LIMIT 5;'}}
result	tool='QuerySQLDatabaseTool' tool_input={'sql_query': 'SELECT name FROM passenger ORDER BY name LIMIT 5;'} log="\nInvoking: `QuerySQLDatabaseTool` with `{'sql_query': 'SELECT name FROM passenger ORDER BY name LIMIT 5;'}`\n\n\n" message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'QuerySQLDatabaseTool', 'arguments': '{"sql_query": "SELECT name FROM passenger ORDER BY name LIMIT 5;"}'}}, id='run-882f5fc3-3d8b-4951-aee3-f6355a48f8ab-0')]

in ollama	{'tool': 'QuerySQLCheckerTool', 'tool_input': {'query_to_check': "SELECT * FROM passenger WHERE EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'passenger' AND column_name = 'name');"}}
result	tool='QuerySQLCheckerTool' tool_input={'query_to_check': "SELECT * FROM passenger WHERE EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'passenger' AND column

In [14]:
try:
    res.append(run_agent(TestData.QUESTIONS[1]))
except Exception as e:
    res.append(str(e))

in ollama	{'tool': 'QuerySQLDatabaseTool', 'tool_input': {'sql_query': 'SELECT * FROM public."company";'}}
result	tool='QuerySQLDatabaseTool' tool_input={'sql_query': 'SELECT * FROM public."company";'} log='\nInvoking: `QuerySQLDatabaseTool` with `{\'sql_query\': \'SELECT * FROM public."company";\'}`\n\n\n' message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'QuerySQLDatabaseTool', 'arguments': '{"sql_query": "SELECT * FROM public.\\"company\\";"}'}}, id='run-d9ae0cca-00fb-4acc-ba6c-c46b65ce3723-0')]

in ollama	{'tool': '__conversational_response', 'tool_input': {'response': "Ah, I see! It looks like there's only one company listed in the database. According to the result, it seems that American Airlines is the only airline present. Would you like me to provide more information about this airline or help with something else?"}}


In [15]:
try:
    res.append(run_agent(TestData.QUESTIONS[2]))
except Exception as e:
    res.append(str(e))

in ollama	{'tool': 'QuerySQLDatabaseTool', 'tool_input': {'sql_query': 'SELECT DISTINCT "plane" FROM public."trip" WHERE EXTRACT(STATE FROM  "departure_city") = \'Washington\';'}}
result	tool='QuerySQLDatabaseTool' tool_input={'sql_query': 'SELECT DISTINCT "plane" FROM public."trip" WHERE EXTRACT(STATE FROM  "departure_city") = \'Washington\';'} log='\nInvoking: `QuerySQLDatabaseTool` with `{\'sql_query\': \'SELECT DISTINCT "plane" FROM public."trip" WHERE EXTRACT(STATE FROM  "departure_city") = \\\'Washington\\\';\'}`\n\n\n' message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'QuerySQLDatabaseTool', 'arguments': '{"sql_query": "SELECT DISTINCT \\"plane\\" FROM public.\\"trip\\" WHERE EXTRACT(STATE FROM  \\"departure_city\\") = \'Washington\';"}'}}, id='run-2997ea82-cbeb-4ba2-9a6f-b81715f89a86-0')]

in ollama	{'tool': 'QuerySQLDatabaseTool', 'tool_input': "{'sql_query': 'SELECT DISTINCT "}
result	tool='QuerySQLDatabaseTool' tool_input="{'sql_query': 'SELECT 

In [16]:
try:
    res.append(run_agent(TestData.QUESTIONS[3]))
except Exception as e:
    res.append(str(e))

in ollama	{'tool': 'QuerySQLDatabaseTool', 'tool_input': {'sql_query': 'SELECT p."passenger_name" FROM "passenger" p WHERE RIGHT(p."passenger_name", 4) = \'man\';'}}
result	tool='QuerySQLDatabaseTool' tool_input={'sql_query': 'SELECT p."passenger_name" FROM "passenger" p WHERE RIGHT(p."passenger_name", 4) = \'man\';'} log='\nInvoking: `QuerySQLDatabaseTool` with `{\'sql_query\': \'SELECT p."passenger_name" FROM "passenger" p WHERE RIGHT(p."passenger_name", 4) = \\\'man\\\';\'}`\n\n\n' message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'QuerySQLDatabaseTool', 'arguments': '{"sql_query": "SELECT p.\\"passenger_name\\" FROM \\"passenger\\" p WHERE RIGHT(p.\\"passenger_name\\", 4) = \'man\';"}'}}, id='run-e2ead174-025b-4599-a998-6e001449e60a-0')]

in ollama	{'__conversational_response': "Here are the names of people that end in 'man':\n\n- Adam\n- Brandon\n- Cameron\n- Damian\n- Emanuel\n- Freeman\n- Hanson\n- Leman\n- Norman\n- Shaman"}


In [17]:
try:
    res.append(run_agent(TestData.QUESTIONS[4]))
except Exception as e:
    res.append(str(e))

in ollama	{'tool': 'QuerySQLDatabaseTool', 'tool_input': {'sql_query': "SELECT COUNT(*) FROM trip WHERE plane = 'Airbus A319';"}}
result	tool='QuerySQLDatabaseTool' tool_input={'sql_query': "SELECT COUNT(*) FROM trip WHERE plane = 'Airbus A319';"} log='\nInvoking: `QuerySQLDatabaseTool` with `{\'sql_query\': "SELECT COUNT(*) FROM trip WHERE plane = \'Airbus A319\';"}`\n\n\n' message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'QuerySQLDatabaseTool', 'arguments': '{"sql_query": "SELECT COUNT(*) FROM trip WHERE plane = \'Airbus A319\';"}'}}, id='run-9c201e42-6764-48bb-b67f-69298c529674-0')]

in ollama	{'tool': '__conversational_response', 'tool_input': {'response': "Unfortunately, it seems that there are no flights recorded in the database using Airbus A319. If you'd like to know more about our fleet or flight schedules, I'd be happy to help!"}}


In [25]:
for (question, agent_answer, correct_answer) in zip(
    TestData.QUESTIONS[:4], 
    res, 
    TestData.ANSWER[:4]
):
    print(
        f"Qustion:\t{question}\nCorrect answer:\t{correct_answer}\nAgent answer:\t{agent_answer}", 
        end="\n\n" + "-" * 200 + "\n\n"
    )

Qustion:	Print the first 5 names of the passengers, sorted by name
Correct answer:	Here are the names you asked for: 'Alice', 'Bob', 'Charlie', 'Christofer', 'David'
Agent answer:	Failed to parse a function call from llama3:instruct output: { "tool": "QuerySQLExecutorTool", "tool_input": "'query_to_execute': 'SELECT * FROM passenger;'"}

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Qustion:	Print the names of all companies
Correct answer:	Here are the names you asked for: 'American Airlines', 'S7 Airlines', 'Nordwind Airlines'
Agent answer:	return_values={'output': "Ah, I see! It looks like there's only one company listed in the database. According to the result, it seems that American Airlines is the only airline present. Would you like me to provide more information about this airline or help with something else?"} log="Ah, I see