In [None]:
# Authenticate the Colab user before any Google Cloud clients are created below.
from google.colab import auth as colab_auth
colab_auth.authenticate_user()

In [None]:
!pip install langchain langchain_community  langchain-google-vertexai

Collecting langchain
  Downloading langchain-0.2.6-py3-none-any.whl (975 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m975.5/975.5 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain_community
  Downloading langchain_community-0.2.6-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain-google-vertexai
  Downloading langchain_google_vertexai-1.0.6-py3-none-any.whl (72 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.0/73.0 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
Collecting langchain-core<0.3.0,>=0.2.10 (from langchain)
  Downloading langchain_core-0.2.11-py3-none-any.whl (337 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.4/337.4 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain-text-splitters<0.3.0,>=0.2.0 (from langchain)
  Downloading langchain_text

In [None]:
from vertexai.language_models import CodeGenerationModel
from google.cloud import bigquery
import logging
import os

In [None]:
# --- Google Cloud project / region ---
PROJECT_ID = 'pradeep-genai'
LOCATION = 'us-central1'

# --- Code-generation model settings ---
# NOTE(review): code-bison is a PaLM 2 model that Google has deprecated; confirm it is still served.
CODE_GEN_MODEL_NAME = 'code-bison'
TEMPERATURE = 0           # default value = 0
MAX_OUTPUT_TOKENS = 2048  # cap on response length; overrides the default value of 128
# Optional sampling parameters, left at their defaults:
# TOP_P = 0.95
# TOP_K = 40

In [None]:
# BigQuery dataset and the tables whose column schemas are fed to the model.
# NOTE(review): 'loyality_points' presumably matches the real table name; verify before "fixing" the spelling.
DATASET = 'flight_reservations'
TABLES = 'customers flights reservations transactions loyality_points'.split()

In [None]:
import vertexai

In [None]:
# Create the BigQuery client, initialise the Vertex AI SDK, and load the code-generation model.
bq_client = bigquery.Client(project=PROJECT_ID)
# Use the LOCATION constant from the config cell rather than repeating the region literal,
# so that changing LOCATION actually takes effect.
vertexai.init(project=PROJECT_ID, location=LOCATION)
code_gen_model = CodeGenerationModel.from_pretrained(model_name=CODE_GEN_MODEL_NAME)


In [None]:
query = f"""
    SELECT *
    FROM `{PROJECT_ID}.{DATASET}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ({','.join([f'"{table}"' for table in TABLES])})
"""
print(query)


    SELECT *
    FROM `pradeep-genai.flight_reservations.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ("customers","flights","reservations","transactions","loyality_points")



In [None]:
# Run the INFORMATION_SCHEMA query and materialise the column metadata as a DataFrame.
schema_job = bq_client.query(query=query)
schema_columns = schema_job.to_dataframe()
schema_columns

Unnamed: 0,table_catalog,table_schema,table_name,column_name,field_path,data_type,description,collation_name,rounding_mode
0,pradeep-genai,flight_reservations,transactions,transaction_id,transaction_id,INT64,,,
1,pradeep-genai,flight_reservations,transactions,reservation_id,reservation_id,INT64,,,
2,pradeep-genai,flight_reservations,transactions,amount,amount,FLOAT64,,,
3,pradeep-genai,flight_reservations,transactions,transaction_datetime,transaction_datetime,DATETIME,,,
4,pradeep-genai,flight_reservations,reservations,reservation_id,reservation_id,INT64,,,
5,pradeep-genai,flight_reservations,reservations,customer_id,customer_id,INT64,,,
6,pradeep-genai,flight_reservations,reservations,flight_id,flight_id,INT64,,,
7,pradeep-genai,flight_reservations,reservations,reservation_datetime,reservation_datetime,DATETIME,,,
8,pradeep-genai,flight_reservations,reservations,status,status,STRING,,,
9,pradeep-genai,flight_reservations,flights,flight_id,flight_id,INT64,,,


In [None]:
# Render the schema as a Markdown table to embed in the prompt.
# Guarded so that re-running this cell (when schema_columns is already the
# Markdown string) is a no-op instead of raising AttributeError on str.to_markdown.
if not isinstance(schema_columns, str):
    schema_columns = schema_columns.to_markdown(index=False)
print(schema_columns)

| table_catalog   | table_schema        | table_name   | column_name          | field_path           | data_type   | description   | collation_name   | rounding_mode   |
|:----------------|:--------------------|:-------------|:---------------------|:---------------------|:------------|:--------------|:-----------------|:----------------|
| pradeep-genai   | flight_reservations | transactions | transaction_id       | transaction_id       | INT64       |               | NULL             |                 |
| pradeep-genai   | flight_reservations | transactions | reservation_id       | reservation_id       | INT64       |               | NULL             |                 |
| pradeep-genai   | flight_reservations | transactions | amount               | amount               | FLOAT64     |               | NULL             |                 |
| pradeep-genai   | flight_reservations | transactions | transaction_datetime | transaction_datetime | DATETIME    |               | NULL             

In [None]:
def _strip_code_fences(text):
    """
    Extract SQL from a model response that may be wrapped in a Markdown code fence.

    If `text` contains a fenced block (```sql ... ``` or ``` ... ```), the contents of the
    first such block are returned. An opening fence with no closing fence has just its
    fence line removed. Otherwise the response is returned unchanged. Surrounding
    whitespace is always stripped.
    """
    import re

    stripped = text.strip()
    match = re.search(r"```[^\n]*\n(.*?)```", stripped, re.DOTALL)
    if match:
        return match.group(1).strip()
    if stripped.startswith("```"):
        # Opening fence without a closing one (e.g. truncated output): drop the fence line only.
        return stripped.partition("\n")[2].strip()
    return stripped


def generate_and_execute_sql(prompt, max_tries=15):
    """
    Generate an SQL query using the code_gen_model and execute it using bq_client.

    After each failed attempt (the model call or the BigQuery execution raised), the error
    message and the SQL that produced it are appended to the prompt, and the model is asked
    to correct the query on the next attempt.

    Args:
    - prompt (str): Prompt to provide to the model for generating SQL.
    - max_tries (int): Maximum number of attempts to generate and execute SQL.

    Returns:
    - dict: {"dataframe": the `.to_dataframe()` result of the first successful query, else None,
             "prompts": the initial prompt followed by the evolved prompt built after each failure,
             "errors": the error message of each failed attempt, in order}
    """

    tries = 0
    error_messages = []
    prompts = [prompt]
    df = None
    # Bound up front so the retry prompt can still be built when the very first
    # model call fails before any SQL has been generated.
    generated_sql_query = ''

    while tries < max_tries:
        print(f'ATTEMPT: {tries+1}')
        try:
            # Predict SQL using the model
            response = code_gen_model.predict(prompt, temperature=TEMPERATURE, max_output_tokens=MAX_OUTPUT_TOKENS)
            # Remove a surrounding ```sql fence only if one is present; blindly dropping the
            # first and last lines would truncate SQL returned without a fence.
            generated_sql_query = _strip_code_fences(response.text)
            print('-' * 50)
            print(generated_sql_query)
            print('-' * 50)
            # Execute SQL using BigQuery client
            df = bq_client.query(generated_sql_query).to_dataframe()
            print('SUCCEEDED')
            return {"dataframe": df, "prompts": prompts, "errors": error_messages}
        except Exception as e:
            # Catch the error, store and show the message, and try again
            msg = str(e)
            print(f'FAILED: {msg}')
            error_messages.append(msg)
            # Evolve the prompt by appending the error message and asking the model to correct it
            prompt = f"""{prompt}
Encountered an error: {msg}.
To address this, please generate an alternative SQL query response that avoids this specific error.
Follow the instructions mentioned above to remediate the error.

Modify the below SQL query to resolve the issue:
{generated_sql_query}

Ensure the revised SQL query aligns precisely with the requirements outlined in the initial question."""
            prompts.append(prompt)
            tries += 1
        print('=' * 100)

    return {"dataframe": df, "prompts": prompts, "errors": error_messages}

In [None]:
seed_prompt = """
Please craft a SQL query for BigQuery that addresses the following QUESTION provided below.
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below.
When joining tables, employ type coercion to guarantee data type consistency for the join columns.
Additionally, the output column names should specify units where applicable.\n
QUESTION:
{}\n
SCHEMA:
{}\n
IMPORTANT:
Use ONLY DATETIME and DO NOT use TIMESTAMP.
--
Ensure your SQL query accurately defines both the start and end of the DATETIME range.
"""
print(seed_prompt)


Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.

QUESTION:
{}

SCHEMA:
{}

IMPORTANT: 
Use ONLY DATETIME and DO NOT use TIMESTAMP.
--
Ensure your SQL query accurately defines both the start and end of the DATETIME range.



For this scenario, you want to find all reservations made within a specific date range.

In [None]:
question = "Provide a list of all reservations from October 10th to October 15th, 2023"

In [None]:
# Fill the seed prompt's two {} slots: the question first, then the Markdown schema table.
prompt = seed_prompt.format(
    question,        # QUESTION slot
    schema_columns,  # SCHEMA slot
)
print(prompt)


Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.

QUESTION:
Provide a list of all reservations from October 10th to October 15th, 2023

SCHEMA:
| table_catalog   | table_schema        | table_name   | column_name          | field_path           | data_type   | description   | collation_name   | rounding_mode   |
|:----------------|:--------------------|:-------------|:---------------------|:---------------------|:------------|:--------------|:-----------------|:----------------|
| pradeep-genai   | flight_reservations | transactions | transaction_id       | transaction_id       | INT64       |               | NULL             |                 |
| pradeep-genai   | fli

In [None]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT
  r.reservation_id AS Reservation_ID,
  r.reservation_datetime AS Reservation_DateTime,
  r.status AS Reservation_Status,
  f.flight_id AS Flight_ID,
  f.origin AS Flight_Origin,
  f.destination AS Flight_Destination,
  f.departure_datetime AS Flight_Departure_DateTime,
  f.arrival_datetime AS Flight_Arrival_DateTime,
  f.carrier AS Flight_Carrier,
  f.price AS Flight_Price,
  c.customer_id AS Customer_ID,
  c.first_name AS Customer_FirstName,
  c.last_name AS Customer_LastName,
  c.email AS Customer_Email
FROM
  reservations r
JOIN
  flights f ON r.flight_id = f.flight_id
JOIN
  customers c ON r.customer_id = c.customer_id
WHERE
  r.reservation_datetime BETWEEN '2023-10-10 00:00:00' AND '2023-10-15 23:59:59';
--------------------------------------------------
FAILED
ATTEMPT: 2
--------------------------------------------------
SELECT
  r.reservation_id AS Reservation_ID,
  r.reservation_datetime AS Reservation_DateTi

Unnamed: 0,Reservation_ID,Reservation_DateTime,Reservation_Status,Flight_ID,Flight_Origin,Flight_Destination,Flight_Departure_DateTime,Flight_Arrival_DateTime,Flight_Carrier,Flight_Price,Customer_ID,Customer_FirstName,Customer_LastName,Customer_Email
0,6,2023-10-10 10:00:00,Confirmed,6,SEA,JFK,2023-11-25 06:00:00,2023-11-25 14:30:00,United,550.0,6,Diana,Prince,diana.p@example.com
1,7,2023-10-12 11:30:00,Confirmed,7,JFK,MIA,2023-11-27 20:00:00,2023-11-27 23:30:00,American,380.0,6,Diana,Prince,diana.p@example.com
2,8,2023-10-15 13:20:00,Confirmed,8,MIA,JFK,2023-11-30 10:00:00,2023-11-30 13:30:00,American,380.0,8,Fiona,Shrek,fiona.s@example.com


In [None]:
question = "Identify all customers who have made flight reservations within the last 7 days."

In [None]:
# Fill the seed prompt's two {} slots: the question first, then the Markdown schema table.
prompt = seed_prompt.format(
    question,        # QUESTION slot
    schema_columns,  # SCHEMA slot
)

In [None]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT DISTINCT c.customer_id, c.first_name, c.last_name, c.email
FROM customers c
JOIN reservations r ON c.customer_id = r.customer_id
WHERE r.reservation_datetime BETWEEN DATE_SUB(CURRENT_DATETIME(), INTERVAL 7 DAY) AND CURRENT_DATETIME();
--------------------------------------------------
FAILED
ATTEMPT: 2
--------------------------------------------------
SELECT DISTINCT c.customer_id, c.first_name, c.last_name, c.email
FROM flight_reservations.customers c
JOIN flight_reservations.reservations r ON c.customer_id = r.customer_id
WHERE r.reservation_datetime BETWEEN DATE_SUB(CURRENT_DATETIME(), INTERVAL 7 DAY) AND CURRENT_DATETIME();
--------------------------------------------------
SUCCEEDED
CPU times: user 75 ms, sys: 8.49 ms, total: 83.4 ms
Wall time: 4.55 s


Unnamed: 0,customer_id,first_name,last_name,email


In [None]:
question = "Calculate the total revenue generated from transactions in October 2023, specifically from all reservations with a Confirmed status."

In [None]:
# Fill the seed prompt's two {} slots: the question first, then the Markdown schema table.
prompt = seed_prompt.format(
    question,        # QUESTION slot
    schema_columns,  # SCHEMA slot
)

In [None]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT SUM(t.amount) AS total_revenue_usd
FROM transactions t
JOIN reservations r ON t.reservation_id = r.reservation_id
WHERE r.status = 'Confirmed'
AND DATE(t.transaction_datetime) BETWEEN '2023-10-01' AND '2023-10-31';
--------------------------------------------------
FAILED
ATTEMPT: 2
--------------------------------------------------
SELECT SUM(t.amount) AS total_revenue_usd
FROM transactions t
JOIN reservations r ON CAST(t.reservation_id AS STRING) = CAST(r.reservation_id AS STRING)
WHERE r.status = 'Confirmed'
AND DATE(t.transaction_datetime) BETWEEN '2023-10-01' AND '2023-10-31';
--------------------------------------------------
FAILED
ATTEMPT: 3
--------------------------------------------------
SELECT SUM(t.amount) AS total_revenue_usd
FROM transactions t
JOIN reservations r ON t.reservation_id = r.reservation_id
WHERE r.status = 'Confirmed'
AND DATE(t.transaction_datetime) BETWEEN '2023-10-01' AND '2023-10-31';


Unnamed: 0,total_revenue_usd
0,3860.0


In [None]:
question = "Determine the departure months with the highest frequency for the year 2023."

In [None]:
# Fill the seed prompt's two {} slots: the question first, then the Markdown schema table.
prompt = seed_prompt.format(
    question,        # QUESTION slot
    schema_columns,  # SCHEMA slot
)

In [None]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT DATE_TRUNC(departure_datetime, MONTH) AS departure_month,
       COUNT(*) AS num_departures
FROM flight_reservations.flights
WHERE departure_datetime BETWEEN '2023-01-01' AND '2023-12-31'
GROUP BY departure_month
ORDER BY num_departures DESC
LIMIT 10;
--------------------------------------------------
SUCCEEDED


CPU times: user 63.7 ms, sys: 10.9 ms, total: 74.6 ms
Wall time: 4.69 s


Unnamed: 0,departure_month,num_departures
0,2023-11-01,8
1,2023-12-01,7


In [None]:
question = "Group customers into five distinct age brackets and count the number of customers in each bracket."

In [None]:
# Fill the seed prompt's two {} slots: the question first, then the Markdown schema table.
prompt = seed_prompt.format(
    question,        # QUESTION slot
    schema_columns,  # SCHEMA slot
)

In [None]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
-- Calculate customer age brackets and count customers within each bracket
SELECT
  CASE
    WHEN DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR) < 20 THEN '0-19 years'
    WHEN DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR) BETWEEN 20 AND 29 THEN '20-29 years'
    WHEN DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR) BETWEEN 30 AND 39 THEN '30-39 years'
    WHEN DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR) BETWEEN 40 AND 49 THEN '40-49 years'
    ELSE '50+ years'
  END AS age_bracket,
  COUNT(c.customer_id) AS customer_count
FROM
  customers c
GROUP BY
  age_bracket;
--------------------------------------------------
FAILED
ATTEMPT: 2
--------------------------------------------------
-- Calculate customer age brackets and count customers within each bracket
SELECT
  CASE
    WHEN DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR) < 20 THEN '0-19 years'
    WHEN DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR) BETWEEN 

Unnamed: 0,age_bracket,customer_count
0,30-39 years,6
1,40-49 years,5
2,20-29 years,3
3,50+ years,6


In [None]:
question = "Identify and rank all customers aged 18+ who have `Confirmed` reservations for the current month, ordered by their age. Make sure to display their ages in the result."

In [None]:
# Fill the seed prompt's two {} slots: the question first, then the Markdown schema table.
prompt = seed_prompt.format(
    question,        # QUESTION slot
    schema_columns,  # SCHEMA slot
)

In [None]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
WITH CurrentMonthConfirmedReservations AS (
    SELECT
        t.customer_id
    FROM
        transactions t
    JOIN
        reservations r ON t.reservation_id = r.reservation_id
    WHERE
        r.status = 'Confirmed'
        AND DATE(t.transaction_datetime) BETWEEN DATE_TRUNC(DATE(CURRENT_DATE()), MONTH) AND DATE_ADD(DATE_TRUNC(DATE(CURRENT_DATE()), MONTH), INTERVAL 1 MONTH) - INTERVAL 1 DAY
)

SELECT
    c.customer_id,
    c.first_name,
    c.last_name,
    DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR) AS age,
    RANK() OVER (ORDER BY DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR)) AS age_rank
FROM
    customers c
JOIN
    CurrentMonthConfirmedReservations cmcr ON c.customer_id = cmcr.customer_id
WHERE
    DATE_DIFF(CURRENT_DATE(), c.date_of_birth, YEAR) >= 18;
--------------------------------------------------
FAILED
ATTEMPT: 2
--------------------------------------------------
WITH CurrentMonthConfirmedReservati

Inspect how the seed prompt evolved across attempts and how the LLM automatically fixed the SQL query.

In [None]:
# Print every prompt recorded in response['prompts'], showing how error feedback was
# appended after each failed attempt. Distinct loop names keep this cell from
# overwriting the global `prompt` that the generation cells pass to the model.
for attempt_no, attempt_prompt in enumerate(response['prompts'], start=1):
    print(f'==================== ATTEMPT {attempt_no} ====================')
    print(attempt_prompt)


Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.

QUESTION:
Identify and rank all customers aged 18+ who have `Confirmed` reservations for the current month, ordered by their age. Make sure to display their ages in the result.

SCHEMA:
| table_catalog   | table_schema        | table_name   | column_name          | field_path           | data_type   | description   | collation_name   | rounding_mode   |
|:----------------|:--------------------|:-------------|:---------------------|:---------------------|:------------|:--------------|:-----------------|:----------------|
| pradeep-genai   | flight_reservations | transactions | transaction_id       | transaction_id       