## Multi-turn SQL Generation

In [1]:
%load_ext autoreload
%autoreload 2

import sys  
# Make the project modules importable. Both entries are inserted at index 1,
# so the later insert (the absolute path) ends up ahead of '../'.
sys.path.insert(1, '../')
# NOTE(review): absolute, machine-specific path; it will not resolve on another host.
sys.path.insert(1, '/home/jupyter/git_repo/nl2sql-generic/nl2sql_src')
print(sys.path)

['/opt/conda/lib/python310.zip', '/home/jupyter/git_repo/nl2sql-generic/nl2sql_src', '../', '/opt/conda/lib/python3.10', '/opt/conda/lib/python3.10/lib-dynload', '', '/opt/conda/lib/python3.10/site-packages']


In [2]:
import os
from datetime import datetime
from getpass import getpass

import faiss
import numpy as np
from pandas import DataFrame
from proto.marshal.collections import maps
from proto.marshal.collections import repeated
from vertexai.preview.generative_models import GenerativeModel, GenerationResponse, Tool

from nl2sql_generic import Nl2sqlBq


In [None]:
# Project and dataset identifiers passed to the Nl2sqlBq client.
PROJECT_ID = "sl-test-project-353312"
DATASET_ID = "zoominfo"

In [3]:
# Build the NL2SQL client from a metadata cache file that already exists on disk.
metadata_cache_file = "../nl2sql_src/cache_metadata/metadata_cache.json"
nl2sqlbq_client = Nl2sqlBq(
    project_id=PROJECT_ID,
    dataset_id=DATASET_ID,
    metadata_json_path=metadata_cache_file,
)

In [4]:
# Connection settings handed to nl2sqlbq_client.init_pgdb.
PGPROJ = "sl-test-project-353312"
PGLOCATION = 'us-central1'
PGINSTANCE = "test-nl2sql"
PGDB = "test-db"
PGUSER = "postgres"
# Do not keep the database password in the notebook: read it from the PGPWD
# environment variable, and fall back to an interactive prompt when it is unset
# (the typed value is not echoed and never ends up in the saved notebook).
PGPWD = os.environ.get("PGPWD") or getpass("PG database password: ")
nl2sqlbq_client.init_pgdb(PGPROJ, PGLOCATION, PGINSTANCE, PGDB, PGUSER, PGPWD)

class Initiated


In [5]:
def get_text(resp: GenerationResponse):
    """Return the text of the first part of the first candidate, or None.

    Args:
        resp: A model response; only ``resp.candidates[0].content.parts[0]``
            is inspected.

    Returns:
        The part's ``text`` value, or ``None`` when reading ``text`` fails
        because the part carries no text (the caller treats this case as a
        function call).

    Raises:
        IndexError: If the response has no candidates or the content has no parts.
    """
    part = resp.candidates[0].content.parts[0]
    try:
        return part.text
    except (AttributeError, ValueError):
        # Only "this part has no text" is treated as a miss. A bare `except:`
        # would also swallow KeyboardInterrupt/SystemExit and hide real bugs.
        return None

def prior_sql_result(question) -> str:
    """Return a placeholder result string for ``question``.

    This is a stub: no SQL is executed here. It only echoes the question
    behind a fixed prefix so the tool-calling flow can be exercised.
    """
    prefix = "Test output for question : "
    return prefix + question


def call_api(name: str, args: str) -> str:
    """Dispatch a tool call by name.

    Only ``"prior_sql_tool"`` is handled; it forwards ``args`` to
    ``prior_sql_result``. Any other name yields ``None``.
    """
    if name != "prior_sql_tool":
        return None
    return prior_sql_result(args)


def recurse_proto_repeated_composite(repeated_object):
    """Convert an iterable of proto-marshal items into a plain Python list.

    Nested ``RepeatedComposite`` items are converted recursively into lists,
    ``MapComposite`` items are converted via ``recurse_proto_marshal_to_dict``,
    and every other item is kept as-is. Order is preserved.
    """

    def _to_plain(element):
        if isinstance(element, repeated.RepeatedComposite):
            return recurse_proto_repeated_composite(element)
        if isinstance(element, maps.MapComposite):
            return recurse_proto_marshal_to_dict(element)
        return element

    return [_to_plain(element) for element in repeated_object]

def recurse_proto_marshal_to_dict(marshal_object):
    """Convert a proto-marshal mapping into a plain ``dict``.

    Entries whose value is falsy are left out of the result. ``MapComposite``
    values are converted recursively into dicts, ``RepeatedComposite`` values
    are converted via ``recurse_proto_repeated_composite``, and all other
    values are copied unchanged.
    """
    plain = {}
    for key, value in marshal_object.items():
        if not value:
            # Falsy values are skipped entirely.
            continue
        if isinstance(value, maps.MapComposite):
            plain[key] = recurse_proto_marshal_to_dict(value)
        elif isinstance(value, repeated.RepeatedComposite):
            plain[key] = recurse_proto_repeated_composite(value)
        else:
            plain[key] = value
    return plain

# Function declaration for the "prior_sql_tool" tool: it takes one required
# string parameter, "question".
previous_sql_spec = dict(
    name="prior_sql_tool",
    description="Provides the SQL query that is generated for the previous question",
    parameters=dict(
        type="object",
        properties=dict(
            question=dict(
                type="string",
                description="The natural language question for which the SQL query was generated",
            ),
        ),
        required=["question"],
    ),
)

def get_function_name(response: GenerationResponse):
    """Return the function-call name from the first part of the first candidate."""
    first_part = response.candidates[0].content.parts[0]
    return first_part.function_call.name

def get_function_args(response: GenerationResponse) -> dict:
    """Return the function-call arguments of the first part as a plain dict.

    The raw ``function_call.args`` mapping is converted with
    ``recurse_proto_marshal_to_dict``.
    """
    first_part = response.candidates[0].content.parts[0]
    raw_args = first_part.function_call.args
    return recurse_proto_marshal_to_dict(raw_args)


# Wrap the single function declaration in a Tool so it can be passed to send_message.
sql_tools = Tool.from_dict({"function_declarations": [previous_sql_spec]})

In [6]:
# One model instance backs two separate chat sessions: table_chat for
# table identification and sql_chat for SQL generation.
model = GenerativeModel("gemini-1.0-pro")
table_chat, sql_chat = model.start_chat(), model.start_chat()

In [7]:
def get_chat_response(chat_model, prompt) -> str:
    """Send ``prompt`` to ``chat_model`` in streaming mode and return the full reply.

    Each streamed chunk contributes the text of its first candidate's first
    part; the pieces are concatenated in arrival order.
    """
    stream = chat_model.send_message(prompt, stream=True)
    return "".join(
        chunk.candidates[0].content.parts[0].text for chunk in stream
    )

In [8]:
# Conversation to replay, in order.
questions = [
    "How many people are enrolled in CalFresh?",
    "How many of them live in Los Angeles County?",
]

# Alternative conversation, kept for reference:
# questions = [
#     "How many Black individuals are served across CalHHS programs?",
#     "What is the breakdown by program?",
#     "Has this changed over time?",
#     "Change over time by program?",
# ]

# Start with the first question of the conversation.
question = questions[0]

In [9]:
# Build the prompt returned by table_filter_promptonly and send it as a turn of table_chat.
table_prompt = nl2sqlbq_client.table_filter_promptonly("Table identification initiation")
# Uncomment to inspect the prompt text: print(table_prompt)
# send_message is the cell's last expression, so the raw response object is displayed below.
table_chat.send_message(table_prompt)




candidates {
  content {
    role: "model"
    parts {
      text: "wic-redemptions-by-vendor-county-with-family-counts-2021-2022-part-c"
    }
  }
  finish_reason: STOP
  safety_ratings {
    category: HARM_CATEGORY_HATE_SPEECH
    probability: NEGLIGIBLE
    probability_score: 0.2916865348815918
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.16344544291496277
  }
  safety_ratings {
    category: HARM_CATEGORY_DANGEROUS_CONTENT
    probability: NEGLIGIBLE
    probability_score: 0.2040245532989502
    severity: HARM_SEVERITY_LOW
    severity_score: 0.21157968044281006
  }
  safety_ratings {
    category: HARM_CATEGORY_HARASSMENT
    probability: NEGLIGIBLE
    probability_score: 0.2639787197113037
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.1445106565952301
  }
  safety_ratings {
    category: HARM_CATEGORY_SEXUALLY_EXPLICIT
    probability: NEGLIGIBLE
    probability_score: 0.1489114761352539
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.0

In [10]:
# Prompt template sent to table_chat for each question; {question} is filled in per turn.
q_prompt_template = """Using the context in the chat history, identify the table name that is most probable to contain the data requested for the question given below.

Question: {question}
"""
# Ask table_chat about every question in order and print the table it names.
# Note: `question` and `q_prompt` stay bound to the last iteration after the loop.
for question in questions:
    q_prompt = q_prompt_template.format(question=question)
    table_identified = get_chat_response(table_chat, q_prompt)
    print("Table identified for queestion :'", question, "' is: ", table_identified)

Table identified for queestion :' How many people are enrolled in CalFresh? ' is:  medi-cal-and-calfresh-enrollment
Table identified for queestion :' How many of them live in Los Angeles County? ' is:  medi-cal-and-calfresh-enrollment


In [11]:
# For each question: identify the table via table_chat, then generate SQL via
# sql_chat, passing the previous turn's SQL text along as context.
for question in questions:
    q_prompt = q_prompt_template.format(question=question)
    table_identified = get_chat_response(table_chat, q_prompt)
    print("Table identified for queestion :'", question, "' is: ", table_identified)

    # sql_chat.history holds Content objects, not strings. Pass the text of the
    # most recent turn (the model's last reply) rather than the object itself.
    try:
        previous_question_sql = sql_chat.history[-1].parts[0].text
    except (IndexError, AttributeError, ValueError):
        # No prior turn yet (first question), or the last turn has no text part.
        previous_question_sql = ""

    sql_prompt = nl2sqlbq_client.generate_sql_few_shot_promptonly(
        question,
        table_name=table_identified,
        prev_sql=previous_question_sql,
    )
    sql_gen = get_chat_response(sql_chat, sql_prompt)
    print("Generated SQL for question : ", question, " is \n ", sql_gen)




Table identified for queestion :' How many people are enrolled in CalFresh? ' is:  medi-cal-and-calfresh-enrollment
Trying to read the index file ../../nl2sql-generic/nl2sql_src/cache_metadata/saved_index_pgdata
Generated SQL for question :  How many people are enrolled in CalFresh?  is 
  ```sql
SELECT 
  COALESCE(SUM(SAFE_CAST(Number_of_Beneficiaries AS INT64)), 0) AS total_enrolled_in_calfresh
FROM 
  `bigquery-public-data.dhcs_medi_cal.medi_cal_and_calfresh_enrollment`
WHERE 
  Program = 'CalFresh Only' OR Program = 'CalFresh and Medi-Cal';
```
Table identified for queestion :' How many of them live in Los Angeles County? ' is:  medi-cal-and-calfresh-enrollment
Trying to read the index file ../../nl2sql-generic/nl2sql_src/cache_metadata/saved_index_pgdata
Generated SQL for question :  How many of them live in Los Angeles County?  is 
  ```sql
SELECT 
  COALESCE(SUM(SAFE_CAST(Number_of_Beneficiaries AS INT64)), 0) AS total_enrolled_in_calfresh
FROM 
  `bigquery-public-data.dhcs_medi

In [12]:
question = "How has participation in CalFresh changed since 2015?"
# Rebuild the table-identification prompt for THIS question. Reusing the
# q_prompt left over from the earlier loop would ask table_chat about the
# last question of that loop instead.
q_prompt = q_prompt_template.format(question=question)
table_identified = get_chat_response(table_chat, q_prompt)
previous_question_sql = ""
sql_prompt = nl2sqlbq_client.generate_sql_few_shot_promptonly(
    question,
    table_name=table_identified,
    prev_sql=previous_question_sql,
)

# Fresh SQL chat session, this time with the prior-SQL tool made available.
sql_chat = model.start_chat()

sql_resp = sql_chat.send_message(sql_prompt, tools=[sql_tools])
txt = get_text(sql_resp)
if txt:
    # The model answered with plain text.
    print("Response from chat", txt)
else:
    # No text in the first part: treat the response as a function call and dispatch it.
    fname = get_function_name(sql_resp)
    print("Function to call", fname)
    func_args = get_function_args(sql_resp)
    print("function call arguments =", func_args)
    print(call_api(fname, func_args['question']))

Trying to read the index file ../../nl2sql-generic/nl2sql_src/cache_metadata/saved_index_pgdata
Function to call prior_sql_tool
function call arguments = {'question': 'How has participation in CalFresh changed since 2015?'}
Test output for question : How has participation in CalFresh changed since 2015?


In [13]:
# Display the turns recorded so far in the SQL chat session (can be very long).
sql_chat.history

[role: "user"
 parts {
   text: "\nOnly use the following tables meta-data:\n\n```\nTable Name : medi-cal-and-calfresh-enrollment\n\nDescription: This table provides information about the enrollment of individuals in Medi-Cal and CalFresh programs in different counties. This information in this table can be used to calculate the ratio of beneficiaries and service providers for a particular program\n\nThis table has the following columns : \nEligiblity_Date                     (STRING) : . \nCounty                     (STRING) : . \nProgram                     (STRING) : . It contains values : \"CalFresh Only\", \"CalFresh only\", \"CalFresh and Medi-Cal\".\nNumber_of_Beneficiaries                     (INTEGER) : . \nANNOTATION_CODE                     (INTEGER) : . \nCOUNT_ANNOTATION_DESC                     (STRING) : . It contains values : \"Cell Suppressed For Small Numbers\".\n\n\n\n```\n\nYou are an SQL expert at generating SQL queries from a natural language question. Given the i