In [None]:
# Authenticate the Colab session's Google account so later Google Cloud calls
# (Vertex AI, BigQuery) in this notebook can run with the user's credentials.
from google.colab import auth
auth.authenticate_user()

In [None]:
!pip install langchain langchain_community langchain-google-genai langchain-google-vertexai jq faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.0/27.0 MB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: faiss-cpu
Successfully installed faiss-cpu-1.8.0.post1


In [None]:
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.embeddings import VertexAIEmbeddings
from langchain.document_loaders import JSONLoader
from langchain.embeddings.base import Embeddings
from langchain_google_vertexai import ChatVertexAI
from langchain.vectorstores import FAISS
from google.cloud import bigquery
from typing import List
from tqdm import tqdm
import logging
import json
import os

In [None]:
# Google Cloud settings used by the model and BigQuery clients in this notebook.
PROJECT = "pradeep-genai"
LOCATION = "us-central1"
# NOTE(review): MODEL_NAME is not referenced by the chat-model cell, which
# hard-codes "codechat-bison" — confirm which name is intended.
MODEL_NAME = "codechat-bison@latest"

In [None]:
# Chat model that turns the assembled prompt into SQL; temperature 0.0 keeps the
# output as repeatable as possible.
# NOTE(review): the model name is hard-coded here instead of using MODEL_NAME
# ('codechat-bison@latest') from the config cell — confirm intent.
llm = ChatVertexAI(
    model_name="codechat-bison",
    project=PROJECT,
    location=LOCATION,
    temperature=0.0,
    max_output_tokens=512,
)

In [None]:
class MyVertexAIEmbeddings(VertexAIEmbeddings, Embeddings):
    """Vertex AI text embeddings that send texts to the model in small batches.

    LangChain's ``Embeddings`` interface (used by vector stores such as FAISS)
    calls ``embed_documents`` for corpora and ``embed_query`` for queries.
    ``embed_documents`` is overridden here to delegate to ``embed_segments``, so
    the batching below is actually used when the FAISS index is built.
    """

    model_name = 'textembedding-gecko'
    # Number of texts sent per `get_embeddings` request.
    max_batch_size = 5

    def embed_segments(self, segments: List) -> List:
        """Embed ``segments`` in chunks of ``max_batch_size``.

        Returns one embedding vector (the ``.values`` list) per input segment,
        in input order. An empty input returns an empty list.
        """
        embeddings = []
        for start in tqdm(range(0, len(segments), self.max_batch_size)):
            batch = segments[start:start + self.max_batch_size]
            embeddings.extend(self.client.get_embeddings(batch))
        return [embedding.values for embedding in embeddings]

    def embed_documents(self, texts: List[str], batch_size: int = 0) -> List[List[float]]:
        """Embed a list of documents via the batched ``embed_segments`` path.

        ``batch_size`` is accepted only for compatibility with the base-class
        signature and is ignored; batching always uses ``max_batch_size``.
        """
        return self.embed_segments(list(texts))

    def embed_query(self, query: str) -> List:
        """Embed a single query string and return its vector."""
        embeddings = self.client.get_embeddings([query])
        return embeddings[0].values

In [None]:
# Pass project/location explicitly (same settings as the chat model) rather than
# relying on whatever default project the runtime environment resolves.
embedding = MyVertexAIEmbeddings(project=PROJECT, location=LOCATION)



In [None]:
# Natural-language question that drives both schema retrieval and SQL generation.
query = (
    "Provide a list of all flight reservations "
    "from October 10th to October 15th, 2023"
)

In [None]:
# Mount Google Drive so the schema JSONL files under /content/drive can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Load the table-level schema descriptions: one JSON object per line, each kept
# as raw JSON text in the resulting Document's page_content.
tables_path = '/content/drive/MyDrive/SQL Generation/data/rag-schema/tables.jsonl'
documents = JSONLoader(
    file_path=tables_path,
    jq_schema='.',
    text_content=False,
    json_lines=True,
).load()

In [None]:
# Embed the table descriptions and index them in a FAISS vector store.
db = FAISS.from_documents(documents, embedding)

In [None]:
# MMR retriever returning the 5 best-matching tables. Per LangChain's docs,
# lambda_mult=1 means minimum diversity, i.e. candidates are picked by relevance only.
retriever = db.as_retriever(
    search_type='mmr',
    search_kwargs=dict(k=5, lambda_mult=1),
)

In [None]:

# Retrieve the tables most relevant to the question. `invoke` is the supported
# retriever API; `get_relevant_documents` is deprecated.
matched_documents = retriever.invoke(query)


  warn_deprecated(


In [None]:
# Turn each retrieved table document (raw JSON text) into a "dataset.table" name.
matched_tables = []
for doc in matched_documents:
    table_info = json.loads(doc.page_content)
    matched_tables.append(f"{table_info['dataset_name']}.{table_info['table_name']}")

print(f'Matched tables = {matched_tables}')


Matched tables = ['flight_reservations.reservations', 'flight_reservations.transactions', 'flight_reservations.flights', 'hotel_reservations.reservations', 'hotel_reservations.inventory']


In [None]:
# Load the column-level schema descriptions and index them in a new FAISS store.
# NOTE: this rebinds `documents` and `db`, replacing the table-level ones.
columns_path = '/content/drive/MyDrive/SQL Generation/data/rag-schema/columns.jsonl'
documents = JSONLoader(
    file_path=columns_path,
    jq_schema='.',
    text_content=False,
    json_lines=True,
).load()
db = FAISS.from_documents(documents, embedding)

In [None]:
# Peek at the first column record to check the loaded structure.
documents[0]

Document(metadata={'source': '/content/drive/MyDrive/SQL Generation/data/rag-schema/columns.jsonl', 'seq_num': 1}, page_content='{"dataset_name": "hotel_reservations", "table_name": "hotels", "column_name": "hotel_id", "description": "A unique identifier assigned to each hotel.", "usage": "This ID helps in maintaining a distinct record for each hotel and acts as a primary key. It\'s also used for referencing in other tables like Rooms.", "data_type": "INT64"}')

In [None]:
# Plain similarity search over column descriptions; fetch a generous 20 hits
# since they are filtered by dataset afterwards.
search_kwargs = dict(k=20)

retriever = db.as_retriever(search_kwargs=search_kwargs, search_type='similarity')

In [None]:
# Retrieve candidate columns for the question. `invoke` is the supported
# retriever API; `get_relevant_documents` is deprecated.
matched_columns = retriever.invoke(query)
print(matched_columns)

[Document(metadata={'source': '/content/drive/MyDrive/SQL Generation/data/rag-schema/columns.jsonl', 'seq_num': 62}, page_content='{"dataset_name": "flight_reservations", "table_name": "reservations", "column_name": "reservation_datetime", "description": "Timestamp of when the reservation was made.", "usage": "Helps track reservation history and manage bookings.", "data_type": "DATETIME"}'), Document(metadata={'source': '/content/drive/MyDrive/SQL Generation/data/rag-schema/columns.jsonl', 'seq_num': 55}, page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "column_name": "departure_datetime", "description": "The departure time of the flight.", "usage": "Informs users and helps them plan their travel.", "data_type": "DATETIME"}'), Document(metadata={'source': '/content/drive/MyDrive/SQL Generation/data/rag-schema/columns.jsonl', 'seq_num': 53}, page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "column_name": "origin", "description"

In [None]:
# Keep only the retrieved columns that belong to the dataset the question is about.
# NOTE(review): the dataset is hard-coded for the current query and the
# `matched_tables` result of the table-retrieval step is not used here.
# TODO: consider keeping columns whose "dataset.table" is in matched_tables.
TARGET_DATASET = 'flight_reservations'

matched_columns_filtered = []
for column in matched_columns:
    page_content = json.loads(column.page_content)
    if page_content['dataset_name'] == TARGET_DATASET:
        matched_columns_filtered.append(page_content)
print(matched_columns_filtered)


[{'dataset_name': 'flight_reservations', 'table_name': 'reservations', 'column_name': 'reservation_datetime', 'description': 'Timestamp of when the reservation was made.', 'usage': 'Helps track reservation history and manage bookings.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'departure_datetime', 'description': 'The departure time of the flight.', 'usage': 'Informs users and helps them plan their travel.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'origin', 'description': 'The departure location of the flight.', 'usage': 'Helps users find flights based on their travel plans.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'destination', 'description': 'The arrival location of the flight.', 'usage': 'Used to find flights and plan journeys.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservation

In [None]:
# Flatten each column record into one pipe-delimited line for the prompt's
# MATCHED_SCHEMA section (description/usage are intentionally left out).
column_lines = [
    f"dataset_name={col['dataset_name']}|table_name={col['table_name']}|column_name={col['column_name']}|data_type={col['data_type']}"
    for col in matched_columns_filtered
]
matched_columns_cleaned = '\n'.join(column_lines)
print(matched_columns_cleaned)

dataset_name=flight_reservations|table_name=reservations|column_name=reservation_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=departure_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=origin|data_type=STRING
dataset_name=flight_reservations|table_name=flights|column_name=destination|data_type=STRING
dataset_name=flight_reservations|table_name=transactions|column_name=transaction_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=arrival_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=customers|column_name=created_at|data_type=DATETIME
dataset_name=flight_reservations|table_name=customers|column_name=date_of_birth|data_type=DATE
dataset_name=flight_reservations|table_name=reservations|column_name=status|data_type=STRING
dataset_name=flight_reservations|table_name=transactions|column_name=reservation_id|data_type=INT64
dataset_name=fl

In [None]:
# Ordered list of message templates that make up the chat prompt.
messages = list()

In [None]:
# System message: sets the model's role as a BigQuery SQL writer.
template = "You are a SQL master expert capable of writing complex SQL query in BigQuery."
system_message_prompt = SystemMessagePromptTemplate.from_template(template=template)
messages += [system_message_prompt]

In [None]:
# Human message template; {query} and {matched_schema} are filled in at format time.
# The goal sentence is generic (it previously described hotel availability,
# which did not match the flight-reservation query).
human_template = """Given the following inputs:
USER_QUERY:
--
{query}
--
MATCHED_SCHEMA:
--
{matched_schema}
--
Please construct a SQL query using the MATCHED_SCHEMA and the USER_QUERY provided above.
The goal is to answer the USER_QUERY using only the data described in MATCHED_SCHEMA.

Highly IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this.
Highly IMPORTANT: Associate each column_name in MATCHED_SCHEMA only with the table_name listed next to it, and reference tables as dataset_name.table_name.
NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
"""

In [None]:
# Wrap the human template and append it after the system message.
human_message = HumanMessagePromptTemplate.from_template(template=human_template)
messages += [human_message]

In [None]:
# Build the prompt from the two templates explicitly rather than from the shared
# `messages` list: that list is appended to by earlier cells, so re-running any
# of them would otherwise duplicate messages in the prompt.
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message])

In [None]:
# Fill both template variables and materialize the list of chat messages.
prompt_value = chat_prompt.format_prompt(
    matched_schema=matched_columns_cleaned,
    query=query,
)
request = prompt_value.to_messages()

In [None]:
# Show the exact messages that will be sent to the model.
print(request)

[SystemMessage(content='You are a SQL master expert capable of writing complex SQL query in BigQuery.'), HumanMessage(content="Given the following inputs:\nUSER_QUERY:\n--\nProvide a list of all flight reservations from October 10th to October 15th, 2023\n--\nMATCHED_SCHEMA: \n--\ndataset_name=flight_reservations|table_name=reservations|column_name=reservation_datetime|data_type=DATETIME\ndataset_name=flight_reservations|table_name=flights|column_name=departure_datetime|data_type=DATETIME\ndataset_name=flight_reservations|table_name=flights|column_name=origin|data_type=STRING\ndataset_name=flight_reservations|table_name=flights|column_name=destination|data_type=STRING\ndataset_name=flight_reservations|table_name=transactions|column_name=transaction_datetime|data_type=DATETIME\ndataset_name=flight_reservations|table_name=flights|column_name=arrival_datetime|data_type=DATETIME\ndataset_name=flight_reservations|table_name=customers|column_name=created_at|data_type=DATETIME\ndataset_name=f

In [None]:
%%time

response = llm(request)
sql = '\n'.join(response.content.strip().split('\n')[1:-1])
print(sql)

SELECT DISTINCT
  r.reservation_id,
  r.reservation_datetime,
  f.departure_datetime,
  f.origin,
  f.destination,
  f.arrival_datetime,
  f.carrier
FROM
  flight_reservations.reservations AS r
JOIN
  flight_reservations.flights AS f
ON
  r.flight_id = f.flight_id
WHERE
  r.reservation_datetime BETWEEN '2023-10-10' AND '2023-10-15';
CPU times: user 19.5 ms, sys: 0 ns, total: 19.5 ms
Wall time: 1.08 s


In [None]:
# BigQuery client bound to PROJECT, used to execute the generated SQL.
bq = bigquery.Client(PROJECT)

In [None]:
# The SQL text comes from the LLM, so only run it if it looks like a single
# read-only query. This is a shallow textual guard, not a SQL parser: it requires
# the statement to start with SELECT/WITH and rejects extra ';'-separated
# statements (a ';' inside a string literal is also rejected, explicitly).
statement = sql.strip().rstrip(';')
first_keyword = statement.split(None, 1)[0].upper() if statement else ''
if first_keyword not in ('SELECT', 'WITH') or ';' in statement:
    raise ValueError(f'Refusing to run non-SELECT or multi-statement SQL from the model:\n{sql}')

df = bq.query(sql).to_dataframe()
df

Unnamed: 0,reservation_id,reservation_datetime,departure_datetime,origin,destination,arrival_datetime,carrier
0,6,2023-10-10 10:00:00,2023-11-25 06:00:00,SEA,JFK,2023-11-25 14:30:00,United
1,7,2023-10-12 11:30:00,2023-11-27 20:00:00,JFK,MIA,2023-11-27 23:30:00,American
