![tracker](https://us-central1-vertex-ai-mlops-369716.cloudfunctions.net/pixel-tracking?path=statmike%2Fvertex-ai-mlops%2FApplied+GenAI%2Flegacy&file=Vertex+AI+GenAI+For+BigQuery+Q%26A+-+Overview.ipynb)
<!--- header table --->
<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/statmike/vertex-ai-mlops/blob/main/Applied%20GenAI/legacy/Vertex%20AI%20GenAI%20For%20BigQuery%20Q%26A%20-%20Overview.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo">
      <br>Run in<br>Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https%3A%2F%2Fraw.githubusercontent.com%2Fstatmike%2Fvertex-ai-mlops%2Fmain%2FApplied%2520GenAI%2Flegacy%2FVertex%2520AI%2520GenAI%2520For%2520BigQuery%2520Q%2526A%2520-%2520Overview.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo">
      <br>Run in<br>Colab Enterprise
    </a>
  </td>      
  <td style="text-align: center">
    <a href="https://github.com/statmike/vertex-ai-mlops/blob/main/Applied%20GenAI/legacy/Vertex%20AI%20GenAI%20For%20BigQuery%20Q%26A%20-%20Overview.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      <br>View on<br>GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/statmike/vertex-ai-mlops/main/Applied%20GenAI/legacy/Vertex%20AI%20GenAI%20For%20BigQuery%20Q%26A%20-%20Overview.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
      <br>Open in<br>Vertex AI Workbench
    </a>
  </td>
</table>

---

**File Move Notices**

This file moved locations:
- On 09/07/2024 (mm/dd/yyyy)
	- From: `Applied GenAI/Vertex AI GenAI For BigQuery Q&A - Overview.ipynb`
	- To: `Applied GenAI/legacy/Vertex AI GenAI For BigQuery Q&A - Overview.ipynb`
---
<!---end of move notices--->

# Answering Questions Using BigQuery Tables As Context

Ask a question of a table in BigQuery.  How?  First, translate the question to SQL and extract the results from the BigQuery table.  Then provide the results as context, along with the original question, to an LLM for answering.

The workflow:
- Set up the environment
- Setup access to LLMs: Test and Code
- Retrieve Table Schemas using BigQuery Information Schema
- Translate Question to Code (SQL) with context of table schemas
- Ask an LLM to answer the question with context of results from running the generated SQL
- Ask more questions!

---
## Overview

<p><center>
    <img alt="Overview Chart" src="../../architectures/notebooks/applied/genai/bq_qa.png" width="55%">
</center></p>


---
## Colab Setup

To run this notebook in Colab click [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/statmike/vertex-ai-mlops/blob/main/Applied%20GenAI/legacy/Vertex%20AI%20GenAI%20For%20BigQuery%20Q%26A%20-%20Overview.ipynb) and run the cells in this section.  Otherwise, skip this section.

This cell will authenticate to GCP (follow prompts in the popup).

In [1]:
PROJECT_ID = 'statmike-mlops-349915' # replace with project ID

In [2]:
try:
    import google.colab
    from google.colab import auth
    auth.authenticate_user()
    !gcloud config set project {PROJECT_ID}
except Exception:
    pass

---
## Installs and API Enablement

The client packages may need installing in this environment.  Also, the APIs used in this notebook (Vertex AI and BigQuery) need to be enabled (if not already enabled).

### Installs (If Needed)

In [3]:
install = False
try: import google.cloud.aiplatform
except ImportError:
    print('You need to pip install google-cloud-aiplatform (VERTEX AI), ... commencing')
    !pip install google-cloud-aiplatform -U -q
    install = True

### Restart Kernel (If Installs Occurred)

After a kernel restart the code submission can start with the next cell after this one.

In [4]:
if install:
    import IPython
    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

---
## Setup

Inputs

In [5]:
project = !gcloud config get-value project
PROJECT_ID = project[0]
PROJECT_ID

'statmike-mlops-349915'

In [6]:
REGION = 'us-central1'
EXPERIMENT = 'bq-citibikes'
SERIES = 'applied-genai'

Packages

In [7]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import vertexai.language_models
import vertexai.preview.generative_models

from google.cloud import bigquery

Clients

In [8]:
# vertex ai clients
vertexai.init(project = PROJECT_ID, location = REGION)

# bigquery client
bq = bigquery.Client(project = PROJECT_ID)

---
## Goal

New York City has [Citibike](https://citibikenyc.com/) stations where you can rent a bicycle by the ride, the day, or subscribe monthly/annually.  There is a sample of the usage of Citibike stations in BigQuery public datasets.  We would like to answer possible questions in natural language using this data as the source.

A possible question is: Which Citibike station had the most rentals during July 2015?

The approach used here is to ask an LLM to answer the question.  The approach has multiple steps:
1. Ask a code generation LLM to write a SQL query that retrieves the information relevant to the question from the tables - the context
2. Run the query generated in 1
3. Ask a text generation LLM to answer the question and give it a context to help accurately answer the question - the result of 2
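The three steps above can be sketched as a single function.  This is a minimal sketch, not the notebook's implementation: the function name `answer_with_table_context` and the three callables are hypothetical, and the stand-in lambdas replace the LLM and BigQuery calls so the flow can be followed without any cloud access.

```python
from typing import Callable


def answer_with_table_context(
    question: str,
    schema_markdown: str,
    generate_sql: Callable[[str], str],
    run_sql: Callable[[str], str],
    generate_answer: Callable[[str], str],
) -> str:
    """Chain the three steps: write SQL, run it, answer from the result."""
    # Step 1: ask a code model for SQL, giving it the table schema as context
    sql_prompt = (
        f"context (BigQuery Table Schema):\n{schema_markdown}\n\n"
        "Write a query for Google BigQuery using fully qualified table names "
        f"to answer this question:\n{question}\n"
    )
    sql = generate_sql(sql_prompt)
    # Step 2: run the generated query; the result becomes the answer context
    result_markdown = run_sql(sql)
    # Step 3: ask a text model to answer using the query result as context
    answer_prompt = (
        f"context (result from BigQuery query):\n{result_markdown}\n\n"
        "Answer the following question using the provided context.\n"
        f"{question}\n"
    )
    return generate_answer(answer_prompt)


# Stand-in callables (assumptions for illustration only, no real models or data)
seen_prompts = []
demo_answer = answer_with_table_context(
    question="Which station had the most rentals?",
    schema_markdown="| table_name | column_name |",
    generate_sql=lambda prompt: seen_prompts.append(prompt) or "SELECT 1",
    run_sql=lambda sql: "| station |\n|:--|\n| Example St |",
    generate_answer=lambda prompt: seen_prompts.append(prompt) or "Example St",
)
print(demo_answer)
```

Passing the model and database calls in as arguments keeps each step swappable, which is handy when comparing different LLMs for the SQL and answer steps.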

In [9]:
question = "Which station had most rentals (longest total duration) during July 2015?"

---
## Vertex LLM Setup

- CodeGenerationModel [Guide](https://cloud.google.com/vertex-ai/docs/generative-ai/code/code-generation-prompts)
    - CodeGenerationModel [API](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.language_models.CodeGenerationModel)
- TextGenerationModel [Guide](https://cloud.google.com/vertex-ai/docs/generative-ai/text/test-text-prompts)
    - TextGenerationModel [API](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.language_models.TextGenerationModel)

In [36]:
# create links to models: text generation, code generation, and Gemini
textgen_model = vertexai.language_models.TextGenerationModel.from_pretrained('text-bison@002')
codegen_model = vertexai.language_models.CodeGenerationModel.from_pretrained('code-bison@002')
gemini_model = vertexai.preview.generative_models.GenerativeModel("gemini-pro")

Test the text generation (LLM) model:

In [11]:
textgen_model.predict(f"Write a Google SQL query that answers the following question.\nquestion: {question}")

 ```sql
SELECT station_id, SUM(duration) AS total_duration
FROM (
  SELECT 
    start_station_id AS station_id,
    start_time,
    DATEDIFF(end_time, start_time) AS duration
  FROM bike_rentals
  WHERE YEAR(start_time) = 2015 AND MONTH(start_time) = 7
)
GROUP BY station_id
ORDER BY total_duration DESC
LIMIT 1;
```

In [12]:
codegen_model.predict(f"Write a Google SQL query that answers the following question.\nquestion: {question}")

```sql
SELECT station_id, SUM(duration) AS total_duration
FROM (
  SELECT 
    start_station_id AS station_id,
    strftime('%Y-%m', start_time) AS rental_month,
    duration
  FROM bike_rentals
  WHERE start_time BETWEEN '2015-07-01' AND '2015-07-31'
)
GROUP BY station_id
ORDER BY total_duration DESC
LIMIT 1;
```

In [13]:
print(gemini_model.generate_content(f"Write a Google SQL query that answers the following question.\nquestion: {question}", generation_config = dict(temperature=0)).text)

```sql
SELECT station_id, SUM(duration) AS total_duration
FROM (
  SELECT 
    *,
    DATE(start_time) AS rental_date
  FROM bike_rentals
  WHERE rental_date BETWEEN '2015-07-01' AND '2015-07-31'
)
GROUP BY station_id
ORDER BY total_duration DESC
LIMIT 1;
```


These each write code, but notice how they make assumptions about table names and column names.  The following sections address how to guide the LLM to write runnable SQL with correct tables and columns.

---
## The Problem
The LLMs write plausible-looking SQL queries.  However, notice that asking an LLM to write the query this way faces several issues:
- The queries do not reference the correct tables
- The column names are not the correct ones from the correct tables
- Some functions, such as `strftime`, are not available in BigQuery

Basically, the generated SQL is a good starting point for a user to write a query that would retrieve valid context for the user's question.
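One quick way to spot the first issue is to compare the tables a generated query references against the tables that actually exist.  The helper below is a rough sketch under stated assumptions: `referenced_tables` is a hypothetical name, the regular expression only looks at names that follow `FROM` or `JOIN`, and it is not a SQL parser (nested parentheses inside `EXTRACT(...)` would confuse it).  The known table names are the two public tables used in the next section.

```python
import re


def referenced_tables(sql):
    """Return the names that follow FROM or JOIN, with backticks removed."""
    # Drop EXTRACT(part FROM expr) so its FROM is not mistaken for a table reference
    without_extract = re.sub(r"EXTRACT\s*\(.*?\)", "", sql, flags=re.IGNORECASE | re.DOTALL)
    pattern = re.compile(r"\b(?:FROM|JOIN)\s+`?([\w.\-]+)`?", re.IGNORECASE)
    return set(pattern.findall(without_extract))


known_tables = {
    "bigquery-public-data.new_york.citibike_trips",
    "bigquery-public-data.new_york.citibike_stations",
}

# A query in the style of the guesses above: the table name is invented by the model
guessed_sql = "SELECT station_id FROM bike_rentals WHERE duration > 0"
unknown_tables = referenced_tables(guessed_sql) - known_tables
print(unknown_tables)
```

An empty result does not prove the query is runnable, but a non-empty one is a cheap signal that the model guessed at table names.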

**How to get a fully executable SQL query from the LLM?**

The following approach was created by iteratively refining text prompts and approaches for specific questions.  

---
## Retrieve Table Schemas

The context provided to the LLM to help write the SQL query will be the schemas of the related tables.  The BigQuery tables used for this experiment are BigQuery public dataset tables:
- `bigquery-public-data.new_york.citibike_trips`
- `bigquery-public-data.new_york.citibike_stations`

To retrieve the schemas for the tables the BigQuery [INFORMATION_SCHEMA](https://cloud.google.com/bigquery/docs/information-schema-intro) is used - specifically the [INFORMATION_SCHEMA.COLUMN_FIELD_PATHS](https://cloud.google.com/bigquery/docs/information-schema-column-field-paths) view.

In [16]:
BQ_PROJECT = 'bigquery-public-data'
BQ_DATASET = 'new_york'
BQ_TABLES = ['citibike_trips', 'citibike_stations']

In [17]:
query = f"""
    SELECT * EXCEPT(field_path, collation_name, rounding_mode)
    FROM `{BQ_PROJECT}.{BQ_DATASET}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ({','.join([f'"{table}"' for table in BQ_TABLES])})
"""
print(query)
schema_columns = bq.query(query = query).to_dataframe()


    SELECT * EXCEPT(field_path, collation_name, rounding_mode)
    FROM `bigquery-public-data.new_york.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ("citibike_trips","citibike_stations")



In [18]:
schema_columns.head()

|    | table_catalog        | table_schema | table_name        | column_name | data_type | description                                        |
|---:|:---------------------|:-------------|:------------------|:------------|:----------|:---------------------------------------------------|
|  0 | bigquery-public-data | new_york     | citibike_stations | station_id  | STRING    | Unique identifier of a station.                    |
|  1 | bigquery-public-data | new_york     | citibike_stations | name        | STRING    | Public name of the station.                        |
|  2 | bigquery-public-data | new_york     | citibike_stations | short_name  | STRING    | Short name or other type of identifier, as use...  |
|  3 | bigquery-public-data | new_york     | citibike_stations | latitude    | FLOAT64   | The latitude of station. The field value must ...  |
|  4 | bigquery-public-data | new_york     | citibike_stations | longitude   | FLOAT64   | The longitude of station. The field value must...  |


In [17]:
#schema_columns.to_markdown(index = False)

---
## Notes On Efficient Information Schema Retrieval

In this example the entire schema is used as context.  But what happens when there are many tables and many columns?  Eventually the size will exceed the input limit of the LLM.  It is also possible to misguide the LLM in the creation of code by supplying too much information - especially when it covers different topics.

Using semantic retrieval with embeddings is a great solution in this situation:
- create embeddings for all table descriptions
- create embeddings for all column descriptions
- create a vector database with indexes for the table and column embeddings
- When a new question comes in find the most applicable table(s) and columns:
    - embed the question
    - do a vector search for matching table(s)
    - do a vector search for matching columns on the already matched table(s)
    - prepare a schema for just the matched tables and columns
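The retrieval steps above can be illustrated with a toy example.  Everything here is an assumption made for illustration: `embed` is a stand-in that counts words rather than calling an embedding model, the in-memory dictionary stands in for a vector database, and the descriptions are shortened versions of the column descriptions shown earlier.

```python
import math
from collections import Counter


def embed(text):
    """Toy stand-in for an embedding model: lowercase word counts."""
    cleaned = text.lower().replace(".", " ").replace("?", " ").replace(",", " ")
    return Counter(cleaned.split())


def cosine(a, b):
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[token] * b[token] for token in a if token in b)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def top_matches(question, descriptions, k=2):
    """Return the k keys whose descriptions are most similar to the question."""
    question_vector = embed(question)
    ranked = sorted(
        descriptions.items(),
        key=lambda item: cosine(question_vector, embed(item[1])),
        reverse=True,
    )
    return [name for name, _ in ranked[:k]]


# Column descriptions keyed by table.column (shortened from the schema above)
column_descriptions = {
    "citibike_stations.station_id": "Unique identifier of a station.",
    "citibike_stations.name": "Public name of the station.",
    "citibike_stations.latitude": "The latitude of station.",
    "citibike_stations.longitude": "The longitude of station.",
}

best_column = top_matches("What is the public name of each station?", column_descriptions, k=1)
print(best_column)
```

With a real embedding model the matching would work on meaning rather than shared words, but the flow is the same: embed the question, rank the stored descriptions, and keep only the best matches for the schema context.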

---
## Code Generation LLM - With Context

In this attempt, ask the code generation LLM to write valid SQL and also provide it context.  In this case the context is the schema of the tables that are relevant to Citibike rentals.

Turning the table that represents the schema into context is done here by converting it to markdown.  This is an area where users can experiment with the format: JSON, CSV, and so on.  I like markdown because it includes the column names a single time and is delimited by markdown's header row notation!
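To make the format comparison concrete, here is a small sketch that renders the same two schema rows as markdown, JSON, and CSV and prints the size of each.  The helper names are hypothetical and the rendering is hand-rolled to keep the example dependency-free (the notebook itself uses `DataFrame.to_markdown`).  Character count is only a rough proxy for how many tokens a model will see.

```python
import csv
import io
import json

# Two rows in the shape of the schema table used in this notebook
rows = [
    {"table_name": "citibike_stations", "column_name": "station_id", "data_type": "STRING"},
    {"table_name": "citibike_stations", "column_name": "latitude", "data_type": "FLOAT64"},
]


def to_markdown_table(rows):
    """Render a list of dicts as a markdown table with a single header row."""
    headers = list(rows[0])
    lines = ["| " + " | ".join(headers) + " |"]
    lines.append("|" + "|".join(":--" for _ in headers) + "|")
    for row in rows:
        lines.append("| " + " | ".join(str(row[h]) for h in headers) + " |")
    return "\n".join(lines)


def to_csv_text(rows):
    """Render a list of dicts as CSV text with a header line."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=list(rows[0]), lineterminator="\n")
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()


as_markdown = to_markdown_table(rows)
as_json = json.dumps(rows)
as_csv = to_csv_text(rows)

for name, text in [("markdown", as_markdown), ("json", as_json), ("csv", as_csv)]:
    print(name, len(text), "characters")
```

JSON repeats every column name on every row, while markdown and CSV state them once in the header, which matters more as the schema grows.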

In [18]:
print(question)

Which station had most rentals (longest total duration) during July 2015?


In [19]:
context_prompt = f"""
context (BigQuery Table Schema):
{schema_columns.to_markdown(index = False)}

Write a query for Google BigQuery using fully qualified table names to answer this question:
{question}
"""

#context_query = textgen_model.predict(f"Write a Google SQL query that answers the following question.\nquestion: {question}")
#context_query = codegen_model.predict(context_prompt, max_output_tokens = 500)
context_query = gemini_model.generate_content(context_prompt, generation_config = dict(temperature = 0))

print(context_query.text)

```sql
SELECT start_station_name, SUM(tripduration) AS total_rental_duration
FROM `bigquery-public-data.new_york.citibike_trips`
WHERE EXTRACT(MONTH FROM starttime) = 7
  AND EXTRACT(YEAR FROM starttime) = 2015
GROUP BY start_station_name
ORDER BY total_rental_duration DESC
LIMIT 1;
```


Clean the query up by removing the first and last lines so it can be submitted to BigQuery:

In [20]:
context_query = query = '\n'.join(context_query.text.split('\n')[1:-1])
print(context_query)

SELECT start_station_name, SUM(tripduration) AS total_rental_duration
FROM `bigquery-public-data.new_york.citibike_trips`
WHERE EXTRACT(MONTH FROM starttime) = 7
  AND EXTRACT(YEAR FROM starttime) = 2015
GROUP BY start_station_name
ORDER BY total_rental_duration DESC
LIMIT 1;
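Dropping the first and last lines works when the response is exactly one fenced block.  A slightly more defensive sketch is shown below; `extract_sql` is a hypothetical helper, not part of any library, and it assumes the query sits in the first fenced block of the response (falling back to the whole text when there is no fence).

```python
import re

FENCE = "`" * 3  # the three-backtick markdown fence, built here to keep this example readable


def extract_sql(response_text):
    """Return the body of the first fenced code block, or the stripped text if none."""
    pattern = FENCE + r"[a-zA-Z]*[ \t]*\n(.*?)" + FENCE
    match = re.search(pattern, response_text, flags=re.DOTALL)
    if match:
        return match.group(1).strip()
    return response_text.strip()


# Example responses (made up for illustration)
fenced = FENCE + "sql\nSELECT 1;\n" + FENCE
chatty = " Here is the query:\n" + fenced + "\nLet me know if it works."
print(extract_sql(fenced))
print(extract_sql(chatty))
```

This tolerates leading spaces, a missing language tag, and explanatory text around the fence, all of which appear in model responses from time to time.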


Use BigQuery to [Dry run](https://cloud.google.com/bigquery/docs/samples/bigquery-query-dry-run#bigquery_query_dry_run-python) the query and see if it has correct syntax:

In [21]:
dry_run = bq.query(context_query, job_config = bigquery.QueryJobConfig(dry_run = True, use_query_cache = False))

In [22]:
dry_run.errors

In [23]:
dry_run.total_bytes_processed

1232565425
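The dry run also reports how much data the query would scan, which makes it a convenient place to stop an unexpectedly expensive generated query before it runs.  A minimal sketch follows; the helper names and the 5 GiB threshold are arbitrary assumptions chosen for illustration.

```python
def bytes_to_gib(total_bytes):
    """Convert a byte count to gibibytes (2**30 bytes)."""
    return total_bytes / 2**30


def within_scan_budget(total_bytes, max_gib=5.0):
    """Return True when the estimated scan size is at or below the budget."""
    return bytes_to_gib(total_bytes) <= max_gib


estimated_bytes = 1232565425  # the total_bytes_processed value reported by the dry run above
print(f"{bytes_to_gib(estimated_bytes):.2f} GiB would be scanned")
print("run it" if within_scan_budget(estimated_bytes) else "skip it")
```

The BigQuery client's `QueryJobConfig` also accepts a `maximum_bytes_billed` setting, which enforces a limit on the server side rather than in notebook code.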

Now run the query and return the result to a local dataframe:

In [24]:
context_response = bq.query(context_query).to_dataframe()

In [25]:
context_response

|    | start_station_name     |   total_rental_duration |
|---:|:-----------------------|------------------------:|
|  0 | Central Park S & 6 Ave |                18055103 |


---
## Answer The Question

Now that valid context has been retrieved from BigQuery, it can be passed to a text generation LLM to answer the user's question.

In [26]:
question_prompt = f"""
context (result from BigQuery query):
{context_response.to_markdown(index = False)}

Answer the following question using the provided context.  Note that the context is a tabular result returned from a BigQuery query.  Do not repeat the question or the context when responding.
{question}
"""

question_response = textgen_model.predict(question_prompt)

print(question_response.text)

 Central Park S & 6 Ave


In [27]:
#print(question_prompt)

---
## Put It All Together

Ask a new question and try it out:

In [28]:
question = 'What were the top five stations with most unique trips in July 2015?'

In [29]:
context_prompt = f"""
context (BigQuery Table Schema):
{schema_columns.to_markdown(index = False)}

Write a query for Google BigQuery using fully qualified table names to answer this question:
{question}
"""

#context_query = textgen_model.predict(f"Write a Google SQL query that answers the following question.\nquestion: {question}")
#context_query = codegen_model.predict(context_prompt, max_output_tokens = 500)
context_query = gemini_model.generate_content(context_prompt, generation_config = dict(temperature = 0))


context_response = bq.query(query = '\n'.join(context_query.text.split('\n')[1:-1])).to_dataframe()

In [30]:
print(context_response.to_markdown(index = False))

| start_station_name                |   total_trips |
|:----------------------------------|--------------:|
| Central Park S & 6 Ave            |          3933 |
| Grand Army Plaza & Central Park S |          3290 |
| Broadway & W 60 St                |          3215 |
| Centre St & Chambers St           |          2656 |
| West St & Chambers St             |          2603 |


In [31]:
question_prompt = f"""
context (result from BigQuery query):
{context_response.to_markdown(index = False)}

Answer the following question.  Note that the context is a tabular result returned from a BigQuery query.  Do not repeat the question or the context when responding.
{question}
"""

question_response = textgen_model.predict(question_prompt)

print(question_response.text)

 | start_station_name                |   total_trips |
|:----------------------------------|--------------:|
| Central Park S & 6 Ave            |          3933 |
| Grand Army Plaza & Central Park S |          3290 |
| Broadway & W 60 St                |          3215 |
| Centre St & Chambers St           |          2656 |
| West St & Chambers St             |          2603 |


---
## All Together and More Complex

In [32]:
question = 'What were the top stations with most trips started in July 2015 near central park?'

In [33]:
context_prompt = f"""
context (BigQuery Table Schema):
{schema_columns.to_markdown(index = False)}

Write a query for Google BigQuery using fully qualified table names to answer this question:
{question}
"""

#context_query = textgen_model.predict(f"Write a Google SQL query that answers the following question.\nquestion: {question}")
#context_query = codegen_model.predict(context_prompt, max_output_tokens = 500, temperature = 0.9)
context_query = gemini_model.generate_content(context_prompt, generation_config = dict(temperature = 0))

context_query = '\n'.join(context_query.text.split('\n')[1:-1])
context_response = bq.query(query = context_query).to_dataframe()

In [34]:
#context_prompt

In [35]:
print(context_query)

SELECT start_station_name, COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips`
WHERE EXTRACT(MONTH FROM starttime) = 7
  AND EXTRACT(YEAR FROM starttime) = 2015
  AND start_station_latitude BETWEEN 40.769 AND 40.783
  AND start_station_longitude BETWEEN -73.984 AND -73.964
GROUP BY start_station_name
ORDER BY num_trips DESC
LIMIT 10;


In [36]:
context_response

|    | start_station_name |   num_trips |
|---:|:-------------------|------------:|
|  0 | Broadway & W 60 St |        8793 |


In [37]:
question_prompt = f"""
Answer the following question.  Note that the data below is a tabular result returned from a BigQuery query specific to this question.  Do not repeat the question or the data when responding.
{question}

Use this data:
{context_response.to_markdown(index = False)}
"""

question_response = gemini_model.generate_content(question_prompt)

print(question_response.text)

Broadway & W 60 St


---
## When The Generated Query Fails...

The following question leads to an error in the `context_query` execution.  This section covers a method of using a Code Chat LLM to iteratively fix the query based on the error details returned from the job execution in BigQuery.


In [37]:
question = "How many trips were started on a weekend, in the afternoon, during 2015, by a regular rider, who is over the age of 60?"

In [38]:
context_prompt = f"""
context (BigQuery Table Schema):
{schema_columns.to_markdown(index = False)}

Write a query for Google BigQuery using fully qualified table names to answer this question:
{question}
"""

context_query = codegen_model.predict(context_prompt, max_output_tokens = 256)

In [39]:
context_query

```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON trips.start_station_id = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;
```

Clean the query up by removing the first and last lines so it can be submitted to BigQuery:

In [40]:
context_query = '\n'.join(context_query.text.split('\n')[1:-1])
print(context_query)

SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON trips.start_station_id = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;


### Detect Errors: Using Dry Run

Use BigQuery to [Dry run](https://cloud.google.com/bigquery/docs/samples/bigquery-query-dry-run#bigquery_query_dry_run-python) the query and see if it has correct syntax:

In [41]:
try:
    bq.query(context_query, job_config = bigquery.QueryJobConfig(dry_run = True, use_query_cache = False))
except Exception as err:
    errors = f"{type(err).__name__} was raised: {err}"
    print(errors)

BadRequest was raised: 400 POST https://bigquery.googleapis.com/bigquery/v2/projects/statmike-mlops-349915/jobs?prettyPrint=false: No matching signature for operator = for argument types: INT64, STRING. Supported signature: ANY = ANY at [4:4]

Location: None
Job ID: d9a604e1-b4fc-4b64-833f-a4383149dd98



### Detect Errors: From Query Execution

In [42]:
query_job = bq.query(query = context_query)

In [43]:
type(query_job)

google.cloud.bigquery.job.query.QueryJob

In [44]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'No matching signature for operator = for argument types: INT64, STRING. Supported signature: ANY = ANY at [4:4]'}]

### Use A Chat Model To Iteratively Refine A Query:

Chat models, like `codechat-bison@002`, are interactive in that they keep a history of the chat session as context for future messages.  This suits both generating the initial query and asking for help repairing it after errors are returned from BigQuery.
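The manual steps that follow amount to a loop: run the query, and on failure send the error back and try the suggested replacement.  Here is a minimal sketch of that loop.  The function `repair_query` and its two callables are hypothetical, and the stand-in functions below replace BigQuery and the chat model so the example runs on its own.

```python
def repair_query(query, run_query, ask_for_fix, max_tries=3):
    """Run a query; on failure, request a fix and retry, up to max_tries runs.

    run_query(query) returns an error message string, or None on success.
    ask_for_fix(query, error) returns a replacement query.
    """
    for attempt in range(1, max_tries + 1):
        error = run_query(query)
        if error is None:
            return query, attempt
        query = ask_for_fix(query, error)
    raise RuntimeError(f"query still failing after {max_tries} tries")


# Stand-ins for BigQuery and the chat model (assumptions for illustration only)
def fake_run(query):
    return "Function not found: strftime at [1:8]" if "strftime" in query else None


def fake_fix(query, error):
    return query.replace("strftime('%Y', starttime)", "EXTRACT(YEAR FROM starttime)")


fixed_query, tries = repair_query("SELECT strftime('%Y', starttime) FROM t", fake_run, fake_fix)
print(tries, fixed_query)
```

Capping the number of tries matters because a model can keep proposing queries that fail in new ways, as the attempts in this section show.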

#### Start A Chat Session With The Same Context (Schema)

In [45]:
codechat_model = vertexai.language_models.CodeChatModel.from_pretrained('codechat-bison@002')

In [46]:
codechat = codechat_model.start_chat(
    context = f"""
The BigQuery Environment has tables defined by the following schema:
{schema_columns.to_markdown(index = False)}

This session is trying to troubleshoot a Google BigQuery SQL query that is being written to answer the question:
{question}

BigQuery SQL query that needs to be fixed:
{context_query}

Instructions:
As the user provides versions of the query and the errors returned by BigQuery, offer suggestions that fix the errors but it is important that the query still answer the original question.
"""
)

#### Create a Hint From The Error: Dry Run

Use the error detected during the dry run to construct a hint for the LLM:

In [47]:
errors

'BadRequest was raised: 400 POST https://bigquery.googleapis.com/bigquery/v2/projects/statmike-mlops-349915/jobs?prettyPrint=false: No matching signature for operator = for argument types: INT64, STRING. Supported signature: ANY = ANY at [4:4]\n\nLocation: None\nJob ID: d9a604e1-b4fc-4b64-833f-a4383149dd98\n'

In [48]:
hint = ''
# rfind returns -1 when a bracket is missing (rindex would raise ValueError)
begin = errors.rfind('[') + 1
end = errors.rfind(']')
# verify that it looks like an error location: [line:column]
if 0 < begin < end and ':' in errors[begin:end]:
    query_index = [int(q) for q in errors[begin:end].split(':')]
    hint += context_query.split('\n')[query_index[0] - 1].strip()
print(hint)

ON trips.start_station_id = stations.station_id
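The same location lookup can also be written with a regular expression, which avoids the manual bracket arithmetic.  This is an alternative sketch, not the notebook's method: `error_location` and `line_hint` are hypothetical helpers, and they assume the location appears as `[line:column]` in the message, as it does in the errors shown here.

```python
import re


def error_location(message):
    """Return (line, column) from the last '[line:column]' marker, or None."""
    matches = re.findall(r"\[(\d+):(\d+)\]", message)
    if not matches:
        return None
    line, column = matches[-1]
    return int(line), int(column)


def line_hint(query, message):
    """Return the query line named in the error message, or '' if unavailable."""
    location = error_location(message)
    if location is None:
        return ''
    lines = query.split('\n')
    line_number = location[0]
    if 1 <= line_number <= len(lines):
        return lines[line_number - 1].strip()
    return ''


# The message text is the one returned above; the query is a shortened stand-in
sample_message = (
    "No matching signature for operator = for argument types: INT64, STRING. "
    "Supported signature: ANY = ANY at [4:4]"
)
sample_query = "SELECT COUNT(*)\nFROM a\nJOIN b\nON a.id = b.id"
print(error_location(sample_message))
print(line_hint(sample_query, sample_message))
```

Returning an empty hint when the location is missing or out of range lets the fix prompt fall back to the raw error message alone.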


#### Create a Hint From The Error: Query Execution

Use the error's `message` field to retrieve the line of the query related to the error as a hint:

In [49]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'No matching signature for operator = for argument types: INT64, STRING. Supported signature: ANY = ANY at [4:4]'}]

In [50]:
hint = ''
for error in query_job.errors or []:
    # detect error message
    message = error.get('message', '')
    # detect index of error location; rfind returns -1 when a bracket is missing
    begin = message.rfind('[') + 1
    end = message.rfind(']')
    # verify that it looks like an error location:
    if 0 < begin < end and ':' in message[begin:end]:
        # retrieve the two parts of the error index: query line, query column
        query_index = [int(q) for q in message[begin:end].split(':')]
        hint += context_query.split('\n')[query_index[0]-1].strip()
        break
print(hint)

ON trips.start_station_id = stations.station_id


#### Fix Query (Try 1)

In [51]:
fix_prompt = f"""
This query:
{context_query}

Returns these errors:
{query_job.errors}
"""

if hint != '':
    fix_prompt += f"""
Hint, the error appears to be in this line of the query:
{hint}
"""
print(fix_prompt)


This query:
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON trips.start_station_id = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;

Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'No matching signature for operator = for argument types: INT64, STRING. Supported signature: ANY = ANY at [4:4]'}]

Hint, the error appears to be in this line of the query:
ON trips.start_station_id = stations.station_id



In [52]:
response = codechat.send_message(fix_prompt)

In [53]:
response

 The error message suggests that there is an issue with the equality comparison between `trips.start_station_id` and `stations.station_id`. To fix this, ensure that both columns have compatible data types. In this case, `trips.start_station_id` is an INT64, while `stations.station_id` is a STRING. To resolve this, you can convert `trips.start_station_id` to a STRING using the CAST function, as shown below:

```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;
```

In [54]:
context_query = codechat.send_message(f'Respond with only the corrected query as a markdown code block.')

In [55]:
context_query

 ```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;
```

In [56]:
if context_query.text.find("```") >= 0:
    context_query = context_query.text.split("```")[1]
    if context_query.startswith('sql'): context_query = context_query[3:]
    print(context_query)
else:
    print('no query in response')


SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;



In [57]:
query_job = bq.query(query = context_query)

In [58]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'Function not found: strftime at [7:10]'}]

In [59]:
hint = ''
for error in query_job.errors or []:
    # detect error message
    message = error.get('message', '')
    # detect index of error location; rfind returns -1 when a bracket is missing
    begin = message.rfind('[') + 1
    end = message.rfind(']')
    # verify that it looks like an error location:
    if 0 < begin < end and ':' in message[begin:end]:
        # retrieve the two parts of the error index: query line, query column
        query_index = [int(q) for q in message[begin:end].split(':')]
        hint += context_query.split('\n')[query_index[0]-1].strip()
        break
print(hint)

CAST(strftime('%Y', trips.starttime) AS INT64) = 2015


#### Fix Query (Try 2)

In [60]:
fix_prompt = f"""
This query:
{context_query}

Returns these errors:
{query_job.errors}
"""

if hint != '':
    fix_prompt += f"""
Hint, the error appears to be in this line of the query:
{hint}
"""
print(fix_prompt)


This query:

SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;


Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'Function not found: strftime at [7:10]'}]

Hint, the error appears to be in this line of the query:
CAST(strftime('%Y', trips.starttime) AS INT64) = 2015



In [61]:
response = codechat.send_message(fix_prompt)

In [62]:
response

 The `strftime` function is not supported in BigQuery Standard SQL. Instead, you can use the `DATE` function to extract the year from the `starttime` column. The corrected query is:

```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(DATE(trips.starttime, '%Y') AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(DATE(trips.starttime, '%Y') AS INT64) - trips.birth_year > 60;
```

In [63]:
context_query = codechat.send_message(f'Respond with only the corrected query as a markdown code block.')

In [64]:
context_query

 ```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(DATE(trips.starttime, '%Y') AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(DATE(trips.starttime, '%Y') AS INT64) - trips.birth_year > 60;
```

In [65]:
if context_query.text.find("```") >= 0:
    context_query = context_query.text.split("```")[1]
    if context_query.startswith('sql'): context_query = context_query[3:]
    print(context_query)
else:
    print('no query in response')


SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(DATE(trips.starttime, '%Y') AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(DATE(trips.starttime, '%Y') AS INT64) - trips.birth_year > 60;
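
The extraction step above can also be written as a small reusable helper. This is a minimal sketch, not the notebook's method: the name `extract_sql` is hypothetical, and it assumes the model puts the opening fence (with an optional language tag such as `sql`) on its own line.

```python
import re

# First fenced block in the text; an optional language tag after the fence is skipped.
FENCE = re.compile(r"```[ \t]*([A-Za-z0-9_+-]*)[ \t]*\n(.*?)```", re.DOTALL)

def extract_sql(text):
    """Return the body of the first fenced code block in text, or None if there is none."""
    match = FENCE.search(text)
    if match is None:
        return None
    # Surrounding whitespace is stripped; error locations reported by BigQuery
    # stay consistent as long as this same string is the one that gets executed.
    return match.group(2).strip()
```

Returning `None` instead of printing lets the caller decide what to do when the reply contains no query.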



In [66]:
query_job = bq.query(query = context_query)

In [67]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'Invalid cast from DATE to INT64 at [7:10]'}]

In [68]:
hint = ''
for error in query_job.errors:
    # detect error message
    if 'message' in list(error.keys()):
        # detect index of error location
        if '[' in error['message'] and ']' in error['message']:
            begin = error['message'].rindex('[') + 1
            end = error['message'].rindex(']')
            # verify that it looks like an error location:
            if end > begin and ':' in error['message'][begin:end]:
                # retrieve the two parts of the error index: query line, query column
                query_index = [int(q) for q in error['message'][begin:end].split(':')]
                hint += context_query.split('\n')[query_index[0]-1].strip()
                break
print(hint)

CAST(DATE(trips.starttime, '%Y') AS INT64) = 2015
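
The hint lookup relies on the error message ending with a `[line:column]` location. Because `str.rindex` raises `ValueError` when the character is missing, a regular expression is one way to make the lookup tolerate messages without a location. The sketch below is an illustration under that assumption; the name `error_hint` is hypothetical.

```python
import re

# Matches a trailing "[line:column]" location such as "[7:10]".
LOCATION = re.compile(r"\[(\d+):(\d+)\]\s*$")

def error_hint(errors, query):
    """Return the stripped query line named by the first locatable error, else ''."""
    lines = query.split("\n")
    for error in errors or []:
        match = LOCATION.search(error.get("message", ""))
        if match is None:
            continue  # no location in this message, try the next error
        line_number = int(match.group(1))
        # BigQuery line numbers are 1-based; ignore locations outside the query.
        if 1 <= line_number <= len(lines):
            return lines[line_number - 1].strip()
    return ""
```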


#### Fix Query (Try 3)

In [69]:
fix_prompt = f"""
This query:
{context_query}

Returns these errors:
{query_job.errors}
"""

if hint != '':
    fix_prompt += f"""
Hint, the error appears to be in this line of the query:
{hint}
"""
print(fix_prompt)


This query:

SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(DATE(trips.starttime, '%Y') AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(DATE(trips.starttime, '%Y') AS INT64) - trips.birth_year > 60;


Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'Invalid cast from DATE to INT64 at [7:10]'}]

Hint, the error appears to be in this line of the query:
CAST(DATE(trips.starttime, '%Y') AS INT64) = 2015



In [70]:
response = codechat.send_message(fix_prompt)

In [71]:
response

 The error message suggests that there is an issue with casting the result of the `DATE` function to INT64. To fix this, you can use the `EXTRACT` function to extract the year from the `starttime` column, as shown below:

```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [72]:
context_query = codechat.send_message(f'Respond with only the corrected query as a markdown code block.')

In [73]:
context_query

 ```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [74]:
if context_query.text.find("```") >= 0:
    context_query = context_query.text.split("```")[1]
    if context_query.startswith('sql'): context_query = context_query[3:]
    print(context_query)
else:
    print('no query in response')


SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;



In [75]:
query_job = bq.query(query = context_query)

In [76]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'Function not found: strftime at [8:9]'}]

In [77]:
hint = ''
for error in query_job.errors:
    # detect error message
    if 'message' in list(error.keys()):
        # detect index of error location
        if '[' in error['message'] and ']' in error['message']:
            begin = error['message'].rindex('[') + 1
            end = error['message'].rindex(']')
            # verify that it looks like an error location:
            if end > begin and ':' in error['message'][begin:end]:
                # retrieve the two parts of the error index: query line, query column
                query_index = [int(q) for q in error['message'][begin:end].split(':')]
                hint += context_query.split('\n')[query_index[0]-1].strip()
                break
print(hint)

AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday


#### Fix Query (Try 3)

In [78]:
fix_prompt = f"""
This query:
{context_query}

Returns these errors:
{query_job.errors}
"""

if hint != '':
    fix_prompt += f"""
Hint, the error appears to be in this line of the query:
{hint}
"""
print(fix_prompt)


This query:

SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;


Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'Function not found: strftime at [8:9]'}]

Hint, the error appears to be in this line of the query:
AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday



In [79]:
response = codechat.send_message(fix_prompt)

In [80]:
response

 The `strftime` function is not supported in BigQuery Standard SQL. Instead, you can use the `WEEKDAY` function to extract the day of the week from the `starttime` column. The corrected query is:

```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND WEEKDAY(trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [81]:
context_query = codechat.send_message(f'Respond with only the corrected query as a markdown code block.')

In [82]:
context_query

 ```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND WEEKDAY(trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [83]:
if context_query.text.find("```") >= 0:
    context_query = context_query.text.split("```")[1]
    if context_query.startswith('sql'): context_query = context_query[3:]
    print(context_query)
else:
    print('no query in response')


SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND WEEKDAY(trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;



In [84]:
query_job = bq.query(query = context_query)

In [85]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'Function not found: WEEKDAY at [8:9]'}]

In [86]:
hint = ''
for error in query_job.errors:
    # detect error message
    if 'message' in list(error.keys()):
        # detect index of error location
        if '[' in error['message'] and ']' in error['message']:
            begin = error['message'].rindex('[') + 1
            end = error['message'].rindex(']')
            # verify that it looks like an error location:
            if end > begin and ':' in error['message'][begin:end]:
                # retrieve the two parts of the error index: query line, query column
                query_index = [int(q) for q in error['message'][begin:end].split(':')]
                hint += context_query.split('\n')[query_index[0]-1].strip()
                break
print(hint)

AND WEEKDAY(trips.starttime) IN (0, 6) -- Sunday or Saturday


#### Fix Query (Try 4)

In [87]:
fix_prompt = f"""
This query:
{context_query}

Returns these errors:
{query_job.errors}
"""

if hint != '':
    fix_prompt += f"""
Hint, the error appears to be in this line of the query:
{hint}
"""
print(fix_prompt)


This query:

SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND WEEKDAY(trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;


Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'Function not found: WEEKDAY at [8:9]'}]

Hint, the error appears to be in this line of the query:
AND WEEKDAY(trips.starttime) IN (0, 6) -- Sunday or Saturday



In [88]:
response = codechat.send_message(fix_prompt)

In [89]:
response

 The `WEEKDAY` function is not supported in BigQuery Standard SQL. Instead, you can use the `DAYOFWEEK` function to extract the day of the week from the `starttime` column. The corrected query is:

```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND DAYOFWEEK(trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [90]:
context_query = codechat.send_message(f'Respond with only the corrected query as a markdown code block.')

In [91]:
context_query

 ```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND DAYOFWEEK(trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [92]:
if context_query.text.find("```") >= 0:
    context_query = context_query.text.split("```")[1]
    if context_query.startswith('sql'): context_query = context_query[3:]
    print(context_query)
else:
    print('no query in response')


SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND DAYOFWEEK(trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;



In [93]:
query_job = bq.query(query = context_query)

In [94]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'Function not found: DAYOFWEEK at [8:9]'}]

In [95]:
hint = ''
for error in query_job.errors:
    # detect error message
    if 'message' in list(error.keys()):
        # detect index of error location
        if '[' in error['message'] and ']' in error['message']:
            begin = error['message'].rindex('[') + 1
            end = error['message'].rindex(']')
            # verify that it looks like an error location:
            if end > begin and ':' in error['message'][begin:end]:
                # retrieve the two parts of the error index: query line, query column
                query_index = [int(q) for q in error['message'][begin:end].split(':')]
                hint += context_query.split('\n')[query_index[0]-1].strip()
                break
print(hint)

AND DAYOFWEEK(trips.starttime) IN (0, 6) -- Sunday or Saturday


#### Fix Query (Try 5)

In [96]:
fix_prompt = f"""
This query:
{context_query}

Returns these errors:
{query_job.errors}
"""

if hint != '':
    fix_prompt += f"""
Hint, the error appears to be in this line of the query:
{hint}
"""
print(fix_prompt)


This query:

SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND DAYOFWEEK(trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;


Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'Function not found: DAYOFWEEK at [8:9]'}]

Hint, the error appears to be in this line of the query:
AND DAYOFWEEK(trips.starttime) IN (0, 6) -- Sunday or Saturday



In [97]:
response = codechat.send_message(fix_prompt)

In [98]:
response

 The `DAYOFWEEK` function is not supported in BigQuery Standard SQL. Instead, you can use the `EXTRACT` function to extract the day of the week from the `starttime` column. The corrected query is:

```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND EXTRACT(DAYOFWEEK FROM trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [99]:
context_query = codechat.send_message(f'Respond with only the corrected query as a markdown code block.')

In [100]:
context_query

 ```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND EXTRACT(DAYOFWEEK FROM trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [101]:
if context_query.text.find("```") >= 0:
    context_query = context_query.text.split("```")[1]
    if context_query.startswith('sql'): context_query = context_query[3:]
    print(context_query)
else:
    print('no query in response')


SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND EXTRACT(DAYOFWEEK FROM trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;



In [102]:
query_job = bq.query(query = context_query)

In [103]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'Function not found: strftime at [9:9]'}]

In [104]:
hint = ''
for error in query_job.errors:
    # detect error message
    if 'message' in list(error.keys()):
        # detect index of error location
        if '[' in error['message'] and ']' in error['message']:
            begin = error['message'].rindex('[') + 1
            end = error['message'].rindex(']')
            # verify that it looks like an error location:
            if end > begin and ':' in error['message'][begin:end]:
                # retrieve the two parts of the error index: query line, query column
                query_index = [int(q) for q in error['message'][begin:end].split(':')]
                hint += context_query.split('\n')[query_index[0]-1].strip()
                break
print(hint)

AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM


#### Fix Query (Try 6)

In [105]:
fix_prompt = f"""
This query:
{context_query}

Returns these errors:
{query_job.errors}
"""

if hint != '':
    fix_prompt += f"""
Hint, the error appears to be in this line of the query:
{hint}
"""
print(fix_prompt)


This query:

SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND EXTRACT(DAYOFWEEK FROM trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;


Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'Function not found: strftime at [9:9]'}]

Hint, the error appears to be in this line of the query:
AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM



In [106]:
response = codechat.send_message(fix_prompt)

In [107]:
response

 The `strftime` function is not supported in BigQuery Standard SQL. Instead, you can use the `FORMAT_TIMESTAMP` function to extract the hour from the `starttime` column. The corrected query is:

```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND EXTRACT(DAYOFWEEK FROM trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND FORMAT_TIMESTAMP('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [108]:
context_query = codechat.send_message(f'Respond with only the corrected query as a markdown code block.')

In [109]:
context_query

 ```sql
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND EXTRACT(DAYOFWEEK FROM trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND FORMAT_TIMESTAMP('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;
```

In [110]:
if context_query.text.find("```") >= 0:
    context_query = context_query.text.split("```")[1]
    if context_query.startswith('sql'): context_query = context_query[3:]
    print(context_query)
else:
    print('no query in response')


SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    EXTRACT(YEAR FROM trips.starttime) = 2015
    AND EXTRACT(DAYOFWEEK FROM trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND FORMAT_TIMESTAMP('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year > 60;



In [111]:
query_job = bq.query(query = context_query)

In [112]:
if query_job.errors:
    hint = ''
    for error in query_job.errors:
        # detect error message
        if 'message' in list(error.keys()):
            # detect index of error location
            if '[' in error['message'] and ']' in error['message']:
                begin = error['message'].rindex('[') + 1
                end = error['message'].rindex(']')
                # verify that it looks like an error location:
                if end > begin and ':' in error['message'][begin:end]:
                    # retrieve the two parts of the error index: query line, query column
                    query_index = [int(q) for q in error['message'][begin:end].split(':')]
                    hint += context_query.split('\n')[query_index[0]-1].strip()
                    break
    print(hint)
else:
    response = query_job.to_dataframe()
    print(response)

   num_trips
0      18968


#### Generate Response

In [113]:
question_prompt = f"""
Answer the following question.  Note that the context is a tabular result returned from a BigQuery query specific to this question.  Do not repeat the question or the context when responding.
{question}

Use this data:
{response.to_markdown(index = False)}
"""

question_response = gemini_model.generate_content(question_prompt)

print(question_response.text)

18968


## Put It All Together

Build functions that generate the starting query, initialize a chat debugging session, and manage the loop that detects errors and requests fixes from that session.
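
Before wiring in BigQuery and the chat model, the control flow can be sketched on its own. This is a minimal illustration, not the notebook's implementation: `repair_loop`, `run_query`, and `request_fix` are hypothetical names, and the two callables stand in for the BigQuery call and the chat call so that the loop can be exercised without cloud access.

```python
def repair_loop(query, run_query, request_fix, max_fixes=10):
    """Run query; on errors, ask for a corrected query, up to max_fixes times.

    run_query(query) -> list of error dicts (empty or None on success)
    request_fix(query, errors) -> corrected query string, or '' if none was offered
    Returns (query, errors, fix_tries); errors is empty when the last run succeeded.
    """
    fix_tries = 0
    errors = run_query(query)
    while errors and fix_tries < max_fixes:
        fix_tries += 1
        query = request_fix(query, errors)
        if not query:
            break  # the model offered no query, nothing left to run
        errors = run_query(query)  # every proposed fix is executed before returning
    return query, errors, fix_tries
```

Passing the two steps in as functions keeps the retry policy separate from the service calls, which makes it easy to test the policy with fake functions.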

### Functions To Answer The Question Using Iteration To Fix Context Query

In [125]:
def initial_query(question, schema_columns):
    
    # code generation model
    codegen_model = vertexai.language_models.CodeGenerationModel.from_pretrained('code-bison@002')
    
    # initial request for query:
    context_prompt = f"""
context (BigQuery Table Schema):
{schema_columns.to_markdown(index = False)}

Write a query for Google BigQuery using fully qualified table names to answer this question:
{question}
"""

    context_query = codegen_model.predict(context_prompt, max_output_tokens = 256)
    
    # extract query from response
    if context_query.text.find("```") >= 0:
        context_query = context_query.text.split("```")[1]
        if context_query.startswith('sql'):
            context_query = context_query[3:]
        print('Initial Query:\n', context_query)
    else:
        print('No query provided (first try) - unforeseen error, printing out response to help with editing this function:\n', context_query.text)
        # return an empty query rather than the raw response object
        context_query = ''
    
    return context_query    

In [126]:
def codechat_start(question, query, schema_columns):

    # code chat model
    codechat_model = vertexai.language_models.CodeChatModel.from_pretrained('codechat-bison@002')
    
    # start a code chat session and give the schema for columns as the starting context:
    codechat = codechat_model.start_chat(
        context = f"""
The BigQuery Environment has tables defined by the following schema:
{schema_columns.to_markdown(index = False)}

This session is trying to troubleshoot a Google BigQuery SQL query that is being written to answer the question:
{question}

BigQuery SQL query that needs to be fixed:
{query}

Instructions:
As the user provides versions of the query and the errors returned by BigQuery, offer suggestions that fix the errors but it is important that the query still answer the original question.
"""
    )
    
    return codechat

In [127]:
def fix_query(question, query, schema_columns, max_fixes):
    
    # iteratively run query, and fix it using codechat until success (or max_fixes reached):
    fix_tries = 0
    codechat = None
    query_job = None
    while fix_tries < max_fixes:
        # stop when the last response did not include a query:
        if not query: 
            break
        # run query:
        query_job = bq.query(query = query)
        # if errors, then generate repair query:
        if query_job.errors:
            fix_tries += 1
            
            # start the chat session for the question being asked (not a global):
            if fix_tries == 1:
                codechat = codechat_start(question, query, schema_columns)
            
            # construct hint from error
            hint = ''
            for error in query_job.errors:
                # detect error message
                message = error.get('message', '')
                # detect index of error location, rfind returns -1 when not found
                begin = message.rfind('[') + 1
                end = message.rfind(']')
                # verify that it looks like an error location:
                if 0 < begin < end and ':' in message[begin:end]:
                    # retrieve the two parts of the error index: query line, query column
                    query_index = [int(q) for q in message[begin:end].split(':')]
                    hint += query.split('\n')[query_index[0]-1].strip()
                    break
            
            # construct prompt to request a fix:
            fix_prompt = f"""This query:\n{query}\n\nReturns these errors:\n{query_job.errors}"""

            if hint != '':
                fix_prompt += f"""\n\nHint, the error appears to be in this line of the query:\n{hint}"""            
            
            query_response = codechat.send_message(fix_prompt)
            query_response = codechat.send_message('Respond with only the corrected query that still answers the question as a markdown code block.')
            if query_response.text.find("```") >= 0:
                query = query_response.text.split("```")[1]
                if query.startswith('sql'):
                    query = query[4:]
                print(f'Fix #{fix_tries}:\n', query)
            # response did not have a query????:
            else:
                query = ''
                print('No query in response...')

        # no error, break while loop
        else:
            break
    
    return query, query_job, fix_tries, codechat

In [128]:
def answer_question(question, query_job):

    # text generation model
    gemini_model = vertexai.preview.generative_models.GenerativeModel("gemini-pro")

    # answer question
    result = query_job.to_dataframe()
    question_prompt = f"""
Answer the following question.  Note that the context is a tabular result returned from a BigQuery query specific to this question.  Do not repeat the question or the context when responding.
{question}

Use this data:
{result.to_markdown(index = False)}
    """

    question_response = gemini_model.generate_content(question_prompt)
    
    return question_response.text

In [129]:
def BQ_QA(question, max_fixes = 10, schema_columns = schema_columns):
    
    # generate query
    query = initial_query(question, schema_columns)
    
    # run query:
    query_job = bq.query(query = query)
    # no fixes attempted and no chat session yet:
    fix_tries = 0
    codechat = None
    # if errors, then generate repair query:
    if query_job.errors:
        print('found errors')
        query, query_job, fix_tries, codechat = fix_query(question, query, schema_columns, max_fixes)
    
    # respond with outcome:
    if query_job.errors:
        print(f'No answer generated after {fix_tries} tries.')
    else:
        question_response = answer_question(question, query_job)
        print(question_response)
    
    # the chat session, or None when no fixes were needed:
    return codechat

### Answer The Same Question That Had Failures:

In [130]:
question

'How many trips were started on a weekend, in the afternoon, during 2015, by a regular rider, who is over the age of 60?'

In [131]:
session = BQ_QA(question)

Initial Query:
 
SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON trips.start_station_id = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;

found errors
Fix #1:
 SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '

In [132]:
for message in session.message_history:
    print(message.content)

This query:

SELECT COUNT(*) AS num_trips
FROM `bigquery-public-data.new_york.citibike_trips` AS trips
JOIN `bigquery-public-data.new_york.citibike_stations` AS stations
ON trips.start_station_id = stations.station_id
WHERE 
    CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
    AND strftime('%w', trips.starttime) IN (0, 6) -- Sunday or Saturday
    AND strftime('%H', trips.starttime) BETWEEN '12' AND '17' -- 12:00 PM - 5:00 PM
    AND trips.usertype = 'Subscriber'
    AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year > 60;


Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'No matching signature for operator = for argument types: INT64, STRING. Supported signature: ANY = ANY at [5:4]'}]

Hint, the error appears to be in this line of the query:
ON trips.start_station_id = stations.station_id
 The error message suggests that there is a type mismatch between the `start_station_id` column in the `trips` table and the `station_id`

### Try Another Question:

In [133]:
session = BQ_QA('How many trips during the year 2015: started in the evening, were over an hour long, and were by regular riders, who were over age 60?')

Initial Query:
 
SELECT
  COUNT(*) AS total_trips
FROM
  `bigquery-public-data.new_york.citibike_trips` AS trips
WHERE
  CAST(strftime('%H', trips.starttime) AS INT64) BETWEEN 17 AND 23
  AND trips.tripduration >= 3600
  AND trips.usertype = 'Subscriber'
  AND CAST(strftime('%Y', trips.starttime) AS INT64) = 2015
  AND CAST(strftime('%Y', trips.starttime) AS INT64) - trips.birth_year >= 60;

found errors
Fix #1:
 SELECT
  COUNT(*) AS total_trips
FROM
  `bigquery-public-data.new_york.citibike_trips` AS trips
WHERE
  EXTRACT(HOUR FROM trips.starttime) BETWEEN 17 AND 23
  AND trips.tripduration >= 3600
  AND trips.usertype = 'Subscriber'
  AND EXTRACT(YEAR FROM trips.starttime) = 2015
  AND EXTRACT(YEAR FROM trips.starttime) - trips.birth_year >= 60;

529


### Ideas For Improvement

After a few failed fix attempts, ask the CodeChat LLM to start fresh and write a new query to answer the question.  Then iterate on the new query.
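
One way to picture the restart idea is a repair loop that switches from "fix this query" to "write a new query" after a set number of consecutive failed fixes. The sketch below is only an illustration: `repair_with_restart`, `run_query`, `request_fix`, `fresh_query`, and `restart_after` are hypothetical names, and the callables stand in for the BigQuery and model calls.

```python
def repair_with_restart(query, run_query, request_fix, fresh_query,
                        max_fixes=10, restart_after=3):
    """Fix a failing query, asking for a brand-new query after repeated failures.

    run_query(query) -> list of error dicts (empty or None on success)
    request_fix(query, errors) -> corrected query string, or ''
    fresh_query() -> a newly written query string, or ''
    Returns (query, errors, fix_tries).
    """
    fix_tries = 0
    since_restart = 0  # consecutive fixes attempted on the current line of queries
    errors = run_query(query)
    while errors and fix_tries < max_fixes:
        fix_tries += 1
        if since_restart >= restart_after:
            query = fresh_query()  # abandon the current query and start over
            since_restart = 0
        else:
            query = request_fix(query, errors)
            since_restart += 1
        if not query:
            break
        errors = run_query(query)
    return query, errors, fix_tries
```

A restart counts against `max_fixes` here, so the total number of model calls stays bounded.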

If the fixes are breaking the logic of the query, try asking the model whether the successful query actually answers the question before providing the result.
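
The walkthrough above contains an example of this risk: the query that finally ran filters on `EXTRACT(DAYOFWEEK FROM trips.starttime) IN (0, 6)`, but BigQuery numbers `DAYOFWEEK` from 1 (Sunday) to 7 (Saturday), so that filter cannot match Sunday or Saturday even though the query runs without errors. A verification step could look like the sketch below. It is an assumption-laden illustration: `build_verification_prompt`, `query_answers_question`, and the `ask` callable (a stand-in for a model call that returns text) are hypothetical, and the prompt wording is only one possibility.

```python
def build_verification_prompt(question, query):
    """Compose a prompt asking whether a query that ran cleanly answers the question."""
    return (
        f"Question:\n{question}\n\n"
        f"BigQuery SQL query that ran without errors:\n{query}\n\n"
        "Does this query answer the question exactly? "
        "Reply with YES or NO on the first line, then one sentence of explanation."
    )

def query_answers_question(question, query, ask):
    """Return True only when the reply from ask(prompt) starts with YES."""
    reply = ask(build_verification_prompt(question, query)).strip()
    if not reply:
        return False  # treat an empty reply as "not verified"
    first_line = reply.splitlines()[0].strip().upper()
    return first_line.startswith("YES")
```

When the check returns `False`, the explanation in the reply could be sent back into the chat session as the next fix request.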