## Infer schema directly & self-correct

#### Imports 

In [2]:
from vertexai.language_models import CodeGenerationModel
from google.cloud import bigquery
import logging 
import os 

##### Set up logging

In [3]:
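logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Attach the handler only once so re-running this cell does not duplicate every log line
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())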

#### Set up essentials

In [4]:
SERVICE_ACCOUNT_CREDENTIALS = './../credentials/vai-key.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_CREDENTIALS

In [5]:
PROJECT_ID = 'arun-genai-bb'
CODE_GEN_MODEL_NAME = 'code-bison@latest'
TEMPERATURE = 0 # default value = 0
MAX_OUTPUT_TOKENS = 2048  # length of the output response | overriding the default value, which is 128
# TOP_P = 0.95  # default value
# TOP_K = 40  # default value
LOCATION = 'us-central1'

In [6]:
DATASET = 'flight_reservations'
TABLES = ['customers', 'flights', 'reservations', 'transactions', 'loyality_points']

In [7]:
code_gen_model = CodeGenerationModel.from_pretrained(CODE_GEN_MODEL_NAME)
bq_client = bigquery.Client()



GoogleAuthError: 
Unable to authenticate your request.
Depending on your runtime environment, you can complete authentication by:
- if in local JupyterLab instance: `!gcloud auth login` 
- if in Colab:
    -`from google.colab import auth`
    -`auth.authenticate_user()`
- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication
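The `GoogleAuthError` above means the client could not authenticate. One possible cause (an assumption, since the log does not say) is that the relative path in `SERVICE_ACCOUNT_CREDENTIALS` does not resolve from the notebook's working directory. A minimal sketch of a check that fails early with a clearer message; `resolve_credentials` and `set_credentials_env` are hypothetical helper names, not part of any Google library:

```python
import os
from pathlib import Path


def resolve_credentials(path: str) -> Path:
    """Return the absolute path to a key file, or raise if it does not exist."""
    resolved = Path(path).expanduser().resolve()
    if not resolved.is_file():
        raise FileNotFoundError(f"Service account key not found: {resolved}")
    return resolved


def set_credentials_env(path: str) -> str:
    """Validate the key file, then export it for the Google client libraries."""
    resolved = resolve_credentials(path)
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(resolved)
    return str(resolved)


# Usage (would raise FileNotFoundError if the key file is missing):
# set_credentials_env('./../credentials/vai-key.json')
```

Calling this before creating `bq_client` would surface a missing key file immediately instead of at the first API call.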

In [8]:
query = f"""
    SELECT *
    FROM `{PROJECT_ID}.{DATASET}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ({','.join([f'"{table}"' for table in TABLES])})
"""
logger.info(query)


    SELECT *
    FROM `arun-genai-bb.flight_reservations.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ("customers","flights","reservations","transactions","loyality_points")



In [9]:
schema_columns = bq_client.query(query=query).to_dataframe()
schema_columns

NameError: name 'bq_client' is not defined

In [9]:
schema_columns = schema_columns.to_markdown(index=False)
logger.info(schema_columns)

| table_catalog   | table_schema        | table_name   | column_name          | field_path           | data_type   | description   | collation_name   | rounding_mode   |
|:----------------|:--------------------|:-------------|:---------------------|:---------------------|:------------|:--------------|:-----------------|:----------------|
| arun-genai-bb   | flight_reservations | transactions | transaction_id       | transaction_id       | INT64       |               | NULL             |                 |
| arun-genai-bb   | flight_reservations | transactions | reservation_id       | reservation_id       | INT64       |               | NULL             |                 |
| arun-genai-bb   | flight_reservations | transactions | amount               | amount               | FLOAT64     |               | NULL             |                 |
| arun-genai-bb   | flight_reservations | transactions | transaction_datetime | transaction_datetime | DATETIME    |               | NULL             

Utility function to convert text into SQL and automatically correct the SQL query if any execution errors occur.

In [10]:
def generate_and_execute_sql(prompt, max_tries=8):
    """
    Generate an SQL query using the code_gen_model and execute it using bq_client.
    
    Args:
    - prompt (str): Prompt to provide to the model for generating SQL.
    - max_tries (int): Maximum number of attempts to generate and execute SQL.
    
    Returns:
    - dict: A dictionary containing the successful dataframe or error messages and prompt evolution.
    """
    
    tries = 0
    error_messages = []
    prompts = [prompt]
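    df = None
    generated_sql_query = ''  # defined up front so the except block can reference it even if predict() fails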
    
    while tries < max_tries:
        logger.info(f'ATTEMPT: {tries+1}')
        try:
            # Predict SQL using the model
            response = code_gen_model.predict(prompt, temperature=TEMPERATURE, max_output_tokens=MAX_OUTPUT_TOKENS)
            generated_sql_query = response.text
            generated_sql_query = '\n'.join(generated_sql_query.split('\n')[1:-1])
            logger.info('-' * 50)
            logger.info(generated_sql_query)
            logger.info('-' * 50)
            # Execute SQL using BigQuery client
            df = bq_client.query(generated_sql_query).to_dataframe()
            logger.info('SUCCEEDED')
            return {"dataframe": df, "prompts": prompts, "errors": error_messages}
        except Exception as e:
            logger.error('FAILED')
            # Catch the error, store the message, and try again
            msg = str(e)
            error_messages.append(msg)
            # Evolve the prompt by appending the error message and asking the model to correct it
            prompt = f"""{prompt}
Encountered an error: {msg}. 
To address this, please generate an alternative SQL query response that avoids this specific error. 
Follow the instructions mentioned above to remediate the error. 

Modify the below SQL query to resolve the issue:
{generated_sql_query}

Ensure the revised SQL query aligns precisely with the requirements outlined in the initial question."""
            prompts.append(prompt)
            tries += 1
        logger.info('=' * 100)

    return {"dataframe": df, "prompts": prompts, "errors": error_messages}
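The retry loop above can be exercised without Vertex AI or BigQuery by passing in stand-in callables. The sketch below is a simplified illustration under that assumption, not the notebook's implementation: `run_with_self_correction`, `strip_code_fences`, `fake_generate`, and `fake_execute` are hypothetical names, and the retry prompt is shortened. It also shows one way to remove the Markdown fence that still works when the model returns bare SQL, whereas `split('\n')[1:-1]` always drops the first and last lines.

```python
import re
from typing import Callable, Dict, List

FENCE = "`" * 3  # the Markdown code-fence marker


def strip_code_fences(text: str) -> str:
    """Return the body of the first fenced block, or the stripped text if none."""
    pattern = re.escape(FENCE) + r"[\w+-]*[ \t]*\n(.*?)" + re.escape(FENCE)
    match = re.search(pattern, text, flags=re.DOTALL)
    return (match.group(1) if match else text).strip()


def run_with_self_correction(
    prompt: str,
    generate: Callable[[str], str],
    execute: Callable[[str], object],
    max_tries: int = 3,
) -> Dict[str, object]:
    """Generate SQL, execute it, and feed any error back into the next prompt."""
    prompts: List[str] = [prompt]
    errors: List[str] = []
    for _ in range(max_tries):
        sql = strip_code_fences(generate(prompt))
        try:
            return {"result": execute(sql), "prompts": prompts, "errors": errors}
        except Exception as exc:  # broad on purpose: any failure triggers a retry
            errors.append(str(exc))
            prompt = f"{prompt}\nError: {exc}\nFix this query:\n{sql}"
            prompts.append(prompt)
    return {"result": None, "prompts": prompts, "errors": errors}


# Stand-ins for the model and the database (hypothetical, for illustration only).
def fake_generate(prompt: str) -> str:
    fixed = "SELECT 1" if "Error:" in prompt else "SELEC 1"
    return f"{FENCE}sql\n{fixed}\n{FENCE}"


def fake_execute(sql: str) -> str:
    if not sql.startswith("SELECT"):
        raise ValueError(f"Syntax error near: {sql}")
    return "ok"


outcome = run_with_self_correction("Write a query.", fake_generate, fake_execute)
print(outcome["result"], len(outcome["prompts"]), outcome["errors"])
```

Injecting `generate` and `execute` as parameters keeps the loop testable offline; in the notebook they would wrap `code_gen_model.predict` and `bq_client.query`.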

### Test Text-to-SQL scenarios

Construct the SEED prompt

In [12]:
seed_prompt = """
Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.\n
QUESTION:
{}\n
SCHEMA:
{}\n
IMPORTANT: 
Use ONLY DATETIME and DO NOT use TIMESTAMP.
--
Ensure your SQL query accurately defines both the start and end of the DATETIME range.
"""
logger.info(seed_prompt)


Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.

QUESTION:
{}

SCHEMA:
{}

IMPORTANT: 
Use ONLY DATETIME and DO NOT use TIMESTAMP.
--
Ensure your SQL query accurately defines both the start and end of the DATETIME range.
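The template relies on two positional `{}` placeholders, so the argument order in `seed_prompt.format(question, schema_columns)` matters. A small sketch of the same idea with named placeholders, which makes a swapped argument harder to miss; `PROMPT_TEMPLATE` and `build_prompt` are hypothetical names and the template text is abbreviated for illustration:

```python
PROMPT_TEMPLATE = (
    "Please craft a SQL query for BigQuery that addresses the QUESTION below.\n\n"
    "QUESTION:\n{question}\n\n"
    "SCHEMA:\n{schema}\n"
)


def build_prompt(question: str, schema: str) -> str:
    """Fill the template by name; braces inside the values are left untouched."""
    return PROMPT_TEMPLATE.format(question=question.strip(), schema=schema.strip())


example = build_prompt(
    "List all flights departing from JFK.",
    "| table_name | column_name |\n| flights | origin |",
)
print(example)
```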



#### Scenario 1: Retrieve Active Reservations for a Specific Date Range

For this scenario, you want to find all active reservations within a specific date range.

In [13]:
question = "Provide a list of all reservations from October 10th to October 15th, 2023"

In [15]:
prompt = seed_prompt.format(question, schema_columns)
logger.info(prompt)

NameError: name 'schema_columns' is not defined

In [14]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT
  r.reservation_id,
  r.customer_id,
  r.flight_id,
  r.reservation_datetime AS reservation_datetime_utc,
  f.origin,
  f.destination,
  f.departure_datetime AS departure_datetime_utc,
  f.arrival_datetime AS arrival_datetime_utc,
  f.carrier,
  f.price
FROM arun-genai-bb.flight_reservations.reservations AS r
JOIN arun-genai-bb.flight_reservations.flights AS f
ON CAST(r.flight_id AS STRING) = f.flight_id
WHERE r.reservation_datetime >= '2023-10-10 00:00:00'
AND r.reservation_datetime < '2023-10-16 00:00:00';
--------------------------------------------------
FAILED
ATTEMPT: 2
--------------------------------------------------
SELECT
  r.reservation_id,
  r.customer_id,
  r.flight_id,
  r.reservation_datetime AS reservation_datetime_utc,
  f.origin,
  f.destination,
  f.departure_datetime AS departure_datetime_utc,
  f.arrival_datetime AS arrival_datetime_utc,
  f.carrier,
  f.price
FROM arun-genai-bb.flight_reservatio

CPU times: user 115 ms, sys: 19.5 ms, total: 135 ms
Wall time: 9.64 s


Unnamed: 0,reservation_id,customer_id,flight_id,reservation_datetime_utc,origin,destination,departure_datetime_utc,arrival_datetime_utc,carrier,price
0,6,6,6,2023-10-10 10:00:00,SEA,JFK,2023-11-25 06:00:00,2023-11-25 14:30:00,United,550.0
1,7,6,7,2023-10-12 11:30:00,JFK,MIA,2023-11-27 20:00:00,2023-11-27 23:30:00,American,380.0
2,8,8,8,2023-10-15 13:20:00,MIA,JFK,2023-11-30 10:00:00,2023-11-30 13:30:00,American,380.0


#### Scenario 2: Identify Customers Who Made Reservations in the Past N Days

In [15]:
question = "Identify all customers who have made flight reservations within the last 7 days."

In [16]:
prompt = seed_prompt.format(question, schema_columns)

In [17]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT c.customer_id,
       c.first_name,
       c.last_name,
       c.email,
       t.transaction_datetime AS reservation_date
FROM flight_reservations.customers c
JOIN flight_reservations.reservations r
ON c.customer_id = r.customer_id
JOIN flight_reservations.transactions t
ON r.reservation_id = t.reservation_id
WHERE t.transaction_datetime >= DATE_SUB(CURRENT_DATETIME(), INTERVAL 7 DAY)
--------------------------------------------------
SUCCEEDED


CPU times: user 50.3 ms, sys: 9.21 ms, total: 59.5 ms
Wall time: 4.86 s


Unnamed: 0,customer_id,first_name,last_name,email,reservation_date
0,6,Diana,Prince,diana.p@example.com,2023-10-12 11:31:00
1,8,Fiona,Shrek,fiona.s@example.com,2023-10-15 13:21:00
2,10,Hannah,Montana,hannah.m@example.com,2023-10-22 15:46:00
3,11,Ian,Somerhalder,ian.s@example.com,2023-10-25 12:31:00
4,11,Ian,Somerhalder,ian.s@example.com,2023-10-28 17:11:00
5,13,Kate,Winslet,kate.w@example.com,2023-11-02 08:21:00
6,15,Mary,Jane,mary.j@example.com,2023-11-04 10:46:00
7,16,Nick,Fury,nick.f@example.com,2023-11-08 15:31:00
8,17,Olivia,Newton,olivia.n@example.com,2023-11-11 10:16:00
9,18,Peter,Parker,peter.p@example.com,2023-11-15 12:51:00
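The result lists customer 11 twice because the join returns one row per transaction. If one row per customer is wanted, the rows can be reduced after the fact. A minimal sketch in plain Python, using three rows copied from the output above; `latest_per_customer` is a hypothetical helper, and keeping the most recent row is an assumption about what "one row per customer" should mean:

```python
from datetime import datetime
from typing import Dict, List, Tuple

rows = [
    {"customer_id": 10, "first_name": "Hannah", "reservation_date": "2023-10-22 15:46:00"},
    {"customer_id": 11, "first_name": "Ian", "reservation_date": "2023-10-25 12:31:00"},
    {"customer_id": 11, "first_name": "Ian", "reservation_date": "2023-10-28 17:11:00"},
]


def latest_per_customer(rows: List[dict]) -> List[dict]:
    """Keep only the most recent row for each customer_id."""
    latest: Dict[int, Tuple[datetime, dict]] = {}
    for row in rows:
        when = datetime.strptime(row["reservation_date"], "%Y-%m-%d %H:%M:%S")
        key = row["customer_id"]
        if key not in latest or when > latest[key][0]:
            latest[key] = (when, row)
    return [row for _, row in latest.values()]


unique_rows = latest_per_customer(rows)
print(unique_rows)
```

The same effect could be requested in the question itself, for example by asking for distinct customers.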


#### Scenario 3: Calculate Monthly Revenue

Calculate the total revenue generated from transactions for a given month and year.

In [18]:
question = "Calculate the total revenue generated from transactions in October 2023, specifically from all reservations with a Confirmed status."

In [19]:
prompt = seed_prompt.format(question, schema_columns)

In [20]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT
  SUM(transactions.amount) AS total_revenue_usd
FROM
  arun-genai-bb.flight_reservations.transactions AS transactions
JOIN
  arun-genai-bb.flight_reservations.reservations AS reservations
ON
  transactions.reservation_id = reservations.reservation_id
WHERE
  reservations.status = 'Confirmed'
  AND transactions.transaction_datetime BETWEEN '2023-10-01 00:00:00' AND '2023-10-31 23:59:59';
--------------------------------------------------
SUCCEEDED


CPU times: user 60.3 ms, sys: 10.4 ms, total: 70.8 ms
Wall time: 4.44 s


Unnamed: 0,total_revenue_usd
0,3860.0
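The generated query closes the range with `BETWEEN ... AND '2023-10-31 23:59:59'`. Because BigQuery `DATETIME` values carry sub-second precision, a transaction stamped after 23:59:59 but before midnight would fall outside that range. A half-open interval (`>=` start, `<` start of next month) avoids the gap. The sketch below computes such bounds; `month_bounds` is a hypothetical helper and the column name is taken from the query above:

```python
from datetime import datetime
from typing import Tuple


def month_bounds(year: int, month: int) -> Tuple[str, str]:
    """Return (inclusive start, exclusive end) DATETIME strings for a month."""
    start = datetime(year, month, 1)
    # Roll over to January of the next year when the month is December
    end = datetime(year + 1, 1, 1) if month == 12 else datetime(year, month + 1, 1)
    fmt = "%Y-%m-%d %H:%M:%S"
    return start.strftime(fmt), end.strftime(fmt)


start, end = month_bounds(2023, 10)
where_clause = (
    f"transaction_datetime >= DATETIME '{start}' "
    f"AND transaction_datetime < DATETIME '{end}'"
)
print(where_clause)
```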


#### Scenario 4: Popular Flight Times

Identify the most popular departure hours or days within a given day, month, or year.

In [21]:
question = "Determine the departure months with the highest frequency for the year 2023."

In [22]:
prompt = seed_prompt.format(question, schema_columns)

In [23]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT DATE_TRUNC(departure_datetime, MONTH) AS departure_month,
       COUNT(*) AS num_departures
FROM flight_reservations.flights
WHERE departure_datetime BETWEEN '2023-01-01' AND '2023-12-31'
GROUP BY departure_month
ORDER BY num_departures DESC
LIMIT 10;
--------------------------------------------------
SUCCEEDED


CPU times: user 63.7 ms, sys: 10.9 ms, total: 74.6 ms
Wall time: 4.69 s


Unnamed: 0,departure_month,num_departures
0,2023-11-01,8
1,2023-12-01,7
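To sanity-check a result like this, the same grouping can be reproduced locally. The sketch below counts departures per month with `collections.Counter`; the sample datetimes are made up for illustration and `departures_per_month` is a hypothetical helper. Filtering on the year rather than on `BETWEEN '2023-01-01' AND '2023-12-31'` also keeps departures that occur later in the day on December 31.

```python
from collections import Counter
from datetime import datetime
from typing import Iterable, List, Tuple


def departures_per_month(
    departures: Iterable[datetime], year: int
) -> List[Tuple[str, int]]:
    """Count departures per YYYY-MM within a year, most frequent month first."""
    counts = Counter(d.strftime("%Y-%m") for d in departures if d.year == year)
    return counts.most_common()


# Hypothetical sample data, not taken from the flights table
sample = [
    datetime(2023, 11, 25, 6, 0),
    datetime(2023, 11, 27, 20, 0),
    datetime(2023, 12, 31, 22, 15),
    datetime(2024, 1, 2, 9, 0),
]
monthly = departures_per_month(sample, 2023)
print(monthly)
```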


#### Scenario 5: Customer Age Group

Group customers by age brackets and count the number in each bracket.

In [24]:
question = "Group customers into five distinct age brackets and count the number of customers in each bracket."

In [25]:
prompt = seed_prompt.format(question, schema_columns)

In [26]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
SELECT
  CASE
    WHEN date_of_birth IS NULL THEN 'Unknown'
    WHEN date_of_birth >= DATE_SUB(CURRENT_DATE(), INTERVAL 65 YEAR) THEN '65+'
    WHEN date_of_birth >= DATE_SUB(CURRENT_DATE(), INTERVAL 55 YEAR) THEN '55-64'
    WHEN date_of_birth >= DATE_SUB(CURRENT_DATE(), INTERVAL 45 YEAR) THEN '45-54'
    WHEN date_of_birth >= DATE_SUB(CURRENT_DATE(), INTERVAL 35 YEAR) THEN '35-44'
    ELSE '0-34'
  END AS age_bracket,
  COUNT(customer_id) AS num_customers
FROM flight_reservations.customers
GROUP BY age_bracket;
--------------------------------------------------
SUCCEEDED


CPU times: user 60.9 ms, sys: 10.7 ms, total: 71.6 ms
Wall time: 5.62 s


Unnamed: 0,age_bracket,num_customers
0,65+,17
1,0-34,3
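A query that executes successfully is not necessarily correct. In the `CASE` expression above, `date_of_birth >= DATE_SUB(CURRENT_DATE(), INTERVAL 65 YEAR)` is true for anyone born within the last 65 years, so customers younger than 65 are labeled `65+` and the remaining branches are rarely reached. The self-correction loop only reacts to execution errors and cannot catch this. A minimal sketch of the intended bucketing, working from ages in whole years; `BRACKETS`, `age_bracket`, and `count_brackets` are hypothetical names, and the sample ages are the distinct customer ages shown in the Scenario 6 result:

```python
from collections import Counter
from typing import Iterable

# (lower bound in years, label), checked from oldest to youngest
BRACKETS = [(65, "65+"), (55, "55-64"), (45, "45-54"), (35, "35-44"), (0, "0-34")]


def age_bracket(age: int) -> str:
    """Map an age in whole years to its bracket label."""
    for lower, label in BRACKETS:
        if age >= lower:
            return label
    raise ValueError(f"Age must be non-negative, got {age}")


def count_brackets(ages: Iterable[int]) -> Counter:
    """Count how many ages fall into each bracket."""
    return Counter(age_bracket(age) for age in ages)


sample_ages = [25, 31, 33, 35, 36, 38, 45]
bracket_counts = count_brackets(sample_ages)
print(bracket_counts)
```

In SQL, the equivalent fix would compare with `<=` (an earlier date of birth means an older customer) or compute the age first and bucket on that.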


#### Scenario 6: Age Calculation

Calculate each customer's age from their date of birth and keep only those older than a given threshold.

In [31]:
question = "Identify and rank all customers aged 18+ who have `Confirmed` reservations for the current month, ordered by their age. Make sure to display their ages in the result."

In [32]:
prompt = seed_prompt.format(question, schema_columns)

In [33]:
%%time

response = generate_and_execute_sql(prompt=prompt)
sql_output = response['dataframe']
sql_output

ATTEMPT: 1
--------------------------------------------------
WITH current_month AS (
  SELECT DATE_TRUNC(CURRENT_DATE(), MONTH) AS start_of_month,
         DATE_ADD(DATE_TRUNC(CURRENT_DATE(), MONTH), INTERVAL 1 MONTH) AS end_of_month
)

SELECT c.customer_id,
       c.first_name,
       c.last_name,
       DATEDIFF(current_month.end_of_month, c.date_of_birth) / 365 AS age,
       r.reservation_id,
       r.status
FROM current_month
CROSS JOIN flight_reservations.customers c
JOIN flight_reservations.reservations r
ON c.customer_id = r.customer_id
WHERE r.status = 'Confirmed'
AND r.reservation_datetime BETWEEN current_month.start_of_month AND current_month.end_of_month
AND DATEDIFF(current_month.end_of_month, c.date_of_birth) / 365 >= 18
ORDER BY age
--------------------------------------------------
FAILED
ATTEMPT: 2
--------------------------------------------------
  SELECT DATE_TRUNC(CURRENT_DATE(), MONTH) AS start_of_month,
         DATE_ADD(DATE_TRUNC(CURRENT_DATE(), MONTH), INTERV

CPU times: user 86.8 ms, sys: 21.3 ms, total: 108 ms
Wall time: 16 s


Unnamed: 0,customer_id,first_name,last_name,age,reservation_id,status
0,10,Hannah,Montana,25,10,Confirmed
1,8,Fiona,Shrek,31,8,Confirmed
2,3,Alice,Johnson,33,3,Confirmed
3,6,Diana,Prince,35,5,Confirmed
4,6,Diana,Prince,35,6,Confirmed
5,6,Diana,Prince,35,7,Confirmed
6,2,Jane,Doe,36,2,Confirmed
7,1,John,Doe,38,1,Confirmed
8,11,Ian,Somerhalder,45,11,Confirmed
9,11,Ian,Somerhalder,45,12,Confirmed
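The first attempt computes age as `DATEDIFF(...) / 365`. The log does not show the error message, but BigQuery's date-difference function is `DATE_DIFF(end, start, part)`, which may explain the failure. Separately, dividing a day count by 365 only approximates age, because leap days accumulate and can push someone over a threshold early. The sketch below compares the approximation with an exact calculation; `approx_age` and `exact_age` are hypothetical helpers and the dates are illustrative:

```python
from datetime import date


def approx_age(dob: date, on: date) -> float:
    """Age estimated as elapsed days divided by 365 (ignores leap days)."""
    return (on - dob).days / 365


def exact_age(dob: date, on: date) -> int:
    """Age in completed years, based on whether the birthday has passed."""
    birthday_pending = (on.month, on.day) < (dob.month, dob.day)
    return on.year - dob.year - int(birthday_pending)


dob = date(2000, 1, 1)
on = date(2017, 12, 28)  # four days before the 18th birthday
print(approx_age(dob, on) >= 18, exact_age(dob, on))
```

Here the approximation already reports 18 or more, while the exact age is still 17, so a filter such as "aged 18+" could admit a customer a few days early.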


Inspect how the seed prompt evolved and how the LLM automatically fixed the SQL query.
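Each retry prompt in `generate_and_execute_sql` is built by appending to the previous one, so logging every prompt in full repeats the seed text and the schema table each time. A minimal sketch that isolates only what each attempt added; `prompt_deltas` is a hypothetical helper and the sample prompts are placeholders:

```python
from typing import List


def prompt_deltas(prompts: List[str]) -> List[str]:
    """Return the first prompt, then only the text each later prompt appended."""
    deltas = prompts[:1]
    for previous, current in zip(prompts, prompts[1:]):
        if current.startswith(previous):
            deltas.append(current[len(previous):])
        else:
            # Fall back to the full prompt if it was not a pure extension
            deltas.append(current)
    return deltas


sample_prompts = ["seed prompt", "seed prompt\nEncountered an error: X"]
for attempt, delta in enumerate(prompt_deltas(sample_prompts), start=1):
    print(f"ATTEMPT {attempt}: {delta!r}")
```

With the real data this could be applied to `response['prompts']`.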

In [34]:
for i, prompt in enumerate(response['prompts']):
    logger.info(f'==================== ATTEMPT {i+1} ====================')
    logger.info(prompt)


Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.

QUESTION:
Identify and rank all customers aged 18+ who have `Confirmed` reservations for the current month, ordered by their age. Make sure to display their ages in the result.

SCHEMA:
| table_catalog   | table_schema        | table_name   | column_name          | field_path           | data_type   | description   | collation_name   | rounding_mode   |
|:----------------|:--------------------|:-------------|:---------------------|:---------------------|:------------|:--------------|:-----------------|:----------------|
| arun-genai-bb   | flight_reservations | transactions | transaction_id       | transaction_id       