## Multi-turn SQL Generation

In [None]:
%load_ext autoreload
%autoreload 2

import sys  
sys.path.insert(1, '../')

In [None]:
import numpy as np
import faiss
from pandas import DataFrame
from datetime import datetime
from vertexai.preview.generative_models import GenerativeModel, GenerationResponse, Tool
from nl2sql_generic import Nl2sqlBq

In [None]:
# Create the NL2SQL client against an already-generated metadata cache file.
metadata_cache_file = "../nl2sql_src/cache_metadata/metadata_cache.json"
nl2sqlbq_client = Nl2sqlBq(
    project_id="sl-test-project-353312",
    dataset_id="EY",
    metadata_json_path=metadata_cache_file,
)

In [None]:
# Connection settings passed to nl2sqlbq_client.init_pgdb.
import os

PGPROJ = "sl-test-project-353312"
PGLOCATION = 'us-central1'
PGINSTANCE = "test-nl2sql"
PGDB = "test-db"
PGUSER = "postgres"
# The password can be supplied through the NL2SQL_PGPWD environment variable
# so a real secret never has to be typed into the notebook; the original
# literal is kept only as the default so existing runs behave the same.
# NOTE(review): consider dropping the literal fallback once every user of
# this notebook sets NL2SQL_PGPWD.
PGPWD = os.environ.get("NL2SQL_PGPWD", "test-nl2sql")
nl2sqlbq_client.init_pgdb(PGPROJ, PGLOCATION, PGINSTANCE, PGDB, PGUSER, PGPWD)

In [None]:
def get_text(resp: "GenerationResponse"):
    """Return the text of the first part of the first candidate, or None.

    Args:
        resp: A model response exposing ``candidates[0].content.parts[0]``.

    Returns:
        The ``text`` attribute of that part, or ``None`` when reading it
        raises ``AttributeError`` or ``ValueError`` (the part carries no
        text; the calling cell then inspects ``function_call`` instead).

    Raises:
        IndexError: If the response has no candidates or no parts.
    """
    part = resp.candidates[0].content.parts[0]
    try:
        return part.text
    except (AttributeError, ValueError):
        # Only "this part has no text" is expected here; anything else
        # (including KeyboardInterrupt) must not be silently swallowed.
        return None

def prior_sql_result()->str:
    """Return the result of the previously generated SQL.

    Placeholder implementation: no SQL is executed yet and a fixed string
    is returned.
    """
    placeholder_result = "Test output"
    return placeholder_result


def call_api(name: str, args: str) -> str:
    """Dispatch a tool call requested by the model.

    Args:
        name: Name of the tool to run.
        args: Arguments for the tool; currently not used by any tool.

    Returns:
        The output of ``prior_sql_result()`` when ``name`` is
        ``"prior_sql_tool"``; ``None`` for any other name.
    """
    if name != "prior_sql_tool":
        return None
    return prior_sql_result()


# Function declaration for the "prior_sql_tool" tool: one required string
# parameter, "question".
previous_sql_spec = dict(
    name="prior_sql_tool",
    description="Provides the SQL query that is generated for the previous question",
    parameters=dict(
        type="object",
        properties=dict(
            question=dict(
                type="string",
                description="The natural language question for which the SQL query was generated",
            ),
        ),
        required=["question"],
    ),
)

# Wrap the function declaration in a Tool so it can be passed to send_message.
sql_tools = Tool.from_dict(dict(function_declarations=[previous_sql_spec]))

In [None]:
# One model, two separate chat sessions: one used for table identification
# and one used for SQL generation.
model = GenerativeModel("gemini-1.0-pro")

table_chat, sql_chat = model.start_chat(), model.start_chat()

In [None]:
def get_chat_response(chat_model, prompt) -> str:
    """Send ``prompt`` to a chat session in streaming mode and return the text.

    Args:
        chat_model: Object whose ``send_message(prompt, stream=True)`` yields
            response chunks.
        prompt: The message to send.

    Returns:
        The ``text`` of the first part of the first candidate of every
        chunk, concatenated in arrival order.
    """
    stream = chat_model.send_message(prompt, stream=True)
    return "".join(
        chunk.candidates[0].content.parts[0].text for chunk in stream
    )

In [None]:
# Send the table-filter prompt to the table chat first, so the question
# prompts that follow can rely on it being in the chat history.
table_prompt = nl2sqlbq_client.table_filter_promptonly(
    "Table identification initiation"
)
# print(table_prompt)
table_chat.send_message(table_prompt)

# The second question only makes sense in the context of the first one.
questions = [
    "How many people are enrolled in CalFresh?",
    "How many of them live in Los Angeles County?",
]

question = questions[0]


In [None]:
# Prompt asking the table chat to pick the table for one question; the
# {question} placeholder is filled in per question.
q_prompt_template = """Using the context in the chat history, identify the table name that is most probable to contain the data requested for the question given below.

Question: {question}
"""
# Ask the table chat, one question at a time, which table it would use.
for question in questions:
    q_prompt = q_prompt_template.format(question=question)
    table_identified = get_chat_response(table_chat, q_prompt)
    # f-string instead of comma-separated print arguments: avoids the stray
    # spaces inside the quotes and fixes the "queestion" typo.
    print(f"Table identified for question '{question}' is: {table_identified}")

In [None]:
# For every question: identify the table, then generate SQL in the SQL chat,
# handing the previous SQL-chat entry to the prompt builder.
for question in questions:
    q_prompt = q_prompt_template.format(question=question)
    table_identified = get_chat_response(table_chat, q_prompt)
    # f-string instead of comma-separated print arguments: avoids the stray
    # spaces inside the quotes and fixes the "queestion" typo.
    print(f"Table identified for question '{question}' is: {table_identified}")

    # On the first question the SQL chat history is still empty, so there is
    # nothing to carry over. Only that case is handled; other errors surface.
    # NOTE(review): this passes the whole history entry, not its text, as
    # prev_sql — confirm that is what generate_sql_few_shot_promptonly expects.
    try:
        previous_question_sql = sql_chat.history[-1]
    except IndexError:
        previous_question_sql = ""

    sql_prompt = nl2sqlbq_client.generate_sql_few_shot_promptonly(
        question,
        table_name=table_identified,
        prev_sql=previous_question_sql,
    )
    print(sql_prompt)
    sql_gen = get_chat_response(sql_chat, sql_prompt)
    print(sql_gen)




In [None]:
# Show the text of the most recent entry in the SQL chat history.
latest_entry = sql_chat.history[-1]
print(latest_entry.text)

In [None]:
# Re-send the last SQL prompt with the tool attached, then either print the
# model's text answer or run the function call it asked for.
import json

sql_resp = sql_chat.send_message(sql_prompt, tools=[sql_tools])
txt = get_text(sql_resp)
if txt:
    print("Response from chat", txt)
else:
    function_call = sql_resp.candidates[0].content.parts[0].function_call
    fname = function_call.name
    print("Function to call", fname)
    fargs = function_call.args
    print("function call arguments =", fargs)
    # Forward the arguments the model actually supplied (JSON-encoded, since
    # call_api declares a string parameter) instead of the literal "args".
    # `fargs or {}` covers a call that carries no arguments; default=str keeps
    # the encoding best-effort for any non-JSON value.
    print(call_api(fname, json.dumps(dict(fargs or {}), default=str)))

In [None]:
# Bare expression: shows the SQL chat session's history as the cell output.
sql_chat.history