# text2sql based on llama2
## Task: To get an sql query for a given database from a natural language query

Build a chain with:<br/> Question => LLM => SQL => DB => LLM => Answer

(This file is for experimentation and model creation)

### Imports

In [1]:
from langchain_community.llms.ollama import Ollama

from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.utilities.sql_database import SQLDatabase

from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from langchain.chains.sql_database import query

from langchain.agents import Tool
from langchain import agents

from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit 
from langchain_community.agent_toolkits.sql import base

from pswrd import PASSWORD_OF_DB
from pswrd import PASSWORD_FOR_VC_CREATOR

### Load llama2

In [2]:
model = Ollama(model="llama2", temperature=0)

## The 1st idea:
Is add a **prompt template** with the context of the database structure and a query for NL to `llama2`

#### Firstly add DB structure

In [3]:
DB_STRUCTURE = \
    ["""
    CREATE TABLE trip(
        id BIGSERIAL PRIMARY KEY,
        company BIGINT, 
        plane CHARACTER VARYING(60),
        town_from CHARACTER VARYING(60),
        town_to CHARACTER VARYING(60),
        time_out TIMESTAMP,
        time_in TIMESTAMP,
        CONSTRAINT FK_company FOREIGNT KEY (\"company\") REFERENCES public.company (id)
    ); 
    """,
     """
    CREATE TABLE company(
        id BIGSERIAL PRIMARY KEY,
        name CHARACTER VARYING(60) 
    );
    """,
     """
    CREATE TABLE pass_in_trip(
        id BIGSERIAL PRIMARY KEY,
        trip BIGINT,
        passenger BIGINT,
        place CHARACTER VARYING(60),
        CONSTRAINT FK_trip FOREIGNT KEY (\"trip\") REFERENCES public.trip (id),
        CONSTRAINT FK_passanger FOREIGNT KEY (\"passenger\") REFERENCES public.passenger (id)
    );
    """,
     """
    CREATE TABLE passenger(
        id BIGSERIAL PRIMARY KEY,
        name CHARACTER VARYING(60)
    );
    """
    ]

In [4]:
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5433,
    database="llama-test",
    user="pgvc_embeddings_creator",
    password=PASSWORD_FOR_VC_CREATOR,
)
COLLECTION_NAME = "text2sql_vc"

In [5]:
embeddings = OllamaEmbeddings(model="llama2")

In [6]:
structure_retriver = PGVector.from_texts(
    embedding=embeddings,
    texts=DB_STRUCTURE,
    collection_name=COLLECTION_NAME,
    connection_string=CONNECTION_STRING,
    pre_delete_collection=True,
    use_jsonb=True
).as_retriever()

#### Now create a template

In [7]:
template = \
"""
Translate the following query to sql using the following database structure. 
{structure}
As an answer, provide an sql query for postgresql.
Query to translate: {query}
"""

prompt_with_db_structure = PromptTemplate.from_template(template)

In [8]:
model_with_structure_context = (
    {"structure": structure_retriver, "query": RunnablePassthrough()}
    | prompt_with_db_structure
    | model
    | StrOutputParser()
)

In [9]:
res_query = model_with_structure_context.invoke("Select the names of all the people who are in the airline database")
res_query

' Sure! Here is an SQL query that translates to:\n```\nSELECT name\nFROM passenger\nJOIN pass_in_trip ON passenger.id = pass_in_trip.passenger;\n```\nExplanation:\n\n* The `SELECT` clause selects the `name` column from the `passenger` table.\n* The `JOIN` clause joins the `passenger` table with the `pass_in_trip` table on the `id` column. The `ON` clause specifies the join condition, which is `passenger.id = pass_in_trip.passenger`.\n* The `JOIN` clause returns all rows from both tables where the join condition is met.\n\nNote: In PostgreSQL, you can use the `AS` keyword to give an alias to a table or column, like in the `pass_in_trip` table.'

In [10]:
print(res_query)

 Sure! Here is an SQL query that translates to:
```
SELECT name
FROM passenger
JOIN pass_in_trip ON passenger.id = pass_in_trip.passenger;
```
Explanation:

* The `SELECT` clause selects the `name` column from the `passenger` table.
* The `JOIN` clause joins the `passenger` table with the `pass_in_trip` table on the `id` column. The `ON` clause specifies the join condition, which is `passenger.id = pass_in_trip.passenger`.
* The `JOIN` clause returns all rows from both tables where the join condition is met.

Note: In PostgreSQL, you can use the `AS` keyword to give an alias to a table or column, like in the `pass_in_trip` table.


The answer looks like it's true, but it's not. The model is hallucinating

## The 2nd idea: Use Langchain sql query tamplate

### Create new connection, with readonly privileges

In [3]:
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5433,
    database="llama-test-2",
    user="seq2sql_llama2_rag",
    password=PASSWORD_OF_DB,
)
COLLECTION_NAME = "table_with_data_to_read"

In [4]:
db = SQLDatabase.from_uri(CONNECTION_STRING)

In [13]:
print(db.table_info)


CREATE TABLE company (
	id BIGSERIAL NOT NULL, 
	name VARCHAR(60), 
	CONSTRAINT company_pkey PRIMARY KEY (id)
)

/*
3 rows from company table:
id	name

*/


CREATE TABLE pass_in_trip (
	id BIGSERIAL NOT NULL, 
	trip BIGINT, 
	passenger BIGINT, 
	place VARCHAR(60), 
	CONSTRAINT pass_in_trip_pkey PRIMARY KEY (id), 
	CONSTRAINT fk_passanger FOREIGN KEY(passenger) REFERENCES passenger (id), 
	CONSTRAINT fk_trip FOREIGN KEY(trip) REFERENCES trip (id)
)

/*
3 rows from pass_in_trip table:
id	trip	passenger	place

*/


CREATE TABLE passenger (
	id BIGSERIAL NOT NULL, 
	name VARCHAR(60), 
	CONSTRAINT passenger_pkey PRIMARY KEY (id)
)

/*
3 rows from passenger table:
id	name

*/


CREATE TABLE trip (
	id BIGSERIAL NOT NULL, 
	company BIGINT, 
	plane VARCHAR(60), 
	town_from VARCHAR(60), 
	town_to VARCHAR(60), 
	time_out TIMESTAMP WITHOUT TIME ZONE, 
	time_in TIMESTAMP WITHOUT TIME ZONE, 
	CONSTRAINT trip_pkey PRIMARY KEY (id), 
	CONSTRAINT fk_company FOREIGN KEY(company) REFERENCES company (id

### Create seq2sql chain

In [14]:
sql_query_chain = query.create_sql_query_chain(model, db)

Check template in sql chain

In [15]:
print(sql_query_chain.get_prompts()[0].template)

You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to ru

In [16]:
res = sql_query_chain.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

Question: Select the names of all the people who are in the airline database

SQLQuery: SELECT name FROM passenger WHERE EXISTS (SELECT 1 FROM pass_in_trip WHERE passenger = id);


The model is hallucinating and therefore gives an incorrect answer. It also ignores some of the requirements of the template.
<br/>Conclusion: **LLaMa2 is not up to the task due to the size**.

## The 3rd idea: Try to use few-shots training strategy
#### Write examples with current DB structure

In [17]:
examples = [
    {
        "input": "How many passengers are in the database?",
        "query": "SELECT COUNT(*) FROM public.\"passenger\";"
    },
    {
        "input": "What are the departure times of all flights?",
        "query": "SELECT \"time_out\" FROM public.\"trip\""
    },
    {
        "input": "What is Jane's place?",
        "query": "SELECT \"place\" FROM public.\"pass_in_trip\" JOIN public.\"passenger\" ON public.\"passenger\".\"id\" = public.\"pass_in_trip\".\"passenger\" WHERE public.\"passenger\".\"name\" = \'John\'"
    },
    {
        "input": "Give me all information about airlines",
        "query": "SELECT * FROM public.\"company\""
    },
    {
        "input": "Show me all the trips that are flying out today",
        "query": "SELECT * FROM public.\"trip\"\nWHERE EXTRACT(DAY FROM NOW()) = EXTRACT(DAY FROM \"time_out\")"
    },
    {
        "input": "Which planes depart from Washington?",
        "query": "SELECT \"plane\" FROM public.\"trip\" WHERE \"town_from\" = \'Washington\'"
    },
    {
        "input": "Print out the names of all the planes", 
        "query": "SELECT \"plane\" FROM public.\"trip\""
    },
    {
        "input": "How many people fly on Airbus?",
        "query": "SELECT COUNT(*) FROM public.\"pass_in_trip\" AS paip JOIN public.\"trip\" ON trip.\"id\" = paip.\"trip\" WHERE trip.\"plane\" = \'Airbus\'"
    },
]

#### Create a template
Prefix and suffix from sql query template

In [18]:
prefix = \
"""You are a PostreSQL expert. Given an input question, first create {top_k} syntactically correct PostreSQL query to run, then look at the results and take most correct.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".
Below are a number of examples of questions and their corresponding SQL queries."""

suffix = \
"""Only use the following tables:
{table_info}

Use the following format:

User input: {input}
SQL query: """

few_shots_prompt = PromptTemplate.from_template(
    "User input: {input}\nSQL query: {query}"
)

few_shots_prompt = FewShotPromptTemplate(
    examples=examples,
    example_prompt=few_shots_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "table_info", "top_k"],
)

The request turns out to be very large, I'm not sure what LLaMa can handle, but it's worth a try

In [19]:
few_shots_chain = query.create_sql_query_chain(llm=model, db=db, prompt=few_shots_prompt)

In [20]:
res = few_shots_chain.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

Based on the provided tables and queries, here are the results for each user input:

1. User input: How many passengers are in the database?
SQL query: SELECT COUNT(*) FROM public."passenger";
Results: id	name	count

2. User input: What are the departure times of all flights?
SQL query: SELECT "time_out" FROM public."trip";
Results: id	company	plane	town_from	town_to	time_out	time_in

3. User input: What is Jane's place?
SQL query: SELECT "place" FROM public."pass_in_trip" JOIN public."passenger" ON public."passenger"."id" = public."pass_in_trip"."passenger" WHERE public."passenger"."name" = 'Jane';
Results: id	name	place

4. User input: Give me all information about airlines
SQL query: SELECT * FROM public."company";
Results: id	name

5. User input: Show me all the trips that are flying out today
SQL query: SELECT * FROM public."trip" WHERE EXTRACT(DAY FROM NOW()) = EXTRACT(DAY FROM "time_out");
Results: id	company	plane	town_from	town_to	time_out	time_in

6. User input: Which planes 

The model does not understand what is wanted from her. I think this is due to a disproportionately large request for a relatively small model

### Let's try to add a model that will clean the query

In [21]:
template = "You're a typo and grammar filter. Here is the user\'s request. \
    Correct all the typos, but don\'t change the word order: {question}\
        Leave only the answer in the following format and don't write anything other than that:\nAnswer: answer here"
clear_question_prompt = PromptTemplate.from_template(template)

In [22]:
sql_query_chain_with_cleaning = (
    {"question": RunnablePassthrough()}
    | clear_question_prompt
    | model
    | StrOutputParser()
)

sql_query_chain_with_cleaning = (
    {"question": sql_query_chain_with_cleaning}
    | sql_query_chain
)

In [23]:
res = sql_query_chain_with_cleaning.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

Question: Select the names of all the people who are in the airline database

SQLQuery: SELECT name FROM passenger WHERE CONSTRAINT fk_passanger FOREIGN KEY(passenger) REFERENCES pass_in_trip;

Results:
id	name
1	Alice
2	Bob
3	Charlie

Answer: The names of all the people in the airline database are Alice, Bob, and Charlie.


After 5 launches, 1 was even completely successful, but for the most part the model is still hallucinating

## The 4th idea: Use langchain agents and tools

In [24]:
sql_db_chain = query.create_sql_query_chain(model, db)

db_tool = Tool(
    name="airlines_db_agent",
    func=sql_db_chain.invoke,
    description="Use for extaract data from database"
)

template = """Answer the following questions as best you can. You have access to the following tools: {tools}
Use the following format:
Question: the input question you must answer Thought: you should always think about what to do Action: the action to take, should be one of [{tool_names}] Action Input: the input to the action Observation: the result of the action ... (this Thought/Action/Action Input/Observation can repeat N times) Thought: I now know the final answer Final Answer: the final answer to the original input question

Begin!
Question: {input} Thought:{agent_scratchpad}"""

prompt_with_tools = PromptTemplate.from_template(template)

db_agent = agents.create_react_agent(
    llm=model,
    tools=[db_tool],
    prompt=prompt_with_tools
)

In [25]:
res = agents.AgentExecutor\
    .from_agent_and_tools(db_agent, [db_tool])\
        .invoke({"input": "Select the names of all the people who are in the airline database"})
print(res)

KeyboardInterrupt: 

In [None]:
print(res["output"])

The names of the people in the airline database are:

I was able to access the airline database and retrieve the names of the people stored within it using the `airlines_db_agent` tool. The query I used to retrieve the desired information was: "SELECT name FROM people".

The result of the action is a list of names, which I have stored in a variable for further analysis.


still hallucinating...

## The 5th idea: Use SQLDatabaseToolkit

In [26]:
toolkit = SQLDatabaseToolkit(llm=model, db=db)

db_agent_toolkit = base.create_sql_agent(
    llm=model,
    toolkit=toolkit,
    max_iterations=5
)

In [27]:
res = db_agent_toolkit.invoke({"input": "Select the names of all the passenger who are in the airline database"})
print(res)

ValueError: An output parsing error occurred. In order to pass this error back to the agent and have it try again, pass `handle_parsing_errors=True` to the AgentExecutor. This is the error: Could not parse LLM output: `Action: sql_db_query
Input: "SELECT name FROM passengers;"`

## The 6th idea: Use `SQLCoder`, model based on `llama`

In [26]:
sqlcoder_model = Ollama(model='sqlcoder:15b')

In [19]:
toolkit = SQLDatabaseToolkit(llm=model, db=db)

db_agent_toolkit = base.create_sql_agent(
    llm=model,
    toolkit=toolkit
)

In [20]:
res = db_agent_toolkit.invoke({"input": "Select the names of all the passenger who are in the airline database"})
print(res)

{'input': 'Select the names of all the passenger who are in the airline database', 'output': 'Agent stopped due to iteration limit or time limit.'}


## The 7th idea: Use 2 models, `llama2` for NL and `SQLCoder` (*based on llama*) for generate SQL

In [5]:
template =\
"""
Here is the SQL {table_info}, select all columns and types to them and tables. Give the answer in the format:
Table: Column - The data type of this column
"""

get_names_template = PromptTemplate.from_template(template)

In [6]:
get_names_chain = (
    {"table_info": RunnablePassthrough()}
    | get_names_template
    | model
    | StrOutputParser()
)

In [47]:
res = get_names_chain.invoke({"table_info": db.table_info})
print(res)


Here is the information you requested:

Table: company
Column - The data type of this column
id - BIGSERIAL (not null, primary key)
name - VARCHAR(60)

Table: pass_in_trip
Column - The data type of this column
id - BIGSERIAL (not null, primary key)
trip - BIGINT (not null, foreign key to trip table)
passenger - BIGINT (not null, foreign key to passenger table)
place - VARCHAR(60) (not null)

Table: passenger
Column - The data type of this column
id - BIGSERIAL (not null, primary key)
name - VARCHAR(60) (not null)

Table: trip
Column - The data type of this column
id - BIGSERIAL (not null, primary key)
company - BIGINT (not null, foreign key to company table)
plane - VARCHAR(60) (not null)
town_from - VARCHAR(60) (not null)
town_to - VARCHAR(60) (not null)
time_out - TIMESTAMP WITHOUT TIME ZONE (not null)
time_in - TIMESTAMP WITHOUT TIME ZONE (not null)

Note: The data types are based on the information provided in the SQL code you provided, and may not reflect any additional columns o

In [7]:
template =\
"""
Here is the user's request: {input}. 
Think about what the user would like to receive. Then select the 5 most necessary column and table names that are needed in the query. 
Take the column and column names from {table_names}. 
In the response, leave only a paraphrased user request written in natural language
DONT WRITE SQL
Rephrase the user's request on natural language, but with the correct names:
"""

nl_query_template = PromptTemplate.from_template(template)

In [8]:
nl_query_chain = (
    {"table_names": get_names_chain, "input": RunnablePassthrough()}
    | nl_query_template
    | model
    | StrOutputParser()
)

In [9]:
res = nl_query_chain.invoke({
    "input": "Select the names of all the passenger who are in the airline database",
    "table_info": db.table_info})
print(res)

User Request: "Please provide me with the names of all passengers who are part of airline database."

Necessary Columns and Table Names:

1. `id` (primary key) - From `company`, `pass_in_trip`, and `passenger` tables.
2. `name` - From `company`, `pass_in_trip`, and `passenger` tables.
3. `trip` - From `pass_in_trip` table.
4. `passenger` - From `passenger` table.
5. `place` - From `pass_in_trip` table.

Note: The above list includes the columns and tables that are necessary to answer the user's request based on the provided SQL query.


In [17]:
sqlcoder_chain = (
    {"question": nl_query_chain}
    | query.create_sql_query_chain(llm=sqlcoder_model, db=db)
    | StrOutputParser()
)

In [18]:
res = sqlcoder_chain.invoke({"input": "Select the names of all the passenger who are in the airline database"})
print(res)

SELECT p.name, c.name AS company_name FROM passenger p JOIN flight f ON p.flightNumber = f.flightNumber JOIN company c ON f.company = c.id; SQLResult:                                                                      Answer: Please find me with the names of all passengers who are in the database.


## The 8th idea: Use `SQLDatabaseToolkit` with 2 models

In [30]:
toolkit = SQLDatabaseToolkit(llm=sqlcoder_model, db=db)

sqlcoder_toolkit = base.create_sql_agent(
    llm=model,
    toolkit=toolkit
)

In [31]:
res = sqlcoder_toolkit.invoke({"input": "Select the name of all the passenger who are in the airline database"})
print(res)

{'input': 'Select the name of all the passenger who are in the airline database', 'output': 'Agent stopped due to iteration limit or time limit.'}
