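## Text-to-SQL on the Yelp dataset with a LangChain SQL agent

This notebook converts natural-language questions into SQL queries that run against the Yelp dataset. Column descriptions are embedded in a FAISS vector store, and only the tables whose columns best match a question are exposed to the SQL agent.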

### Install Dependencies

In [1]:
%pip --quiet install langchain psycopg2 jq faiss-gpu transformers openai tiktoken python-environ

Note: you may need to restart the kernel to use updated packages.


### Import libraries

In [2]:
import os
import json
from typing import List
from urllib.parse import quote

import environ
from langchain_community.llms import OpenAI
from langchain.schema import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.sql_database import SQLDatabase
from langchain.agents import create_sql_agent
from langchain.document_loaders import JSONLoader
from langchain.vectorstores import FAISS
from langchain.agents.agent_toolkits import SQLDatabaseToolkit

from config import DATABASE_NAME, USERNAME, PASSWORD, TABLES

### Database connection string

In [3]:
# Connection settings for the local Postgres instance holding the Yelp tables.
HOST = "localhost"
PORT = "5432"

# Percent-encode the credentials (safe="" also encodes '/', and spaces become %20
# rather than '+'): an '@', ':', '/' or '#' in the username or password would
# otherwise be parsed as URL structure and produce a wrong or unparsable URI.
pg_uri = (
    f"postgresql+psycopg2://{quote(USERNAME, safe='')}:{quote(PASSWORD, safe='')}"
    f"@{HOST}:{PORT}/{DATABASE_NAME}"
)

### Load the LLM

In [4]:
# Load a local .env file (if present) so OPENAI_API_KEY can be read below.
environ.Env.read_env()
env = environ.Env()
API_KEY = env('OPENAI_API_KEY')

# temperature=0.0 keeps the generated SQL as repeatable as possible between runs.
llm = OpenAI(temperature=0.0, openai_api_key=API_KEY)

  warn_deprecated(


### Initialize the embeddings class

In [5]:
# Embeddings used to index the table/column description documents in FAISS.
# Pass the key explicitly, the same way the LLM receives it, instead of relying
# on OPENAI_API_KEY being present in the process environment.
embeddings = OpenAIEmbeddings(openai_api_key=API_KEY)

  warn_deprecated(


### Prepare schema documents

In [6]:
def _load_schema_jsonl(path: str):
    """Load one JSON-lines schema file into a list of Documents.

    jq_schema='.' with text_content=False keeps each whole line's JSON object;
    downstream code json.loads() the page_content of each Document.
    """
    return JSONLoader(file_path=path, jq_schema='.', text_content=False, json_lines=True).load()


tables_document = _load_schema_jsonl('./schemas/tables.jsonl')
columns_document = _load_schema_jsonl('./schemas/columns.jsonl')

### Helper for retrieving the best-matching tables from the vector store

In [7]:
# Cache for the tables FAISS index. It is rebuilt only when `tables_document` or
# `embeddings` is rebound (e.g. by re-running their cells), not on every query.
_tables_index_cache = {'documents': None, 'embedding': None, 'index': None}


def get_matched_tables(query: str, k: int = 5, lambda_mult: float = 1.0) -> List[Document]:
    """Return the table-description documents most relevant to `query`.

    Args:
        query: natural-language question matched against the table descriptions.
        k: number of documents to return (default 5, as before).
        lambda_mult: MMR relevance/diversity trade-off. The default 1.0 weights
            relevance only, so MMR effectively ranks by similarity; lower it to
            get more diverse tables.

    Returns:
        The matched Documents from `tables_document`.
    """
    cache = _tables_index_cache
    if cache['documents'] is not tables_document or cache['embedding'] is not embeddings:
        # Building the index embeds every document through `embeddings`, so do it
        # once per (documents, embeddings) pair instead of once per query.
        cache['index'] = FAISS.from_documents(documents=tables_document, embedding=embeddings)
        cache['documents'] = tables_document
        cache['embedding'] = embeddings

    retriever = cache['index'].as_retriever(
        search_type='mmr',
        search_kwargs={'k': k, 'lambda_mult': lambda_mult},
    )
    return retriever.get_relevant_documents(query=query)

### Helper for retrieving the best-matching columns from the vector store

In [8]:
# Cache for the columns FAISS index. It is rebuilt only when `columns_document` or
# `embeddings` is rebound (e.g. by re-running their cells), not on every query.
_columns_index_cache = {'documents': None, 'embedding': None, 'index': None}


def get_matched_columns(query: str, k: int = 20) -> List[Document]:
    """Return the column-description documents most similar to `query`.

    Args:
        query: natural-language question matched against the column descriptions.
        k: number of documents to return (default 20, as before).

    Returns:
        The top-`k` Documents from `columns_document` by similarity search.
    """
    cache = _columns_index_cache
    if cache['documents'] is not columns_document or cache['embedding'] is not embeddings:
        # Building the index embeds every document through `embeddings`, so do it
        # once per (documents, embeddings) pair instead of once per query.
        cache['index'] = FAISS.from_documents(documents=columns_document, embedding=embeddings)
        cache['documents'] = columns_document
        cache['embedding'] = embeddings

    retriever = cache['index'].as_retriever(search_type='similarity', search_kwargs={'k': k})
    return retriever.get_relevant_documents(query=query)

### Construct the system prompt

In [9]:
# System prompt passed to create_sql_agent as `prefix`.
# NOTE(review): prompt instructions are not a security boundary. Connect with a
# read-only database role so write/DDL statements fail even if the model emits them.
# NOTE(review): some LangChain versions call str.format() on the prefix. Keep literal
# braces out of this text (or double them); confirm for the installed version.
system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct Postgres query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT run any DML or DDL statements (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, etc.) against the database.

If the question does not seem related to the database, just return "I don't know" as the answer."""

### Build the agent

In [12]:
# Position of each known table's base definition in config.TABLES:
# TABLES[i] is the base schema text for _TABLE_ORDER[i].
_TABLE_ORDER = ('businesses', 'users', 'reviews', 'checkins', 'tips')


def find_tables_to_include(query: str, matched_columns=None, table_definitions=None) -> tuple:
    """Choose the tables, and custom schema text, to expose to the SQL agent for `query`.

    Retrieves the column descriptions most similar to `query`, collects the table
    each belongs to, and for every known table builds custom table info: its base
    definition followed by one "<column>: <description> <usage>" line per matched
    column, in retrieval order.

    Args:
        query: natural-language question.
        matched_columns: optional pre-retrieved column Documents. Each
            `page_content` must be JSON with 'table_name', 'column_name',
            'description' and 'usage' keys. Defaults to `get_matched_columns(query)`.
        table_definitions: optional sequence aligned with `_TABLE_ORDER`.
            Defaults to `TABLES` from config.

    Returns:
        (tables, table_info): table names in first-match order, and a dict mapping
        each matched known table to its custom schema text. Tables outside
        `_TABLE_ORDER` appear in `tables` but get no custom info; presumably
        SQLDatabase then uses its own reflected schema for them (confirm).
    """
    if matched_columns is None:
        matched_columns = get_matched_columns(query)
    if table_definitions is None:
        table_definitions = TABLES

    # A list keeps first-seen order; a set's order varies between interpreter runs.
    tables = []
    column_notes = {}  # table name -> list of "column: description usage" lines
    for doc in matched_columns:
        content = json.loads(doc.page_content)
        table_name = content['table_name']
        if table_name not in tables:
            tables.append(table_name)
        if table_name in _TABLE_ORDER:
            column_notes.setdefault(table_name, []).append(
                f"{content['column_name']}: {content['description']} {content['usage']}"
            )

    # One compact newline-joined block per table. The previous triple-quoted
    # f-strings injected 28/40-space indents and whitespace-only lines into the prompt.
    table_info = {
        name: "\n".join([table_definitions[_TABLE_ORDER.index(name)], *notes])
        for name, notes in column_notes.items()
    }
    return tables, table_info

def execute_query(query: str):
    """Answer a natural-language question with a SQL agent scoped to the relevant tables.

    The tables (and custom schema text) come from find_tables_to_include(query).
    The selected table names are printed, then a fresh SQLDatabase, toolkit and
    agent are built for this question, and the agent is invoked on it.

    Returns:
        The result of the agent executor's invoke(). In the recorded runs below
        this is a dict with 'input' and 'output' keys.
    """
    relevant_tables, custom_info = find_tables_to_include(query)
    print(relevant_tables)

    scoped_db = SQLDatabase.from_uri(
        pg_uri,
        include_tables=relevant_tables,
        custom_table_info=custom_info,
    )
    agent = create_sql_agent(
        llm=llm,
        toolkit=SQLDatabaseToolkit(db=scoped_db, llm=llm),
        prefix=system_prefix,
        return_intermediate_steps=True,
        verbose=True,
    )
    return agent.invoke(query)

In [14]:
# Recorded run: retrieval exposed reviews, tips, businesses and users; the agent answered 'Ormond'.
execute_query("Find me the user who has posted the coolest review?")

['reviews', 'tips', 'businesses', 'users']


[1m> Entering new AgentExecutor chain...[0m




[32;1m[1;3m We need to find the table that contains user information and review information, and then join them together to find the coolest review.
Action: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mbusinesses, reviews, tips, users[0m[32;1m[1;3mWe need to find the table that contains user information and review information, and then join them together to find the coolest review.
Action: sql_db_schema
Action Input: users, reviews[0m[33;1m[1;3m
    CREATE TABLE reviews (
        review_id VARCHAR(22) PRIMARY KEY NOT NULL, 
        user_id VARCHAR(22) NOT NULL REFERENCES users (user_id), 
        business_id VARCHAR(22) NOT NULL REFERENCES businesses (business_id), 
        stars INT NOT NULL, 
        date DATE NOT NULL, 
        text TEXT NOT NULL, 
        useful INT, 
        funny INT, 
        cool INT
    )
    
                                        cool: Number of users who found the review cool or the number of cool votes received by this review. Displayin

{'input': 'Find me the user who has posted the coolest review?',
 'output': "The user who has posted the coolest review is 'Ormond'."}

In [90]:
# Recorded run (failure): the agent asked for non-existent tables `user`/`checkin`,
# queried a `city` column that checkins lacks, and stopped at the iteration limit.
execute_query("Which user from Boston checked in into a spa?")





[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m I need to query the user and check-in tables to find the necessary information.
Action: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mbusinesses, checkins, reviews, tips, users[0m[32;1m[1;3mI need to use the user and check-in tables to get the relevant data.
Action: sql_db_schema
Action Input: user, checkin[0m[33;1m[1;3mError: table_names {'checkin', 'user'} not found in database[0m[32;1m[1;3mI need to double check if the table names are correct using sql_db_query_checker.
Action: sql_db_query_checker
Action Input: SELECT * FROM checkins[0m[36;1m[1;3m
SELECT * FROM checkins
--No common mistakes in this query[0m[32;1m[1;3m I need to add the necessary columns and conditions to the query.
Action: sql_db_query
Action Input: SELECT * FROM checkins WHERE city = 'Boston' AND category = 'spa'[0m[36;1m[1;3mError: (psycopg2.errors.UndefinedColumn) column "city" does not exist
LINE 1: SELECT * FROM checkins 

{'input': 'Which user from Boston checked in into a spa?',
 'output': 'Agent stopped due to iteration limit or time limit.'}

In [91]:
# Recorded run (failure): the agent queried a non-existent `business_reviews` table
# and returned the SQL text itself as its final answer.
execute_query("Find me the business with the funniest review in San Francisco and get me the user who posted the review as well")





[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m We should start by looking at the reviews and finding the ones with the highest ratings, then filter by location and business type.
Action: sql_db_query
Action Input: SELECT business_name, user_name, review_rating FROM business_reviews WHERE review_rating = 5 AND city = 'San Francisco' AND business_type = 'funny'[0m[36;1m[1;3mError: (psycopg2.errors.UndefinedTable) relation "business_reviews" does not exist
LINE 1: ...LECT business_name, user_name, review_rating FROM business_r...
                                                             ^

[SQL: SELECT business_name, user_name, review_rating FROM business_reviews WHERE review_rating = 5 AND city = 'San Francisco' AND business_type = 'funny']
(Background on this error at: https://sqlalche.me/e/20/f405)[0m[32;1m[1;3mWe should check if the table name is correct and if the columns exist in the table.
Action: sql_db_schema
Action Input: business_reviews[0m[33;1m[1;3m

{'input': 'Find me the business with the funniest review in San Francisco and get me the user who posted the review as well',
 'output': "SELECT business_name, user_name, review_rating FROM business_reviews WHERE review_rating = 5 AND city = 'San Francisco' AND business_type = 'funny'"}

In [11]:
# Recorded run: the agent saw businesses, reviews, tips and users and inspected the businesses/reviews schemas.
execute_query("find me the business with the highest stars")





[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m I need to find the table that contains business information and the table that contains review information. Then I can join the two tables and sort by the stars column.
Action: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mbusinesses, reviews, tips, users[0m[32;1m[1;3mI need to use the businesses and reviews tables.
Action: sql_db_schema
Action Input: businesses, reviews[0m[33;1m[1;3m
    CREATE TABLE businesses (
        business_id VARCHAR(22) PRIMARY KEY NOT NULL, 
        name TEXT NOT NULL, 
        address TEXT NOT NULL, 
        city TEXT NOT NULL, 
        state TEXT NOT NULL, 
        postal_code TEXT NOT NULL, 
        latitude DOUBLE PRECISION, 
        longitude DOUBLE PRECISION, 
        stars FLOAT, 
        review_count INT, 
        is_open BOOLEAN NOT NULL, 
        categories TEXT[], 
        monday_hours TEXT, 
        tuesday_hours TEXT, 
        wednesday_hours TEXT, 
        thursday_hou