
# Answering Questions Using BigQuery Tables As Context

Ask questions of a table in BigQuery.  How?  First, translate the question to SQL and extract results from the BigQuery table.  Then provide the results as context, along with the original question, to an LLM for answering.

The workflow:
- Set up Environment
- Set up access to LLMs: Test and Code
- Retrieve Table Schemas using BigQuery Information Schema
- Translate Question to Code (SQL) with context of table schemas
- Ask an LLM to answer the question with context of results from running the generated SQL
- Ask more questions!

---
## Overview

<p><center>
    <img alt="Overview Chart" src="../architectures/notebooks/applied/genai/bq_qa.png" width="55%">
</center></p>


---
## Colab Setup

To run this notebook in Colab click [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/statmike/vertex-ai-mlops/blob/main/Applied%20GenAI/Vertex%20AI%20GenAI%20For%20BigQuery%20Q&A%20-%20Overview.ipynb) and run the cells in this section.  Otherwise, skip this section.

This cell will authenticate to GCP (follow prompts in the popup).

In [1]:
PROJECT_ID = 'statmike-mlops-349915' # replace with project ID

In [2]:
try:
    import google.colab
    from google.colab import auth
    auth.authenticate_user()
    !gcloud config set project {PROJECT_ID}
except Exception:
    pass

---
## Installs and API Enablement

The client packages may need to be installed in this environment.  Also, the Vertex AI and BigQuery APIs need to be enabled (if not already enabled).

### Installs (If Needed)

In [3]:
install = False
try: import google.cloud.aiplatform
except ImportError:
    print('You need to pip install google-cloud-aiplatform (VERTEX AI), ... commencing')
    !pip install google-cloud-aiplatform -U -q
    install = True

### Restart Kernel (If Installs Occurred)

After a kernel restart the code submission can start with the next cell after this one.

In [4]:
if install:
    import IPython
    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

---
## Setup

Inputs

In [5]:
project = !gcloud config get-value project
PROJECT_ID = project[0]
PROJECT_ID

'statmike-mlops-349915'

In [6]:
REGION = 'us-central1'
EXPERIMENT = 'bq-citibikes'
SERIES = 'applied-genai'

Packages

In [7]:
import vertexai.language_models
from google.cloud import aiplatform
from google.cloud import bigquery

import pandas as pd



Clients

In [8]:
# vertex ai clients
vertexai.init(project = PROJECT_ID, location = REGION)
aiplatform.init(project = PROJECT_ID, location = REGION)

# bigquery client
bq = bigquery.Client(project = PROJECT_ID)

---
## Goal

New York City has [Citibike](https://citibikenyc.com/) stations where you can rent a bicycle by the ride, the day, or subscribe monthly/annually.  There is a sample of the usage of Citibike stations in BigQuery public datasets.  We would like to answer possible questions in natural language using this data as the source.

A possible question is: Which Citibike station had the most rentals during July 2015?

The approach used here is to ask an LLM to answer the question.  The approach has multiple steps:
1. Ask a code generation LLM to write a SQL query that retrieves the information relevant to the question from the tables - the context
2. Run the query generated in 1
3. Ask a text generation LLM to answer the question and give it a context to help accurately answer the question - the result of 2
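
The three steps above can be sketched as a single function.  This is a minimal sketch rather than the notebook's implementation: `answer_with_context` and its three callable parameters are hypothetical names, and the lambdas are stand-ins for the code generation model, BigQuery, and the text generation model so the example runs without any cloud services.

```python
from typing import Callable

def answer_with_context(
    question: str,
    write_sql: Callable[[str], str],          # hypothetical: question -> SQL (code generation LLM)
    run_sql: Callable[[str], str],            # hypothetical: SQL -> tabular result as text (BigQuery)
    write_answer: Callable[[str, str], str],  # hypothetical: (question, context) -> answer (text LLM)
) -> str:
    sql = write_sql(question)                 # step 1: generate the query
    context = run_sql(sql)                    # step 2: run it to get the context
    return write_answer(question, context)    # step 3: answer using the context

# stand-ins so the sketch runs offline
demo = answer_with_context(
    "Which station had the most rentals?",
    write_sql=lambda q: "SELECT 'Station A' AS station_name",
    run_sql=lambda sql: "| station_name |\n|:--|\n| Station A |",
    write_answer=lambda q, ctx: ctx.splitlines()[-1].strip("| "),
)
print(demo)
```

Passing the three stages as callables keeps the flow visible and makes each stage easy to swap or test on its own.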

In [9]:
question = "Which station had most rentals (longest total duration) during July 2015?"

---
## Vertex LLM Setup

- CodeGenerationModel [Guide](https://cloud.google.com/vertex-ai/docs/generative-ai/code/code-generation-prompts)
    - CodeGenerationModel [API](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.language_models.CodeGenerationModel)
- TextGenerationModel [Guide](https://cloud.google.com/vertex-ai/docs/generative-ai/text/test-text-prompts)
    - TextGenerationModel [API](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.language_models.TextGenerationModel)

In [10]:
# create links to models: text generation and code generation
textgen_model = vertexai.language_models.TextGenerationModel.from_pretrained('text-bison')
codegen_model = vertexai.language_models.CodeGenerationModel.from_pretrained('code-bison')

Test the text generation (LLM) model:

In [11]:
textgen_model.predict(f"Write a Google SQL query that answers the following question.\nquestion: {question}")

 ```sql
SELECT start_station_id,
       SUM(duration_minutes) AS total_duration
FROM bike_sharing_trips
WHERE EXTRACT(MONTH FROM start_time) = 7
  AND EXTRACT(YEAR FROM start_time) = 2015
GROUP BY start_station_id
ORDER BY total_duration DESC
LIMIT 1;
```

In [12]:
codegen_model.predict(f"Write a Google SQL query that answers the following question.\nquestion: {question}")

```sql
SELECT start_station_id,
       SUM(duration_minutes) AS total_duration
FROM bike_sharing_trips
WHERE DATE BETWEEN '2015-07-01' AND '2015-07-31'
GROUP BY start_station_id
ORDER BY total_duration DESC
LIMIT 1;
```

Both models write code, but notice how each has to make assumptions about table names and column names.  The following section addresses how to guide the LLM to write runnable SQL with the correct tables and columns.

---
## The Problem
Both LLMs write valid SQL queries.  However, notice that asking either LLM to write the query this way faces several issues:
- The queries do not reference the correct tables
- The column names are not the correct ones from the correct tables

Basically, the generated SQL is a good starting point for a user to write a query that would retrieve valid context for the user's question.
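
One way to spot the first issue automatically is to compare the tables a generated query references against the tables that actually exist.  The sketch below is a rough heuristic, not a SQL parser: `referenced_tables` is a hypothetical helper, and `known_tables` lists the two public tables used later in this notebook.

```python
import re

def referenced_tables(sql: str) -> set:
    # remove EXTRACT(part FROM expr) so its FROM keyword is not read as a table reference
    cleaned = re.sub(r"EXTRACT\s*\([^)]*\)", "", sql, flags=re.IGNORECASE)
    # capture the identifier after FROM or JOIN, with or without backticks
    return set(re.findall(r"\b(?:FROM|JOIN)\s+`?([\w.\-]+)`?", cleaned, flags=re.IGNORECASE))

known_tables = {
    "bigquery-public-data.new_york.citibike_trips",
    "bigquery-public-data.new_york.citibike_stations",
}

# query shaped like the ungrounded responses above
generated_sql = """
SELECT start_station_id, SUM(duration_minutes) AS total_duration
FROM bike_sharing_trips
WHERE EXTRACT(MONTH FROM start_time) = 7
GROUP BY start_station_id
"""

unknown_tables = referenced_tables(generated_sql) - known_tables
print(unknown_tables)
```

A non-empty `unknown_tables` set signals that the model guessed a table name.  Subqueries, CTE names, and comma joins are not handled by this heuristic.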

**How to get a fully executable SQL query from the LLM?**

The following approach was created by iteratively refining text prompts and approaches for specific questions.  

---
## Retrieve Table Schemas

The context provided to the LLM to help it write the SQL query is the schemas of the related tables.  The BigQuery tables used for this experiment are BigQuery public dataset tables:
- `bigquery-public-data.new_york.citibike_trips`
- `bigquery-public-data.new_york.citibike_stations`

To retrieve the schemas for the tables, the BigQuery [INFORMATION_SCHEMA](https://cloud.google.com/bigquery/docs/information-schema-intro) is used - specifically the [INFORMATION_SCHEMA.COLUMN_FIELD_PATHS](https://cloud.google.com/bigquery/docs/information-schema-column-field-paths) view.

In [13]:
BQ_PROJECT = 'bigquery-public-data'
BQ_DATASET = 'new_york'
BQ_TABLES = ['citibike_trips', 'citibike_stations']

In [14]:
query = f"""
    SELECT *
    FROM `{BQ_PROJECT}.{BQ_DATASET}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ({','.join([f'"{table}"' for table in BQ_TABLES])})
"""
print(query)
schema_columns = bq.query(query = query).to_dataframe()


    SELECT *
    FROM `bigquery-public-data.new_york.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ("citibike_trips","citibike_stations")



In [15]:
schema_columns.head()

|    | table_catalog        | table_schema | table_name        | column_name | field_path | data_type | description                                        | collation_name | rounding_mode |
|---:|:---------------------|:-------------|:------------------|:------------|:-----------|:----------|:---------------------------------------------------|:---------------|:--------------|
|  0 | bigquery-public-data | new_york     | citibike_stations | station_id  | station_id | STRING    | Unique identifier of a station.                    |                |               |
|  1 | bigquery-public-data | new_york     | citibike_stations | name        | name       | STRING    | Public name of the station.                        |                |               |
|  2 | bigquery-public-data | new_york     | citibike_stations | short_name  | short_name | STRING    | Short name or other type of identifier, as use...  |                |               |
|  3 | bigquery-public-data | new_york     | citibike_stations | latitude    | latitude   | FLOAT64   | The latitude of station. The field value must ...  |                |               |
|  4 | bigquery-public-data | new_york     | citibike_stations | longitude   | longitude  | FLOAT64   | The longitude of station. The field value must...  |                |               |


In [16]:
#schema_columns.to_markdown(index = False)

---
## Notes On Efficient Information Schema Retrieval

In this example the entire schema is used as context.  But what happens when there are many tables and many columns?  Eventually the size will exceed the input size of the LLM.  It is also possible to misguide the LLM in the creation of code by supplying too much information - especially when it covers different topics.

Using semantic retrieval with embeddings is a great solution in this situation:
- create embeddings for all table descriptions
- create embeddings for all column descriptions
- create a vector database with indexes for tables and columns
- When a new question comes in find the most applicable table(s) and columns:
    - embed the question
    - do a vector search for matching table(s)
    - do a vector search for matching columns on the already matched table(s)
    - prepare a schema for just the matched tables and columns
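
The matching step in this list can be illustrated with plain cosine similarity.  This is a minimal sketch under loud assumptions: the three-dimensional vectors are made up for illustration (real embeddings would come from an embedding model and have many more dimensions), a vector database would replace the in-memory dictionary, and `cosine_similarity` and `top_matches` are hypothetical helper names.

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def top_matches(question_vector, candidates, k=1):
    # candidates: dict of name -> embedding vector; returns the k most similar names
    ranked = sorted(
        candidates,
        key=lambda name: cosine_similarity(question_vector, candidates[name]),
        reverse=True,
    )
    return ranked[:k]

# made-up vectors standing in for embeddings of table descriptions
table_vectors = {
    "citibike_trips": [0.9, 0.1, 0.0],
    "citibike_stations": [0.2, 0.9, 0.1],
}
question_vector = [0.8, 0.2, 0.0]  # made-up embedding of a question about trips

matched_tables = top_matches(question_vector, table_vectors, k=1)
print(matched_tables)
```

The same function could be applied a second time to the column vectors of the matched table(s) to narrow the schema further.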

---
## Code Generation LLM - With Context

In this attempt, ask the code generation LLM to write valid SQL and also provide it context.  In this case the context is the schema of the tables that are relevant to Citibike rentals.

Turning the table that represents the schema into context is done here by converting it to markdown.  This is an area where users can experiment with the format: JSON, CSV, and so on.  I like markdown because it includes the column names a single time and is delimited by the header row notation of markdown!
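
To make the format concrete, here is a minimal sketch of what a markdown rendering of schema rows looks like.  The notebook itself uses the pandas `to_markdown` method; `rows_to_markdown` below is a hypothetical standard-library stand-in, and the two rows are a small subset of the schema columns shown earlier.

```python
def rows_to_markdown(rows):
    # rows: non-empty list of dicts that share the same keys
    headers = list(rows[0].keys())
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join("---" for _ in headers) + " |",
    ]
    for row in rows:
        lines.append("| " + " | ".join(str(row[h]) for h in headers) + " |")
    return "\n".join(lines)

schema_rows = [
    {"table_name": "citibike_stations", "column_name": "station_id", "data_type": "STRING"},
    {"table_name": "citibike_stations", "column_name": "name", "data_type": "STRING"},
]
schema_markdown = rows_to_markdown(schema_rows)
print(schema_markdown)
```

Each header appears once, followed by one line per column of the source table, which keeps the prompt compact compared with repeating keys in every record as JSON would.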

In [17]:
print(question)

Which station had most rentals (longest total duration) during July 2015?


In [18]:
context_prompt = f"""
Write a Google SQL query for BigQuery that answers the following question while using the provided context to correctly refer to BigQuery tables and the needed column names.  When joining tables use coercion to ensure all join columns are the same data type. Output column names should include the units when applicable.  Tables should be referred to using a fully qualified name that includes the project and dataset along with the table name.
question:
{question}

context:
{schema_columns.to_markdown(index = False)}
"""

context_query = codegen_model.predict(context_prompt, max_output_tokens = 256)

print(context_query.text)

```sql
SELECT 
  start_station_name AS station_name,
  SUM(tripduration) AS total_duration_minutes
FROM bigquery-public-data.new_york.citibike_trips
WHERE 
  EXTRACT(MONTH FROM starttime) = 7
  AND EXTRACT(YEAR FROM starttime) = 2015
GROUP BY 1
ORDER BY 2 DESC
LIMIT 1;
```


In [19]:
#print(context_prompt)

In [20]:
context_response = bq.query(query = '\n'.join(context_query.text.split('\n')[1:-1])).to_dataframe()
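
The expression `'\n'.join(context_query.text.split('\n')[1:-1])` assumes the response is exactly an opening fence line, the query, and a closing fence line.  A slightly more defensive extraction is sketched below; `extract_sql` is a hypothetical helper, and the fallback of treating an unfenced response as the query itself is an assumption.

```python
import re

def extract_sql(response_text: str) -> str:
    # take the body of the first fenced block; an optional language tag such as sql is skipped
    match = re.search(r"```[a-zA-Z]*\s*\n(.*?)```", response_text, flags=re.DOTALL)
    if match:
        return match.group(1).strip()
    # no fence found: assume the whole response is the query
    return response_text.strip()

response_text = "Here is the query:\n```sql\nSELECT 1 AS one;\n```\nIt returns a single row."
print(extract_sql(response_text))
```

This tolerates leading whitespace, explanatory text before or after the block, and a missing language tag.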

In [21]:
context_response

|    | station_name           |   total_duration_minutes |
|---:|:-----------------------|-------------------------:|
|  0 | Central Park S & 6 Ave |                 18055103 |


---
## Answer The Question

Now that a valid context has been retrieved from BigQuery, it can be passed to a text generation LLM to answer the user's question.

In [22]:
question_prompt = f"""
Answer the following question using the provided context.  Note that the context is a tabular result returned from a BigQuery query.  Do not repeat the question or the context when responding.

question:
{question}
context:
{context_response.to_markdown(index = False)}
"""

question_response = textgen_model.predict(question_prompt)

print(question_response.text)

 Central Park S & 6 Ave


In [23]:
#print(question_prompt)

---
## Put It All Together

Ask a new question and try it out:

In [24]:
question = 'What were the top five stations with most unique trips in July 2015?'

In [25]:
context_prompt = f"""
Write a Google SQL query for BigQuery that answers the following question while using the provided context to correctly refer to BigQuery tables and the needed column names.  When joining tables use coercion to ensure all join columns are the same data type. Output column names should include the units when applicable.  Tables should be referred to using a fully qualified name that includes the project and dataset along with the table name.
question:
{question}

context:

{schema_columns.to_markdown(index = False)}
"""

context_query = codegen_model.predict(context_prompt, max_output_tokens = 256)

context_response = bq.query(query = '\n'.join(context_query.text.split('\n')[1:-1])).to_dataframe()

In [28]:
print(context_response.to_markdown(index = False))

| station_name          |   unique_trips |
|:----------------------|---------------:|
| 8 Ave & W 31 St       |           6004 |
| Pershing Square North |           5468 |
| West St & Chambers St |           5259 |
| Lafayette St & E 8 St |           5144 |
| E 17 St & Broadway    |           5116 |


In [29]:
question_prompt = f"""
Answer the following question using the provided context.  Note that the context is a tabular result returned from a BigQuery query.  Do not repeat the question or the context when responding.

question:
{question}
context:
{context_response.to_markdown(index = False)}
"""

question_response = textgen_model.predict(question_prompt)

print(question_response.text)

 1. 8 Ave & W 31 St (6004)
2. Pershing Square North (5468)
3. West St & Chambers St (5259)
4. Lafayette St & E 8 St (5144)
5. E 17 St & Broadway (5116)
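
The cells in this section repeat the same prompt text with only the question changed.  A small template helper is one way to avoid the repetition.  This is a sketch: `CONTEXT_PROMPT_TEMPLATE` and `build_context_prompt` are hypothetical names, and the template is an abbreviated version of the notebook's prompt.

```python
# abbreviated version of the prompt used in the cells above
CONTEXT_PROMPT_TEMPLATE = """Write a Google SQL query for BigQuery that answers the following question while using the provided context to correctly refer to BigQuery tables and the needed column names.
question:
{question}

context:
{schema}
"""

def build_context_prompt(question: str, schema_markdown: str) -> str:
    # fill the question and the markdown schema into the shared template
    return CONTEXT_PROMPT_TEMPLATE.format(question=question, schema=schema_markdown)

prompt = build_context_prompt(
    "What were the top five stations with most unique trips in July 2015?",
    "| table_name | column_name |\n|:--|:--|\n| citibike_trips | starttime |",
)
print(prompt)
```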


---
## All Together and More Complex

In [33]:
question = 'What were the top five stations with most trips started in July 2015 near central park?'

In [34]:
context_prompt = f"""
Write a Google SQL query for BigQuery that answers the following question while using the provided context to correctly refer to BigQuery tables and the needed column names.  When joining tables use coercion to ensure all join columns are the same data type. Output column names should include the units when applicable.  Tables should be referred to using a fully qualified name that includes the project and dataset along with the table name.
question:
{question}

context:

{schema_columns.to_markdown(index = False)}
"""

context_query = codegen_model.predict(context_prompt, max_output_tokens = 256)

context_response = bq.query(query = '\n'.join(context_query.text.split('\n')[1:-1])).to_dataframe()

In [35]:
context_query

```sql
SELECT
  start_station_name AS station_name,
  COUNT(*) AS num_trips
FROM
  bigquery-public-data.new_york.citibike_trips
WHERE
  EXTRACT(MONTH FROM starttime) = 7
  AND EXTRACT(YEAR FROM starttime) = 2015
  AND start_station_latitude BETWEEN 40.769 AND 40.784
  AND start_station_longitude BETWEEN -73.984 AND -73.964
GROUP BY
  start_station_name
ORDER BY
  num_trips DESC
LIMIT
  5;
```

In [37]:
question_prompt = f"""
Answer the following question using the provided context.  Note that the context is a tabular result returned from a BigQuery query.  Do not repeat the question or the context when responding.

Include popular points of interest near each listed station.

question:
{question}
context:
{context_response.to_markdown(index = False)}
"""

question_response = textgen_model.predict(question_prompt, max_output_tokens = 500)

print(question_response.text)

 The top five stations with the most trips started in July 2015 near Central Park are:

1. Broadway & W 60 St (8793 trips) - Central Park Zoo, Lincoln Center, Time Warner Center
2. 5 Ave/59 St       (7602 trips) - Central Park Zoo, The Plaza Hotel, FAO Schwarz
3. Columbus Circle   (6841 trips) - Central Park South, Museum of Modern Art, Time Warner Center
4. 8 Ave/W 59 St     (6311 trips) - Central Park South, Lincoln Center, Hearst Tower
5. 7 Ave/W 59 St     (5993 trips) - Central Park South, The Plaza Hotel, FAO Schwarz


---
## When The Generated Query Fails...

The following question leads to an error in the `context_query` execution.  This section covers a method of using a Code Chat LLM to iteratively fix the query based on the error details returned from the job execution in BigQuery.


In [291]:
question = "How many trips were started on a weekend, in the afternoon, during 2015, by a regular rider, who is over the age of 60?"

In [40]:
context_prompt = f"""
Write a Google SQL query for BigQuery that answers the following question while using the provided context to correctly refer to BigQuery tables and the needed column names.  When joining tables use coercion to ensure all join columns are the same data type. Output column names should include the units when applicable.  Tables should be referred to using a fully qualified name that includes the project and dataset along with the table name.
question:
{question}

context:

{schema_columns.to_markdown(index = False)}
"""

context_query = codegen_model.predict(context_prompt, max_output_tokens = 256)

In [41]:
context_query

```sql
SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  EXTRACT(WEEKDAY FROM starttime) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND EXTRACT(YEAR FROM starttime) = 2015
  AND usertype = 'Subscriber'
  AND birth_year < 1955;
```

In [42]:
# uncomment and run this query - it will fail because WEEKDAY is not a valid date part - 10/17/2023
#context_response = bq.query(query = '\n'.join(context_query.text.split('\n')[1:-1])).to_dataframe()

```
---------------------------------------------------------------------------
BadRequest                                Traceback (most recent call last)
Cell In[325], line 1
----> 1 context_response = bq.query(query = '\n'.join(context_query.text.split('\n')[1:-1])).to_dataframe()

File /opt/conda/lib/python3.10/site-packages/google/cloud/bigquery/job/query.py:1799, in QueryJob.to_dataframe(self, bqstorage_client, dtypes, progress_bar_type, create_bqstorage_client, max_results, geography_as_object, bool_dtype, int_dtype, float_dtype, string_dtype, date_dtype, datetime_dtype, time_dtype, timestamp_dtype)
   1633 def to_dataframe(
   1634     self,
   1635     bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None,
   (...)
   1648     timestamp_dtype: Union[Any, None] = None,
   1649 ) -> "pandas.DataFrame":
   1650     """Return a pandas DataFrame from a QueryJob
   1651 
   1652     Args:
   (...)
   1797             :mod:`shapely` library cannot be imported.
   1798     """
-> 1799     query_result = wait_for_query(self, progress_bar_type, max_results=max_results)
   1800     return query_result.to_dataframe(
   1801         bqstorage_client=bqstorage_client,
   1802         dtypes=dtypes,
   (...)
   1813         timestamp_dtype=timestamp_dtype,
   1814     )

File /opt/conda/lib/python3.10/site-packages/google/cloud/bigquery/_tqdm_helpers.py:104, in wait_for_query(query_job, progress_bar_type, max_results)
    100 progress_bar = get_progress_bar(
    101     progress_bar_type, "Query is running", default_total, "query"
    102 )
    103 if progress_bar is None:
--> 104     return query_job.result(max_results=max_results)
    106 i = 0
    107 while True:

File /opt/conda/lib/python3.10/site-packages/google/cloud/bigquery/job/query.py:1520, in QueryJob.result(self, page_size, max_results, retry, timeout, start_index, job_retry)
   1517     if retry_do_query is not None and job_retry is not None:
   1518         do_get_result = job_retry(do_get_result)
-> 1520     do_get_result()
   1522 except exceptions.GoogleAPICallError as exc:
   1523     exc.message = _EXCEPTION_FOOTER_TEMPLATE.format(
   1524         message=exc.message, location=self.location, job_id=self.job_id
   1525     )

File /opt/conda/lib/python3.10/site-packages/google/api_core/retry.py:349, in Retry.__call__.<locals>.retry_wrapped_func(*args, **kwargs)
    345 target = functools.partial(func, *args, **kwargs)
    346 sleep_generator = exponential_sleep_generator(
    347     self._initial, self._maximum, multiplier=self._multiplier
    348 )
--> 349 return retry_target(
    350     target,
    351     self._predicate,
    352     sleep_generator,
    353     self._timeout,
    354     on_error=on_error,
    355 )

File /opt/conda/lib/python3.10/site-packages/google/api_core/retry.py:191, in retry_target(target, predicate, sleep_generator, timeout, on_error, **kwargs)
    189 for sleep in sleep_generator:
    190     try:
--> 191         return target()
    193     # pylint: disable=broad-except
    194     # This function explicitly must deal with broad exceptions.
    195     except Exception as exc:

File /opt/conda/lib/python3.10/site-packages/google/cloud/bigquery/job/query.py:1510, in QueryJob.result.<locals>.do_get_result()
   1507     self._retry_do_query = retry_do_query
   1508     self._job_retry = job_retry
-> 1510 super(QueryJob, self).result(retry=retry, timeout=timeout)
   1512 # Since the job could already be "done" (e.g. got a finished job
   1513 # via client.get_job), the superclass call to done() might not
   1514 # set the self._query_results cache.
   1515 self._reload_query_results(retry=retry, timeout=timeout)

File /opt/conda/lib/python3.10/site-packages/google/cloud/bigquery/job/base.py:922, in _AsyncJob.result(self, retry, timeout)
    919     self._begin(retry=retry, timeout=timeout)
    921 kwargs = {} if retry is DEFAULT_RETRY else {"retry": retry}
--> 922 return super(_AsyncJob, self).result(timeout=timeout, **kwargs)

File /opt/conda/lib/python3.10/site-packages/google/api_core/future/polling.py:261, in PollingFuture.result(self, timeout, retry, polling)
    256 self._blocking_poll(timeout=timeout, retry=retry, polling=polling)
    258 if self._exception is not None:
    259     # pylint: disable=raising-bad-type
    260     # Pylint doesn't recognize that this is valid in this case.
--> 261     raise self._exception
    263 return self._result

BadRequest: 400 A valid date part name is required but found WEEKDAY at [6:11]

Location: US
Job ID: 8bd89480-8117-4c27-aacc-c94affe146ca
```

### Detect Errors

In [189]:
query = '\n'.join(context_query.text.split('\n')[1:-1])

In [190]:
print(query)

SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  EXTRACT(WEEKDAY FROM starttime) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND EXTRACT(YEAR FROM starttime) = 2015
  AND usertype = 'Subscriber'
  AND birth_year < 1955;


In [191]:
query_job = bq.query(query = query)

In [192]:
type(query_job)

google.cloud.bigquery.job.query.QueryJob

In [193]:
query_job.errors

[{'reason': 'invalidQuery',
  'location': 'query',
  'message': 'A valid date part name is required but found WEEKDAY at [6:11]'}]

### Use A Chat Model To Iteratively Refine A Query:

Chat models, like `codechat-bison@001`, are interactive in that they keep a history of the chat session as context for future message interactions.  This is perfect for both generating the initial query and asking for help repairing it when BigQuery returns errors.

#### Start A Chat Session With The Same Context (Schema)

In [194]:
codechat_model = vertexai.language_models.CodeChatModel.from_pretrained('codechat-bison@001')

In [195]:
codechat = codechat_model.start_chat(
    context = f"""This session is trying to troubleshoot a Google BigQuery SQL query that is being written to answer the question:
Question:
{question}

BigQuery SQL query that needs to be fixed:
{query}

The BigQuery Environment has tables defined by the following schema:
{schema_columns.to_markdown(index = False)}

Instructions:
As the user provides versions of the query and the errors returned by BigQuery, offer suggestions that fix the errors but it is important that the query still answer the original question.
"""
)

#### Create a Hint From The Error

Use the error's `message` field to retrieve the line of the query related to the error as a hint:

In [196]:
hint = ''
for error in query_job.errors:
    # detect error message
    message = error.get('message', '')
    # detect index of error location; rfind returns -1 when a bracket is missing (rindex would raise ValueError)
    begin = message.rfind('[') + 1
    end = message.rfind(']')
    # verify that it looks like an error location, for example [6:11]:
    if 0 < begin < end and ':' in message[begin:end]:
        # retrieve the two parts of the error index: query line, query column
        query_index = [int(q) for q in message[begin:end].split(':')]
        hint += query.split('\n')[query_index[0]-1].strip()
        break
print(hint)

EXTRACT(WEEKDAY FROM starttime) IN (0, 6)
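
The same location can also be parsed with a regular expression, which additionally makes the column available for pointing at the exact token.  This is a minimal sketch: `error_location` and `mark_error` are hypothetical helpers, and the `[line:column]` format is assumed from the error message shown above.

```python
import re

def error_location(message: str):
    # the error above ends with a location such as "at [6:11]" (line:column, both 1-based)
    matches = re.findall(r"\[(\d+):(\d+)\]", message)
    if not matches:
        return None
    line, column = matches[-1]
    return int(line), int(column)

def mark_error(query: str, message: str) -> str:
    location = error_location(message)
    if location is None:
        return ""
    line, column = location
    lines = query.split("\n")
    if not 1 <= line <= len(lines):
        return ""
    # show the offending line with a caret under the reported column
    return lines[line - 1] + "\n" + " " * (column - 1) + "^"

sample_query = "SELECT 1\nFROM t\nWHERE EXTRACT(WEEKDAY FROM x) = 1"
sample_message = "A valid date part name is required but found WEEKDAY at [3:15]"
print(mark_error(sample_query, sample_message))
```

Returning `None` or an empty string when no location is present keeps the hint optional, so errors without a position do not interrupt the repair flow.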


#### Fix Query (Try 1)

In [197]:
fix_prompt = f"""
This query:
{query}

Returns these errors:
{query_job.errors}
"""

if hint != '':
    fix_prompt += f"""
Hint, the error appears to be in this line of the query:
{hint}
"""
print(fix_prompt)


This query:
SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  EXTRACT(WEEKDAY FROM starttime) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND EXTRACT(YEAR FROM starttime) = 2015
  AND usertype = 'Subscriber'
  AND birth_year < 1955;

Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'A valid date part name is required but found WEEKDAY at [6:11]'}]

Hint, the error appears to be in this line of the query:
EXTRACT(WEEKDAY FROM starttime) IN (0, 6)



In [198]:
response = codechat.send_message(fix_prompt)

In [199]:
response

The following query fixes the error:

```
SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  DATE(starttime) BETWEEN '2015-01-01' AND '2015-12-31'
  AND EXTRACT(DAYOFWEEK FROM starttime) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND usertype = 'Subscriber'
  AND birth_year < 1955;
```

The error was caused by the use of the `WEEKDAY` function, which is not supported in BigQuery. The `DATE` function can be used to extract the date from a timestamp, and the `DAYOFWEEK` function can be used to extract the day of the week from a date.

In [200]:
response = codechat.send_message(f'Respond with only the corrected query as a markdown code block.')

In [201]:
response

```
SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  DATE(starttime) BETWEEN '2015-01-01' AND '2015-12-31'
  AND EXTRACT(DAYOFWEEK FROM starttime) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND usertype = 'Subscriber'
  AND birth_year < 1955;
```

In [202]:
if response.text.find("```") >= 0:
    query = response.text.split("```")[1]
    if query.startswith('sql'): query = query[3:]
    print(query)
else:
    print('no query in response')


SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  DATE(starttime) BETWEEN '2015-01-01' AND '2015-12-31'
  AND EXTRACT(DAYOFWEEK FROM starttime) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND usertype = 'Subscriber'
  AND birth_year < 1955;



In [203]:
query_job = bq.query(query = query)

In [204]:
query_job.errors

In [205]:
query_job.to_dataframe()

|    |   num_trips |
|---:|------------:|
|  0 |       18968 |
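
A query that runs without errors may still not answer the question.  In BigQuery the `DAYOFWEEK` date part returns values from 1 (Sunday) to 7 (Saturday), so the filter `IN (0, 6)` in the repaired query selects Fridays only, and a weekend filter would be `IN (1, 7)`.  The numbering can be checked offline with the sketch below; `bigquery_dayofweek` is a hypothetical helper that mimics the BigQuery numbering using Python's standard library.

```python
import datetime

def bigquery_dayofweek(day: datetime.date) -> int:
    # isoweekday(): Monday=1 ... Sunday=7; BigQuery DAYOFWEEK: Sunday=1 ... Saturday=7
    return day.isoweekday() % 7 + 1

friday = datetime.date(2015, 7, 3)
saturday = datetime.date(2015, 7, 4)
sunday = datetime.date(2015, 7, 5)

for day in (friday, saturday, sunday):
    print(day.isoformat(), bigquery_dayofweek(day))
```

Reviewing the generated SQL, or asking the chat session to confirm the weekend filter, is worthwhile before trusting the count.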


#### Generate Response

In [206]:
question_prompt = f"""
Answer the following question using the provided context.  Note that the context is a tabular result returned from a BigQuery query.  Do not repeat the question or the context when responding.

question:
{question}

context:
{query_job.to_dataframe().to_markdown(index = False)}
"""

question_response = textgen_model.predict(question_prompt, max_output_tokens = 500)

print(question_response.text)

 18968


### Put It All Together

In [292]:
question

'How many trips were started on a weekend, in the afternoon, during 2015, by a regular rider, who is over the age of 60?'

#### Functions To Answer The Question Using Iteration To Fix Context Query
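
The functions in this section follow a run, check, repair loop.  Its control flow can be sketched independently of BigQuery and Vertex AI.  This is a simplified illustration, not the notebook's code: `repair_loop` and its `run` and `fix` callables are hypothetical, and the fake implementations only exist so the sketch runs offline.

```python
def repair_loop(query, run, fix, max_fixes=3):
    # run: hypothetical callable, query -> list of error messages (empty when the query succeeds)
    # fix: hypothetical callable, (query, errors) -> revised query (stands in for the code chat model)
    fixes = 0
    errors = run(query)
    while errors and fixes < max_fixes:
        query = fix(query, errors)
        fixes += 1
        errors = run(query)
    return query, errors, fixes

# stand-ins so the sketch runs offline
def fake_run(query):
    return ["A valid date part name is required but found WEEKDAY"] if "WEEKDAY" in query else []

def fake_fix(query, errors):
    return query.replace("WEEKDAY", "DAYOFWEEK")

final_query, remaining_errors, fixes_used = repair_loop(
    "SELECT EXTRACT(WEEKDAY FROM starttime)", fake_run, fake_fix
)
print(final_query, remaining_errors, fixes_used)
```

Capping the number of fixes bounds the cost of model calls and query jobs when the model cannot converge on a valid query.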

In [301]:
def initial_query(question, schema_columns):
    
    # code generation model
    codegen_model = vertexai.language_models.CodeGenerationModel.from_pretrained('code-bison')
    
    # initial request for query:
    query_response = codegen_model.predict(f"""Write a Google SQL query for BigQuery that answers the following question while correctly referring to BigQuery tables and the needed column names.  When joining tables use coercion to ensure all join columns are the same data type. Output column names should include the units when applicable.  Tables should be referred to using a fully qualified name that includes the project and dataset along with the table name.

Question: {question}

Context:
{schema_columns.to_markdown(index = False)}
""")
    
    # extract query from response
    if query_response.text.find("```") >= 0:
        query = query_response.text.split("```")[1]
        if query.startswith('sql'):
            query = query[3:]
        print('First try:\n', query)
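    else:
        # no query found: return an empty string so the caller can detect it (avoids an unbound variable)
        query = ''
        print('No query provided (first try) - unforeseen error, printing out response to help with editing this function:\n', query_response.text)
    
    return query    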

In [302]:
def codechat_start(question, query, schema_columns):

    # code chat model
    codechat_model = vertexai.language_models.CodeChatModel.from_pretrained('codechat-bison@001')
    
    # start a code chat session and give the schema for columns as the starting context:
    codechat = codechat_model.start_chat(
        context = f"""This session is trying to troubleshoot a Google BigQuery SQL query that is being written to answer a question.
Question: {question}

BigQuery SQL Query: {query}

information_schema:
{schema_columns.to_markdown(index = False)}

Instructions:
As the user provides versions of the query and the errors returned by BigQuery, offer suggestions that fix the errors but it is important that the query still answer the original question.
"""
    )
    
    return codechat

In [303]:
def fix_query(query, max_fixes):
    
    # iteratively run query, and fix it using codechat until success (or max_fixes reached):
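    fix_tries = 0
    query_job = None # defined up front so the return below works even when the loop never runs
    codechat = None # only created when a fix is needed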
    while fix_tries < max_fixes:
        if not query: 
            return
        # run query:
        query_job = bq.query(query = query)
        # if errors, then generate repair query:
        if query_job.errors:
            fix_tries += 1
            
            if fix_tries == 1:
                codechat = codechat_start(question, query, schema_columns)
            
            # construct hint from error
            hint = ''
            for error in query_job.errors:
                # detect error message
                message = error.get('message', '')
                # detect index of error location; rfind returns -1 when a bracket is missing (rindex would raise ValueError)
                begin = message.rfind('[') + 1
                end = message.rfind(']')
                # verify that it looks like an error location, for example [6:11]:
                if 0 < begin < end and ':' in message[begin:end]:
                    # retrieve the two parts of the error index: query line, query column
                    query_index = [int(q) for q in message[begin:end].split(':')]
                    hint += query.split('\n')[query_index[0]-1].strip()
                    break
            
            # construct prompt to request a fix:
            fix_prompt = f"""This query:\n{query}\n\nReturns these errors:\n{query_job.errors}\n\nPlease fix it and make sure it matches the schema."""

            #if hint != '':
            #    fix_prompt += f"""Hint, the error appears to be in this line of the query:\n{hint}"""            
            
            query_response = codechat.send_message(fix_prompt)
            query_response = codechat.send_message('Respond with only the corrected query that still answers the question as a markdown code block.')
            if query_response.text.find("```") >= 0:
                query = query_response.text.split("```")[1]
                if query.startswith('sql'):
                    query = query[4:]
                print(f'Fix #{fix_tries}:\n', query)
            # response did not have a query????:
            else:
                query = ''
                print('No query in response...')

        # no error, break while loop
        else:
            break
    
    return query, query_job, fix_tries, codechat
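
The response parsing in `fix_query` takes whatever sits between the first pair of code-fence markers and drops a leading `sql` tag. A slightly more defensive variant can be sketched with a regular expression. `extract_sql` is a hypothetical helper name, not part of the notebook or of the Vertex AI SDK, and the sketch assumes the model puts the language tag and the query on separate lines.

```python
import re

# Hypothetical helper (not from the notebook): pull the first fenced code block
# out of an LLM reply and drop an optional language tag such as "sql".
TICKS = "`" * 3  # built this way so the example contains no literal fence marker
FENCE = re.compile(TICKS + r"[ \t]*([A-Za-z]*)[ \t]*\n(.*?)" + TICKS, re.DOTALL)

def extract_sql(response_text):
    match = FENCE.search(response_text)
    if match is None:
        # same convention as fix_query: an empty string means "no query found"
        return ''
    return match.group(2).strip()

reply = "Here is the fix:\n" + TICKS + "sql\nSELECT 1 AS x;\n" + TICKS + "\nHope it helps."
print(extract_sql(reply))
```

Returning an empty string keeps the existing `if not query` check in the loop usable as the stop condition.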

In [304]:
def answer_question(question, query_job):

    # text generation model
    textgen_model = vertexai.language_models.TextGenerationModel.from_pretrained('text-bison')

    result = query_job.to_dataframe()
    # answer question
    question_prompt = f"""Answer the following question using the provided context.  Note that the context is a tabular result returned from a BigQuery query.  Do not repeat the question or the context when responding.

question:
{question}
context:
{result.to_markdown(index = False)}
"""
    question_response = textgen_model.predict(question_prompt, max_output_tokens = 500)
    
    return question_response.text

In [309]:
def BQ_QA(question, max_fixes = 7, schema_columns = schema_columns):
    
    query = initial_query(question, schema_columns)
    # run query:
    query_job = bq.query(query = query)
    # no repair session exists unless the first try fails:
    fix_tries, codechat = 0, None
    # if errors, then generate repair query:
    if query_job.errors:
        query, query_job, fix_tries, codechat = fix_query(query, max_fixes)
    
    # respond with outcome:
    if query_job.errors:
        print(f'No answer generated after {fix_tries} tries.')
    else:
        question_response = answer_question(question, query_job)
        print(question_response)
    # return the codechat session (None when no fixes were needed):
    return codechat

In [310]:
session = BQ_QA(question)

First try:
 
SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  EXTRACT(WEEKDAY FROM starttime) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND EXTRACT(YEAR FROM starttime) = 2015
  AND usertype = "Subscriber"
  AND birth_year < 1955;

Fix #1:
 
SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  DATE(starttime) BETWEEN '2015-01-01' AND '2015-12-31'
  AND EXTRACT(WEEKDAY FROM DATE(starttime)) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND usertype = "Subscriber"
  AND birth_year < 1955;

Fix #2:
 
SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-pub

In [311]:
for message in session.message_history:
    print(message.content)

This query:

SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS stations
ON CAST(trips.start_station_id AS STRING) = stations.station_id
WHERE 
  EXTRACT(WEEKDAY FROM starttime) IN (0, 6)
  AND EXTRACT(HOUR FROM starttime) BETWEEN 12 AND 17
  AND EXTRACT(YEAR FROM starttime) = 2015
  AND usertype = "Subscriber"
  AND birth_year < 1955;


Returns these errors:
[{'reason': 'invalidQuery', 'location': 'query', 'message': 'A valid date part name is required but found WEEKDAY at [7:11]'}]

Please fix it and make sure it matches the schema.
The query is not valid because the `WEEKDAY` function is not supported for the `starttime` field. The `starttime` field is a `TIMESTAMP` field, and the `WEEKDAY` function is only supported for `DATE` fields.

To fix the query, you can use the `DATE` function to convert the `starttime` field to a `DATE` field. Then, you can use the `WEEKDAY` function on the `DATE` 
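
The error in the history above ends with a `[7:11]` marker: line 7, column 11 of the submitted query. The hint logic in `fix_query` relies on that marker, and the same idea can be sketched with a regular expression. `error_line_hint` is a hypothetical name, and the sketch assumes the marker is the last thing in the message, as in the example shown here.

```python
import re

# A trailing "[line:column]" marker, e.g. "... found WEEKDAY at [7:11]".
LOCATION = re.compile(r"\[(\d+):(\d+)\]\s*$")

def error_line_hint(query, message):
    """Return the stripped query line the error points at, or '' if none is found."""
    match = LOCATION.search(message)
    if match is None:
        return ''
    line_number = int(match.group(1))  # 1-based line within the submitted query
    lines = query.split('\n')
    if 1 <= line_number <= len(lines):
        return lines[line_number - 1].strip()
    return ''

# Illustrative query with a leading empty line, so the EXTRACT clause is line 5.
example_query = "\nSELECT COUNT(*) AS num_trips\nFROM t\nWHERE\n  EXTRACT(WEEKDAY FROM starttime) IN (0, 6)"
example_message = "A valid date part name is required but found WEEKDAY at [5:11]"
print(error_line_hint(example_query, example_message))
```

The range check means a marker pointing past the end of the query yields an empty hint rather than an `IndexError`.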

### Try Another Question:

In [312]:
session = BQ_QA('How many trips during the year 2015: started in the evening, were over an hour long, and were by regular riders, who were over age 60?')

First try:
 
SELECT COUNT(*) AS num_trips
FROM bigquery-public-data.new_york.citibike_trips AS trips
JOIN bigquery-public-data.new_york.citibike_stations AS start_stations
ON CAST(trips.start_station_id AS STRING) = start_stations.station_id
JOIN bigquery-public-data.new_york.citibike_stations AS end_stations
ON CAST(trips.end_station_id AS STRING) = end_stations.station_id
WHERE EXTRACT(HOUR FROM trips.starttime) BETWEEN 18 AND 23
AND trips.tripduration > 3600
AND trips.usertype = "Subscriber"
AND EXTRACT(YEAR FROM trips.starttime) = 2015
AND trips.birth_year < 1955;

 180


### Ideas For Improvement

After a few failed fix attempts, ask the CodeChat LLM to start fresh and write a new query to answer the question.  Then iterate on the new query.
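
A minimal sketch of this restart idea, with the LLM and BigQuery calls passed in as plain functions so the control flow can be read on its own. The names `write_query`, `repair_query`, `run_query`, and `restart_after` are hypothetical; in this notebook they would wrap `initial_query`, the CodeChat session, and `bq.query`.

```python
def fix_with_restarts(question, write_query, repair_query, run_query,
                      max_fixes=7, restart_after=3):
    """Look for a working query, starting over after `restart_after` failed repairs.

    write_query(question) -> str        : writes a fresh query (hypothetical)
    repair_query(query, errors) -> str  : proposes a fix for a failing query (hypothetical)
    run_query(query) -> errors          : falsy on success, error list otherwise (hypothetical)
    """
    query = write_query(question)
    fixes_since_restart = 0
    for attempt in range(max_fixes + 1):
        errors = run_query(query)
        if not errors:
            return query, attempt  # working query and the number of rewrites used
        if attempt == max_fixes:
            break
        if fixes_since_restart >= restart_after:
            # repairs are not converging: discard the query and write a new one
            query = write_query(question)
            fixes_since_restart = 0
        else:
            query = repair_query(query, errors)
            fixes_since_restart += 1
    return None, max_fixes
```

A restart counts against `max_fixes` just like a repair, so the total number of queries run stays bounded.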

If the fixes are breaking the logic of the query, try asking the LLM whether the successful query still answers the question before providing the result.
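
One way this check could look, sketched under the assumption that a plain YES/NO reply is enough. `build_check_prompt`, `query_answers_question`, and the `ask` callable are hypothetical names; `ask` stands in for any function that sends a prompt to an LLM and returns the reply text, for example a wrapper around `textgen_model.predict(prompt).text`.

```python
def build_check_prompt(question, query):
    """Build a yes/no prompt asking whether `query` still answers `question`."""
    return (
        "Does the following BigQuery SQL query answer the question? "
        "Respond with only YES or NO.\n\n"
        f"Question: {question}\n\n"
        f"BigQuery SQL Query:\n{query}\n"
    )

def query_answers_question(question, query, ask):
    """Return True when the reply from `ask` starts with YES (case-insensitive)."""
    reply = ask(build_check_prompt(question, query))
    return reply.strip().upper().startswith('YES')
```

Calling this after a successful run and before `answer_question` would make it possible to keep fixing, or to start over, when the reply is NO.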