## Multi-turn SQL Generation

In [1]:
%load_ext autoreload
%autoreload 2

import sys  
sys.path.insert(1, '../')
sys.path.insert(1, '/home/jupyter/git_repo/nl2sql-generic/nl2sql_src')
print(sys.path)

['/opt/conda/lib/python310.zip', '/home/jupyter/git_repo/nl2sql-generic/nl2sql_src', '../', '/opt/conda/lib/python3.10', '/opt/conda/lib/python3.10/lib-dynload', '', '/opt/conda/lib/python3.10/site-packages']


In [2]:
import numpy as np
import faiss
from pandas import DataFrame
from datetime import datetime
from vertexai.preview.generative_models import GenerativeModel, GenerationResponse, Tool
from nl2sql_generic import Nl2sqlBq

from proto.marshal.collections import repeated
from proto.marshal.collections import maps


In [3]:
# Initializing when metadata cache is already created
metadata_cache_file = "../nl2sql_src/cache_metadata/metadata_cache.json"
nl2sqlbq_client = Nl2sqlBq(project_id="sl-test-project-353312", dataset_id="EY", metadata_json_path = metadata_cache_file)

In [4]:
PGPROJ = "sl-test-project-353312"
PGLOCATION = 'us-central1'
PGINSTANCE = "test-nl2sql"
PGDB = "test-db"
PGUSER = "postgres"
PGPWD = "test-nl2sql"
nl2sqlbq_client.init_pgdb(PGPROJ, PGLOCATION, PGINSTANCE, PGDB, PGUSER, PGPWD)

class Initiated


In [5]:
def get_text(resp: GenerationResponse):
    part = resp.candidates[0].content.parts[0]
    try:
        text = part.text
    except:
        text = None
    return text

def prior_sql_result(question)->str:
    # Execute the SQL and return the result
    return "Test output for question : " + question


def call_api(name: str, args: str) -> str:
    if name == "prior_sql_tool":
        return prior_sql_result(args)


def recurse_proto_repeated_composite(repeated_object):
    repeated_list = []
    for item in repeated_object:
        if isinstance(item, repeated.RepeatedComposite):
            item = recurse_proto_repeated_composite(item)
            repeated_list.append(item)
        elif isinstance(item, maps.MapComposite):
            item = recurse_proto_marshal_to_dict(item)
            repeated_list.append(item)
        else:
            repeated_list.append(item)

    return repeated_list

def recurse_proto_marshal_to_dict(marshal_object):
    new_dict = {}
    for k, v in marshal_object.items():
      if not v:
        continue
      elif isinstance(v, maps.MapComposite):
          v = recurse_proto_marshal_to_dict(v)
      elif isinstance(v, repeated.RepeatedComposite):
          v = recurse_proto_repeated_composite(v)
      new_dict[k] = v

    return new_dict

previous_sql_spec = {
    "name": "prior_sql_tool",
    "description": "Provides the SQL query that is generated for the previous question",
    "parameters": {
        "type": "object",
        "properties": {
            "question": {
                "type": "string",
                "description": "The natural language question for which the SQL query was generated"
            }
        },
        "required": [
            "question"
        ]
    }
}

def get_function_name(response: GenerationResponse):
  return response.candidates[0].content.parts[0].function_call.name

def get_function_args(response: GenerationResponse) -> dict:
  return recurse_proto_marshal_to_dict(response.candidates[0].content.parts[0].function_call.args)


sql_tools = Tool.from_dict(
    {
        "function_declarations":[previous_sql_spec]
    }
)

In [6]:
# from vertexai.preview.generative_models import GenerativeModel
model = GenerativeModel("gemini-1.0-pro")

table_chat = model.start_chat()
sql_chat = model.start_chat()

In [7]:
def get_chat_response(chat_model, prompt) -> str:
    responses = chat_model.send_message(prompt, stream=True)
    output = []
    for response in responses:
        output.append(response.candidates[0].content.parts[0].text)
    return "".join(output)

In [8]:
questions = ["How many people are enrolled in CalFresh?",
             "How many of them live in Los Angeles County?"
            ]

# questions = ["How many Black individuals are served across CalHHS programs?",
#              "What is the breakdown by program?",
#              "Has this changed over time?",
#              "Change over time by program?"
#             ]

question = questions[0]

In [9]:
table_prompt = nl2sqlbq_client.table_filter_promptonly("Table identification initiation")
# print(table_prompt)
table_chat.send_message(table_prompt)




candidates {
  content {
    role: "model"
    parts {
      text: "wic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c"
    }
  }
  finish_reason: STOP
  safety_ratings {
    category: HARM_CATEGORY_HATE_SPEECH
    probability: NEGLIGIBLE
    probability_score: 0.2916865348815918
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.16344544291496277
  }
  safety_ratings {
    category: HARM_CATEGORY_DANGEROUS_CONTENT
    probability: NEGLIGIBLE
    probability_score: 0.2040245532989502
    severity: HARM_SEVERITY_LOW
    severity_score: 0.21157968044281006
  }
  safety_ratings {
    category: HARM_CATEGORY_HARASSMENT
    probability: NEGLIGIBLE
    probability_score: 0.2639787197113037
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.1445106565952301
  }
  safety_ratings {
    category: HARM_CATEGORY_SEXUALLY_EXPLICIT
    probability: NEGLIGIBLE
    probability_score: 0.1489114761352539
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.0

In [10]:
q_prompt_template = """Using the context in the chat history, identify the table name that is most probable to contain the data requested for the question given below.

Question: {question}
"""
for question in questions:
    q_prompt = q_prompt_template.format(question=question)
    table_identified = get_chat_response(table_chat, q_prompt)
    print("Table identified for queestion :'", question, "' is: ", table_identified)

Table identified for queestion :' How many people are enrolled in CalFresh? ' is:  medi-cal-and-calfresh-enrollment
Table identified for queestion :' How many of them live in Los Angeles County? ' is:  medi-cal-and-calfresh-enrollment


In [11]:
for question in questions:
    q_prompt = q_prompt_template.format(question=question)
    table_identified = get_chat_response(table_chat, q_prompt)
    print("Table identified for queestion :'", question, "' is: ", table_identified)

    try:
        previous_question_sql = sql_chat.history[-1]
    except:
        previous_question_sql = ""

    sql_prompt = nl2sqlbq_client.generate_sql_few_shot_promptonly(question, table_name=table_identified, prev_sql=previous_question_sql)
    # print(sql_prompt)
    sql_gen = get_chat_response(sql_chat, sql_prompt)
    print("Generated SQL for question : ", question, " is \n ", sql_gen)




Table identified for queestion :' How many people are enrolled in CalFresh? ' is:  medi-cal-and-calfresh-enrollment
Trying to read the index file ../../nl2sql-generic/nl2sql_src/cache_metadata/saved_index_pgdata
Generated SQL for question :  How many people are enrolled in CalFresh?  is 
  ```sql
SELECT 
  COALESCE(SUM(SAFE_CAST(Number_of_Beneficiaries AS INT64)), 0) AS total_enrolled_in_calfresh
FROM 
  `bigquery-public-data.dhcs_medi_cal.medi_cal_and_calfresh_enrollment`
WHERE 
  Program = 'CalFresh Only' OR Program = 'CalFresh and Medi-Cal';
```
Table identified for queestion :' How many of them live in Los Angeles County? ' is:  medi-cal-and-calfresh-enrollment
Trying to read the index file ../../nl2sql-generic/nl2sql_src/cache_metadata/saved_index_pgdata
Generated SQL for question :  How many of them live in Los Angeles County?  is 
  ```sql
SELECT 
  COALESCE(SUM(SAFE_CAST(Number_of_Beneficiaries AS INT64)), 0) AS total_enrolled_in_calfresh
FROM 
  `bigquery-public-data.dhcs_medi

In [12]:
question = "How has participation in CalFresh changed since 2015?"
table_identified = get_chat_response(table_chat, q_prompt)
previous_question_sql=""
sql_prompt = nl2sqlbq_client.generate_sql_few_shot_promptonly(question, table_name=table_identified, prev_sql=previous_question_sql)

sql_chat = model.start_chat()

sql_resp = sql_chat.send_message(sql_prompt, tools=[sql_tools])
txt = get_text(sql_resp)
if txt:
    print("Response from chat", txt)
else:
    fname = get_function_name(sql_resp)
    # fname = sql_resp.candidates[0].content.parts[0].function_call.name
    print("Function to call", fname)
    func_args = get_function_args(sql_resp)
    # fargs = sql_resp.candidates[0].content.parts[0].function_call.args
    print("function call arguments =", func_args)
    print(call_api(fname, func_args['question']))

Trying to read the index file ../../nl2sql-generic/nl2sql_src/cache_metadata/saved_index_pgdata
Function to call prior_sql_tool
function call arguments = {'question': 'How has participation in CalFresh changed since 2015?'}
Test output for question : How has participation in CalFresh changed since 2015?


In [13]:
sql_chat.history

[role: "user"
 parts {
   text: "\nOnly use the following tables meta-data:\n\n```\nTable Name : medi-cal-and-calfresh-enrollment\n\nDescription: This table provides information about the enrollment of individuals in Medi-Cal and CalFresh programs in different counties. This information in this table can be used to calculate the ratio of beneficiaries and service providers for a particular program\n\nThis table has the following columns : \nEligiblity_Date                     (STRING) : . \nCounty                     (STRING) : . \nProgram                     (STRING) : . It contains values : \"CalFresh Only\", \"CalFresh only\", \"CalFresh and Medi-Cal\".\nNumber_of_Beneficiaries                     (INTEGER) : . \nANNOTATION_CODE                     (INTEGER) : . \nCOUNT_ANNOTATION_DESC                     (STRING) : . It contains values : \"Cell Suppressed For Small Numbers\".\n\n\n\n```\n\nYou are an SQL expert at generating SQL queries from a natural language question. Given the i

## Join related

In [14]:
import json
from prompts import * 
metadata_json_path = "../nl2sql_src/cache_metadata/metadata_cache.json"


In [15]:
def return_table_details(table_name):
    f = open(metadata_json_path, encoding="utf-8")
    metadata_json = json.loads(f.read())
        
    table_json = metadata_json[table_name]
    columns_json = table_json["Columns"]
    columns_info = ""
    for column_name in columns_json:
        column = columns_json[column_name]            
        column_info = f"""{column["Name"]} \
                    ({column["Type"]}) : {column["Description"]}. {column["Examples"]}\n"""
        columns_info = columns_info + column_info
        
    prompt = Table_info_template.format(table_name = table_name,
                                        table_description = metadata_json[table_name]['Description'],
                                        columns_info = columns_info)
    return prompt


In [16]:
table_name_1 = "calhhs-dashboard-2015-2020-annual-data-file"
table_name_2 = "medi-cal-and-calfresh-enrollment"
table_1 = return_table_details(table_name_1)
table_2 = return_table_details(table_name_2)
sample_question = "Which five counties have the lowest number of CalFresh authorized vendors compared to CalFresh participants?"
sample_sql = """SELECT Vendor_Location,(vendor_cnt/total_participants)*100 as vendor_participants_ratio FROM
((SELECT TRIM(Vendor_Location) AS Vendor_Location,COALESCE(SUM(SAFE_CAST(_Number_of_Participants_Redeemed_ AS INT64))) as total_participants FROM `cdii-poc.HHS_Program_Counts.calfresh-redemption-by-county-by-participant-category-data-2010-2018`  group by Vendor_Location) as participants
JOIN
(SELECT TRIM(COUNTY) AS COUNTY,count(VENDOR) as vendor_cnt FROM `cdii-poc.HHS_Program_Counts.women-infants-and-children-wic-authorized-vendors` 
group by COUNTY having COUNTY is not null) as vendors
ON UPPER(participants.Vendor_Location)=UPPER(vendors.COUNTY))
WHERE (vendor_cnt/total_participants)*100 is not null
order by vendor_participants_ratio asc limit 5;"""

# question = "Which counties have the highest and lowest ratios of providers to enrolled participants in Medi-cal?"
question = "Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?"



In [17]:
join_prompt = join_prompt_template.format(table_1 = table_1,
                                          table_2 = table_2,
                                          question = question)
join_prompt_one_shot = join_prompt_template_one_shot.format(table_1 = table_1,
                                          table_2 = table_2,
                                          sample_question = sample_question,
                                          sample_sql = sample_sql,
                                          question = question)


In [18]:
model = GenerativeModel("gemini-1.0-pro")
resp = model.generate_content(join_prompt)
print("Zero shot prompting :\n", resp.text)

resp = model.generate_content(join_prompt_one_shot)
print("One-shot prompt: \n", resp.text)


Zero shot prompting :
 ```sql
SELECT
  c.county,
  a.number_of_beneficiaries,
  (SELECT COUNT(*) FROM WIC_Vendors) AS TotalWICVendors
FROM calhhs_dashboard_2015_2020_annual_data_file AS a
JOIN medi_cal_and_calfresh_enrollment AS c
  ON (
    a.program = c.program || " only"
  )
WHERE
  a.program = "WIC"
  AND c.program = "CalFresh only"
ORDER BY
  a.number_of_beneficiaries
LIMIT 5;
```
One-shot prompt: 
 ```sql
SELECT
  Vendor_Location,
  (
    vendor_cnt / total_participants
  ) * 100 AS vendor_participants_ratio
FROM (
  (
    SELECT
      TRIM(Vendor_Location) AS Vendor_Location,
      COALESCE(SUM(SAFE_CAST(_Number_of_Participants_Redeemed_ AS INT64))) AS total_participants
    FROM `cdii-poc.HHS_Program_Counts.calfresh-redemption-by-county-by-participant-category-data-2010-2018`
    GROUP BY
      Vendor_Location
  ) AS participants
  JOIN (
    SELECT
      TRIM(COUNTY) AS COUNTY,
      COUNT(VENDOR) AS vendor_cnt
    FROM `cdii-poc.HHS_Program_Counts.women-infants-and-children-w

In [19]:
print(join_prompt)

You are an SQL expert at generating SQL queries from a natural language question. Given the input question, create a syntactically correct Biguery query to run.

Only use the few relevant columns required based on the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. Do not use more than 10 columns in the query. Focus on the keywords indicating calculation. 
Please think step by step and always validate the reponse.
Rectify each column names by referencing them from the meta-data.

Only use the following tables meta-data: 

Table 1: 

Table Name : calhhs-dashboard-2015-2020-annual-data-file

Description: This table provides comprehensive annual data from 2015 to 2018 related to various programs and services offerred by the California Health and Human Services agency (CalHHS).  This table contains information related to the services

In [20]:
table_name = nl2sqlbq_client.table_filter(question)
table_name

['wic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c']

In [21]:
print(question)

Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?


In [22]:
multi_table_prompt = """
Tables context:
{table_info}

Example Question, SQL and tables containing the required info are given below
You are required to identify more than 1 table that probably contains the information requested in the question given below
Return the list of tables that may contain the information

Question : {example_question} :
SQL : {example_SQL}
Tables: {table_name_1} and {table_name_2}

Question: {question}
Tables:
"""

tab_prompt = nl2sqlbq_client.table_filter_promptonly(question)
# print(tab_prompt)

multi_prompt=multi_table_prompt.format(table_info=tab_prompt, 
                                       example_question=sample_question,
                                       example_SQL=sample_sql,
                                       table_name_1=table_name_1,
                                       table_name_2=table_name_2,
                                       question=question)
# print(multi_prompt)

In [23]:
# resp = model.generate_content(multi_prompt)

In [24]:
# print(resp.text)

In [25]:
model = GenerativeModel("gemini-1.0-pro")
multi_chat = model.start_chat()

# tab_prompt = tab_prompt + "\nContext : " + question

responses = multi_chat.send_message(multi_prompt, tools=[sql_tools])

In [26]:
responses.candidates[0].content.parts[0].text

'2010-2018-part-a-wic-redemptions-by-vendor-county-with-family-counts and wic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c'

In [27]:
follow_up_prompt = """Review the question given in above context along with the table and column description and determine whether one table contains all the required information or you need to get data from another table
If two tables's information are required, then identify those tables from the tables info
What are the two tables that should be joined in the SQL query
Only mention the table name from the tables context.
"""
resp = multi_chat.send_message(follow_up_prompt)

In [28]:
res = resp.candidates[0].content.parts[0].text.split('\n*')[1:]

In [29]:
res

[]

In [30]:
res = [r.strip() for r in res]

In [31]:
res

[]

In [32]:
table_1 = return_table_details(res[0])
table_2 = return_table_details(res[1])


IndexError: list index out of range

In [None]:
join_prompt = join_prompt_template.format(table_1 = table_1,
                                          table_2 = table_2,
                                          question = question)
join_prompt_one_shot = join_prompt_template_one_shot.format(table_1 = table_1,
                                          table_2 = table_2,
                                          sample_question = sample_question,
                                          sample_sql = sample_sql,
                                          question = question)


In [None]:
model = GenerativeModel("gemini-1.0-pro")
resp = model.generate_content(join_prompt)
print("Zero shot prompting :\n", resp.text)

resp = model.generate_content(join_prompt_one_shot)
print("One-shot prompt: \n", resp.text)
