In [8]:
import os
# Import Google Cloud and Vertex AI libraries
from google.cloud import bigquery
import vertexai
import vertexai.language_models
import vertexai.preview.generative_models

# Set environment variables for Google Cloud credentials and TensorFlow logging level.
# A GOOGLE_APPLICATION_CREDENTIALS value that is already present in the environment is
# respected so the notebook is not tied to one machine; the original local key path is
# kept only as the fallback when nothing is configured.
GOOGLE_APPLICATION_CREDENTIALS = os.environ.get(
    "GOOGLE_APPLICATION_CREDENTIALS",
    "/Users/zacharynguyen/Documents/GitHub/2024/Applied-Generative-AI/IAM/zacharynguyen-genai-656c475b142a.json",
)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = GOOGLE_APPLICATION_CREDENTIALS
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Define Google Cloud project and region information
PROJECT_ID = 'zacharynguyen-genai'
REGION = 'us-central1'
EXPERIMENT = 'sdoh_cdc_wonder_natality'
SERIES = 'applied-genai'

# Initialize Vertex AI
vertexai.init(project=PROJECT_ID, location=REGION)

# Initialize the BigQuery client
bq = bigquery.Client(project=PROJECT_ID)

# Initialize Vertex AI models
textgen_model = vertexai.language_models.TextGenerationModel.from_pretrained('text-bison@002')
codegen_model = vertexai.language_models.CodeGenerationModel.from_pretrained('code-bison@002')
gemini_model = vertexai.preview.generative_models.GenerativeModel("gemini-pro")

# BigQuery constants for querying
BQ_PROJECT = 'bigquery-public-data'
BQ_DATASET = 'sdoh_cdc_wonder_natality'
BQ_TABLES = ['county_natality', 'county_natality_by_mother_race', 'county_natality_by_father_race']

# Construct and execute a BigQuery query to retrieve schema columns.
# The table names come from the BQ_TABLES constant above (not user input), so they are
# quoted and joined directly into the IN (...) list.
query = f"""
    SELECT * EXCEPT(field_path, collation_name, rounding_mode)
    FROM `{BQ_PROJECT}.{BQ_DATASET}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ({','.join([f'"{table}"' for table in BQ_TABLES])})
"""
schema_columns = bq.query(query=query).to_dataframe()

# Define the question for the model
question = "Which mother's single race category reports the highest average number of births and what is that number?"


In [9]:
def initial_query(question, schema_columns):
    """Ask the code generation model for a BigQuery SQL query that answers `question`.

    Args:
        question: Natural-language question the query should answer.
        schema_columns: DataFrame describing the available table columns; it is
            rendered with `to_markdown` and placed in the prompt as context.

    Returns:
        The query text as a string. When the response contains a ``` fence this is
        the content of the first fenced block with a leading `sql` tag removed;
        otherwise it is the raw response text.
    """
    # code generation model
    codegen_model = vertexai.language_models.CodeGenerationModel.from_pretrained('code-bison@002')

    # initial request for query:
    context_prompt = f"""
context (BigQuery Table Schema):
{schema_columns.to_markdown(index = False)}

Write a query for Google BigQuery using fully qualified table names to answer this question:
{question}
"""

    response = codegen_model.predict(context_prompt, max_output_tokens = 256)
    response_text = response.text

    # extract query from response
    if response_text.find("```") >= 0:
        context_query = response_text.split("```")[1]
        if context_query.startswith('sql'):
            context_query = context_query[3:]
        print('Initial Query:\n', context_query)
    else:
        # The original referenced an undefined `query_response` here (NameError) and
        # would otherwise have returned the response object instead of text. Return
        # the raw text so callers always receive a string they can submit.
        print('No query provided (first try) - unforseen error, printing out response to help with editing this funcntion:\n', response_text)
        context_query = response_text

    return context_query

In [10]:
# Generate a query for the module-level `question`; as the last expression in the
# cell, the returned query string is also shown as the cell's value.
initial_query(question,schema_columns)

Initial Query:
 
SELECT
  Mothers_Single_Race,
  AVG(Births) AS average_births
FROM
  `bigquery-public-data.sdoh_cdc_wonder_natality.county_natality_by_mother_race`
GROUP BY
  Mothers_Single_Race
ORDER BY
  average_births DESC
LIMIT 1;


'\nSELECT\n  Mothers_Single_Race,\n  AVG(Births) AS average_births\nFROM\n  `bigquery-public-data.sdoh_cdc_wonder_natality.county_natality_by_mother_race`\nGROUP BY\n  Mothers_Single_Race\nORDER BY\n  average_births DESC\nLIMIT 1;\n'

In [11]:
def codechat_start(question, query, schema_columns):
    """Open a code chat session primed with the schema, the question and the failing query.

    Args:
        question: The question the query is meant to answer.
        query: The SQL text that needs to be repaired.
        schema_columns: DataFrame of column metadata, rendered via `to_markdown`.

    Returns:
        The chat session returned by `start_chat` on the `codechat-bison@002` model.
    """
    # load the code chat model
    chat_model = vertexai.language_models.CodeChatModel.from_pretrained('codechat-bison@002')

    # the schema, question and query become the standing context of the session
    session_context = f"""
The BigQuery Environment has tables defined by the follow schema:
{schema_columns.to_markdown(index = False)}

This session is trying to troubleshoot a Google BigQuery SQL query that is being writen to answer the question:
{question}

BigQuery SQL query that needs to be fixed:
{query}

Instructions:
As the user provides versions of the query and the errors returned by BigQuery, offer suggestions that fix the errors but it is important that the query still answer the original question.
"""

    chat_session = chat_model.start_chat(context = session_context)
    return chat_session

In [12]:
def fix_query(query, max_fixes):
    
    # iteratively run query, and fix it using codechat until success (or max_fixes reached):
    fix_tries = 0
    answer = False
    while fix_tries < max_fixes:
        if not query: 
            return
        # run query:
        query_job = bq.query(query = query)
        # if errors, then generate repair query:
        if query_job.errors:
            fix_tries += 1
            
            if fix_tries == 1:
                codechat = codechat_start(question, query, schema_columns)
            
            # construct hint from error
            hint = ''
            for error in query_job.errors:
                # detect error message
                if 'message' in list(error.keys()):
                    # detect index of error location
                    if error['message'].rindex('[') and error['message'].rindex(']'):
                        begin = error['message'].rindex('[') + 1
                        end = error['message'].rindex(']')
                        # verify that it looks like an error location:
                        if end > begin and error['message'][begin:end].index(':'):
                            # retrieve the two parts of the error index: query line, query column
                            query_index = [int(q) for q in error['message'][begin:end].split(':')]
                            hint += query.split('\n')[query_index[0]-1].strip()
                            break
            
            # construct prompt to request a fix:
            fix_prompt = f"""This query:\n{query}\n\nReturns these errors:\n{query_job.errors}"""

            if hint != '':
                fix_prompt += f"""\n\nHint, the error appears to be in this line of the query:\n{hint}"""            
            
            query_response = codechat.send_message(fix_prompt)
            query_response = codechat.send_message('Respond with only the corrected query that still answers the question as a markdown code block.')
            if query_response.text.find("```") >= 0:
                query = query_response.text.split("```")[1]
                if query.startswith('sql'):
                    query = query[4:]
                print(f'Fix #{fix_tries}:\n', query)
            # response did not have a query????:
            else:
                query = ''
                print('No query in response...')

        # no error, break while loop
        else:
            break
    
    return query, query_job, fix_tries, codechat

In [41]:
def answer_question(question, query_job):

    # text generation model
    gemini_model = vertexai.preview.generative_models.GenerativeModel("gemini-pro")

    # answer question
    result = query_job.to_dataframe()
    question_prompt = f"""
Please derive insights from: {question}
Utilize the statistics from the BigQuery table relevant to this inquiry. Emphasize crucial discoveries and their implications for strategic actions, ensuring to mention specific statistics where possible. Avoid repeating the question or detailing the dataset's context.

Use this data:
{result.to_markdown(index = False)}
    """

    question_response = gemini_model.generate_content(question_prompt)
    
    return question_response.text

In [39]:
def BQ_QA(question, max_fixes = 10, schema_columns = schema_columns):
    
    # generate query
    query = initial_query(question, schema_columns)
    
    # run query:
    query_job = bq.query(query = query)
    # if errors, then generate repair query:
    if query_job.errors:
        print('found errors')
        query, query_job, fix_tries, codechat = fix_query(query, max_fixes)
    
    # respond with outcome:
    if query_job.errors:
        print(f'No answer generated after {fix_tries} tries.')
        return codechat
    else:
        question_response = answer_question(question, query_job)
        print(question_response)
        try:
            return codechat
        except:
            return None

In [24]:
# Run the full pipeline on the same question text stored in `question`; `session`
# receives BQ_QA's return value (the repair chat session, or None if none was started).
session = BQ_QA("Which mother's single race category reports the highest average number of births and what is that number?")

Initial Query:
 
SELECT
  Mothers_Single_Race,
  AVG(Births) AS average_births
FROM
  `bigquery-public-data.sdoh_cdc_wonder_natality.county_natality_by_mother_race`
GROUP BY
  Mothers_Single_Race
ORDER BY
  average_births DESC
LIMIT 1;
**Insight:**

White mothers have the highest average number of births among single-race mothers, with an average of 4528.37 births.

**Implications:**

* This finding suggests that White mothers may be more likely to have children out of wedlock or to have more children overall.
* This information could be used to develop targeted interventions aimed at reducing unintended pregnancies or supporting single mothers.
* Additionally, it could help policymakers understand the needs of single mothers and families and develop policies that address their unique challenges.


In [25]:
# Second question; the result overwrites `session` from the previous cell.
session = BQ_QA("How does the average gestational age at birth vary by the mother's single race?")

Initial Query:
 
SELECT
  county_natality_by_mother_race.Year,
  county_natality_by_mother_race.Mothers_Single_Race,
  AVG(county_natality_by_mother_race.Ave_OE_Gestational_Age_Wks) AS average_gestational_age
FROM
  bigquery-public-data.sdoh_cdc_wonder_natality.county_natality_by_mother_race
GROUP BY
  county_natality_by_mother_race.Year,
  county_natality_by_mother_race.Mothers_Single_Race
ORDER BY
  county_natality_by_mother_race.Year,
  average_gestational_age DESC;
**Key Insights:**

* **Racial Disparities:** White mothers have consistently higher average gestational ages than other racial groups, with a difference of up to 0.4 weeks (approximately 2.8 days) compared to Black or African American mothers.
* **Time Trend:** The average gestational age has been declining slightly over time for all racial groups, with the exception of Native Hawaiian or Other Pacific Islander mothers.
* **Lowest Gestational Ages:** Black or African American mothers have the lowest average gestational a

In [27]:
# Third question, aggregated by county; the result overwrites `session` again.
session = BQ_QA ("How does the average pre-pregnancy BMI vary by county?")

Initial Query:
 
SELECT
  county_natality.County_of_Residence,
  AVG(county_natality.Ave_Pre_pregnancy_BMI) AS average_pre_pregnancy_bmi
FROM
  bigquery-public-data.sdoh_cdc_wonder_natality.county_natality
GROUP BY
  county_natality.County_of_Residence
ORDER BY
  average_pre_pregnancy_bmi DESC;
**Crucial Insights:**

* **County-level Variations:** The average pre-pregnancy BMI varies widely across counties, ranging from 23.8967 in New York County, NY, to 29.3533 in Sumter County, SC.
* **Top 5 Counties with Highest BMIs:**
    * Sumter County, SC: 29.3533
    * Oswego County, NY: 29.1667
    * St. Lawrence County, NY: 29.0667
    * Hardin County, KY: 29.0133
    * Merced County, CA: 28.9467
* **Bottom 5 Counties with Lowest BMIs:**
    * San Francisco County, CA: 24.0067
    * Marin County, CA: 24.6867
    * Arlington County, VA: 24.3567
    * Williamson County, TN: 24.6833
    * Boulder County, CO: 24.8567
* **State-level Patterns:** Some states have consistently higher or lower avera

In [40]:
# In the recorded run this question's initial query was rejected, so the repair path ran.
# NOTE(review): the recorded "Fix #1" below is the mother's-race births query rather than
# a BMI / birth-weight query, so the printed answer does not address this question.
session = BQ_QA("Is there a correlation between pre-pregnancy BMI and birth weight?")

Initial Query:
 
SELECT
  county_natality.County_of_Residence,
  county_natality.Ave_Pre_pregnancy_BMI,
  county_natality.Ave_Birth_Weight_gms,
  CORR(county_natality.Ave_Pre_pregnancy_BMI, county_natality.Ave_Birth_Weight_gms) AS correlation
FROM
  bigquery-public-data.sdoh_cdc_wonder_natality.county_natality;
found errors
Fix #1:
 SELECT
  Mothers_Single_Race,
  AVG(Births) AS average_births
FROM
  bigquery-public-data.sdoh_cdc_wonder_natality.county_natality_by_mother_race
GROUP BY
  Mothers_Single_Race
ORDER BY
  average_births DESC
LIMIT 1;
**Correlation:**

There is a positive correlation between pre-pregnancy BMI and birth weight. This suggests that mothers with a higher pre-pregnancy BMI tend to have babies with higher birth weights.

**Crucial Discoveries:**

* The average birth weight for mothers in the White race category is 4528.37. This statistic can serve as a benchmark for comparing the birth weights of other race/ethnic groups or subgroups.
* The positive correlation be

# Zachary Version

In [31]:
def start_code_chat(question, query, schema_columns):
    """
    Start a `codechat-bison@002` chat session for repairing a BigQuery SQL query.

    The session context carries the markdown-rendered schema, the question the
    query must answer, and the query text that needs attention.

    Returns the chat session produced by `start_chat`.
    """
    chat_model = vertexai.language_models.CodeChatModel.from_pretrained('codechat-bison@002')
    schema_table = schema_columns.to_markdown(index=False)
    session_context = f"""
    The BigQuery Environment has tables defined by the following schema:
    {schema_table}

    Troubleshoot a Google BigQuery SQL query for:
    {question}

    Query needing attention:
    {query}

    Instructions: Provide fixes for query errors ensuring the query still answers the original question.
    """

    chat_session = chat_model.start_chat(context=session_context)
    return chat_session

def iteratively_fix_query(question, query, schema_columns, max_fixes=10):
    """
    Runs the query and, while BigQuery reports errors, asks code chat for a fix.

    Args:
        question: The question the query must answer (given to the chat context).
        query: SQL text to execute.
        schema_columns: Schema DataFrame for the chat context.
        max_fixes: Maximum number of fix requests sent to the chat session.

    Returns:
        A tuple `(query, query_job, fix_tries, codechat_session)`. `query_job` is
        the job for the last query that was executed, or None when nothing was
        executed (empty `query`). `codechat_session` is None when no fix was
        requested. `query` is '' when the chat response contained no query.
    """
    fix_tries = 0
    codechat_session = None
    # Bound up front so the return below is safe even if the loop body never runs.
    query_job = None

    # Run first, fix second: the last corrected query is executed as well, so the
    # returned job always corresponds to the returned query.
    while query:
        query_job = bq.query(query=query)
        if not query_job.errors or fix_tries >= max_fixes:
            break

        fix_tries += 1
        if codechat_session is None:
            codechat_session = start_code_chat(question, query, schema_columns)

        hint = extract_hint_from_errors(query, query_job.errors)
        fix_prompt = construct_fix_prompt(query, query_job.errors, hint)
        query, has_query = attempt_query_fix(codechat_session, fix_prompt, fix_tries)
        if not has_query:
            break

    return query, query_job, fix_tries, codechat_session

def extract_hint_from_errors(query, errors):
    """
    Extracts the query line an error points at, for use as a troubleshooting hint.

    Each error is a mapping that may carry a 'message'. The last `[line:column]`
    marker in a message is treated as the error location; other bracketed text
    (for example a quoted "[" token or a bracketed name) is ignored.

    Args:
        query: The SQL text that produced the errors.
        errors: Iterable of error mappings.

    Returns:
        The stripped text of the referenced query line, or '' when no error
        carries a usable location inside the query.
    """
    import re

    query_lines = query.split('\n')
    for error in errors or []:
        message = error.get('message') or ''
        locations = re.findall(r'\[(\d+):(\d+)\]', message)
        if not locations:
            continue
        query_line = int(locations[-1][0])
        # Line numbers are 1-based; skip anything that falls outside the query
        # instead of raising or silently wrapping to the last line.
        if 1 <= query_line <= len(query_lines):
            return query_lines[query_line - 1].strip()
    return ''

def construct_fix_prompt(query, errors, hint):
    """
    Builds the message sent to code chat when a query fails.

    The prompt contains the query, the error messages joined with '; ' (errors
    without a 'message' contribute 'Unknown error'), and, when `hint` is
    non-empty, the query line suspected of causing the failure.
    """
    messages = []
    for err in errors:
        messages.append(err.get('message', 'Unknown error'))
    joined_messages = '; '.join(messages)

    sections = [f"This query:\n{query}\n\nReturned errors:\n{joined_messages}"]
    if hint:
        sections.append(f"\n\nHint, possible issue in this line:\n{hint}")
    return ''.join(sections)

def attempt_query_fix(codechat_session, fix_prompt, fix_tries):
    """
    Sends the fix prompt to the code chat session and extracts the corrected query.

    Args:
        codechat_session: Object with `send_message(prompt)` returning a response
            that exposes `.text`.
        fix_prompt: Prompt describing the failing query and its errors.
        fix_tries: Number of the current attempt; used only in the printed label.

    Returns:
        `(corrected_query, True)` when the response holds a non-empty fenced code
        block, otherwise `('', False)`.
    """
    response_text = codechat_session.send_message(fix_prompt).text or ''

    segments = response_text.split("```")
    corrected_query = segments[1] if len(segments) > 1 else ''

    # Drop a fence language tag ("sql", "SQL", ...) only when it is a standalone
    # token, so the slice cannot eat the first character of the query itself.
    if corrected_query[:3].lower() == 'sql' and (
        len(corrected_query) == 3 or corrected_query[3].isspace()
    ):
        corrected_query = corrected_query[4:]

    # A missing or blank code block is not a usable fix.
    if not corrected_query.strip():
        print('No query in response.')
        return '', False

    print(f'Fix #{fix_tries} applied:\n', corrected_query)
    return corrected_query, True

def generate_insightful_answer(question, query_job):
    """
    Asks the `gemini-pro` model to turn the query result into a written answer.

    The job's result is converted to a DataFrame, rendered as a markdown table
    and embedded in the prompt together with the question. Returns the `.text`
    of the model response.
    """
    model = vertexai.preview.generative_models.GenerativeModel("gemini-pro")

    result_frame = query_job.to_dataframe()
    result_md = result_frame.to_markdown(index=False)

    prompt = f"""
    Please derive insights from: {question}
    Utilize the statistics from the BigQuery table relevant to this inquiry, emphasizing crucial discoveries and strategic implications. Mention specific statistics where possible.

    Data:
    {result_md}
    """

    response = model.generate_content(prompt)
    return response.text



In [35]:
def bq_qa_process(question, schema_columns=schema_columns, max_fixes=10):
    """
    The main process for generating a query, fixing it if necessary, and providing an answer.

    Args:
        question: Natural-language question to answer.
        schema_columns: Schema DataFrame used for generation and fixing (defaults
            to the module-level value at definition time).
        max_fixes: Maximum number of fix attempts passed to `iteratively_fix_query`.

    Returns:
        The code chat session returned by `iteratively_fix_query` (its fourth
        tuple element), which that function initializes to None.
    """
    query = initial_query(question, schema_columns)
    query, query_job, fix_tries, codechat_session = iteratively_fix_query(question, query, schema_columns, max_fixes)

    # Only answer when a job exists and finished without errors.
    if query_job is not None and not query_job.errors:
        answer = generate_insightful_answer(question, query_job)
        print(answer)
    else:
        print(f'Could not generate an answer after {fix_tries} attempts.')

    # `codechat_session` is always bound by the tuple unpacking above, so the
    # original `in locals()` check could never take its None branch.
    return codechat_session


In [37]:
# Run the refactored pipeline on the correlation question; `session` receives the
# code chat session returned by bq_qa_process.
session = bq_qa_process("Is there a correlation between pre-pregnancy BMI and birth weight?")

Initial Query:
 
SELECT
  county_natality.County_of_Residence,
  county_natality.Ave_Pre_pregnancy_BMI,
  county_natality.Ave_Birth_Weight_gms,
  CORR(county_natality.Ave_Pre_pregnancy_BMI, county_natality.Ave_Birth_Weight_gms) AS correlation
FROM
  bigquery-public-data.sdoh_cdc_wonder_natality.county_natality;
Fix #1 applied:
 SELECT
  county_natality.County_of_Residence,
  AVG(county_natality.Ave_Pre_pregnancy_BMI) AS average_pre_pregnancy_bmi,
  AVG(county_natality.Ave_Birth_Weight_gms) AS average_birth_weight,
  CORR(
    AVG(county_natality.Ave_Pre_pregnancy_BMI),
    AVG(county_natality.Ave_Birth_Weight_gms)
  ) AS correlation
FROM
  bigquery-public-data.sdoh_cdc_wonder_natality.county_natality
GROUP BY
  county_natality.County_of_Residence;
Fix #2 applied:
 SELECT
  county_natality.County_of_Residence,
  AVG(county_natality.Ave_Pre_pregnancy_BMI) AS average_pre_pregnancy_bmi,
  AVG(county_natality.Ave_Birth_Weight_gms) AS average_birth_weight,
  CORR(
    county_natality.Ave_Pre