# SQL agent

## Setup

### Imports

In [1]:
import sys
import os

# Resolve the parent of the current working directory and use it as the
# project root, so that the `app.*` imports in the next cell can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Append only when missing so re-running this cell does not add duplicates.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import re
import pickle
from typing import List


from app.databases.external_db import get_db
from app.test_datasets import TestData
from app.prompts import PromptGenerator
from app.examples import FEW_SHOT_EXAMPLES


from langsmith import Client


from langchain.prompts import ChatPromptTemplate

from langchain.schema.agent import AgentFinish, AgentActionMessageLog

from langchain.agents.openai_tools.base import convert_to_openai_tool
from langchain.agents.output_parsers.openai_functions import (
    OpenAIFunctionsAgentOutputParser,
)


from langchain_core.output_parsers import StrOutputParser

from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage

from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_core.runnables import RunnablePassthrough


from langchain_community.chat_models.ollama import ChatOllama
from langchain_community.embeddings.ollama import OllamaEmbeddings


from langchain_experimental.llms.ollama_functions import (
    OllamaFunctions,
    DEFAULT_RESPONSE_FUNCTION
)

### LangSmith

In [3]:
# Set the LANGCHAIN_PROJECT environment variable to "text2sql".
os.environ["LANGCHAIN_PROJECT"] = "text2sql"
# LangSmith client built with no explicit arguments
# (presumably configured from environment variables -- confirm).
client = Client()

### Load models

In [4]:
# Embedding model; passed to the prompt generator's example selector below.
llama3_embeddings = OllamaEmbeddings(
    model="llama3:8b", 
    temperature=0
)

# Plain chat models (no function calling). `llama3__chat` is the model used
# by the SQL query checker tool.
llama3__chat = ChatOllama(
    model="llama3:instruct", 
    temperature=0
)
llama3_q8__chat = ChatOllama(
    model="llama3:8b-instruct-q8_0", 
    temperature=0
)

# Function-calling wrappers around the same model tags, requesting JSON
# output (`format="json"`). `llama3__with_functions` drives the agent.
llama3__with_functions = OllamaFunctions(
    model="llama3:instruct",
    format="json",
    temperature=0
)
llama3_q8__with_functions = OllamaFunctions(
    model="llama3:8b-instruct-q8_0",
    format="json",
    temperature=0
)

### Connect to the DB with a read-only role

In [5]:
# Database handle shared by every tool in this notebook (defined in
# app.databases.external_db; per the heading above it should use a
# read-only role -- confirm in that module).
db = get_db()

#### Check connection

In [None]:
# Smoke test: run a simple query against the `passenger` table to confirm
# that the connection works.
db.run("select * from passenger")

-----

## Create agent

### Get prompt generator

In [None]:
# Build the prompt generator and configure its few-shot example selector
# with the project's examples, the llama3 embedding model and k=3
# (presumably the number of examples selected per question -- confirm in
# app.prompts).
prompt_generator = PromptGenerator().set_example_selector(
        examples=FEW_SHOT_EXAMPLES,
        embedding_llm=llama3_embeddings,
        k=3
)

### Create functions for `OllamaFunctions`

#### Pydantic classes (Functions)

In [None]:
# Names under which the tools are exposed to the model.
# The order matters: index i is assigned to the i-th entry of `functions`,
# and the agent code refers to tools by position (tool_names[0] is the query
# tool, [1] the table-info tool, [2] the table-list tool, [3] the checker).
tool_names = [
    "__query_sql_database_tool",
    "__info_sql_database_tool",
    "__list_sql_database_tool",
    "__query_sql_checker_tool"
]

In [None]:
# Argument schema for the tool exposed as `__query_sql_database_tool`
# (first entry of `functions` / `tool_names`).
# NOTE: the docstring is prompt text, not just documentation -- it ends up as
# the function's "description" (whitespace-collapsed in the post-processing
# step below), so any rewording changes what the model sees.
class QuerySQLDatabaseTool(BaseModel):
    """INPUT to this tool is a DETAILED and CORRECT SQL query, 
    OUTPUT is a RESULT FROM the DATABASE. If the query is not 
    correct, an error message will be returned. If an error is 
    returned, rewrite the query, check the query, and try again. 
    If you encounter an issue with Unknown column \'xxxx\' in 
    \'field list\', use \'__info_sql_database_tool\' to query the correct table fields."""
    # The SQL text to execute; the name matches the keyword argument of the
    # Python function of the same name defined later in the notebook.
    sql_query: str = Field(
        description="Detailed and correct SQL query."
    )


# Argument schema for the tool exposed as `__info_sql_database_tool`
# (second entry of `functions` / `tool_names`).
# NOTE: the docstring is prompt text -- it becomes the function's
# "description" sent to the model, so rewording it changes agent behaviour.
class InfoSQLDatabaseTool(BaseModel):
    """INPUT to this tool is a comma-separated LIST OF TABLES, 
    output is the schema and sample rows for those tables. 
    Be sure that the tables actually exist by calling 
    \'__list_sql_database_tool\' first! Example Input: table1, table2, table3.
    This tool will not help if you need to get information about a column!"""
    # Single string holding all requested table names, not a Python list.
    list_of_tables: str = Field(
        description="Comma-separated list of tables. Example: table1, table2, table3"
    )


# Argument schema for the tool exposed as `__list_sql_database_tool`
# (third entry of `functions` / `tool_names`).
# NOTE: the docstring is prompt text -- it becomes the function's
# "description" sent to the model, so rewording it changes agent behaviour.
class ListSQLDatabaseTool(BaseModel):
    """INPUT to this tool is a EMPTY STRING, 
    OUTPUT is a comma-separated LIST OF TABLES 
    in database. Use this tool to select the 
    tables needed to respond to the user."""
    # Placeholder argument: the tool needs no input, but the schema still
    # declares one field.
    empty_string: str = Field(description="Empty string")


# Argument schema for the tool exposed as `__query_sql_checker_tool`
# (fourth entry of `functions` / `tool_names`).
# NOTE: the docstring is prompt text -- it becomes the function's
# "description" sent to the model, so rewording it changes agent behaviour.
class QuerySQLCheckerTool(BaseModel):
    """INPUT to this tool is a QUERY to check, 
    OUTPUT is a CORRECT QUERY in database. 
    Use this tool to double check if 
    your query is correct before executing it. 
    Always use this tool before executing a 
    query with \'__query_sql_database_tool\'!"""
    # The candidate SQL statement to be reviewed before execution.
    query_to_check: str = Field(
        description="SQL query that needs to be checked before execution."
    )


# Convert each Pydantic schema into its OpenAI-style function description.
# The order here must match `tool_names`.
functions = [
    convert_to_openai_tool(schema_class)["function"]
    for schema_class in (
        QuerySQLDatabaseTool,
        InfoSQLDatabaseTool,
        ListSQLDatabaseTool,
        QuerySQLCheckerTool,
    )
]

# Post-processing of functions: give every function its public tool name and
# collapse the docstring's line breaks/indentation into single spaces.
for position, function_schema in enumerate(functions):
    function_schema.update(
        name=tool_names[position],
        description=re.sub(r"\s+", " ", function_schema["description"]),
    )

#### Tools (Python functions) for the agent to execute

In [None]:
# Upper-case SQL keywords that can start a line inside a CREATE TABLE body
# without being a column name. They are filtered out when column names are
# extracted from the table info (compared against the upper-cased token).
key_words = [
    "CONSTRAINT",
    "CHECK",
    "UNIQUE",
    "PRIMARY",
    "FOREIGN",
    "EXCLUDE",
    "DEFERRABLE",
    "NOT",
    "INITIALLY",
    "LIKE",
]

def exclude_key_words_from_list(
    list: List[str], 
    key_words: List[str]
) -> List[str]:
    """Return the items of `list` that are not SQL keywords.

    An item is dropped when its upper-cased form is contained in
    `key_words`, so `key_words` is expected to hold upper-case entries.
    The relative order of the remaining items is preserved.
    """
    # NOTE: the first parameter shadows the builtin `list`; the name is kept
    # so that existing keyword-argument callers keep working.
    excluded = frozenset(key_words)
    return [item for item in list if item.upper() not in excluded]


def extract_names_from_db(create_table_query: str, key_words_to_exclude: List[str]):
    """Extract table and column names from CREATE TABLE statements.

    Args:
        create_table_query: Text containing one or more statements of the
            form ``CREATE TABLE name (`` followed by one definition per
            line, each line starting with a newline and a tab, and closed
            by a ``)`` on its own line. ``/* ... */`` comments (the sample
            rows in the table info) are ignored.
        key_words_to_exclude: Upper-case keywords (e.g. ``CONSTRAINT``)
            that may start a definition line but are not column names.

    Returns:
        A list of ``{"name": <table>, "columns": [<column>, ...]}`` dicts in
        order of appearance. A table whose body does not have the expected
        layout is reported with an empty column list instead of raising.
    """
    # Removes VARCHAR(60) and others for subsequent processing
    clean_sql = re.sub(r"\(\d+\)", "", create_table_query)

    # Removes sample rows in table info. DOTALL + `.*?` matches the same text
    # as the former `(.|\s)*?`, without the ambiguous alternation that
    # backtracks exponentially when a comment is left unterminated.
    clean_sql = re.sub(r"/\*.*?\*/", "", clean_sql, flags=re.DOTALL)

    res = []

    for name in re.findall(r"CREATE TABLE (\w+) \(", clean_sql):
        # Body of the statement: everything up to the first ")" that is
        # preceded by whitespace (i.e. the closing parenthesis on its own
        # line, not the one ending e.g. "PRIMARY KEY (id)").
        body_match = re.search(
            rf"CREATE TABLE {name} \((\s+.*?)\s+\)",
            clean_sql,
            flags=re.DOTALL,
        )

        # First word of every definition line: column names plus constraint
        # keywords (noise). No match means an unexpected layout -> no columns.
        columns_part = (
            re.findall(r"\n\t(\w+)", body_match.group(1)) if body_match else []
        )

        res.append({
            "name": name,
            # Exclude constraints (noise)
            "columns": exclude_key_words_from_list(
                columns_part, key_words_to_exclude
            )
        })

    return res

In [None]:
def QuerySQLDatabaseTool(sql_query: str):
    """Execute `sql_query` with `db.run` and report the outcome as a dict.

    Returns:
        ``{"type": "ok", "result": <value returned by db.run>}`` on success,
        or ``{"type": "error", "result": <exception text>}`` when running the
        query raises. Exceptions are never propagated to the caller.
    """
    try:
        query_output = db.run(sql_query)
    except Exception as e:
        # Hand the failure back as data so the agent can react to it.
        return {"type": "error", "result": str(e)}

    return {"type": "ok", "result": query_output}

def InfoSQLDatabaseTool(list_of_tables: str):
    """Return the table info (via `db.get_table_info_no_throw`) for the requested tables.

    Args:
        list_of_tables: Comma-separated table names, e.g.
            ``"passenger, trip"``. Surrounding whitespace, quotes, backticks
            and square brackets around a name are ignored, and names are
            matched case-insensitively against `db.get_usable_table_names()`.

    Returns:
        The string produced by `db.get_table_info_no_throw` for the tables
        that exist, or an error string listing the available tables when
        none of the requested names matches.
    """
    usable_tables = list(db.get_usable_table_names())
    canonical_by_lower = {table.lower(): table for table in usable_tables}

    # Match whole names only. The previous substring test selected `trip`
    # whenever `pass_in_trip` was requested, and it only knew four
    # hard-coded table names.
    requested = []
    for raw_name in list_of_tables.split(","):
        cleaned = raw_name.strip().strip("'\"`[]").strip()
        table = canonical_by_lower.get(cleaned.lower())
        if table is not None and table not in requested:
            requested.append(table)

    if not requested:
        # Previously this case ended up asking the database for a table
        # named "" -- tell the model which names are valid instead.
        return (
            "Error: none of the requested tables exist in the database. "
            f"Available tables: {', '.join(usable_tables)}"
        )

    return db.get_table_info_no_throw(requested)

def ListSQLDatabaseTool(empty_string: str = ""):
    """Return the usable table names of `db` as one comma-separated string.

    `empty_string` is ignored; it exists only because the tool schema
    declares a single (empty) argument.
    """
    usable_table_names = db.get_usable_table_names()
    return ", ".join(usable_table_names)

# System prompt for the query checker. Template variables:
#   {dialect}  - SQL dialect of the database,
#   {db_names} - the table/column names the query is checked against.
QUERY_CHECKER = """
Double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

Also check that all column and table names correspond to the database, these are the ones you need to check:

{db_names}

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

OUTPUT THE FINAL SQL QUERY ONLY."""

# Chat prompt: the system instructions, two worked examples of a wrong
# column name and a wrong table name being corrected, then the query to
# check ({query}).
checker_prompt = ChatPromptTemplate.from_messages((
    ("system", QUERY_CHECKER),
    ("human", "SQL Query: SELECT name FROM passenger;"),
    ("ai", "SELECT passenger_name FROM passenger;"),
    ("human", "SQL Query: SELECT plane FROM flight;"),
    ("ai", "SELECT plane FROM trip;"),
    ("human", "SQL Query: {query}")
))

def QuerySQLCheckerTool(query_to_check: str):
    """Ask the chat model to review `query_to_check` against the database names.

    The table and column names are extracted from `db.get_table_info()` and
    inserted into the checker prompt together with `db.dialect`; the prompt
    is then sent to `llama3__chat`.

    Args:
        query_to_check: The SQL statement to review.

    Returns:
        The model's reply as a string (the prompt asks for the final SQL
        query only).
    """
    db_names_list = extract_names_from_db(db.get_table_info(), key_words)

    # One "Table name / Column names" entry per table. Single quotes inside
    # the replacement fields keep this valid on Python versions before 3.12,
    # where reusing the enclosing quote type in an f-string is a SyntaxError.
    db_names = "".join(
        f"Table name: {entry['name']}\n\tColumn names: {entry['columns']}\n"
        for entry in db_names_list
    )

    chain = (
        checker_prompt.partial(db_names=db_names)
        | llama3__chat
        | StrOutputParser()
    )
    return chain.invoke({
        "query": query_to_check,
        "dialect": db.dialect,
    })

# The Python implementations of the tools, in the same order as `tool_names`.
# At this point the names refer to the functions defined above, which have
# replaced the Pydantic classes of the same names.
# NOTE(review): this list is not referenced elsewhere in the visible
# notebook; `run_agent` builds its own name -> function mapping.
tools = [
    QuerySQLDatabaseTool,
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
]

### Util functions

In [None]:
def format_to_ollama_chat_messages(intermediate_steps):
    """Turn the agent's past tool calls into alternating AI/human chat messages.

    Args:
        intermediate_steps: Sequence of ``(action, observation)`` pairs.
            ``action`` exposes ``.tool`` and ``.tool_input``; ``observation``
            is what the tool returned. For the query tool (``tool_names[0]``)
            it is a dict with ``"type"`` and ``"result"`` keys.

    Returns:
        A list with two messages per step: an `AIMessage` restating the tool
        call, and a `HumanMessage` carrying a tool-specific hint followed by
        the tool result.
    """
    final_answer_tool = DEFAULT_RESPONSE_FUNCTION["name"]
    agent_scratchpad = []

    for action, tool_result in intermediate_steps:
        agent_scratchpad.append(
            AIMessage(
                # The closing brace was previously missing, which left the
                # restated tool call unbalanced.
                content=(
                    "{ "
                    f"'tool': '{action.tool}', "
                    f"'tool_input': {action.tool_input}"
                    " }"
                )
            )
        )

        # Reset per step so that an unknown tool or result type can neither
        # raise UnboundLocalError nor reuse the previous step's hint.
        human_message = ""

        if action.tool == tool_names[0]:
            if tool_result["type"] == "ok":
                human_message = (
                    "This is similar to the database "
                    "response. Maybe there's something I need there. "
                    "I can't see it, please help me understand, if "
                    "there is an response to my question, please give it "
                    f"to me, use it for this: '{final_answer_tool}'"
                )
            elif tool_result["type"] == "error":
                human_message = (
                    "Oops.. There seems to be a mistake here. "
                    "Apparently the request was incorrect, try to fix it! "
                    f"Use '{tool_names[1]}' to get info about database"
                    "\nREWRITE QUERY AND COME BACK WITH RESPONSE!"
                )
            # Only the payload is shown to the model, not the status wrapper.
            tool_result = tool_result["result"]
        elif action.tool == tool_names[1]:
            human_message = (
                "It looks like the structure of the database "
                "tables(s). Maybe this will help you write the right query? "
                "Please make sure you remember the names!"
            )
        elif action.tool == tool_names[2]:
            human_message = (
                "It looks like a list of table names. "
                "Maybe it will help to get an response. I think if you don't "
                f"have enough information, you can use this: '{tool_names[1]}'"
            )
        elif action.tool == tool_names[3]:
            human_message = (
                "It looks like an SQL query! I think it's the "
                "right one. Try to execute this query."
            )

        agent_scratchpad.append(
            HumanMessage(
                content=(
                    "I can help you execute the tool. Give me a second... "
                    "And so, I think I managed to call it, but I can't read "
                    "what's written there, only a smart AI can understand it."
                    f"\nI checked, the result of the tool\n{human_message}"
                    f"\n\nResult of tool: \n{tool_result}"
                )
            )
        )

    return agent_scratchpad

def get_agent(llm: OllamaFunctions, functions: List[BaseModel]):
    """Assemble the agent pipeline around the given `llm`.

    The model is bound to `functions` plus `DEFAULT_RESPONSE_FUNCTION`. The
    returned runnable adds an `agent_scratchpad` (built from the input's
    `intermediate_steps`) to its input, renders the prompt from
    `prompt_generator`, calls the model and parses the reply with
    `OpenAIFunctionsAgentOutputParser`.

    Note: in this notebook `functions` holds the function-description dicts
    produced by `convert_to_openai_tool`, despite the annotation.
    """
    available_functions = functions + [DEFAULT_RESPONSE_FUNCTION]
    llm_with_tools = llm.bind_tools(tools=available_functions)

    with_scratchpad = RunnablePassthrough.assign(
        agent_scratchpad=lambda state: format_to_ollama_chat_messages(
            state['intermediate_steps']
        )
    )
    prompt = prompt_generator.get_prompt()
    output_parser = OpenAIFunctionsAgentOutputParser()

    return with_scratchpad | prompt | llm_with_tools | output_parser

### Loop for executing the agent

In [None]:
# Module-level agent used by `run_agent`: the llama3:instruct function-calling
# model bound to the four tool descriptions in `functions`.
agent = get_agent(llama3__with_functions, functions)

def run_agent(user_input: str, iter_limit = 10) -> AgentFinish | str:
    """Drive the module-level `agent` until it finishes or the step limit is hit.

    Each round invokes the agent with the question and the steps taken so
    far. An `AgentFinish` result is returned immediately; otherwise the
    requested tool is executed and its observation is recorded for the next
    round.

    Args:
        user_input: The question passed to the agent as ``input``.
        iter_limit: Maximum number of agent invocations.

    Returns:
        The `AgentFinish` produced by the agent, or a fixed message string
        when `iter_limit` rounds pass without one.

    Raises:
        KeyError: If the agent asks for a tool that is not in `tool_names`.
    """
    # Public tool name -> Python implementation.
    implementations = {
        tool_names[0]: QuerySQLDatabaseTool,
        tool_names[1]: InfoSQLDatabaseTool,
        tool_names[2]: ListSQLDatabaseTool,
        tool_names[3]: QuerySQLCheckerTool,
    }

    intermediate_steps = []
    rounds_done = 0

    while rounds_done < iter_limit:
        result = agent.invoke({
            "dialect": "PostgreSQL",
            "input": user_input,
            "top_k": "20",
            "intermediate_steps": intermediate_steps,
        })

        if isinstance(result, AgentFinish):
            return result

        # Anything else is a tool request.
        tool = implementations[result.tool]
        tool_input = result.tool_input

        # A bare string is passed positionally; a dict is expanded into
        # keyword arguments.
        if isinstance(tool_input, str):
            observation = tool(tool_input)
        else:
            observation = tool(**tool_input)

        intermediate_steps.append((result, observation))
        rounds_done += 1

    return "Agent stop due to limited number of iterations!"

-----
## Testing an agent

In [None]:
# Collected outcomes, one per test question: whatever `run_agent` returned,
# or the text of the exception it raised.
results = []

In [None]:
# Run the agent on every test question. Failures are recorded as strings
# instead of aborting the run, so `results` stays aligned with
# TestData.QUESTIONS (only the exception message is kept, not the traceback).
for question in TestData.QUESTIONS:
    try:
        results.append(run_agent(question))
    except Exception as e:
        results.append(str(e))

In [6]:
# Persist the collected results with pickle.
# The path is built from components instead of a hard-coded Windows string:
# with literal backslashes, non-Windows systems would treat the whole string
# as a single file name in the current directory.
output_path = os.path.join(
    "..", "data", "ollama_functions_agent", "v1", "llama3_inst.pickle"
)
# Create the target directory if needed so a long evaluation run is not lost
# to a missing folder.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "wb") as f:
    pickle.dump(results, f)