# Developing the theme of few-shots

## Init (imports, load models...)

### Necessary imports

In [1]:
import db_connect
import vc_connect
from examples import FEW_SHOT_EXAMPLES

from langchain_community.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains.sql_database import query

from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.prompts import PromptTemplate

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from langchain.agents import Tool
from langchain import agents

from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_toolkits.sql import base

from langchain.prompts.example_selector import SemanticSimilarityExampleSelector

from langchain.globals import set_verbose
set_verbose(True)

### Load models

In [2]:
llama = Ollama(model="llama2:13b", temperature=0.25, repeat_penalty=1)
llama_embeddings = OllamaEmbeddings(model="llama2:13b", temperature=0.25, repeat_penalty=1)
sqlcoder_embeddings = OllamaEmbeddings(model="sqlcoder:15b", temperature=0.25, repeat_penalty=1)
sqlcoder = Ollama(model='sqlcoder:15b', temperature=0.25, repeat_penalty=1)

### Create prompt

In [3]:
prefix = \
"""You are a silent PostreSQL expert. Given an input question, first create {top_k} syntactically correct PostreSQL query to run, then look at the results and take most correct.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date(\'now\') function to get the current date, if the question involves \"today\".
Below are a number of examples of questions and their corresponding SQL queries."""

suffix = \
"""Only use the following tables:
{table_info}

Don\'t explain, use the following format:

User input: {input}
SQL query: """

example_prompt = PromptTemplate.from_template(
    "User input: {input}\nSQL query: {query}"
)

### Connect to DB with Readonly role

In [4]:
db = db_connect.get_db()

#### Check connection

In [5]:
db.run("select * from passenger")

"[(16, 'John'), (17, 'James'), (18, 'Poul'), (19, 'Christofer'), (20, 'Superman')]"

### Connect example selector

In [6]:
example_selector = vc_connect.get_selector(llama_embeddings)

## Simple few-shots by using `create_sql_query_chain`

In [7]:
few_shots_prompt = FewShotPromptTemplate(
    examples=FEW_SHOT_EXAMPLES,
    example_prompt=example_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "table_info", "top_k"],
)

In [8]:
few_shots_chain = query.create_sql_query_chain(llm=llama, db=db, prompt=few_shots_prompt)

In [9]:
res = few_shots_chain.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

User input: Select the names of all the people who are in the airline database

SQL query: SELECT "passenger"."passenger_name" FROM public."passenger"

Note: This query will return all the names of the people in the airline database.


In [10]:
db.run("SELECT \"passenger\".\"passenger_name\" FROM public.\"passenger\"")

"[('John',), ('James',), ('Poul',), ('Christofer',), ('Superman',)]"

In [11]:
res = few_shots_chain.invoke({"question": "Print the names of all airlines"})
print(res)

User input: Print the names of all airlines

SQL query: SELECT "company_name" FROM public."company"

Note: The above query will return all the company names in the "company" table.


In [12]:
db.run("SELECT \"company_name\" FROM public.\"company\"")

"[('American Airlines',)]"

In [13]:
res = few_shots_chain.invoke({"question": "Print the names of people that end in \"man\""})
print(res)

User input: Print the names of people that end in "man"
SQL query: SELECT "passenger_name" FROM public."passenger" WHERE "passenger_name" LIKE '%man';


In [14]:
db.run("SELECT \"passenger_name\" FROM public.\"passenger\" WHERE \"passenger_name\" LIKE \'%man\';")

"[('Superman',)]"

## Few-shots with `SemanticSimilarityExampleSelector`

### Connect to another DB to create embeddings

In [15]:
few_shots_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "table_info", "top_k"],
)

In [16]:
few_shots_selector_chain = query.create_sql_query_chain(llm=llama, db=db, prompt=few_shots_prompt)

In [17]:
res = few_shots_selector_chain.invoke({
    "question": "Select the names of all the people who are in the airline database", 
    "top_k": "1"
})
print(res)

User input: Select the names of all the people who are in the airline database

SQL query: SELECT passenger_name FROM passenger;


In [18]:
db.run("SELECT passenger_name FROM passenger;")

"[('John',), ('James',), ('Poul',), ('Christofer',), ('Superman',)]"

In [19]:
res = few_shots_selector_chain.invoke({
    "question": "Print the names of all airlines", 
    "top_k": "1"
})
print(res)

User input: Print the names of all airlines

SQL query: SELECT company_name FROM company;

This query will retrieve the names of all airlines in the "company" table.


In [25]:
db.run("SELECT company_name FROM company;")

"[('American Airlines',)]"

In [20]:
res = few_shots_selector_chain.invoke({
    "question": "Print the names of people that end in \"man\"", 
    "top_k": "1"
})
print(res)

User input: Print the names of people that end in "man"

SQL query: SELECT passenger_name FROM passenger WHERE passenger_name LIKE '%man';


In [26]:
db.run("SELECT passenger_name FROM passenger WHERE passenger_name LIKE \'%man\';")

"[('Superman',)]"

## Few-shots selector with SQLCoder

In [21]:
few_shots_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "table_info", "top_k"],
)

In [22]:
few_shots_selector_chain = query.create_sql_query_chain(llm=sqlcoder, db=db, prompt=few_shots_prompt)

In [23]:
res = few_shots_selector_chain.invoke({
    "question": "Select the names of all the people who are in the airline database", 
    "top_k": "1"
})
print(res)

```
<|endoftext|>


In [None]:
db.run("SELECT name FROM passenger")

"[('John',), ('James',), ('Poul',), ('Christofer',), ('Superman',)]"

In [24]:
res = few_shots_selector_chain.invoke({
    "question": "Print the names of all airlines", 
    "top_k": "1"
})
print(res)

Your response should be a comma-separated list of airline names.

Sample output:

"American Airlines", "Delta Air Lines"

Explanation:

The user wants to know the names of all airlines. This information is available in the 'company' table, which contains a column named 'company_name' for the name of the airline. We can simply select all the values from this column to get the names of all airlines.

The 'company_id' column in the 'trip' table is also useful for determining which airlines are involved in the flights. We can use this column to join the 'company' table with the 'trip' table and filter out the rows where the 'company_id' is 4. This will give us the names of the airlines that are involved in the flights.

The 'passenger_id' column in the 'pass_in_trip' table is also useful for determining which airlines are involved in the flights. We can use this column to join the 'passenger' table with the 'pass_in_trip' table and filter out the rows where the 'passenger_id' is 18. This w

In [None]:
db.run("SELECT name FROM company")

"[('American Airlines',)]"

In [None]:
res = few_shots_selector_chain.invoke({
    "question": "Print the names of people that end in \"man\"", 
    "top_k": "1"
})
print(res)

```
<|endoftext|>


In [None]:
db.run("SELECT name FROM passenger WHERE name LIKE \'%man\'")

"[('Superman',)]"

## Few-shots as tool in agent

In [None]:
tools = SQLDatabaseToolkit(llm=llama, db=db)

example_tool = Tool(
    name="sql_examples",
    func=example_selector.select_examples,
    description="Input this tool a user question to get most semantically similar examples. Always use this tool before generate a query.",
)

sql_suffix=\
"""Begin!

Question: {input}
Thought: I should look at semantically similar examples. So I should use sql_examples from tools:
{agent_scratchpad}"""

agent = base.create_sql_agent(
    llm=llama,
    toolkit=tools,
    extra_tools=[example_tool],
    suffix=sql_suffix
)

In [None]:
res = agent.invoke({
    "input": "Select the names of all the people who are in the airline database", 
    "top_k": "1"
})
print(res)



[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3m
SELECT name FROM passenger WHERE status = 'active'
Double check the postgresql query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only.

SQL Query: [0m

[1m> Finished chain.[0m


## Agent prompt with few-shots

In [None]:
suffix = "\n\n{input}"

few_shots_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix=suffix,
    input_variables=["input"],
)

In [None]:
tools = SQLDatabaseToolkit(llm=sqlcoder, db=db)

agent = base.create_sql_agent(
    llm=llama,
    toolkit=tools
)

In [None]:
agent_chain = (
    {"input": RunnablePassthrough()}
    | few_shots_prompt
    | agent
    | StrOutputParser()
)

In [None]:
res = agent_chain.invoke("Select the names of all the people who are in the airline database")
print(res)

KeyboardInterrupt: 

## My own agent

### Create prompt

In [None]:
template =\
"""# Personality 
Act as SQL expert for answer to an office employee. 

# Task
Create and run SQL query to answer user\'s question. 

# Tools
Always keep in mind the tools that you can use: {tools}

# Format
Give final answer in following format:
User\'s input: \"<input>\"
Created SQL query: \"<sql_code>\"
Run query result: \"<sql_run_result>\"
Final answer: \"<final_answer>\"

# Tonality
Be strict, don't use emojis. Be careful and paranoid, who checks everything 3 times

# Atypical cases 
If there is not enough information from the database to answer, write \"I cannot answer\"

Let's start a chain of thoughts!

User's input: \"{input}\"
Thought: {agent_scratchpad}
Choose the one of [{tool_names}] you need and use it"""

prompt = PromptTemplate.from_template(template)

### Define tools

#### Tool for create SQL

In [None]:
tool_prefix =\
"""# Personality 
Act as PostgreSQL expert for answer to an office employee. 

# Task
Create correct PostgreSQL code based on user\'s question.

# Examples
"""

tool_example_prompt = PromptTemplate.from_template(
    "## Example\nUser input: {input}\nSQL query: {query}\n"
)

tool_suffix=\
"""
# Format
Leave ONLY the PostgreSQL CODE in the answer. After it, insert <|endoftext|>."""

few_shots_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=tool_example_prompt,
    prefix=tool_prefix,
    suffix=tool_suffix,
    input_variables=["input"],
)

tool_chain = (
    {"input": RunnablePassthrough()} 
    | few_shots_prompt 
    | sqlcoder 
)

sql_creator = Tool(
    name="sql_creator",
    func=tool_chain.invoke,
    description="Input to this tool an user input. Output of this tool is SQL code. Use this tool to create an SQL code based on user\'s input."
)

In [None]:
few_shots_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "table_info", "top_k"],
)

In [None]:
few_shots_chain = query.create_sql_query_chain(llm=llama, db=db, prompt=few_shots_prompt)

sql_creator = Tool(
    name="sql_creator",
    func=tool_chain.invoke,
    description="Input to this tool an user input. Output of this tool is SQL code. Use this tool to create an SQL code based on user\'s input."
)

In [None]:
tools = SQLDatabaseToolkit(llm=sqlcoder, db=db).get_tools() + [sql_creator]

agent = agents.create_react_agent(
    llm=llama,
    tools=tools,
    prompt=prompt,
    output_parser=StrOutputParser(),
)

agent = agents.AgentExecutor(agent=agent, tools=tools, verbose=True)

In [None]:
from langchain.globals import set_debug
set_debug(False)

In [None]:
res1 = agent.invoke({
    "input": "Select the names of all the people who are in the airline database",
})
print(res1)



[1m> Entering new AgentExecutor chain...[0m


Error in StdOutCallbackHandler.on_agent_action callback: AttributeError("'str' object has no attribute 'log'")


AttributeError: 'str' object has no attribute 'tool'

In [None]:
res2 = agent.invoke({
    "input": "Print the names of all airlines",
})
print(res2)

KeyboardInterrupt: 