## Generating SQL with JOINs

In [1]:
%load_ext autoreload
%autoreload 2

import sys  
sys.path.insert(1, '../')
sys.path.insert(1, '/home/jupyter/git_repo/nl2sql-generic/nl2sql_src')
print(sys.path)

['/opt/conda/lib/python310.zip', '/home/jupyter/git_repo/nl2sql-generic/nl2sql_src', '../', '/opt/conda/lib/python3.10', '/opt/conda/lib/python3.10/lib-dynload', '', '/opt/conda/lib/python3.10/site-packages']


In [2]:
import numpy as np
import faiss
from pandas import DataFrame
from datetime import datetime
from vertexai.preview.generative_models import GenerativeModel, GenerationResponse, Tool
from nl2sql_generic import Nl2sqlBq

import json
from prompts import * 

from proto.marshal.collections import repeated
from proto.marshal.collections import maps

from google.cloud import bigquery


import logging
import time

from vertexai.language_models import CodeChatSession
from vertexai.language_models import CodeChatModel

from vertexai.language_models import CodeGenerationModel


In [None]:
# BigQuery project and dataset used by the NL2SQL client and the BigQuery client.
PROJECT_ID, DATASET_ID = "sql-test-project-353312", "zoominfo"

In [3]:
# Start from a metadata cache that has already been created on disk.
# Both names point at the same JSON file: `metadata_cache_file` is handed to
# the client, `metadata_json_path` is read by return_table_details().
metadata_cache_file = "../nl2sql_src/cache_metadata/metadata_cache.json"
metadata_json_path = "../nl2sql_src/cache_metadata/metadata_cache.json"

nl2sqlbq_client = Nl2sqlBq(
    project_id=PROJECT_ID,
    dataset_id=DATASET_ID,
    metadata_json_path=metadata_cache_file,
)

In [4]:
import os

# Connection settings passed to nl2sqlbq_client.init_pgdb().
# NOTE(review): PGPROJ is spelled "sl-test-project-353312" while PROJECT_ID is
# 'sql-test-project-353312' — confirm whether this is a separate project or a typo.
PGPROJ = "sl-test-project-353312"
PGLOCATION = 'us-central1'
PGINSTANCE = "test-nl2sql"
PGDB = "test-db"
PGUSER = "postgres"
# Prefer a password supplied through the PGPWD environment variable so real
# credentials never have to be committed; the previous literal is kept only as
# a fallback so existing test setups keep working.
PGPWD = os.environ.get("PGPWD", "test-nl2sql")
nl2sqlbq_client.init_pgdb(PGPROJ, PGLOCATION, PGINSTANCE, PGDB, PGUSER, PGPWD)

class Initiated


In [6]:
# GenerativeModel comes from vertexai.preview.generative_models (imported above).
model = GenerativeModel("gemini-1.0-pro")

# Two independent chat sessions created from the same model.
table_chat, sql_chat = model.start_chat(), model.start_chat()

In [15]:
def return_table_details(table_name):
    """Build the table-description prompt for one table from the metadata cache.

    Reads the JSON file at the notebook-level ``metadata_json_path``, looks up
    ``table_name`` and renders ``Table_info_template`` with the table name, the
    table's "Description" and one line per column.

    Args:
        table_name (str): Key of the table in the metadata JSON.

    Returns:
        str: ``Table_info_template`` formatted with ``table_name``,
        ``table_description`` and ``columns_info``.

    Raises:
        KeyError: If the table, or one of the expected keys ("Columns",
            "Description", "Name", "Type", "Examples"), is missing.
        OSError: If the metadata file cannot be opened.
    """
    # The context manager guarantees the file handle is closed; the previous
    # bare open() left it open.
    with open(metadata_json_path, encoding="utf-8") as metadata_file:
        metadata_json = json.load(metadata_file)

    table_json = metadata_json[table_name]

    columns_info = ""
    for column in table_json["Columns"].values():
        # The backslash continues the string onto the next source line, so the
        # rendered text is "<name> <indent>(<type>) : <description>. <examples>".
        column_info = f"""{column["Name"]} \
                    ({column["Type"]}) : {column["Description"]}. {column["Examples"]}\n"""
        columns_info += column_info

    return Table_info_template.format(table_name=table_name,
                                      table_description=table_json['Description'],
                                      columns_info=columns_info)


In [16]:
# Names of the two tables to join; they are used as keys into the metadata cache.
table_name_1 = "calhhs-dashboard-2015-2020-annual-data-file"
table_name_2 = "medi-cal-and-calfresh-enrollment"
# Rendered table descriptions that are substituted into the join prompts.
table_1 = return_table_details(table_name_1)
table_2 = return_table_details(table_name_2)
# One-shot example: a question and its SQL, passed to the one-shot prompt template.
# Note that the example SQL queries `cdii-poc.HHS_Program_Counts` tables, not
# the two tables named above.
sample_question = "Which five counties have the lowest number of CalFresh authorized vendors compared to CalFresh participants?"
sample_sql = """SELECT Vendor_Location,(vendor_cnt/total_participants)*100 as vendor_participants_ratio FROM
((SELECT TRIM(Vendor_Location) AS Vendor_Location,COALESCE(SUM(SAFE_CAST(_Number_of_Participants_Redeemed_ AS INT64))) as total_participants FROM `cdii-poc.HHS_Program_Counts.calfresh-redemption-by-county-by-participant-category-data-2010-2018`  group by Vendor_Location) as participants
JOIN
(SELECT TRIM(COUNTY) AS COUNTY,count(VENDOR) as vendor_cnt FROM `cdii-poc.HHS_Program_Counts.women-infants-and-children-wic-authorized-vendors` 
group by COUNTY having COUNTY is not null) as vendors
ON UPPER(participants.Vendor_Location)=UPPER(vendors.COUNTY))
WHERE (vendor_cnt/total_participants)*100 is not null
order by vendor_participants_ratio asc limit 5;"""

# Alternative question, currently disabled:
# question = "Which counties have the highest and lowest ratios of providers to enrolled participants in Medi-cal?"
# The question the model is asked to answer with a JOIN query.
question = "Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?"



In [17]:
# Placeholders shared by the zero-shot and the one-shot template.
_join_fields = {
    "table_1": table_1,
    "table_2": table_2,
    "question": question,
}

# Zero-shot prompt: table descriptions plus the question.
join_prompt = join_prompt_template.format(**_join_fields)

# One-shot prompt: same fields plus one worked question/SQL example.
join_prompt_one_shot = join_prompt_template_one_shot.format(
    sample_question=sample_question,
    sample_sql=sample_sql,
    **_join_fields,
)


## SQL Generation using NL2SQL Library

In [None]:
table_names = [table_name_1, table_name_2]
samples = {
    "sample_question": sample_question,
    "sample_sql": sample_sql,
}

# Zero-shot: send the join prompt without an example to the NL2SQL client.
print("*" * 30)
print("Zero-shot")
gen_sql = nl2sqlbq_client.invoke_llm(join_prompt)
print(gen_sql)

# One-shot: same request with the example question/SQL included.
# gen_sql_os is also read by generate_and_execute_sql() as the fallback query
# when its first attempt fails, so this cell must run before that function.
print("*" * 30)
print("One-shot")
gen_sql_os = nl2sqlbq_client.invoke_llm(join_prompt_one_shot)
print(gen_sql_os)

## Generating SQL using Gemini Pro for the same prompts

In [18]:
model = GenerativeModel("gemini-1.0-pro")

# Send the zero-shot prompt first, then the one-shot prompt, and print each
# reply under its own banner. `resp` ends up holding the one-shot response.
for banner, prompt_text in (
    ("Zero shot prompting :\n", join_prompt),
    ("One-shot prompt: \n", join_prompt_one_shot),
):
    resp = model.generate_content(prompt_text)
    print(banner, resp.text)


Zero shot prompting :
 ```sql
SELECT
  c.county,
  a.number_of_beneficiaries,
  (SELECT COUNT(*) FROM WIC_Vendors) AS TotalWICVendors
FROM calhhs_dashboard_2015_2020_annual_data_file AS a
JOIN medi_cal_and_calfresh_enrollment AS c
  ON (
    a.program = c.program || " only"
  )
WHERE
  a.program = "WIC"
  AND c.program = "CalFresh only"
ORDER BY
  a.number_of_beneficiaries
LIMIT 5;
```
One-shot prompt: 
 ```sql
SELECT
  Vendor_Location,
  (
    vendor_cnt / total_participants
  ) * 100 AS vendor_participants_ratio
FROM (
  (
    SELECT
      TRIM(Vendor_Location) AS Vendor_Location,
      COALESCE(SUM(SAFE_CAST(_Number_of_Participants_Redeemed_ AS INT64))) AS total_participants
    FROM `cdii-poc.HHS_Program_Counts.calfresh-redemption-by-county-by-participant-category-data-2010-2018`
    GROUP BY
      Vendor_Location
  ) AS participants
  JOIN (
    SELECT
      TRIM(COUNTY) AS COUNTY,
      COUNT(VENDOR) AS vendor_cnt
    FROM `cdii-poc.HHS_Program_Counts.women-infants-and-children-w

## Multi-turn Approach

In [None]:
# Ask the NL2SQL client which table it selects for the question, then display
# the result as the cell output.
table_name = nl2sqlbq_client.table_filter(question)
table_name

In [None]:
# Reference copy of the template; the live `multi_table_prompt` is not defined
# in this notebook (presumably it comes from `from prompts import *`).
# multi_table_prompt = """
# Tables context:
# {table_info}

# Example Question, SQL and tables containing the required info are given below
# You are required to identify more than 1 table that probably contains the information requested in the question given below
# Return the list of tables that may contain the information

# Question : {example_question} :
# SQL : {example_SQL}
# Tables: {table_name_1} and {table_name_2}

# Question: {question}
# Tables:
# """

# Table-selection prompt text produced by the client for this question.
tab_prompt = nl2sqlbq_client.table_filter_promptonly(question)

# Values for every placeholder of the multi-table template.
_multi_prompt_fields = {
    "table_info": tab_prompt,
    "example_question": sample_question,
    "example_SQL": sample_sql,
    "table_name_1": table_name_1,
    "table_name_2": table_name_2,
    "question": question,
}
multi_prompt = multi_table_prompt.format(**_multi_prompt_fields)
print(multi_prompt)

In [None]:
# Open a fresh chat session and send the multi-table prompt as the first turn;
# the follow-up cell continues the same session.
multi_chat = model.start_chat()
responses = multi_chat.send_message(multi_prompt) #, tools=[sql_tools])

In [None]:
# Reference copy of the follow-up prompt text; the live `follow_up_prompt` is
# not defined in this notebook (presumably it comes from `from prompts import *`).
# follow_up_prompt = """Review the question given in above context along with the table and column description and determine whether one table contains all the required information or you need to get data from another table
# If two tables's information are required, then identify those tables from the tables info
# What are the two tables that should be joined in the SQL query
# Only mention the table name from the tables context.
# """
# Second turn in the same chat session: ask which two tables should be joined.
resp = multi_chat.send_message(follow_up_prompt)
# Split the first candidate's text on the literal ' and '.
res = resp.candidates[0].content.parts[0].text.split(' and ')

# Take the second space-separated token of the 3rd and 4th ' and '-segments as
# the table names.
# NOTE(review): this is purely positional — it needs at least four segments and
# raises IndexError otherwise; confirm it matches the model's actual reply format.
# These assignments overwrite table_1 / table_2 set in the earlier cell.
table_1 = return_table_details(res[2].split(' ')[1].strip())
table_2 = return_table_details(res[3].split(' ')[1].strip())

# Rebuild both join prompts with the newly selected tables (overwrites the
# earlier join_prompt / join_prompt_one_shot).
join_prompt = join_prompt_template.format(table_1 = table_1,
                                          table_2 = table_2,
                                          question = question)
join_prompt_one_shot = join_prompt_template_one_shot.format(table_1 = table_1,
                                          table_2 = table_2,
                                          sample_question = sample_question,
                                          sample_sql = sample_sql,
                                          question = question)


In [None]:
# Generate SQL for the zero-shot prompt and then the one-shot prompt, printing
# each reply under its banner. `resp` ends up holding the one-shot response.
for banner, prompt_text in (
    ("Zero shot prompting :\n", join_prompt),
    ("One-shot prompt: \n", join_prompt_one_shot),
):
    resp = model.generate_content(prompt_text)
    print(banner, resp.text)

## Self-Correction Approach

In [None]:
# Generation settings (only referenced by commented-out code in this notebook).
TEMPERATURE = 0.3
MAX_OUTPUT_TOKENS = 8192

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# getLogger() returns the same logger object on every run of this cell, so only
# attach a stream handler when none is present; otherwise each re-run would add
# another handler and every log line would be printed multiple times.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

# BigQuery client used to execute the generated SQL.
bq_client = bigquery.Client(project=PROJECT_ID)

CODE_MODEL = "code-bison-32k"
MODEL_NAME = 'codechat-bison-32k'

# code_gen_model = CodeGenerationModel.from_pretrained(CODE_MODEL)
code_gen_model = CodeChatModel.from_pretrained(MODEL_NAME)

In [None]:
def _strip_code_fence(text):
    """Return ``text`` without a surrounding Markdown code fence, if present."""
    lines = text.strip().split('\n')
    if lines and lines[0].lstrip().startswith('```'):
        lines = lines[1:]
    if lines and lines[-1].strip() == '```':
        lines = lines[:-1]
    return '\n'.join(lines)


def generate_and_execute_sql(prompt, max_tries=5, return_all=False):
    """
    Generate SQL in a chat session, execute it on BigQuery and rank the successful queries by latency.

    Every attempt sends the current prompt to a chat session started from the
    notebook-level ``model``, post-processes the reply with
    ``nl2sqlbq_client.case_handler_transform`` and
    ``nl2sqlbq_client.add_dataset_to_query`` and runs it through ``bq_client``.
    On failure the prompt is replaced by a correction request that embeds the
    error message; once more than one query has succeeded, the prompt is
    replaced by a request to optimise the last successful query. The loop
    always runs ``max_tries`` times.

    The recorded latency spans from just before the model call until the query
    result has been loaded into a DataFrame, so it includes generation time.

    Args:
    - prompt (str): Prompt to provide to the model for generating SQL.
    - max_tries (int): Number of generate-and-execute attempts.
    - return_all (bool): Return every successful query instead of only the fastest.

    Returns:
    - dict: With no successful query: ``{"error", "prompts", "errors"}``.
      With ``return_all=True``: ``{"dataframe": DataFrame}`` with the columns
      "Query", "Result" and "Latency", sorted by latency.
      Otherwise: ``{"fastest_query", "result", "latency"}`` for the fastest query.
    """

    tries = 0
    error_messages = []
    prompts = [prompt]
    successful_queries = []

    chat_session = model.start_chat()

    while tries < max_tries:
        logger.info(f'TRIAL: {tries+1}')
        try:
            # Predict SQL using the model
            start_time = time.time()

            response = chat_session.send_message(prompt)

            # Drop the Markdown fence only when the reply actually has one, so
            # an unfenced reply no longer loses its first and last line.
            generated_sql_query = _strip_code_fence(response.text)

            generated_sql_query = nl2sqlbq_client.case_handler_transform(generated_sql_query)
            generated_sql_query = nl2sqlbq_client.add_dataset_to_query(generated_sql_query)

            logger.info('-' * 50)
            logger.info(generated_sql_query)
            logger.info('-' * 50)
            # Execute SQL using BigQuery client
            df = bq_client.query(generated_sql_query).to_dataframe()
            print("Data - ", df)
            latency = time.time() - start_time
            successful_queries.append({
                "query": generated_sql_query,
                "dataframe": df,
                "latency": latency
            })
            logger.info('SUCCEEDED')
            # Evolve the prompt for success path to optimize the last successful query for latency
            if len(successful_queries) > 1:
                prompt = f"""Modify the last successful SQL query by making changes to it and optimizing it for latency. 
            ENSURE that the NEW QUERY is DIFFERENT from the previous one while prioritizing faster execution.
            Reference the tables only from the above given project and dataset
            The last successful query was:
            {successful_queries[-1]["query"]}"""
        except Exception as e:
            logger.error('FAILED')
            # Catch the error, store the message, and try again
            msg = str(e)
            error_messages.append(msg)
            if tries == 0:
                # On a first-attempt failure the correction prompt is seeded
                # with the notebook-level one-shot query `gen_sql_os`.
                # NOTE(review): this also replaces a query that was generated
                # but failed to execute — confirm that is intended.
                generated_sql_query = gen_sql_os

            # Evolve the prompt by appending the error message and asking the model to correct it
            prompt = f"""Encountered an error: {msg}. 
To address this, please generate an alternative SQL query response that avoids this specific error. 
Follow the instructions mentioned above to remediate the error. 

Modify the below SQL query to resolve the issue and ensure it is not a repetition of all previously generated queries.
{generated_sql_query}
            
Ensure the revised SQL query aligns precisely with the requirements outlined in the initial question.
Keep the table names as it is. Do not change hyphen to underscore character
Additionally, please optimize the query for latency while maintaining correctness and efficiency."""
            prompts.append(prompt)
        logger.info('=' * 100)
        tries += 1

    # If no successful queries
    if len(successful_queries) == 0:
        return {
            "error": "All attempts exhausted.",
            "prompts": prompts,
            "errors": error_messages
        }

    # Sort successful queries by latency
    successful_queries.sort(key=lambda x: x['latency'])

    if return_all:
        # Imported here because the notebook never binds the name `pd`
        # (its import cell only has `from pandas import DataFrame`).
        import pandas as pd

        df = pd.DataFrame([(q["query"], q["dataframe"], q["latency"]) for q in successful_queries], columns=["Query", "Result", "Latency"])
        return {
            "dataframe": df
        }
    else:
        return {
            "fastest_query": successful_queries[0]["query"],
            "result": successful_queries[0]["dataframe"],
            "latency": successful_queries[0]["latency"]
        }
    

In [None]:
# Run the self-correction loop on the one-shot prompt and keep every
# successful query, not just the fastest one.
outputs = generate_and_execute_sql(
    prompt=join_prompt_one_shot,
    return_all=True,
)