# SQL Functions, based `OllamaFunctions`

## Setup

### Imports

In [1]:
import os
import pickle

import db_connect
from test_data import TestData

from langsmith import Client

from langchain_core.language_models import BaseLanguageModel

from langchain.prompts import PromptTemplate

from langchain.agents.agent_types import AgentType
from langchain.agents.agent_toolkits.sql import base
from langchain.agents.agent import AgentExecutor
from langchain.agents import tool

from langchain_experimental.llms.ollama_functions import OllamaFunctions, DEFAULT_RESPONSE_FUNCTION, DEFAULT_SYSTEM_TEMPLATE
from langchain_experimental.llms.ollama_functions import convert_to_ollama_tool, parse_response
from langchain_core.utils.function_calling import convert_to_openai_function

from langchain_core.prompts import PromptTemplate

from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableLambda

from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.agents.output_parsers.openai_functions import OpenAIFunctionsAgentOutputParser

from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_core.messages.system import SystemMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.function import FunctionMessage
from langchain_core.messages.ai import AIMessage

from langchain.schema.agent import AgentFinish

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import HumanMessagePromptTemplate

from langchain.callbacks.tracers import ConsoleCallbackHandler

### LangSmith

In [2]:
os.environ["LANGCHAIN_PROJECT"] = "text2sql"
client = Client()

### Load models

In [3]:
llama3 = OllamaFunctions(model="llama3:8b", format="json")
llama3_inst = OllamaFunctions(model="llama3:instruct", format="json")
llama3_inst_q8 = OllamaFunctions(model="llama3:8b-instruct-q8_0", format="json")
llama3_inst_fp16 = OllamaFunctions(model="llama3:8b-instruct-fp16", format="json")

### Connect to DB with Readonly role

In [4]:
db = db_connect.get_db()

#### Check connection

In [5]:
db.run("select * from passenger")

"[(1, 'John'), (2, 'James'), (3, 'Poul'), (4, 'Christofer'), (5, 'Superman'), (6, 'Donald'), (7, 'Douglas'), (8, 'Dwight'), (9, 'Earl'), (10, 'Edgar'), (11, 'Edmund'), (12, 'Edwin'), (13, 'Elliot'), (14, 'Eric'), (15, 'Ernest'), (16, 'Ethan'), (17, 'Ezekiel'), (18, 'Felix'), (19, 'Franklin'), (20, 'Frederick'), (21, 'Gabriel'), (22, 'Joseph'), (23, 'Joshua'), (24, 'Julian'), (25, 'Alice'), (26, 'Bob'), (27, 'Charlie'), (28, 'David'), (29, 'Emily'), (30, 'Frank'), (31, 'George'), (32, 'Helen'), (33, 'Irene'), (34, 'Jack'), (35, 'Kate'), (36, 'Leo'), (37, 'Mary'), (38, 'Nancy'), (39, 'Oliver'), (40, 'Paul'), (41, 'Qiana'), (42, 'Robert'), (43, 'Samantha'), (44, 'Thomas'), (45, 'Victoria')]"

-----

## Create agent

In [6]:
class GetInformationAboutDatabase(BaseModel):
    """Gives a comma-separated list of table names from the database."""

class GetInformationAboutTable(BaseModel):
    """Gives information about table, e.g. name of columns in table, type of columns, etc"""
    name_of_table: str = Field(description="The name of the table you want to get information about.")


get_information_about_database_tool = convert_to_ollama_tool(GetInformationAboutDatabase)
get_information_about_table_tool = convert_to_ollama_tool(GetInformationAboutTable)

tools_list = ["get_information_about_database_tool", "get_information_about_table_tool"]

In [7]:
def get_db_tables():
    return "Passenger, Employee"

In [8]:
def get_table_info(table_name: str):
    if table_name == "Passenger":
        return "Passenger:\n(passnger_name VARCHAR(60) PRIMARY KEY,\npassenger_age INT)"
    else:
        return "Employee:\n(employee_name VARCHAR(60) PRIMARY KEY,\nemployee_age INT)"

In [9]:
model_with_tools_1 = llama3_inst_q8.bind_tools(
    tools=[
        get_information_about_database_tool,
        DEFAULT_RESPONSE_FUNCTION
    ],
    tool_choice="any"
)

In [10]:
model_with_tools_2 = llama3_inst_q8.bind_tools(
    tools=[
        get_information_about_table_tool
    ],
    function_calls="get_information_about_table_tool"
)

In [11]:
prompt_1 = ChatPromptTemplate.from_messages([
    (
        "system", 
        "# Personality\nYou are a friendly AI assistant.\n\n" +
        "# Todo\nYou should answer the user\'s question or just have a conversation with them." +
        "If user want to talk with you use \"__conversational_response\"\n\n# Tools\n" +
        DEFAULT_SYSTEM_TEMPLATE
    ),
    ("user", "{question}")
])

In [12]:
prompt_2 = ChatPromptTemplate.from_messages([
    (
        "system", 
        "# Personality\nYou are a friendly AI assistant.\n\n" +
        "# Todo\nBased on the user's question and the available tables," +
        "select the tables you will need, but don't take too many. " +
        "Please select the minimum required amount.\n\n# Tools\n" +
        DEFAULT_SYSTEM_TEMPLATE
    ),
    ("user", "User question: {question}\nAvaible tables:{tables}")
])

In [13]:
model_with_prompt = prompt_1.partial(tools=tools_list) | model_with_tools_1 

In [14]:
def route_1(result):
    if isinstance(result, AgentFinish):
        return result.return_values['output']
    else:
        tools = { 
            "GetInformationAboutDatabase": get_db_tables,
        }
        return tools[result.tool](**result.tool_input)

In [15]:
def route_2(result):
    if isinstance(result, AgentFinish):
        return result.return_values['output']
    elif isinstance(result, str):
        return result
    else:
        tools = { 
            "GetInformationAboutTable": get_table_info,
        }
        tool_result = tools[result.tool](**result.tool_input)
        return (prompt_2.partial(tables=tool_result, tools=tools_list) | model_with_tools_2)

In [16]:
model = {"question": RunnablePassthrough()} | model_with_prompt | OpenAIFunctionsAgentOutputParser() | route_1 #| route_2

In [17]:
res = model.batch([
    {"question": "Hi"}, 
    {"question": "What information is there about the passengers in db?"}
])#, return_exceptions=True)#, config={'callbacks': [ConsoleCallbackHandler()]})

In [18]:
print(res)

["Hello! It's nice to talk to you. Is there something I can help you with or would you like to chat about something in particular?", "There is no specific database table for passenger information, but we can extract this information from other tables. For example, we have a 'flights' table which contains the flight numbers and passenger counts. We also have a 'bookings' table where we store the booking details of each passenger. If you want to know more about the passengers in specific flights or bookings, please provide me with more information on what exactly you are looking for."]
