# SQL agent

In [12]:
import db_connect 
import vc_connect 
from examples import FEW_SHOT_EXAMPLES

from langchain_community.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains.sql_database import query

from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.prompts import PromptTemplate

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from langchain.agents import Tool
from langchain import agents

from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_toolkits.sql import base

from langchain.prompts.example_selector import SemanticSimilarityExampleSelector

### Load models

In [13]:
# llama2 13B: generation handle plus the matching embedding handle, both
# created with the same temperature / repeat_penalty overrides.
llama = Ollama(
    model="llama2:13b",
    temperature=0.25,
    repeat_penalty=1,
)
llama_embeddings = OllamaEmbeddings(
    model="llama2:13b",
    temperature=0.25,
    repeat_penalty=1,
)

# llama3 8B: this is the model handed to the SQL toolkit and agent further
# down. The generation handle is created with no sampling overrides; only the
# embedding handle sets a temperature.
llama3 = Ollama(model="llama3:8b")
llama3_embeddings = OllamaEmbeddings(
    model="llama3:8b",
    temperature=0.25,
)

# sqlcoder 15B: generation and embedding handles with the same overrides as
# the llama2 pair.
sqlcoder = Ollama(
    model="sqlcoder:15b",
    temperature=0.25,
    repeat_penalty=1,
)
sqlcoder_embeddings = OllamaEmbeddings(
    model="sqlcoder:15b",
    temperature=0.25,
    repeat_penalty=1,
)

### Create prompt

In [14]:
# Prompt preamble shown before the few-shot examples.
# `{top_k}` is a format placeholder that must be supplied when the prompt is
# rendered. Presumably meant as the `prefix` of the FewShotPromptTemplate
# imported at the top of the notebook -- it is not wired into the agent in the
# cells visible here; confirm before relying on it.
# The text targets PostgreSQL: the dialect name is spelled correctly and the
# "today" hint uses CURRENT_DATE rather than the SQLite-style date('now').
prefix = """You are a silent PostgreSQL expert. Given an input question, first create {top_k} syntactically correct PostgreSQL query to run, then look at the results and take most correct.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".
Below are a number of examples of questions and their corresponding SQL queries."""

# Prompt tail shown after the few-shot examples.
# `{table_info}` and `{input}` are format placeholders; the string deliberately
# ends with "SQL query: " so the model's completion is the query itself.
suffix = """Only use the following tables:
{table_info}

Don't explain, use the following format:

User input: {input}
SQL query: """

# Layout used to render a single few-shot example: the user's question on the
# first line and the matching SQL on the second. `{input}` and `{query}` are
# the template variables.
example_prompt = PromptTemplate.from_template(
    "\n".join(["User input: {input}", "SQL query: {query}"])
)

### Connect to DB with read-only role

In [15]:
# Obtain the database handle from the project-local `db_connect` helper.
# The handle is queried with `db.run(...)` and passed to SQLDatabaseToolkit
# below. NOTE(review): the section heading says a read-only role is used; that
# is configured inside `db_connect` and cannot be verified from this notebook.
db = db_connect.get_db()

#### Check connection

In [16]:
# Connection smoke test: fetch a bounded sample instead of the whole table so
# the cell output stays small however many rows `passenger` holds. The column
# names are the ones reported by the table schema later in this notebook.
db.run("SELECT passenger_id, passenger_name FROM passenger LIMIT 5")

"[(16, 'John'), (17, 'James'), (18, 'Poul'), (19, 'Christofer'), (20, 'Superman')]"

### Connect example selector

-----

In [17]:
# Bundle the database handle with the llama3 model into the SQL toolkit. The
# recorded run below shows the resulting agent calling the
# `sql_db_list_tables` and `sql_db_schema` tools.
toolkit = SQLDatabaseToolkit(llm=llama3, db=db)

# Build a zero-shot ReAct SQL agent on top of that toolkit.
# - verbose=True prints every Thought / Action / Observation step.
# - handle_parsing_errors is given as a string: when the model's reply cannot
#   be parsed, this text is returned to the model as the observation so it can
#   correct its formatting instead of the run aborting (per the LangChain
#   AgentExecutor documentation -- TODO confirm against the installed version).
agent = base.create_sql_agent(
    llm=llama3,
    toolkit=toolkit,
    agent_type=agents.AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    agent_executor_kwargs={
        "handle_parsing_errors": "Check your output and make sure it conforms! Do not output an action and a final answer at the same time. Else output only final answer!"
    },
)

In [18]:
# Run the agent on one natural-language question. The result is kept in `res`
# so the following cells can print it.
res = agent.invoke(
    dict(input="Select the names of all the people who are in the airline database")
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mI'm excited to start answering questions with these tools!

Thought: To get the names of all people in the airline database, I'll need to query the correct table(s) and retrieve the necessary columns.

Action: sql_db_list_tables

Action Input: Empty string (no input needed[0m[38;5;200m[1;3mcompany, pass_in_trip, passenger, trip[0m[32;1m[1;3mLet's get started.

Thought: The list of tables includes "passenger", which seems relevant for querying people who are in the airline database.
Action: sql_db_schema
Action Input: passenger[0m[33;1m[1;3m
CREATE TABLE passenger (
	passenger_id BIGSERIAL NOT NULL, 
	passenger_name VARCHAR(60), 
	CONSTRAINT passenger_pkey PRIMARY KEY (passenger_id)
)

/*
3 rows from passenger table:
passenger_id	passenger_name
16	John
17	James
18	Poul
*/[0m[32;1m[1;3mThe schema for the "passenger" table shows that it has columns "passenger_id" and "passenger_name". This suggests that I can us

KeyboardInterrupt: 

In [None]:
# Print the whole result object returned by `agent.invoke`; in the recorded
# output it is a mapping with 'input' and 'output' keys.
print(res)

{'input': 'Select the names of all the people who are in the airline database', 'output': 'Agent stopped due to iteration limit or time limit.'}


In [None]:
# Print only the agent's final answer text (the 'output' entry of the result).
print(res['output'])

Agent stopped due to iteration limit or time limit.
