In [1]:
from embeddings.chroma_funcs import get_closest_entries, generate_knowledge_base_from_hf_dataset
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def format_rag_examples(examples):
    """Render retrieved examples as numbered few-shot blocks for a prompt.

    Each element of ``examples`` must support ``["question"]``, ``["context"]``
    and ``["answer"]`` lookups. Every example becomes an ``Example N:`` block
    (numbered from 1) with ``### Input:``, ``### Context:`` and
    ``### Response:`` sections; consecutive blocks are separated by a blank line.

    Returns:
        ``"Given the following examples:\\n"`` followed by the rendered blocks
        (just the header line when ``examples`` is empty).
    """
    blocks = []
    for number, example in enumerate(examples, start=1):
        blocks.append(
            f"Example {number}:\n"
            "### Input:\n"
            f"{example['question']}\n"
            "\n"
            "### Context:\n"
            f"{example['context']}\n"
            "\n"
            "### Response:\n"
            f"{example['answer']}\n"
        )
    # Each block already ends with "\n", so joining on "\n" leaves one blank
    # line between consecutive examples.
    rendered = "\n".join(blocks)
    return "Given the following examples:\n" + rendered


def generate_rag_sql_prompt(knowledge_base, data_point, n_examples):
    """Build a retrieval-augmented text-to-SQL prompt for one data point.

    Asks ``get_closest_entries`` for ``n_examples`` entries of
    ``knowledge_base`` matched on the ``"question"`` field against
    ``data_point["question"]``, renders them with ``format_rag_examples`` and
    places them ahead of the target question and table context.

    Args:
        knowledge_base: Retrieval store passed straight to ``get_closest_entries``.
        data_point: Mapping with ``"question"``, ``"context"`` and ``"answer"`` keys.
        n_examples: Number of retrieved examples to request.

    Returns:
        ``(full_prompt, inference_prompt)``. ``inference_prompt`` ends at the
        ``### Response:`` marker; ``full_prompt`` is ``inference_prompt``
        followed by a newline and ``data_point["answer"]``.
    """
    question = data_point["question"]
    neighbours = get_closest_entries(
        knowledge_base,
        question,
        "question",
        n_results=n_examples,
    )
    # A single query text was sent, so its retrieved metadata rows sit at
    # index 0 of the per-query "metadatas" list.
    # NOTE(review): if the knowledge base was built from the same split being
    # prompted, a row may retrieve itself and leak its own answer as an
    # example — confirm how the knowledge base / get_closest_entries handle this.
    few_shot_block = format_rag_examples(neighbours["metadatas"][0])
    context = data_point["context"]

    inference_prompt = f"""You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.

{few_shot_block}
Please generate the SQL query that answers the following:
### Input:
{question}

### Context:
{context}

### Response:"""
    full_prompt = f"{inference_prompt}\n{data_point['answer']}"
    return full_prompt, inference_prompt

In [3]:
# train_datapoint = {"db_id":"department_management","context":"CREATE TABLE department (creation VARCHAR)","question":"In which year were most departments established?","answer":"SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1","full_prompt":"You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.\n\nGiven the following examples:\nExample 1:\n### Input:\nList the creation year, name and budget of each department.\n\n### Context:\nCREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n\n### Response:\nSELECT creation ,  name ,  budget_in_billions FROM department\n\nExample 2:\n### Input:\nlist names of all departments ordered by their names.\n\n### Context:\nCREATE TABLE department (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM department ORDER BY dept_name\n\nExample 3:\n### Input:\nWhat are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?\n\n### Context:\nCREATE TABLE department (creation VARCHAR, department_id VARCHAR); CREATE TABLE management (department_id VARCHAR, head_id VARCHAR); CREATE TABLE head (head_id VARCHAR, born_state VARCHAR)\n\n### Response:\nSELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'\n\nExample 4:\n### Input:\nWhat is the number of departments in Division \"AS\"?\n\n### Context:\nCREATE TABLE DEPARTMENT (Division VARCHAR)\n\n### Response:\nSELECT count(*) FROM DEPARTMENT WHERE Division  =  \"AS\"\n\nExample 5:\n### Input:\nFind the names of the top 3 departments that provide the largest amount of courses?\n\n### Context:\nCREATE TABLE course (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM course GROUP BY dept_name 
ORDER BY count(*) DESC LIMIT 3\n\nPlease generate the SQL query that answers the following:\n### Input:\nIn which year were most departments established?\n\n### Context:\nCREATE TABLE department (creation VARCHAR)\n\n### Response:\nSELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1","inference_prompt":"You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.\n\nGiven the following examples:\nExample 1:\n### Input:\nList the creation year, name and budget of each department.\n\n### Context:\nCREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n\n### Response:\nSELECT creation ,  name ,  budget_in_billions FROM department\n\nExample 2:\n### Input:\nlist names of all departments ordered by their names.\n\n### Context:\nCREATE TABLE department (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM department ORDER BY dept_name\n\nExample 3:\n### Input:\nWhat are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?\n\n### Context:\nCREATE TABLE department (creation VARCHAR, department_id VARCHAR); CREATE TABLE management (department_id VARCHAR, head_id VARCHAR); CREATE TABLE head (head_id VARCHAR, born_state VARCHAR)\n\n### Response:\nSELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'\n\nExample 4:\n### Input:\nWhat is the number of departments in Division \"AS\"?\n\n### Context:\nCREATE TABLE DEPARTMENT (Division VARCHAR)\n\n### Response:\nSELECT count(*) FROM DEPARTMENT WHERE Division  =  \"AS\"\n\nExample 5:\n### Input:\nFind the names of the top 3 departments that provide the largest amount of courses?\n\n### Context:\nCREATE TABLE course (dept_name VARCHAR)\n\n### 
Response:\nSELECT dept_name FROM course GROUP BY dept_name ORDER BY count(*) DESC LIMIT 3\n\nPlease generate the SQL query that answers the following:\n### Input:\nIn which year were most departments established?\n\n### Context:\nCREATE TABLE department (creation VARCHAR)\n\n### Response:"}
# full_prompt, inference_prompt = generate_rag_sql_prompt(train_datapoint)
# full_prompt

In [4]:
def add_prompt_features(example, knowledge_base, n_examples):
    """Attach RAG prompts to a single dataset row (a ``Dataset.map`` callback).

    Adds ``"full_prompt"`` and ``"inference_prompt"`` keys produced by
    ``generate_rag_sql_prompt`` to ``example`` (mutating it) and returns it.
    """
    prompts = generate_rag_sql_prompt(knowledge_base, example, n_examples)
    example["full_prompt"], example["inference_prompt"] = prompts
    return example

def augment_dataset_with_prompts(dataset_name, knowledge_base, n_examples=5):
    """Attach RAG prompts to every split of a dataset and save each split.

    Every row gains ``full_prompt`` and ``inference_prompt`` columns via
    ``add_prompt_features``. Each split is written to
    ``<dataset_name with '/' replaced by '-'>-<split>-with-prompts.jsonl``
    in the current working directory.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        knowledge_base: Retrieval store used to look up few-shot examples.
        n_examples: Number of retrieved examples per prompt.

    Returns:
        dict mapping each split name to its augmented ``Dataset`` so callers can
        inspect the result without re-reading the written files.
    """
    # Loading without ``split=`` yields one dataset per split.
    dataset_dict = load_dataset(dataset_name)
    file_stem = dataset_name.replace("/", "-")

    augmented_splits = {}
    for split, split_dataset in dataset_dict.items():
        # fn_kwargs forwards the extra arguments to add_prompt_features without
        # building a closure for every split.
        augmented = split_dataset.map(
            add_prompt_features,
            fn_kwargs={"knowledge_base": knowledge_base, "n_examples": n_examples},
            desc=f"Adding prompts ({split})",
        )

        # Save right after mapping so finished splits survive an interrupt in a
        # later (slow) split. to_json writes JSON Lines by default, matching
        # the .jsonl extension.
        augmented.to_json(f"{file_stem}-{split}-with-prompts.jsonl")
        augmented_splits[split] = augmented

    return augmented_splits

In [5]:
# Step 1: index the dataset's "question" column into a retrieval knowledge base.
dataset_name = "samlhuillier/sql-create-context-spider-intersect"
knowledge_base = generate_knowledge_base_from_hf_dataset(dataset_name, "question")

# Step 2: add prompts with this many retrieved few-shot examples to every split
# and write each split to disk.
N_RAG_EXAMPLES = 1
augment_dataset_with_prompts(dataset_name, knowledge_base, n_examples=N_RAG_EXAMPLES)


Map:   7%|▋         | 258/3961 [00:23<05:43, 10.78 examples/s]


KeyboardInterrupt: 

In [17]:
# Spot-check the prompt built for one hand-written data point (Spider car_1 db).
test_datapoint = {
    "question": "What is the average horsepower for all cars produced before 1980 ?",
    "context": "CREATE TABLE cars_data (horsepower INTEGER, year INTEGER)",
    "answer": "select avg(horsepower) from cars_data where year  <  1980;",
    "db_id": "car_1",
}

full_prompt, inference_prompt = generate_rag_sql_prompt(
    knowledge_base, test_datapoint, n_examples=2
)
# print (not a bare expression) so the prompt's newlines render as line breaks.
print(full_prompt)

You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.

Given the following examples:
Example 1:
### Input:
How many vehicle in total?

### Context:
CREATE TABLE Vehicles (Id VARCHAR)

### Response:
SELECT count(*) FROM Vehicles;

Example 2:
### Input:
what is the average number of factories and maximum number of shops for manufacturers that opened before 1990.

### Context:
CREATE TABLE manufacturer (num_of_shops INTEGER, Num_of_Factories INTEGER, open_year INTEGER)

### Response:
SELECT max(num_of_shops) ,  avg(Num_of_Factories) FROM manufacturer WHERE open_year  <  1990

Please generate the SQL query that answers the following:
### Input:
What is the average horsepower for all cars produced before 1980 ?

### Context:
CREATE TABLE cars_data (horsepower INTEGER, year INTEGER)

### Response:
select avg(horsepower) from cars_data w