In [1]:
from embeddings.chroma_funcs import get_closest_entries, generate_knowledge_base_from_hf_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# so first we need to generate the knowledge_base
dataset_name = "samlhuillier/sql-create-context-spider-intersect"
knowledge_base = generate_knowledge_base_from_hf_dataset(dataset_name, "question")

In [45]:
def format_rag_examples(examples):
    formatted_examples = "\n".join(
        f"""Example {j+1}:
### Input:
{example["question"]}

### Context:
{example["context"]}

### Response:
{example["answer"]}
""" for j, example in enumerate(examples)
    )
    
    return f"""Given the following examples:
{formatted_examples}"""


def generate_rag_sql_prompt(data_point):
    results = get_closest_entries(
        knowledge_base,
        data_point["question"],
        "question",
    )
    print("results are: ", results)
    rag_datapoints = results["metadatas"][0]
    formatted_examples = format_rag_examples(rag_datapoints)
    # print("formatted examples are: ", formatted_examples)
    inference_prompt = f"""You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.

{formatted_examples}
Please generate the SQL query that answers the following:
### Input:
{data_point["question"]}

### Context:
{data_point["context"]}

### Response:"""
    full_prompt = f"{inference_prompt}\n{data_point['answer']}"
    return full_prompt, inference_prompt

In [46]:
test_datapoint = {
        "question": "What is the average horsepower for all cars produced before 1980 ?",
        "context": "CREATE TABLE cars_data (horsepower INTEGER, year INTEGER)",
        "answer": "select avg(horsepower) from cars_data where year  <  1980;",
        "db_id": "car_1"
    }

full_prompt, inference_prompt = generate_rag_sql_prompt(test_datapoint)

results are:  {'ids': [['id3766', 'id2025', 'id3032', 'id471', 'id21']], 'distances': [[1.0678924322128296, 1.075580358505249, 1.1425307989120483, 1.1635199785232544, 1.1649410724639893]], 'metadatas': [[{'answer': 'SELECT count(*) FROM Vehicles;', 'context': 'CREATE TABLE Vehicles (Id VARCHAR)', 'db_id': 'driving_school', 'question': 'How many vehicle in total?'}, {'answer': 'SELECT max(num_of_shops) ,  avg(Num_of_Factories) FROM manufacturer WHERE open_year  <  1990', 'context': 'CREATE TABLE manufacturer (num_of_shops INTEGER, Num_of_Factories INTEGER, open_year INTEGER)', 'db_id': 'manufacturer', 'question': 'what is the average number of factories and maximum number of shops for manufacturers that opened before 1990.'}, {'answer': 'SELECT avg(T1.price) ,  T2.name FROM products AS T1 JOIN manufacturers AS T2 ON T1.Manufacturer  =  T2.code GROUP BY T2.name', 'context': 'CREATE TABLE Manufacturers (name VARCHAR, code VARCHAR); CREATE TABLE products (Price INTEGER, manufacturer VARCHA

In [47]:
full_prompt

'You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.\n\nGiven the following examples:\nExample 1:\n### Input:\nHow many vehicle in total?\n\n### Context:\nCREATE TABLE Vehicles (Id VARCHAR)\n\n### Response:\nSELECT count(*) FROM Vehicles;\n\nExample 2:\n### Input:\nwhat is the average number of factories and maximum number of shops for manufacturers that opened before 1990.\n\n### Context:\nCREATE TABLE manufacturer (num_of_shops INTEGER, Num_of_Factories INTEGER, open_year INTEGER)\n\n### Response:\nSELECT max(num_of_shops) ,  avg(Num_of_Factories) FROM manufacturer WHERE open_year  <  1990\n\nExample 3:\n### Input:\nWhat are the average prices of products for each manufacturer?\n\n### Context:\nCREATE TABLE Manufacturers (name VARCHAR, code VARCHAR); CREATE TABLE products (Price INTEGER, manufacturer VARCHAR)\n\n### Response:\

In [48]:
train_datapoint = {"db_id":"department_management","context":"CREATE TABLE department (creation VARCHAR)","question":"In which year were most departments established?","answer":"SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1","full_prompt":"You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.\n\nGiven the following examples:\nExample 1:\n### Input:\nList the creation year, name and budget of each department.\n\n### Context:\nCREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n\n### Response:\nSELECT creation ,  name ,  budget_in_billions FROM department\n\nExample 2:\n### Input:\nlist names of all departments ordered by their names.\n\n### Context:\nCREATE TABLE department (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM department ORDER BY dept_name\n\nExample 3:\n### Input:\nWhat are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?\n\n### Context:\nCREATE TABLE department (creation VARCHAR, department_id VARCHAR); CREATE TABLE management (department_id VARCHAR, head_id VARCHAR); CREATE TABLE head (head_id VARCHAR, born_state VARCHAR)\n\n### Response:\nSELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'\n\nExample 4:\n### Input:\nWhat is the number of departments in Division \"AS\"?\n\n### Context:\nCREATE TABLE DEPARTMENT (Division VARCHAR)\n\n### Response:\nSELECT count(*) FROM DEPARTMENT WHERE Division  =  \"AS\"\n\nExample 5:\n### Input:\nFind the names of the top 3 departments that provide the largest amount of courses?\n\n### Context:\nCREATE TABLE course (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM course GROUP BY dept_name ORDER BY count(*) DESC LIMIT 3\n\nPlease generate the SQL query that answers the following:\n### Input:\nIn which year were most departments established?\n\n### Context:\nCREATE TABLE department (creation VARCHAR)\n\n### Response:\nSELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1","inference_prompt":"You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.\n\nGiven the following examples:\nExample 1:\n### Input:\nList the creation year, name and budget of each department.\n\n### Context:\nCREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n\n### Response:\nSELECT creation ,  name ,  budget_in_billions FROM department\n\nExample 2:\n### Input:\nlist names of all departments ordered by their names.\n\n### Context:\nCREATE TABLE department (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM department ORDER BY dept_name\n\nExample 3:\n### Input:\nWhat are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?\n\n### Context:\nCREATE TABLE department (creation VARCHAR, department_id VARCHAR); CREATE TABLE management (department_id VARCHAR, head_id VARCHAR); CREATE TABLE head (head_id VARCHAR, born_state VARCHAR)\n\n### Response:\nSELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'\n\nExample 4:\n### Input:\nWhat is the number of departments in Division \"AS\"?\n\n### Context:\nCREATE TABLE DEPARTMENT (Division VARCHAR)\n\n### Response:\nSELECT count(*) FROM DEPARTMENT WHERE Division  =  \"AS\"\n\nExample 5:\n### Input:\nFind the names of the top 3 departments that provide the largest amount of courses?\n\n### Context:\nCREATE TABLE course (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM course GROUP BY dept_name ORDER BY count(*) DESC LIMIT 3\n\nPlease generate the SQL query that answers the following:\n### Input:\nIn which year were most departments established?\n\n### Context:\nCREATE TABLE department (creation VARCHAR)\n\n### Response:"}
full_prompt, inference_prompt = generate_rag_sql_prompt(train_datapoint)
full_prompt

results are:  {'ids': [['id3', 'id1948', 'id7', 'id2695', 'id799']], 'distances': [[0.7234521508216858, 0.7478238344192505, 0.7558315396308899, 0.7632932662963867, 0.7800575494766235]], 'metadatas': [[{'answer': 'SELECT creation ,  name ,  budget_in_billions FROM department', 'context': 'CREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)', 'db_id': 'department_management', 'question': 'List the creation year, name and budget of each department.'}, {'answer': 'SELECT dept_name FROM department ORDER BY dept_name', 'context': 'CREATE TABLE department (dept_name VARCHAR)', 'db_id': 'college_1', 'question': 'list names of all departments ordered by their names.'}, {'answer': "SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'", 'context': 'CREATE TABLE department (creation VARCHAR, department_id VARCHAR); CREATE TABLE 

'You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.\n\nGiven the following examples:\nExample 1:\n### Input:\nList the creation year, name and budget of each department.\n\n### Context:\nCREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n\n### Response:\nSELECT creation ,  name ,  budget_in_billions FROM department\n\nExample 2:\n### Input:\nlist names of all departments ordered by their names.\n\n### Context:\nCREATE TABLE department (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM department ORDER BY dept_name\n\nExample 3:\n### Input:\nWhat are the distinct creation years of the departments managed by a secretary born in state \'Alabama\'?\n\n### Context:\nCREATE TABLE department (creation VARCHAR, department_id VARCHAR); CREATE TABLE management (department_id VARCHAR, head_id VA