## This notebook will work on converting text to SQL queries that can be run on Yelp dataset using Google's Code Bison LLM

### Install Dependencies

In [21]:
%pip --quiet install langchain tqdm chromadb==0.3.29 psycopg2 google-cloud-aiplatform==1.38.0 jq faiss-gpu

Note: you may need to restart the kernel to use updated packages.


### Import Libraries

In [36]:
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.embeddings import VertexAIEmbeddings
from langchain.document_loaders import JSONLoader
from langchain.embeddings.base import Embeddings
from langchain.chat_models import ChatVertexAI
from langchain.vectorstores import FAISS
from typing import List
from tqdm import tqdm
import chromadb
import logging
import json
import os
import psycopg2

from config import *

### Setup logging

In [6]:
logger = logging.getLogger('langchain')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

### Google Cloud Config

In [14]:
SERVICE_ACCOUNT_KEY_PATH = 'llm-study-413709-40a30207144b.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_KEY_PATH

PROJECT = 'llm-study-413709'
LOCATION = 'us-central1'
MODEL_NAME = 'codechat-bison'

### Create LLM

In [32]:
llm = ChatVertexAI(project=PROJECT, 
                   location=LOCATION, 
                   model_name=MODEL_NAME,
                   temperature=0.0, 
                   max_output_tokens=2048)

### Create custom embeddings class

In [16]:
class MyVertexAIEmbeddings(VertexAIEmbeddings, Embeddings):
    model_name = 'textembedding-gecko'
    max_batch_size = 5
    
    def embed_segments(self, segments: List) -> List:
        embeddings = []
        for i in tqdm(range(0, len(segments), self.max_batch_size)):
            batch = segments[i: i+self.max_batch_size]
            embeddings.extend(self.client.get_embeddings(batch))
        return [embedding.values for embedding in embeddings]
    
    def embed_query(self, query: str) -> List:
        embeddings = self.client.get_embeddings([query])
        return embeddings[0].values
    
embeddings = MyVertexAIEmbeddings()

Model_name will become a required arg for VertexAIEmbeddings starting from Feb-01-2024. Currently the default is set to textembedding-gecko@001


### Prepare schema documents

In [25]:
tables_document = JSONLoader(file_path='./schemas/tables.jsonl', jq_schema='.', text_content=False, json_lines=True).load()
columns_document = JSONLoader(file_path='./schemas/columns.jsonl', jq_schema='.', text_content=False, json_lines=True).load()
tables_document
print()
columns_document




[Document(page_content='{"dataset_name": "yelp", "table_name": "businesses", "column_name": "business_id", "description": "A unique identifier assigned to each business.", "usage": "This ID helps in maintaining a distinct record for each business and acts as a primary key. It\'s also used for referencing in other tables like reviews.", "data_type": "VARCHAR"}', metadata={'source': '/home/prasad/Desktop/work/yelp_rag/schemas/columns.jsonl', 'seq_num': 1}),
 Document(page_content='{"dataset_name": "yelp", "table_name": "businesses", "column_name": "name", "description": "The name of the business.", "usage": "This column provides users with the name of the business. It aids in branding and recognition.", "data_type": "TEXT"}', metadata={'source': '/home/prasad/Desktop/work/yelp_rag/schemas/columns.jsonl', 'seq_num': 2}),
 Document(page_content='{"dataset_name": "yelp", "table_name": "businesses", "column_name": "address", "description": "The address of the business.", "usage": "This colum

### Helper method for retrieving matched tables from vector store

In [26]:
def get_matched_tables(query: str) -> List:
    db = FAISS.from_documents(documents=tables_document, embedding=embeddings)
    retriever = db.as_retriever(search_type='mmr', search_kwargs={'k': 5, 'lambda_mult': 1})
    matched_documents = retriever.get_relevant_documents(query=query)

    return matched_documents

### Helper method for retrieving matched columns from vector store

In [27]:
def get_matched_columns(query: str) -> List:
    db = FAISS.from_documents(documents=columns_document, embedding=embeddings)
    search_kwargs = {
        'k': 20
    }

    retriever = db.as_retriever(search_type='similarity', search_kwargs=search_kwargs)
    matched_columns = retriever.get_relevant_documents(query=query)

    return matched_columns

### Text to SQL generation

In [42]:
messages = []
template = "You are a SQL master expert capable of writing complex SQL query in Postgres."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
messages.append(system_message_prompt)

human_template = """Given the following inputs:
USER_QUERY:
--
{query}
--
MATCHED_SCHEMA: 
--
{matched_schema}
--
Please construct a SQL query using the MATCHED_SCHEMA and the USER_QUERY provided above. 

IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this. 
IMPORTANT: Associate column_name mentioned in MATCHED_SCHEMA only to the table_name specified under MATCHED_SCHEMA.
NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed. 
"""
human_message = HumanMessagePromptTemplate.from_template(human_template)
messages.append(human_message)

chat_prompt = ChatPromptTemplate.from_messages(messages)

user_question = 'Find me a business with the funniest review and show me the review as well.'
matched_columns = get_matched_columns(user_question)

request = chat_prompt.format_prompt(query=user_question,
    matched_schema=matched_columns).to_messages()

response = llm(request)
query = '\n'.join(response.content.strip().split('\n')[1:-1])
logger.info(query)

with psycopg2.connect(f'dbname={DATABASE_NAME} user={USERNAME} password={PASSWORD}') as conn:
    with conn.cursor() as cursor:
        cursor.execute(query)
        records = cursor.fetchall()
        logger.info(records)

SELECT
    B.name AS business_name,
    R.text AS review_text,
    R.funny AS funny_votes
FROM
    reviews AS R
JOIN
    businesses AS B
ON
    R.business_id = B.business_id
ORDER BY
    R.funny DESC
LIMIT 1;
[('Broadway Oyster Bar', 'Went there for a birthday dinner and had reservations for 9 people (and 3 additional people showed up) after a grueling 2 hr wait and several drinks we finally got seated only to be told that because we had 3 extra people that we would have to split our party up and they would have to wait for a table! The manager was unwilling to work with us and so we took our party of 12 very hungry ad upset people to another Soulard Restaurant (he-hem) right down the street and had a FABULOUS dinner with EXCEPTIONAL CUSTOMER SERVICE! Like my brother in law said, if he ran his business like BroadwayOyster ran theirs then he would be out of business ! I agree with him that it is our fault as a society for allowing this behavior to be accepted and to allow sub par custom