# text2sql based on llama2
## Task: To get an sql query for a given database from a natural language query

Build a chain with:<br/> Question => LLM => SQL => DB => LLM => Answer

(This file is for experimentation and model creation)

### Imports

In [13]:
from langchain_community.llms.ollama import Ollama

from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.utilities.sql_database import SQLDatabase

from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from langchain.chains.sql_database import query

from langchain.agents import Tool
from langchain import agents

from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit 
from langchain_community.agent_toolkits.sql import base

from pswrd import PASSWORD_OF_DB
from pswrd import PASSWORD_FOR_VC_CREATOR

### Load llama2

In [14]:
model = Ollama(model="llama2:13b", temperature=0)

## The first idea:
Is add a **prompt template** with the context of the database structure and a query for NL to `llama2`

#### Firstly add DB structure

In [15]:
DB_STRUCTURE = \
    ["""
    CREATE TABLE trip(
        id BIGSERIAL PRIMARY KEY,
        company BIGINT, 
        plane CHARACTER VARYING(60),
        town_from CHARACTER VARYING(60),
        town_to CHARACTER VARYING(60),
        time_out TIMESTAMP,
        time_in TIMESTAMP,
        CONSTRAINT FK_company FOREIGNT KEY (\"company\") REFERENCES public.company (id)
    ); 
    """,
     """
    CREATE TABLE company(
        id BIGSERIAL PRIMARY KEY,
        name CHARACTER VARYING(60) 
    );
    """,
     """
    CREATE TABLE pass_in_trip(
        id BIGSERIAL PRIMARY KEY,
        trip BIGINT,
        passenger BIGINT,
        place CHARACTER VARYING(60),
        CONSTRAINT FK_trip FOREIGNT KEY (\"trip\") REFERENCES public.trip (id),
        CONSTRAINT FK_passanger FOREIGNT KEY (\"passenger\") REFERENCES public.passenger (id)
    );
    """,
     """
    CREATE TABLE passenger(
        id BIGSERIAL PRIMARY KEY,
        name CHARACTER VARYING(60)
    );
    """
    ]

In [16]:
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5433,
    database="llama-test",
    user="pgvc_embeddings_creator",
    password=PASSWORD_FOR_VC_CREATOR,
)
COLLECTION_NAME = "text2sql_vc"

In [17]:
embeddings = OllamaEmbeddings(model="llama2")

In [18]:
structure_retriver = PGVector.from_texts(
    embedding=embeddings,
    texts=DB_STRUCTURE,
    collection_name=COLLECTION_NAME,
    connection_string=CONNECTION_STRING,
    pre_delete_collection=True,
    use_jsonb=True
).as_retriever()

#### Now create a template

In [19]:
template = \
"""
Translate the following query to sql using the following database structure. 
{structure}
As an answer, provide an sql query for postgresql.
Query to translate: {query}
"""

prompt_with_db_structure = PromptTemplate.from_template(template)

In [20]:
model_with_structure_context = (
    {"structure": structure_retriver, "query": RunnablePassthrough()}
    | prompt_with_db_structure
    | model
    | StrOutputParser()
)

In [21]:
res_query = model_with_structure_context.invoke("Select the names of all the people who are in the airline database")
res_query

' Sure! Here\'s the SQL query to translate the given query to PostgreSQL:\n\nSELECT passenger.name FROM passenger \nJOIN pass_in_trip ON passenger.id = pass_in_trip.passenger \nJOIN trip ON pass_in_trip.trip = trip.id;\n\nExplanation:\n\n1. We start by selecting all the rows from the "passenger" table.\n2. We then join the "passenger" table with the "pass_in_trip" table on the "id" column. This gives us all the rows where a passenger is in a trip.\n3. Finally, we join the "pass_in_trip" table with the "trip" table on the "trip" column. This gives us all the trips that a particular passenger is in.\n\nThe query above will give you the names of all the people who are in the airline database.'

In [22]:
print(res_query)

 Sure! Here's the SQL query to translate the given query to PostgreSQL:

SELECT passenger.name FROM passenger 
JOIN pass_in_trip ON passenger.id = pass_in_trip.passenger 
JOIN trip ON pass_in_trip.trip = trip.id;

Explanation:

1. We start by selecting all the rows from the "passenger" table.
2. We then join the "passenger" table with the "pass_in_trip" table on the "id" column. This gives us all the rows where a passenger is in a trip.
3. Finally, we join the "pass_in_trip" table with the "trip" table on the "trip" column. This gives us all the trips that a particular passenger is in.

The query above will give you the names of all the people who are in the airline database.


The answer looks like it's true, but it's not. The model is hallucinating

## The second idea: Use Langchain sql query tamplate

### Create new connection, with readonly privileges

In [23]:
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5433,
    database="llama-test-2",
    user="seq2sql_llama2_rag",
    password=PASSWORD_OF_DB,
)
COLLECTION_NAME = "table_with_data_to_read"

In [24]:
db = SQLDatabase.from_uri(CONNECTION_STRING)

In [25]:
print(db.table_info)


CREATE TABLE company (
	id BIGSERIAL NOT NULL, 
	name VARCHAR(60), 
	CONSTRAINT company_pkey PRIMARY KEY (id)
)

/*
3 rows from company table:
id	name

*/


CREATE TABLE pass_in_trip (
	id BIGSERIAL NOT NULL, 
	trip BIGINT, 
	passenger BIGINT, 
	place VARCHAR(60), 
	CONSTRAINT pass_in_trip_pkey PRIMARY KEY (id), 
	CONSTRAINT fk_passanger FOREIGN KEY(passenger) REFERENCES passenger (id), 
	CONSTRAINT fk_trip FOREIGN KEY(trip) REFERENCES trip (id)
)

/*
3 rows from pass_in_trip table:
id	trip	passenger	place

*/


CREATE TABLE passenger (
	id BIGSERIAL NOT NULL, 
	name VARCHAR(60), 
	CONSTRAINT passenger_pkey PRIMARY KEY (id)
)

/*
3 rows from passenger table:
id	name

*/


CREATE TABLE trip (
	id BIGSERIAL NOT NULL, 
	company BIGINT, 
	plane VARCHAR(60), 
	town_from VARCHAR(60), 
	town_to VARCHAR(60), 
	time_out TIMESTAMP WITHOUT TIME ZONE, 
	time_in TIMESTAMP WITHOUT TIME ZONE, 
	CONSTRAINT trip_pkey PRIMARY KEY (id), 
	CONSTRAINT fk_company FOREIGN KEY(company) REFERENCES company (id

### Create seq2sql chain

In [26]:
sql_query_chain = query.create_sql_query_chain(model, db)

Check template in sql chain

In [27]:
print(sql_query_chain.get_prompts()[0].template)

You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to ru

In [16]:
res = sql_query_chain.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

Question: Select the names of all the people who are in the airline database

SQLQuery: SELECT name FROM passenger WHERE EXISTS (SELECT 1 FROM pass_in_trip WHERE passenger = id);


The model is hallucinating and therefore gives an incorrect answer. It also ignores some of the requirements of the template.
<br/>Conclusion: **LLaMa2 is not up to the task due to the size**.

## The third idea: Try to use few-shots training strategy
#### Write examples with current DB structure

In [30]:
examples = [
    {
        "input": "How many passengers are in the database?",
        "query": "SELECT COUNT(*) FROM public.\"passenger\";"
    },
    {
        "input": "What are the departure times of all flights?",
        "query": "SELECT \"time_out\" FROM public.\"trip\""
    },
    {
        "input": "What is Jane's place?",
        "query": "SELECT \"place\" FROM public.\"pass_in_trip\" JOIN public.\"passenger\" ON public.\"passenger\".\"id\" = public.\"pass_in_trip\".\"passenger\" WHERE public.\"passenger\".\"name\" = \'John\'"
    },
    {
        "input": "Give me all information about airlines",
        "query": "SELECT * FROM public.\"company\""
    },
    {
        "input": "Show me all the trips that are flying out today",
        "query": "SELECT * FROM public.\"trip\"\nWHERE EXTRACT(DAY FROM NOW()) = EXTRACT(DAY FROM \"time_out\")"
    },
    {
        "input": "Which planes depart from Washington?",
        "query": "SELECT \"plane\" FROM public.\"trip\" WHERE \"town_from\" = \'Washington\'"
    },
    {
        "input": "Print out the names of all the planes", 
        "query": "SELECT \"plane\" FROM public.\"trip\""
    },
    {
        "input": "How many people fly on Airbus?",
        "query": "SELECT COUNT(*) FROM public.\"pass_in_trip\" AS paip JOIN public.\"trip\" ON trip.\"id\" = paip.\"trip\" WHERE trip.\"plane\" = \'Airbus\'"
    },
]

#### Create a template
Prefix and suffix from sql query template

In [31]:
prefix = \
"""You are a SQLite expert. Given an input question, first create {top_k} syntactically correct SQLite query to run, then look at the results and take most correct.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".
Below are a number of examples of questions and their corresponding SQL queries."""

suffix = \
"""Only use the following tables:
{table_info}

Use the following format:

User input: {input}
SQL query: """

few_shots_prompt = PromptTemplate.from_template(
    "User input: {input}\nSQL query: {query}"
)

few_shots_prompt = FewShotPromptTemplate(
    examples=examples,
    example_prompt=few_shots_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "table_info", "top_k"],
)

The request turns out to be very large, I'm not sure what LLaMa can handle, but it's worth a try

In [32]:
few_shots_chain = query.create_sql_query_chain(llm=model, db=db, prompt=few_shots_prompt)

In [33]:
res = few_shots_chain.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

User input: Select the names of all the people who are in the airline database

SQL query: SELECT name FROM passenger;

Explanation: This query retrieves the names of all the people in the "passenger" table. Since we want to retrieve all the names, we don't need to use any condition or join. Simply selecting the "name" column from the "passenger" table will give us the desired output.


The model does not understand what is wanted from her. I think this is due to a disproportionately large request for a relatively small model

### Let's try to add a model that will clean the query

In [34]:
template = "You're a typo and grammar filter. Here is the user\'s request. \
    Correct all the typos, but don\'t change the word order: {question}\
        Leave only the answer in the following format and don't write anything other than that:\nAnswer: answer here"
clear_question_prompt = PromptTemplate.from_template(template)

In [35]:
sql_query_chain_with_cleaning = (
    {"question": RunnablePassthrough()}
    | clear_question_prompt
    | model
    | StrOutputParser()
)

sql_query_chain_with_cleaning = (
    {"question": sql_query_chain_with_cleaning}
    | sql_query_chain
)

In [36]:
res = sql_query_chain_with_cleaning.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

Question: Select the names of all the people who are in the airline database.

SQLQuery: SELECT name FROM passenger;


After 5 launches, 1 was even completely successful, but for the most part the model is still hallucinating

## The fourth idea: Use langchain agents and tools

In [37]:
sql_db_chain = query.create_sql_query_chain(model, db)

db_tool = Tool(
    name="airlines_db_agent",
    func=sql_db_chain.invoke,
    description="Use for extaract data from database"
)

template = """Answer the following questions as best you can. You have access to the following tools: {tools}
Use the following format:
Question: the input question you must answer Thought: you should always think about what to do Action: the action to take, should be one of [{tool_names}] Action Input: the input to the action Observation: the result of the action ... (this Thought/Action/Action Input/Observation can repeat N times) Thought: I now know the final answer Final Answer: the final answer to the original input question

Begin!
Question: {input} Thought:{agent_scratchpad}"""

prompt_with_tools = PromptTemplate.from_template(template)

db_agent = agents.create_react_agent(
    llm=model,
    tools=[db_tool],
    prompt=prompt_with_tools
)

In [38]:
res = agents.AgentExecutor\
    .from_agent_and_tools(db_agent, [db_tool])\
        .invoke({"input": "Select the names of all the people who are in the airline database"})
print(res)

ValueError: An output parsing error occurred. In order to pass this error back to the agent and have it try again, pass `handle_parsing_errors=True` to the AgentExecutor. This is the error: Could not parse LLM output: `
Thought: To select the names of all the people in the airline database, I can use the `airlines_db_agent` tool to query the database.

Action: Use `airlines_db_agent` to query the database and retrieve a list of all the names in the airline database.

Input: None`

In [None]:
print(res["output"])

The names of the people in the airline database are:

I was able to access the airline database and retrieve the names of the people stored within it using the `airlines_db_agent` tool. The query I used to retrieve the desired information was: "SELECT name FROM people".

The result of the action is a list of names, which I have stored in a variable for further analysis.


still hallucinating...

## The fifth idea: Use SQLDatabaseToolkit

In [39]:
toolkit = SQLDatabaseToolkit(llm=model, db=db)

prefix =\
"""
You are an SQL expert. Your task is to create a correct SQL query for {dialect}. 
The first thing you need to do from the user's request is to highlight the names of columns and tables that you will use later. 
Next, make a query, thinking about your actions step by step. 
Choose the {top_k} best queries and choose the best one from them. 

Next, launch it and respond to the user based on the data received. 
In the format: 
Query: 
Answer:
"""

db_agent_toolkit = base.create_sql_agent(
    llm=model,
    toolkit=toolkit,
    max_iterations=5,
    prefix=prefix
)

In [40]:
res = db_agent_toolkit.invoke({"input": "Select the names of all the passenger who are in the airline database"})
print(res)

{'input': 'Select the names of all the passenger who are in the airline database', 'output': 'Agent stopped due to iteration limit or time limit.'}
