# SQL agent

In [None]:
import db_connect
import vc_connect
from examples import FEW_SHOT_EXAMPLES

from langchain_community.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains.sql_database import query

from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.prompts import PromptTemplate

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from langchain.agents import Tool
from langchain import agents

from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_toolkits.sql import base

from langchain.prompts.example_selector import SemanticSimilarityExampleSelector

from langchain.globals import set_verbose
set_verbose(True)

### Load models

In [None]:
llama = Ollama(model="llama2:13b", temperature=0.25, repeat_penalty=1)
llama_embeddings = OllamaEmbeddings(model="llama2:13b", temperature=0.25, repeat_penalty=1)
sqlcoder_embeddings = OllamaEmbeddings(model="sqlcoder:15b", temperature=0.25, repeat_penalty=1)
sqlcoder = Ollama(model='sqlcoder:15b', temperature=0.25, repeat_penalty=1)

### Create prompt

In [None]:
prefix = \
"""You are a silent PostreSQL expert. Given an input question, first create {top_k} syntactically correct PostreSQL query to run, then look at the results and take most correct.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date(\'now\') function to get the current date, if the question involves \"today\".
Below are a number of examples of questions and their corresponding SQL queries."""

suffix = \
"""Only use the following tables:
{table_info}

Don\'t explain, use the following format:

User input: {input}
SQL query: """

example_prompt = PromptTemplate.from_template(
    "User input: {input}\nSQL query: {query}"
)

### Connect to DB with Readonly role

In [None]:
db = db_connect.get_db()

#### Check connection

In [None]:
db.run("select * from passenger")

"[(16, 'John'), (17, 'James'), (18, 'Poul'), (19, 'Christofer'), (20, 'Superman')]"

### Connect example selector

In [None]:
example_selector = vc_connect.get_selector(llama_embeddings)

Exception: Failed to create vector extension: 'utf-8' codec can't decode byte 0xc2 in position 61: invalid continuation byte

-----

In [12]:
# toolkit = SQLDatabaseToolkit(llm=llama, db=db)

In [13]:

# agent = base.create_sql_agent(
#     llm=sqlcoder,
#     toolkit=toolkit,
#     agent_type=agents.AgentType.ZERO_SHOT_REACT_DESCRIPTION,
#     top_k=2,
#     verbose=True
# )

# agent = agents.create_react_agent(llama, toolkit.get_tools(), prompt)

# agent = agents.AgentExecutor(
#     agent=agent,
#     tools=toolkit.get_tools(),
#     verbose=True,
# )

In [14]:
# res1 = agent.invoke({
#     "input": "Select the names of all the people who are in the airline database",
# })

# print(res1)

In [15]:
res = sqlcoder.invoke("""Use this database schema:
CREATE TABLE company (
	id BIGSERIAL NOT NULL, 
	name VARCHAR(60), 
	CONSTRAINT company_pkey PRIMARY KEY (id)
)


CREATE TABLE pass_in_trip (
	id BIGSERIAL NOT NULL, 
	trip BIGINT, 
	passenger BIGINT, 
	place VARCHAR(60), 
	CONSTRAINT pass_in_trip_pkey PRIMARY KEY (id), 
	CONSTRAINT fk_passanger FOREIGN KEY(passenger) REFERENCES passenger (id), 
	CONSTRAINT fk_trip FOREIGN KEY(trip) REFERENCES trip (id)
)


CREATE TABLE passenger (
	id BIGSERIAL NOT NULL, 
	name VARCHAR(60), 
	CONSTRAINT passenger_pkey PRIMARY KEY (id)
)


CREATE TABLE trip (
	id BIGSERIAL NOT NULL, 
	company BIGINT, 
	plane VARCHAR(60), 
	town_from VARCHAR(60), 
	town_to VARCHAR(60), 
	time_out TIMESTAMP WITHOUT TIME ZONE, 
	time_in TIMESTAMP WITHOUT TIME ZONE, 
	CONSTRAINT trip_pkey PRIMARY KEY (id), 
	CONSTRAINT fk_company FOREIGN KEY(company) REFERENCES company (id)
)
Print all names of people in airlines""")

print(res)

 that fly from London to Paris.
'''

# import libraries
import pandas as pd
import numpy as np
import datetime

# Create a dataframe with the relevant information
df = pd.read_csv('C:/Users/User/Documents/GitHub/ML-Challenges/MachineLearning/data/flights.csv')
df = df[['flight_id', 'airline', 'flight_number', 'origin', 'destination', 'departure_time']]
df = df[df['origin'] == 'London']
df = df[df['destination'] == 'Paris']

# Create a new column for the departure time in hours
df['departure_time_hours'] = df['departure_time'].apply(lambda x: datetime.datetime.strptime(x, '%Y-%m-%d %H:%M:%S').hour)

# Group the dataframe by airline and departure time in hours and count the number of flights
df = df.groupby(['airline', 'departure_time_hours']).size().reset_index(name='count')

# Sort the dataframe by airline and departure time in hours
df = df.sort_values(['airline', 'departure_time_hours'])

# Print the names of the airlines that fly from London to Paris
print(df['airline'].unique())
<|