In [2]:
from urllib.parse import quote

from langchain_community.llms.ollama import Ollama

from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.utilities.sql_database import SQLDatabase

from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from langchain.chains.sql_database import query

from langchain.agents import Tool
from langchain import agents

from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit 
from langchain.agents.agent_toolkits.sql import base

from langchain.prompts.example_selector import SemanticSimilarityExampleSelector

from langchain.globals import set_verbose
from langchain.globals import set_debug
set_verbose(True)
# set_debug(True)

from pswrd import PASSWORD_OF_DB
from pswrd import PASSWORD_FOR_VC_CREATOR

### Load models
# Ollama model tags used by the LLMs and embedding models below.
LLAMA_MODEL = "llama2:13b"
SQLCODER_MODEL = "sqlcoder:15b"

# Sampling options shared by every model.
# Ollama's repeat_penalty is a multiplicative factor: 1.0 means "no penalty"
# (the default is 1.1), and values below 1.0 make repetition MORE likely.
# The previous value of 0 therefore did not switch the penalty off.
SAMPLING_OPTIONS = {"temperature": 0.1, "repeat_penalty": 1.0}

llama = Ollama(model=LLAMA_MODEL, **SAMPLING_OPTIONS)
llama_embeddings = OllamaEmbeddings(model=LLAMA_MODEL, **SAMPLING_OPTIONS)
sqlcoder_embeddings = OllamaEmbeddings(model=SQLCODER_MODEL, **SAMPLING_OPTIONS)
sqlcoder = Ollama(model=SQLCODER_MODEL, **SAMPLING_OPTIONS)

### Save examples for few-shot prompt
# Question -> SQL pairs used as few-shot demonstrations for the text-to-SQL
# prompt. Each dict has exactly the keys "input" (user question) and "query"
# (the PostgreSQL answer). Identifiers are double-quoted, and "today" is
# expressed with CURRENT_DATE, the PostgreSQL idiom.
examples = [
    {
        "input": "How many passengers are in the database?",
        "query": "SELECT COUNT(*) FROM public.\"passenger\";"
    },
    {
        "input": "What are the departure times of all flights?",
        "query": "SELECT \"time_out\" FROM public.\"trip\""
    },
    {
        "input": "What is John's place?",
        "query": "SELECT \"place\" FROM public.\"pass_in_trip\" JOIN public.\"passenger\" ON public.\"passenger\".\"id\" = public.\"pass_in_trip\".\"passenger\" WHERE public.\"passenger\".\"name\" = \'John\'"
    },
    {
        "input": "Give me all information about airlines",
        "query": "SELECT * FROM public.\"company\""
    },
    {
        "input": "Show me all the trips that are flying out today",
        # Compare whole calendar dates. Comparing only EXTRACT(DAY ...) matched
        # the same day-of-month in every month and year.
        "query": "SELECT * FROM public.\"trip\"\nWHERE \"time_out\"::date = CURRENT_DATE"
    },
    {
        "input": "Which planes depart from Washington?",
        "query": "SELECT \"plane\" FROM public.\"trip\" WHERE \"town_from\" = \'Washington\'"
    },
    {
        "input": "Print out the names of all the planes",
        "query": "SELECT \"plane\" FROM public.\"trip\""
    },
    {
        "input": "How many people fly on Airbus?",
        # Count people, not bookings: a passenger on several Airbus trips is
        # still one person.
        "query": "SELECT COUNT(DISTINCT paip.\"passenger\") FROM public.\"pass_in_trip\" AS paip JOIN public.\"trip\" ON trip.\"id\" = paip.\"trip\" WHERE trip.\"plane\" = \'Airbus\'"
    },
]

### Create prompt
# Instructions placed before the few-shot examples.
# Template variable: {top_k}.
prefix = """You are a silent PostgreSQL expert. Given an input question, first create {top_k} syntactically correct PostgreSQL queries to run, then look at the results and pick the most correct one.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use the CURRENT_DATE function to get the current date, if the question involves \"today\".
Below are a number of examples of questions and their corresponding SQL queries."""

# Text placed after the few-shot examples.
# Template variables: {table_info}, {input}.
# The trailing space after "SQL query:" is intentional: the model continues
# from there.
suffix = """Only use the following tables:
{table_info}

Don't explain, use the following format:

User input: {input}
SQL query: """

# Renders a single few-shot example (one dict from `examples`) as an
# "input / query" pair.
example_prompt = PromptTemplate.from_template(
    template=(
        "User input: {input}\n"
        "SQL query: {query}"
    )
)

### Connect to DB with Readonly role
# PGVector.connection_string_from_db_params formats the parts into a URL
# without escaping them. The password is percent-encoded first, because
# characters such as "@", ":", "/" or "%" would otherwise corrupt the URL.
# SQLAlchemy decodes it again when parsing.
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5433,
    database="llama-test-2",
    user="seq2sql_llama2_rag",
    password=quote(PASSWORD_OF_DB, safe=""),
)

# `SET ROLE` is session-scoped, so it only affects whichever pooled
# connection it happens to run on; a connection the pool opens later would
# not inherit it. Passing default_transaction_read_only through the libpq
# startup options applies to every connection the engine creates.
# NOTE(review): neither measure stops a query that first runs `RESET ROLE` or
# `SET default_transaction_read_only = off`. For real isolation, connect as a
# login that has been granted only SELECT.
db = SQLDatabase.from_uri(
    CONNECTION_STRING,
    engine_args={"connect_args": {"options": "-c default_transaction_read_only=on"}},
)
db.run("SET ROLE pg_read_all_data")

#### Check connection (a handful of rows is enough; avoids dumping the table)
db.run("select * from passenger limit 5")

### Prompt
# ReAct prompt intended for agents.create_react_agent (see the commented-out
# variant in the agent cell). That agent fills {tools}, {tool_names}, {input}
# and {agent_scratchpad}. LangChain's ReAct output parser only recognises the
# "Action:" / "Action Input:" / "Final Answer:" markers, so the Format section
# must ask for exactly those; the previous "Final answer:" layout could not be
# parsed.
template = """# Personality
Act as an SQL expert answering questions from an office employee.

# Task
Create and run an SQL query that answers the user's question.

# Tools
You have access to the following tools:
{tools}

# Format
Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the answer to the original question, followed by the SQL query you ran and its result

# Tonality
Be strict and do not use emojis. Be careful and paranoid: check every query three times before running it.

# Atypical cases
If there is not enough information in the database to answer, write "Final Answer: I cannot answer".

Begin!

Question: {input}
Thought:{agent_scratchpad}"""

prompt = PromptTemplate.from_template(template)

-----

In [3]:
# SQL tools bound to `db`. The toolkit is handed `llama` as its own LLM,
# while the agent built in the next cell is driven by `sqlcoder`.
toolkit = SQLDatabaseToolkit(db=db, llm=llama)

In [4]:

# Ready-made SQL agent: `sqlcoder` drives a zero-shot ReAct loop over the
# toolkit's tools, `top_k=2` is forwarded to create_sql_agent, and
# `verbose=True` echoes each intermediate step.
agent = base.create_sql_agent(
    toolkit=toolkit,
    llm=sqlcoder,
    verbose=True,
    top_k=2,
    agent_type=agents.AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)

# Alternative: a custom ReAct agent driven by `llama` with the hand-written
# `prompt`, wrapped in an executor.
# react_agent = agents.create_react_agent(llama, toolkit.get_tools(), prompt)
#
# agent = agents.AgentExecutor(
#     agent=react_agent,
#     tools=toolkit.get_tools(),
#     verbose=True,
# )

In [5]:
# res1 = agent.invoke({
#     "input": "Select the names of all the people who are in the airline database",
# })

# print(res1)

In [6]:
# Ask sqlcoder directly (no agent) to write a query for one question.
# The schema is read from the live database instead of a hand-pasted copy of
# its CREATE TABLE statements, so the prompt cannot drift from the real tables.
# NOTE(review): get_table_info() may also append a few sample rows per table,
# which the hand-pasted copy omitted; confirm that is wanted.
question = "Print all names of people in the airlines"
schema = db.get_table_info().strip()
res = sqlcoder.invoke(f"{question}\n\nUse this database schema:\n{schema}")

print(res)