## Generating SQL with JOINs

In [1]:
%load_ext autoreload
%autoreload 2

import sys  
sys.path.insert(1, '../')
sys.path.insert(1, '/home/jupyter/git_repo/nl2sql-generic/nl2sql_src')
print(sys.path)

['/opt/conda/lib/python310.zip', '/home/jupyter/git_repo/nl2sql-generic/nl2sql_src', '../', '/opt/conda/lib/python3.10', '/opt/conda/lib/python3.10/lib-dynload', '', '/opt/conda/lib/python3.10/site-packages']


In [2]:
import numpy as np
import faiss
from pandas import DataFrame
from datetime import datetime
from vertexai.preview.generative_models import GenerativeModel, GenerationResponse, Tool
from nl2sql_generic import Nl2sqlBq

import json
from prompts import * 

from proto.marshal.collections import repeated
from proto.marshal.collections import maps


In [3]:
# Initializing when metadata cache is already created
metadata_cache_file = "../nl2sql_src/cache_metadata/metadata_cache.json"
metadata_json_path = "../nl2sql_src/cache_metadata/metadata_cache.json"

nl2sqlbq_client = Nl2sqlBq(project_id="sl-test-project-353312", dataset_id="EY", metadata_json_path = metadata_cache_file)

In [4]:
PGPROJ = "sl-test-project-353312"
PGLOCATION = 'us-central1'
PGINSTANCE = "test-nl2sql"
PGDB = "test-db"
PGUSER = "postgres"
PGPWD = "test-nl2sql"
nl2sqlbq_client.init_pgdb(PGPROJ, PGLOCATION, PGINSTANCE, PGDB, PGUSER, PGPWD)

class Initiated


In [6]:
# from vertexai.preview.generative_models import GenerativeModel
model = GenerativeModel("gemini-1.0-pro")

table_chat = model.start_chat()
sql_chat = model.start_chat()

In [15]:
def return_table_details(table_name):
    f = open(metadata_json_path, encoding="utf-8")
    metadata_json = json.loads(f.read())
        
    table_json = metadata_json[table_name]
    columns_json = table_json["Columns"]
    columns_info = ""
    for column_name in columns_json:
        column = columns_json[column_name]            
        column_info = f"""{column["Name"]} \
                    ({column["Type"]}) : {column["Description"]}. {column["Examples"]}\n"""
        columns_info = columns_info + column_info
        
    prompt = Table_info_template.format(table_name = table_name,
                                        table_description = metadata_json[table_name]['Description'],
                                        columns_info = columns_info)
    return prompt


In [16]:
table_name_1 = "calhhs-dashboard-2015-2020-annual-data-file"
table_name_2 = "medi-cal-and-calfresh-enrollment"
table_1 = return_table_details(table_name_1)
table_2 = return_table_details(table_name_2)
sample_question = "Which five counties have the lowest number of CalFresh authorized vendors compared to CalFresh participants?"
sample_sql = """SELECT Vendor_Location,(vendor_cnt/total_participants)*100 as vendor_participants_ratio FROM
((SELECT TRIM(Vendor_Location) AS Vendor_Location,COALESCE(SUM(SAFE_CAST(_Number_of_Participants_Redeemed_ AS INT64))) as total_participants FROM `cdii-poc.HHS_Program_Counts.calfresh-redemption-by-county-by-participant-category-data-2010-2018`  group by Vendor_Location) as participants
JOIN
(SELECT TRIM(COUNTY) AS COUNTY,count(VENDOR) as vendor_cnt FROM `cdii-poc.HHS_Program_Counts.women-infants-and-children-wic-authorized-vendors` 
group by COUNTY having COUNTY is not null) as vendors
ON UPPER(participants.Vendor_Location)=UPPER(vendors.COUNTY))
WHERE (vendor_cnt/total_participants)*100 is not null
order by vendor_participants_ratio asc limit 5;"""

# question = "Which counties have the highest and lowest ratios of providers to enrolled participants in Medi-cal?"
question = "Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?"



In [17]:
join_prompt = join_prompt_template.format(table_1 = table_1,
                                          table_2 = table_2,
                                          question = question)
join_prompt_one_shot = join_prompt_template_one_shot.format(table_1 = table_1,
                                          table_2 = table_2,
                                          sample_question = sample_question,
                                          sample_sql = sample_sql,
                                          question = question)


In [18]:
model = GenerativeModel("gemini-1.0-pro")
resp = model.generate_content(join_prompt)
print("Zero shot prompting :\n", resp.text)

resp = model.generate_content(join_prompt_one_shot)
print("One-shot prompt: \n", resp.text)


Zero shot prompting :
 ```sql
SELECT
  c.county,
  a.number_of_beneficiaries,
  (SELECT COUNT(*) FROM WIC_Vendors) AS TotalWICVendors
FROM calhhs_dashboard_2015_2020_annual_data_file AS a
JOIN medi_cal_and_calfresh_enrollment AS c
  ON (
    a.program = c.program || " only"
  )
WHERE
  a.program = "WIC"
  AND c.program = "CalFresh only"
ORDER BY
  a.number_of_beneficiaries
LIMIT 5;
```
One-shot prompt: 
 ```sql
SELECT
  Vendor_Location,
  (
    vendor_cnt / total_participants
  ) * 100 AS vendor_participants_ratio
FROM (
  (
    SELECT
      TRIM(Vendor_Location) AS Vendor_Location,
      COALESCE(SUM(SAFE_CAST(_Number_of_Participants_Redeemed_ AS INT64))) AS total_participants
    FROM `cdii-poc.HHS_Program_Counts.calfresh-redemption-by-county-by-participant-category-data-2010-2018`
    GROUP BY
      Vendor_Location
  ) AS participants
  JOIN (
    SELECT
      TRIM(COUNTY) AS COUNTY,
      COUNT(VENDOR) AS vendor_cnt
    FROM `cdii-poc.HHS_Program_Counts.women-infants-and-children-w

In [19]:
print(join_prompt)

You are an SQL expert at generating SQL queries from a natural language question. Given the input question, create a syntactically correct Biguery query to run.

Only use the few relevant columns required based on the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. Do not use more than 10 columns in the query. Focus on the keywords indicating calculation. 
Please think step by step and always validate the reponse.
Rectify each column names by referencing them from the meta-data.

Only use the following tables meta-data: 

Table 1: 

Table Name : calhhs-dashboard-2015-2020-annual-data-file

Description: This table provides comprehensive annual data from 2015 to 2018 related to various programs and services offerred by the California Health and Human Services agency (CalHHS).  This table contains information related to the services

In [20]:
table_name = nl2sqlbq_client.table_filter(question)
table_name

['wic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c']

In [21]:
print(question)

Which five counties have the lowest number of WIC authorized vendors compared to WIC participants?


In [22]:
multi_table_prompt = """
Tables context:
{table_info}

Example Question, SQL and tables containing the required info are given below
You are required to identify more than 1 table that probably contains the information requested in the question given below
Return the list of tables that may contain the information

Question : {example_question} :
SQL : {example_SQL}
Tables: {table_name_1} and {table_name_2}

Question: {question}
Tables:
"""

tab_prompt = nl2sqlbq_client.table_filter_promptonly(question)
# print(tab_prompt)

multi_prompt=multi_table_prompt.format(table_info=tab_prompt, 
                                       example_question=sample_question,
                                       example_SQL=sample_sql,
                                       table_name_1=table_name_1,
                                       table_name_2=table_name_2,
                                       question=question)
# print(multi_prompt)

In [23]:
# resp = model.generate_content(multi_prompt)

In [24]:
# print(resp.text)

In [25]:
model = GenerativeModel("gemini-1.0-pro")
multi_chat = model.start_chat()

# tab_prompt = tab_prompt + "\nContext : " + question

responses = multi_chat.send_message(multi_prompt, tools=[sql_tools])

In [26]:
responses.candidates[0].content.parts[0].text

'2010-2018-part-a-wic-redemptions-by-vendor-county-with-family-counts and wic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c'

In [27]:
follow_up_prompt = """Review the question given in above context along with the table and column description and determine whether one table contains all the required information or you need to get data from another table
If two tables's information are required, then identify those tables from the tables info
What are the two tables that should be joined in the SQL query
Only mention the table name from the tables context.
"""
resp = multi_chat.send_message(follow_up_prompt)

In [28]:
res = resp.candidates[0].content.parts[0].text.split('\n*')[1:]

In [29]:
res

[]

In [30]:
res = [r.strip() for r in res]

In [31]:
res

[]

In [32]:
table_1 = return_table_details(res[0])
table_2 = return_table_details(res[1])


IndexError: list index out of range

In [None]:
join_prompt = join_prompt_template.format(table_1 = table_1,
                                          table_2 = table_2,
                                          question = question)
join_prompt_one_shot = join_prompt_template_one_shot.format(table_1 = table_1,
                                          table_2 = table_2,
                                          sample_question = sample_question,
                                          sample_sql = sample_sql,
                                          question = question)


In [None]:
model = GenerativeModel("gemini-1.0-pro")
resp = model.generate_content(join_prompt)
print("Zero shot prompting :\n", resp.text)

resp = model.generate_content(join_prompt_one_shot)
print("One-shot prompt: \n", resp.text)
