In [1]:
from embeddings.chroma_funcs import get_closest_entries, generate_knowledge_base_from_hf_dataset, get_random_entries
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def format_rag_examples(examples):
    """Render retrieved examples as the few-shot section of the SQL prompt.

    Args:
        examples: Sequence of mappings, each providing ``"question"``,
            ``"context"`` and ``"answer"`` entries.

    Returns:
        A string that starts with a newline followed by a
        "Given the following example(s):" header and one
        Input/Context/Response block per example. With more than one example
        each block is prefixed by ``Example <k>:`` (1-based) and blocks are
        separated by a blank line. Returns ``""`` when ``examples`` is empty.

    Raises:
        KeyError: If an example is missing one of the three expected keys.
    """
    if len(examples) == 0:
        # Nothing to show: emit no section at all instead of a dangling
        # "Given the following examples:" header with nothing underneath.
        return ""

    def _render(example):
        # Single source of truth for the per-example layout so the
        # one-example and many-example variants cannot drift apart.
        return (
            "### Input:\n"
            f"{example['question']}\n"
            "\n"
            "### Context:\n"
            f"{example['context']}\n"
            "\n"
            "### Response:\n"
            f"{example['answer']}\n"
        )

    if len(examples) == 1:
        # A lone example gets a singular header and no "Example 1:" label.
        return "\nGiven the following example:\n" + _render(examples[0])

    numbered_blocks = "\n".join(
        f"Example {position}:\n" + _render(example)
        for position, example in enumerate(examples, start=1)
    )
    return "\nGiven the following examples:\n" + numbered_blocks

def get_examples(knowledge_base, data_point, n_examples, randomize=False):
    """Fetch example entries from the knowledge base and format them for a prompt.

    Args:
        knowledge_base: Handle passed through to ``get_random_entries`` /
            ``get_closest_entries``.
        data_point: Mapping whose ``"question"`` entry is used as the lookup
            text when ``randomize`` is false.
        n_examples: Number of examples to request; anything not greater than
            zero disables retrieval.
        randomize: When true, use ``get_random_entries`` instead of the
            closest-entry lookup.

    Returns:
        ``""`` when no examples are requested, otherwise the output of
        ``format_rag_examples`` applied to ``result["metadatas"][0]``.
    """
    # Guard clause: nothing requested means no few-shot section at all.
    if not (n_examples > 0):
        return ""

    if randomize:
        retrieved = get_random_entries(knowledge_base, n_examples)
    else:
        retrieved = get_closest_entries(
            knowledge_base,
            data_point["question"],
            "question",
            n_results=n_examples,
        )

    # Both retrieval helpers are indexed the same way here: the first element
    # of the "metadatas" entry holds the example records to format.
    example_records = retrieved["metadatas"][0]
    return format_rag_examples(example_records)

def generate_rag_sql_prompt(knowledge_base, data_point, n_examples, randomize=False):
    """Build the text-to-SQL prompt for one data point.

    Args:
        knowledge_base: Passed through to ``get_examples``.
        data_point: Mapping with ``"question"``, ``"context"`` and
            ``"answer"`` entries.
        n_examples: Number of few-shot examples to request.
        randomize: Passed through to ``get_examples``.

    Returns:
        A ``(full_prompt, inference_prompt)`` tuple. ``inference_prompt`` ends
        at the ``### Response:`` marker; ``full_prompt`` is the same text
        followed by a newline and the data point's answer.
    """
    example_section = get_examples(knowledge_base, data_point, n_examples, randomize)

    # One list entry per output line; empty strings produce the blank lines
    # that separate the Input / Context / Response sections.
    prompt_lines = [
        "You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.",
        f"{example_section}",
        "Please generate the SQL query that answers the following:",
        "### Input:",
        f"{data_point['question']}",
        "",
        "### Context:",
        f"{data_point['context']}",
        "",
        "### Response:",
    ]
    inference_prompt = "\n".join(prompt_lines)

    # The full prompt appends the reference answer after the response marker.
    full_prompt = inference_prompt + "\n" + f"{data_point['answer']}"
    return full_prompt, inference_prompt

In [3]:
# train_datapoint = {"db_id":"department_management","context":"CREATE TABLE department (creation VARCHAR)","question":"In which year were most departments established?","answer":"SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1","full_prompt":"You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.\n\nGiven the following examples:\nExample 1:\n### Input:\nList the creation year, name and budget of each department.\n\n### Context:\nCREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n\n### Response:\nSELECT creation ,  name ,  budget_in_billions FROM department\n\nExample 2:\n### Input:\nlist names of all departments ordered by their names.\n\n### Context:\nCREATE TABLE department (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM department ORDER BY dept_name\n\nExample 3:\n### Input:\nWhat are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?\n\n### Context:\nCREATE TABLE department (creation VARCHAR, department_id VARCHAR); CREATE TABLE management (department_id VARCHAR, head_id VARCHAR); CREATE TABLE head (head_id VARCHAR, born_state VARCHAR)\n\n### Response:\nSELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'\n\nExample 4:\n### Input:\nWhat is the number of departments in Division \"AS\"?\n\n### Context:\nCREATE TABLE DEPARTMENT (Division VARCHAR)\n\n### Response:\nSELECT count(*) FROM DEPARTMENT WHERE Division  =  \"AS\"\n\nExample 5:\n### Input:\nFind the names of the top 3 departments that provide the largest amount of courses?\n\n### Context:\nCREATE TABLE course (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM course GROUP BY dept_name 
ORDER BY count(*) DESC LIMIT 3\n\nPlease generate the SQL query that answers the following:\n### Input:\nIn which year were most departments established?\n\n### Context:\nCREATE TABLE department (creation VARCHAR)\n\n### Response:\nSELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1","inference_prompt":"You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.\n\nGiven the following examples:\nExample 1:\n### Input:\nList the creation year, name and budget of each department.\n\n### Context:\nCREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n\n### Response:\nSELECT creation ,  name ,  budget_in_billions FROM department\n\nExample 2:\n### Input:\nlist names of all departments ordered by their names.\n\n### Context:\nCREATE TABLE department (dept_name VARCHAR)\n\n### Response:\nSELECT dept_name FROM department ORDER BY dept_name\n\nExample 3:\n### Input:\nWhat are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?\n\n### Context:\nCREATE TABLE department (creation VARCHAR, department_id VARCHAR); CREATE TABLE management (department_id VARCHAR, head_id VARCHAR); CREATE TABLE head (head_id VARCHAR, born_state VARCHAR)\n\n### Response:\nSELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'\n\nExample 4:\n### Input:\nWhat is the number of departments in Division \"AS\"?\n\n### Context:\nCREATE TABLE DEPARTMENT (Division VARCHAR)\n\n### Response:\nSELECT count(*) FROM DEPARTMENT WHERE Division  =  \"AS\"\n\nExample 5:\n### Input:\nFind the names of the top 3 departments that provide the largest amount of courses?\n\n### Context:\nCREATE TABLE course (dept_name VARCHAR)\n\n### 
Response:\nSELECT dept_name FROM course GROUP BY dept_name ORDER BY count(*) DESC LIMIT 3\n\nPlease generate the SQL query that answers the following:\n### Input:\nIn which year were most departments established?\n\n### Context:\nCREATE TABLE department (creation VARCHAR)\n\n### Response:"}
# full_prompt, inference_prompt = generate_rag_sql_prompt(train_datapoint)
# full_prompt

In [3]:
def add_prompt_features(example, knowledge_base, n_examples, randomize=False):
    """Attach ``full_prompt`` and ``inference_prompt`` entries to a dataset row.

    Args:
        example: Mutable mapping for one row; it is updated in place.
        knowledge_base: Passed through to ``generate_rag_sql_prompt``.
        n_examples: Number of few-shot examples to include in the prompts.
        randomize: Passed through to ``generate_rag_sql_prompt``.

    Returns:
        The same ``example`` object, now carrying the two prompt entries.
    """
    prompts = generate_rag_sql_prompt(knowledge_base, example, n_examples, randomize)
    # generate_rag_sql_prompt returns (full_prompt, inference_prompt), in that order.
    example['full_prompt'], example['inference_prompt'] = prompts
    return example

def augment_dataset_with_prompts(dataset_name, knowledge_base, n_examples=5, randomize=False):
    """Add prompt columns to every split of a dataset and write each split to JSON.

    Args:
        dataset_name: Name passed to ``load_dataset``; also used (with ``/``
            replaced by ``-``) as the prefix of each output filename.
        knowledge_base: Passed through to ``add_prompt_features``.
        n_examples: Number of few-shot examples per prompt.
        randomize: Passed through to ``add_prompt_features``.

    Returns:
        None. One file per split is written to the current working directory.
    """
    def _attach_prompts(example):
        # Bind the retrieval settings so ``map`` only has to supply the row.
        return add_prompt_features(example, knowledge_base, n_examples=n_examples, randomize=randomize)

    splits = load_dataset(dataset_name)

    for split, split_dataset in splits.items():
        augmented = split_dataset.map(_attach_prompts)

        # TODO: add in embedding function:
        filename = f"{dataset_name.replace('/', '-')}-{split}-with-{n_examples}-examples-random-{randomize}.jsonl"

        # Persist the augmented split via the dataset's JSON writer.
        augmented.to_json(filename)

In [5]:
import os

from chromadb.utils import embedding_functions

# Read the OpenAI key from the environment rather than embedding it in the
# notebook. NOTE: a literal key was previously committed in this cell; treat
# that key as compromised and revoke it.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    raise RuntimeError(
        "Set the OPENAI_API_KEY environment variable before running this cell."
    )

openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=openai_api_key,
    model_name="text-embedding-ada-002",
)

# NOTE(review): `_model_name` is a private attribute; it is printed here only
# as a sanity check of the configured model.
print(openai_ef._model_name)

# Smoke test: embed one string to confirm the OpenAI embedding function is callable.
abc = openai_ef(["hell owrld", ])

# Default embedding function (this is the one handed to the knowledge base below).
default_ef = embedding_functions.DefaultEmbeddingFunction()

print(default_ef.model)

text-embedding-ada-002
None


In [6]:
# Build the knowledge base first: every prompt-generation call below needs it.
# The dataset name, the "question" field name, and `default_ef` (created in an
# earlier cell) are handed to generate_knowledge_base_from_hf_dataset.
# NOTE(review): the saved output of this cell shows repeated
# "Add of existing embedding ID" messages — presumably the collection was
# already populated by a previous run; confirm whether re-adding is intended.
dataset_name = "samlhuillier/sql-create-context-spider-intersect"
knowledge_base = generate_knowledge_base_from_hf_dataset(dataset_name, "question", default_ef)
# Disabled: writes every split with prompt columns to .jsonl files when enabled.
# augment_dataset_with_prompts(dataset_name, knowledge_base, n_examples=1, randomize=True)


embd_fn <chromadb.utils.embedding_functions.ONNXMiniLM_L6_V2 object at 0x7f78ef69cf40>


Map:   0%|          | 0/3961 [00:00<?, ? examples/s]Add of existing embedding ID: id1
Add of existing embedding ID: id2
Add of existing embedding ID: id3
Add of existing embedding ID: id4
Add of existing embedding ID: id5
Add of existing embedding ID: id6
Add of existing embedding ID: id7
Add of existing embedding ID: id8
Add of existing embedding ID: id9
Add of existing embedding ID: id10
Add of existing embedding ID: id11
Add of existing embedding ID: id12
Add of existing embedding ID: id13
Add of existing embedding ID: id14
Add of existing embedding ID: id15
Add of existing embedding ID: id16
Add of existing embedding ID: id17
Add of existing embedding ID: id18
Add of existing embedding ID: id19
Add of existing embedding ID: id20
Add of existing embedding ID: id21
Add of existing embedding ID: id22
Add of existing embedding ID: id23
Add of existing embedding ID: id24
Add of existing embedding ID: id25
Add of existing embedding ID: id26
Add of existing embedding ID: id27
Add of exist

In [None]:
# dataset = load_dataset(dataset_name, split="validation")
# def dataset_chunks(dataset, n):
#     """Yield successive n-sized chunks from the dataset."""
#     for i in range(0, len(dataset), n):
#         start_idx = i
#         end_idx = min(i + n, len(dataset))
#         yield dataset.select(range(start_idx, end_idx))


# # Chunk size
# chunk_size = 300
# created_chunks = dataset_chunks(dataset, chunk_size)
# print(created_chunks)
# for idx, chunk in enumerate(dataset_chunks(dataset, chunk_size)):
#     print(chunk)

In [None]:
# Hand-written sample row used to eyeball the generated prompt.
test_datapoint = dict(
    question="What is the average horsepower for all cars produced before 1980 ?",
    context="CREATE TABLE cars_data (horsepower INTEGER, year INTEGER)",
    answer="select avg(horsepower) from cars_data where year  <  1980;",
    db_id="car_1",
)

# Request two retrieved examples and print the prompt that includes the answer.
full_prompt, inference_prompt = generate_rag_sql_prompt(knowledge_base, test_datapoint, 2)
print(full_prompt)