# text2sql based on llama2
## Task: get an SQL query for a given database from a natural-language query

Build a chain with:<br/> Question => LLM => SQL => DB => LLM => Answer

(This file is for experimentation and model creation)
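The chain above can be sketched end to end in plain Python. This is only a minimal illustration under stated assumptions, not the notebook's implementation: `fake_sql_llm` and `fake_answer_llm` are hypothetical stand-ins for the two llama2 calls, `extract_sql` is a hypothetical helper that pulls the query out of a chatty reply, and an in-memory SQLite database stands in for PostgreSQL.

```python
import re
import sqlite3


def fake_sql_llm(question: str) -> str:
    # Hypothetical stand-in for the first LLM call: a chatty reply with a fenced query.
    return "Sure! Here is the query:\n```\nSELECT name FROM passenger;\n```\nExplanation: ..."


def fake_answer_llm(question: str, sql: str, rows: list) -> str:
    # Hypothetical stand-in for the second LLM call, which phrases the DB result.
    names = ", ".join(row[0] for row in rows)
    return f"The people in the database are: {names}"


def extract_sql(reply: str) -> str:
    # Prefer the first fenced block; otherwise take the text after "SQLQuery:".
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", reply, re.DOTALL)
    if fenced:
        return fenced.group(1).strip()
    return reply.split("SQLQuery:", 1)[-1].strip()


def answer(question: str, conn: sqlite3.Connection) -> str:
    sql = extract_sql(fake_sql_llm(question))     # Question => LLM => SQL
    rows = conn.execute(sql).fetchall()           # SQL => DB
    return fake_answer_llm(question, sql, rows)   # DB => LLM => Answer


# In-memory SQLite stands in for the PostgreSQL database (assumption).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE passenger (id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany("INSERT INTO passenger (name) VALUES (?)", [("Jane",), ("John",)])

result = answer("Select the names of all the people who are in the airline database", conn)
print(result)
```

Keeping the extraction step separate matters because llama2 rarely returns bare SQL, as the experiments below show.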

### Imports

In [1]:
from langchain_community.llms.ollama import Ollama

from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.utilities.sql_database import SQLDatabase

from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from langchain.chains.sql_database import query

from pswrd import PASSWORD_OF_DB
from pswrd import PASSWORD_FOR_VC_CREATOR

### Load llama2 (served locally by Ollama)

In [2]:
model = Ollama(model="llama2", temperature=0)

## The first idea:
Add a **prompt template** that passes `llama2` the database structure as context, together with the natural-language query.

#### First, add the DB structure

In [3]:
DB_STRUCTURE = \
    ["""
    CREATE TABLE trip(
        id BIGSERIAL PRIMARY KEY,
        company BIGINT, 
        plane CHARACTER VARYING(60),
        town_from CHARACTER VARYING(60),
        town_to CHARACTER VARYING(60),
        time_out TIMESTAMP,
        time_in TIMESTAMP,
        CONSTRAINT FK_company FOREIGN KEY (\"company\") REFERENCES public.company (id)
    ); 
    """,
     """
    CREATE TABLE company(
        id BIGSERIAL PRIMARY KEY,
        name CHARACTER VARYING(60) 
    );
    """,
     """
    CREATE TABLE pass_in_trip(
        id BIGSERIAL PRIMARY KEY,
        trip BIGINT,
        passenger BIGINT,
        place CHARACTER VARYING(60),
        CONSTRAINT FK_trip FOREIGN KEY (\"trip\") REFERENCES public.trip (id),
        CONSTRAINT FK_passanger FOREIGN KEY (\"passenger\") REFERENCES public.passenger (id)
    );
    """,
     """
    CREATE TABLE passenger(
        id BIGSERIAL PRIMARY KEY,
        name CHARACTER VARYING(60)
    );
    """
    ]

In [4]:
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5433,
    database="llama-test",
    user="pgvc_embeddings_creator",
    password=PASSWORD_FOR_VC_CREATOR,
)
COLLECTION_NAME = "text2sql_vc"

In [5]:
embeddings = OllamaEmbeddings(model="llama2")

In [6]:
structure_retriver = PGVector.from_texts(
    embedding=embeddings,
    texts=DB_STRUCTURE,
    collection_name=COLLECTION_NAME,
    connection_string=CONNECTION_STRING,
    pre_delete_collection=True,
    use_jsonb=True
).as_retriever()
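The retriever ranks the stored DDL strings by embedding similarity to the question and returns the top `k`. The toy sketch below shows only the ranking idea: a bag-of-words vector and cosine similarity computed in Python are swapped in for llama2 embeddings and pgvector, and `embed`, `cosine` and `retrieve` are hypothetical names. One practical consequence: as far as I know, LangChain vector-store retrievers return 4 documents by default, and `DB_STRUCTURE` contains exactly four tables, so here retrieval reorders the schema rather than filtering it.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    # Toy embedding: lowercase word counts (stands in for OllamaEmbeddings).
    return Counter(re.findall(r"[a-z_]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(question: str, docs: list, k: int = 4) -> list:
    # Rank documents by similarity to the question and keep the top k.
    q = embed(question)
    return sorted(docs, key=lambda d: cosine(q, embed(d)), reverse=True)[:k]


docs = [
    "CREATE TABLE passenger(id BIGSERIAL PRIMARY KEY, name CHARACTER VARYING(60));",
    "CREATE TABLE company(id BIGSERIAL PRIMARY KEY, name CHARACTER VARYING(60));",
    "CREATE TABLE trip(id BIGSERIAL PRIMARY KEY, company BIGINT, plane CHARACTER VARYING(60));",
]
top = retrieve("Which plane does each trip use?", docs, k=1)
print(top[0])
```

To make retrieval actually select tables, `k` could be lowered, e.g. `as_retriever(search_kwargs={"k": 2})`, at the risk of dropping a table the query needs.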

#### Now create a template

In [7]:
template = \
"""
Translate the following query to sql using the following database structure. 
{structure}
As an answer, provide an sql query for postgresql.
Query to translate: {query}
"""

prompt_with_db_structure = PromptTemplate.from_template(template)

In [8]:
model_with_structure_context = (
    {"structure": structure_retriver, "query": RunnablePassthrough()}
    | prompt_with_db_structure
    | model
    | StrOutputParser()
)
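As written, the retriever's output (a list of `Document` objects) is substituted into `{structure}` as-is, so the prompt receives the list's string representation, metadata included, rather than clean DDL. A common fix is a small formatting step between the retriever and the prompt. The sketch below assumes only that each retrieved item has a `page_content` attribute, as LangChain `Document` objects do; `format_docs` is a hypothetical helper, and `SimpleNamespace` stands in for `Document` so the snippet runs without LangChain.

```python
from types import SimpleNamespace


def format_docs(docs) -> str:
    # Join the raw DDL strings, dropping Document metadata and list syntax.
    return "\n\n".join(doc.page_content.strip() for doc in docs)


docs = [
    SimpleNamespace(page_content="\n    CREATE TABLE company(id BIGSERIAL PRIMARY KEY);\n    "),
    SimpleNamespace(page_content="\n    CREATE TABLE passenger(id BIGSERIAL PRIMARY KEY);\n    "),
]
print(format_docs(docs))
```

In the chain, the retriever would be piped into it (`retriever | format_docs`) before the prompt; LangChain wraps a plain function piped after a runnable in a `RunnableLambda`.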

In [9]:
res_query = model_with_structure_context.invoke("Select the names of all the people who are in the airline database")

In [10]:
print(res_query)

 Sure! Here is an SQL query that translates to:
```
SELECT name
FROM passenger
JOIN pass_in_trip ON passenger.id = pass_in_trip.passenger;
```
Explanation:

* The `SELECT` clause selects the `name` column from the `passenger` table.
* The `JOIN` clause joins the `passenger` table with the `pass_in_trip` table on the `id` column. The `ON` clause specifies the join condition, which is `passenger.id = pass_in_trip.passenger`.
* The `JOIN` clause returns all rows from both tables where the join condition is met.

Note: In PostgreSQL, you can use the `AS` keyword to give an alias to a table or column, like in the `pass_in_trip` table.


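The answer looks plausible, but it is wrong: the model is hallucinating. Joining `passenger` with `pass_in_trip` returns only passengers who have at least one trip and repeats a name once for every trip, whereas the question only needs `SELECT name FROM passenger`. The explanation is also off: the join is on `passenger.id = pass_in_trip.passenger`, not on a shared `id` column, and the closing note about `AS` aliases does not apply, since the query uses none.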
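The difference is easy to see on a toy dataset. A minimal sketch, assuming an in-memory SQLite database with a trimmed-down `passenger`/`pass_in_trip` schema and made-up rows: the generated `JOIN` drops a passenger without trips and repeats one with two trips.

```python
import sqlite3

# Toy data (made up for illustration): Jane has two trips, John one, Bob none.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE passenger (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE pass_in_trip (id INTEGER PRIMARY KEY, trip INTEGER, passenger INTEGER, place TEXT);
    INSERT INTO passenger (id, name) VALUES (1, 'Jane'), (2, 'John'), (3, 'Bob');
    INSERT INTO pass_in_trip (trip, passenger, place) VALUES (10, 1, '1A'), (11, 1, '2B'), (10, 2, '3C');
""")

# The query llama2 produced vs. the query the question actually needs.
generated = "SELECT name FROM passenger JOIN pass_in_trip ON passenger.id = pass_in_trip.passenger"
expected = "SELECT name FROM passenger"

generated_names = sorted(row[0] for row in conn.execute(generated))
expected_names = sorted(row[0] for row in conn.execute(expected))
print(generated_names)  # ['Jane', 'Jane', 'John']: Jane duplicated, Bob missing
print(expected_names)   # ['Bob', 'Jane', 'John']
```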

## The second idea: use the LangChain SQL query template

### Create a new connection with read-only privileges

In [11]:
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5433,
    database="llama-test-2",
    user="seq2sql_llama2_rag",
    password=PASSWORD_OF_DB,
)
COLLECTION_NAME = "table_with_data_to_read"

In [12]:
db = SQLDatabase.from_uri(CONNECTION_STRING)

In [13]:
print(db.table_info)


CREATE TABLE company (
	id BIGSERIAL NOT NULL, 
	name VARCHAR(60), 
	CONSTRAINT company_pkey PRIMARY KEY (id)
)

/*
3 rows from company table:
id	name

*/


CREATE TABLE pass_in_trip (
	id BIGSERIAL NOT NULL, 
	trip BIGINT, 
	passenger BIGINT, 
	place VARCHAR(60), 
	CONSTRAINT pass_in_trip_pkey PRIMARY KEY (id), 
	CONSTRAINT fk_passanger FOREIGN KEY(passenger) REFERENCES passenger (id), 
	CONSTRAINT fk_trip FOREIGN KEY(trip) REFERENCES trip (id)
)

/*
3 rows from pass_in_trip table:
id	trip	passenger	place

*/


CREATE TABLE passenger (
	id BIGSERIAL NOT NULL, 
	name VARCHAR(60), 
	CONSTRAINT passenger_pkey PRIMARY KEY (id)
)

/*
3 rows from passenger table:
id	name

*/


CREATE TABLE trip (
	id BIGSERIAL NOT NULL, 
	company BIGINT, 
	plane VARCHAR(60), 
	town_from VARCHAR(60), 
	town_to VARCHAR(60), 
	time_out TIMESTAMP WITHOUT TIME ZONE, 
	time_in TIMESTAMP WITHOUT TIME ZONE, 
	CONSTRAINT trip_pkey PRIMARY KEY (id), 
	CONSTRAINT fk_company FOREIGN KEY(company) REFERENCES company (id

### Create seq2sql chain

In [14]:
sql_query_chain = query.create_sql_query_chain(model, db)

Check the prompt template used by the SQL chain

In [15]:
print(sql_query_chain.get_prompts()[0].template)

You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to ru

In [16]:
res = sql_query_chain.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

Question: Select the names of all the people who are in the airline database

SQLQuery: SELECT name FROM passenger WHERE EXISTS (SELECT 1 FROM pass_in_trip WHERE passenger = id);


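The model is hallucinating and therefore gives an incorrect answer: inside the subquery, the unqualified `id` resolves to `pass_in_trip.id` rather than `passenger.id`, and even a correctly correlated version would return only passengers who have a trip. It also ignores some of the template's requirements: there is no `LIMIT` clause, and the column names are not wrapped in double quotes.
<br/>Conclusion: **LLaMa2 is not up to the task due to its size**.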
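The scoping problem can be reproduced on toy data. A minimal sketch, assuming an in-memory SQLite database and made-up rows (the real tables appear to be empty, judging by the sample rows in `db.table_info`): because `id` binds to the innermost table, the generated `EXISTS` is not correlated with the outer row and returns nothing here, while an explicitly correlated version finds the two passengers who have trips.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE passenger (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE pass_in_trip (id INTEGER PRIMARY KEY, trip INTEGER, passenger INTEGER, place TEXT);
    INSERT INTO passenger (id, name) VALUES (1, 'Jane'), (2, 'John'), (3, 'Bob');
    INSERT INTO pass_in_trip (id, trip, passenger, place) VALUES (1, 10, 2, '1A'), (2, 10, 3, '2B');
""")

# As generated by llama2: `id` resolves to pass_in_trip.id (innermost scope).
generated = "SELECT name FROM passenger WHERE EXISTS (SELECT 1 FROM pass_in_trip WHERE passenger = id)"
# Correlated version: qualify the outer column explicitly.
correlated = ("SELECT name FROM passenger p "
              "WHERE EXISTS (SELECT 1 FROM pass_in_trip t WHERE t.passenger = p.id)")

print(conn.execute(generated).fetchall())            # []
print(sorted(conn.execute(correlated).fetchall()))   # [('Bob',), ('John',)]
```

Even the correlated query answers a different question (people with trips, not all people); `SELECT "name" FROM passenger` is still the intended answer.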

## The third idea: use a few-shot prompting strategy
#### Write examples for the current DB structure

In [17]:
examples = [
    {
        "input": "How many passengers are in the database?",
        "query": "SELECT COUNT(*) FROM public.\"passenger\";"
    },
    {
        "input": "What are the departure times of all flights?",
        "query": "SELECT \"time_out\" FROM public.\"trip\""
    },
    {
        "input": "What is Jane's place?",
        "query": "SELECT \"place\" FROM public.\"pass_in_trip\" JOIN public.\"passenger\" ON public.\"passenger\".\"id\" = public.\"pass_in_trip\".\"passenger\" WHERE public.\"passenger\".\"name\" = \'Jane\'"
    },
    {
        "input": "Give me all information about airlines",
        "query": "SELECT * FROM public.\"company\""
    },
    {
        "input": "Show me all the trips that are flying out today",
        "query": "SELECT * FROM public.\"trip\"\nWHERE \"time_out\"::date = CURRENT_DATE"
    },
    {
        "input": "Which planes depart from Washington?",
        "query": "SELECT \"plane\" FROM public.\"trip\" WHERE \"town_from\" = \'Washington\'"
    },
    {
        "input": "Print out the names of all the planes", 
        "query": "SELECT \"plane\" FROM public.\"trip\""
    },
    {
        "input": "How many people fly on Airbus?",
        "query": "SELECT COUNT(*) FROM public.\"pass_in_trip\" AS paip JOIN public.\"trip\" ON trip.\"id\" = paip.\"trip\" WHERE trip.\"plane\" = \'Airbus\'"
    },
]

#### Create a template
The prefix and suffix are adapted from the SQL query prompt template.

In [18]:
prefix = \
"""You are a PostgreSQL expert. Given an input question, first create {top_k} syntactically correct PostgreSQL queries to run, then look at the results and pick the most correct one.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".
Below are a number of examples of questions and their corresponding SQL queries."""

suffix = \
"""Only use the following tables:
{table_info}

Use the following format:

User input: {input}
SQL query: """

# Template for rendering a single example
example_prompt = PromptTemplate.from_template(
    "User input: {input}\nSQL query: {query}"
)

# Full prompt: prefix + rendered examples + suffix
few_shots_prompt = FewShotPromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "table_info", "top_k"],
)

The resulting prompt is very large. I'm not sure LLaMa 2 can handle it, but it's worth a try.
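To get a feel for the prompt size, its layout can be reproduced in plain Python. A rough sketch, assuming that `FewShotPromptTemplate` joins the prefix, the formatted examples and the suffix with its default `"\n\n"` separator, and using the crude rule of thumb of about four characters per token (llama2's real tokenizer will give different numbers). `assemble_few_shot` and `rough_token_count` are hypothetical helpers, and the strings below are shortened placeholders; passing the notebook's `prefix`, `examples` and `suffix` measures the real thing.

```python
def assemble_few_shot(prefix: str, examples: list, suffix: str,
                      example_template: str = "User input: {input}\nSQL query: {query}",
                      separator: str = "\n\n") -> str:
    # Mirrors the prefix / examples / suffix layout of a few-shot prompt.
    parts = [prefix] + [example_template.format(**ex) for ex in examples] + [suffix]
    return separator.join(parts)


def rough_token_count(text: str) -> int:
    # Heuristic only: roughly 4 characters per token for English text.
    return len(text) // 4


prompt = assemble_few_shot(
    prefix="You are a PostgreSQL expert.",
    examples=[{"input": "How many passengers?", "query": "SELECT COUNT(*) FROM passenger;"}],
    suffix="User input: {input}\nSQL query: ",
)
print(prompt)
print(rough_token_count(prompt))
```

For scale: Llama 2 has a 4,096-token context window, and Ollama has historically defaulted to a smaller one (2,048 tokens) unless `num_ctx` is set, e.g. `Ollama(model="llama2", num_ctx=4096)`. The `table_info` block also counts toward that budget.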

In [19]:
few_shots_chain = query.create_sql_query_chain(llm=model, db=db, prompt=few_shots_prompt)

In [20]:
res = few_shots_chain.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

Based on the provided tables and queries, here are the results for each user input:

1. User input: How many passengers are in the database?
SQL query: SELECT COUNT(*) FROM public."passenger";
Results: id	name	COUNT(*)
0	John	3
0	Jane	2
0	Bob	1

2. User input: What are the departure times of all flights?
SQL query: SELECT "time_out" FROM public."trip";
Results: id	company	plane	town_from	town_to	time_out
0	Airbus	Washington	None	None	2023-03-14 10:00:00
0	Boeing	New York	None	None	2023-03-14 11:00:00
0	Airbus	Los Angeles	None	None	2023-03-14 12:00:00

3. User input: What is Jane's place?
SQL query: SELECT "place" FROM public."pass_in_trip" JOIN public."passenger" ON public."passenger"."id" = public."pass_in_trip"."passenger" WHERE public."passenger"."name" = 'Jane';
Results: id	name	place
0	Jane	Washington

4. User input: Give me all information about airlines
SQL query: SELECT * FROM public."company";
Results: id	name
0	Airbus
0	Boeing

5. User input: Show me all the trips that are fl

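The model does not understand what is being asked of it: instead of answering the question, it replays the few-shot examples and invents result rows, even though the sample-row sections of `db.table_info` above suggest the tables are empty. I think this is due to a disproportionately large prompt for a relatively small model.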
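One likely contributor besides model size: as far as I can tell from the LangChain source, `create_sql_query_chain` appends `\nSQLQuery: ` to the question and stops generation only at `\nSQLResult:`. The few-shot format uses `User input:` / `SQL query:` labels instead, so nothing cuts the model off after its first query. A minimal post-processing sketch, with a hypothetical `first_query` helper whose marker list is an assumption based on the output above, keeps only the text before the model starts a new example or a result block:

```python
def first_query(raw: str, markers=("\nUser input:", "\nResults:", "\nSQLResult:")) -> str:
    # Cut at the earliest marker that signals the model has moved past its answer.
    cut = len(raw)
    for marker in markers:
        pos = raw.find(marker)
        if pos != -1:
            cut = min(cut, pos)
    text = raw[:cut]
    # Keep only what follows the last "SQL query:" label, if any.
    if "SQL query:" in text:
        text = text.rsplit("SQL query:", 1)[1]
    return text.strip()


raw = ('SQL query: SELECT "name" FROM public."passenger";\n'
       'Results: id\tname\n'
       'User input: What are the departure times of all flights?')
print(first_query(raw))  # SELECT "name" FROM public."passenger";
```

This only trims the output; it cannot fix a model that answers the wrong question. On the run above it would return the first example's `COUNT(*)` query. In the notebook it could be applied as `few_shots_chain | first_query`, since LangChain wraps a plain function piped after a runnable in a `RunnableLambda`.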

### Let's try adding a model step that cleans up the question first

In [21]:
template = (
    "You're a typo and grammar filter. Here is the user's request. "
    "Correct all the typos, but don't change the word order: {question}\n"
    "Leave only the answer in the following format and don't write anything other than that:\n"
    "Answer: answer here"
)
clear_question_prompt = PromptTemplate.from_template(template)

In [22]:
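from operator import itemgetter

# Step 1: clean up the question. The chain is invoked with {"question": ...},
# so extract the string with itemgetter instead of passing the whole dict through.
clean_question_chain = (
    {"question": itemgetter("question")}
    | clear_question_prompt
    | model
    | StrOutputParser()
)

# Step 2: feed the cleaned question into the SQL query chain.
sql_query_chain_with_cleaning = (
    {"question": clean_question_chain}
    | sql_query_chain
)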
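The cleaning prompt asks the model to reply as `Answer: ...`, so that label (and any chatter before it) would be passed on to the SQL chain as part of the question. A minimal sketch of a hypothetical `strip_answer_label` step, assuming the model follows the requested format; when the label is missing, the text is returned stripped but otherwise unchanged:

```python
def strip_answer_label(text: str, label: str = "Answer:") -> str:
    # Keep only what follows the last "Answer:" label; otherwise keep the whole text.
    if label in text:
        text = text.rsplit(label, 1)[1]
    return text.strip()


print(strip_answer_label("Sure! Here is the corrected request:\nAnswer: Select the names of all passengers"))
```

It would go right after `StrOutputParser()` in the cleaning chain; LangChain wraps a plain function in a `|` pipeline in a `RunnableLambda`.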

In [23]:
res = sql_query_chain_with_cleaning.invoke({"question": "Select the names of all the people who are in the airline database"})
print(res)

Question: Select the names of all the people who are in the airline database

SQLQuery: SELECT name FROM passenger WHERE trip.id IN (SELECT id FROM trip);

This query selects all the rows from the `passenger` table where the `trip.id` matches any row in the `trip` table. This will return all the passengers who are on a trip.


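Out of 5 runs, one was completely successful, but for the most part the model still hallucinates. The query above, for example, references `trip.id` even though `trip` does not appear in the outer `FROM` clause, so PostgreSQL would reject it.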
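One cheap guard against this kind of hallucination is to let a database check the query before trusting it, and to retry (or re-prompt with the error message) when the check fails. A minimal sketch, assuming an in-memory SQLite copy of a trimmed schema as the checker; `sql_error` is a hypothetical helper, and against the real database PostgreSQL's `EXPLAIN` could play the same role:

```python
import sqlite3
from typing import Optional


def sql_error(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Return the database's error message for `sql`, or None if it compiles."""
    try:
        conn.execute("EXPLAIN " + sql)  # compiles the statement without running it
    except sqlite3.Error as exc:
        return str(exc)
    return None


conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE passenger (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE trip (id INTEGER PRIMARY KEY, plane TEXT);
""")

print(sql_error(conn, "SELECT name FROM passenger WHERE trip.id IN (SELECT id FROM trip);"))
print(sql_error(conn, "SELECT name FROM passenger;"))  # None
```

A check like this only catches invalid queries, not valid queries that answer the wrong question, such as the `JOIN` from the first idea.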