# Simple RAG Example using Vector Search and the Foundation Model API

[Databricks Vector Search](https://docs.databricks.com/en/generative-ai/vector-search.html) is a vector database built into Databricks that offers straightforward integration with the [Databricks Foundation Model API](https://docs.databricks.com/en/machine-learning/foundation-models/index.html) (FMAPI) embedding models.

Retrieval-augmented generation (RAG) is one of the most popular application architectures for creating natural-language interfaces for people to interact with an organization's data. This notebook builds a very simple RAG application, with the following steps:

1. Set up a vector index and configure it to automatically use an embedding model from the FMAPI to generate embeddings.
1. Load some text data into the vector database.
1. Query the database.
1. Build a prompt for an LLM from the query results.
1. Query an LLM via the FMAPI, using that prompt.
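
Before using the managed services, it may help to see steps 1 to 4 in miniature. The sketch below is a toy, pure-Python stand-in and does not show how Databricks Vector Search works internally. It uses a hypothetical three-document corpus, a bag-of-words "embedding" in place of `bge-large-en`, brute-force cosine similarity in place of the vector index, and a prompt built from the top matches. The names `DOCS`, `embed`, `search`, and `build_prompt` are illustrative only.

```python
import math
import re
from collections import Counter

# Hypothetical toy corpus standing in for the chunked source table.
DOCS = {
    "doc_0": "Gallstones can cause pain in the upper right abdomen.",
    "doc_1": "A fever is a temporary rise in body temperature.",
    "doc_2": "Unexplained weight loss can have many causes.",
}


def embed(text: str) -> Counter:
    """Toy 'embedding': lowercase word counts (a real model returns a dense vector)."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


# Steps 1-2: "index" the corpus by precomputing one embedding per document.
INDEX = {doc_id: embed(text) for doc_id, text in DOCS.items()}


def search(query: str, num_results: int = 2, score_threshold: float = 0.0):
    """Step 3: score every document against the query and keep the best matches."""
    q = embed(query)
    scored = [(doc_id, cosine(q, vec)) for doc_id, vec in INDEX.items()]
    scored = [hit for hit in scored if hit[1] >= score_threshold]
    scored.sort(key=lambda hit: hit[1], reverse=True)
    return scored[:num_results]


def build_prompt(question: str, hits) -> str:
    """Step 4: paste the retrieved texts into the prompt as context."""
    context = "\n\n".join(DOCS[doc_id] for doc_id, _ in hits)
    return f"User question: {question}\n\nContext: {context}"


question = "What causes pain in the upper right abdomen?"
hits = search(question)
print(build_prompt(question, hits))
```

Real embedding models capture meaning rather than exact word overlap, and production vector databases typically use approximate nearest-neighbor search instead of scoring every document, but the flow of data from query to prompt is the same.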

To learn more about how Databricks Vector Search works, see the documentation [here](https://docs.databricks.com/en/generative-ai/vector-search.html#how-does-vector-search-work).

For more details on querying models via the Foundation Model APIs, see the documentation [here](https://docs.databricks.com/en/machine-learning/model-serving/score-foundation-models.html#query-foundation-models).

## Setup
First, we install the necessary libraries and create the schema and source tables for this example.

In [0]:
%pip install --upgrade --force-reinstall databricks-vectorsearch databricks-genai-inference
dbutils.library.restartPython()

### Define catalog, schema, and table names

In [0]:
CATALOG = "workspace"
DB = "vs_demo"
SYMPTOMS_SOURCE_TABLE_NAME = "symptoms"
SYMPTOMS_SOURCE_TABLE_FULLNAME = f"{CATALOG}.{DB}.{SYMPTOMS_SOURCE_TABLE_NAME}"
PROCEDURES_SOURCE_TABLE_NAME = "procedures"
PROCEDURES_SOURCE_TABLE_FULLNAME = f"{CATALOG}.{DB}.{PROCEDURES_SOURCE_TABLE_NAME}"

### Create the Schema and Source Tables

A Databricks Vector Search index is created from a Delta table. The source Delta table holds the data we ultimately want to index and search with the vector database. In the next two cells, we create the schema and two source tables (one for symptoms, one for procedures), each with Change Data Feed enabled so that the index can track changes to it. The catalog (`workspace`) is assumed to exist already.

In [0]:
# Set up the schema and the symptoms source table.
# Use the fully qualified table name so the table is created in CATALOG.DB,
# not in the session's current default schema.
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.{DB}")
spark.sql(
    f"""CREATE TABLE IF NOT EXISTS {SYMPTOMS_SOURCE_TABLE_FULLNAME} (
        id STRING,
        text STRING,
        date DATE,
        title STRING
    )
    USING delta
    TBLPROPERTIES ('delta.enableChangeDataFeed' = 'true')
"""
)

In [0]:
# Set up the procedures source table (fully qualified, as above)
spark.sql(
    f"""CREATE TABLE IF NOT EXISTS {PROCEDURES_SOURCE_TABLE_FULLNAME} (
        id STRING,
        text STRING,
        date DATE,
        title STRING
    )
    USING delta
    TBLPROPERTIES ('delta.enableChangeDataFeed' = 'true')
"""
)

## Set up the Vector Database
Next, we set up the vector database. There are three key steps:
1. Initialize the vector search client
2. Create the endpoint
3. Create the indexes using the source Delta tables we created earlier and the `bge-large-en` embedding model from the Foundation Model API

### Initialize the Vector Search Client

In [0]:
from databricks.vector_search.client import VectorSearchClient
vsc = VectorSearchClient()

### Create the Endpoint

The cell below checks whether the endpoint already exists, creates it if it does not, and then waits until it is ready.

In [0]:
VS_ENDPOINT_NAME = 'hackathon'

# list_endpoints() may omit the "endpoints" key when no endpoints exist yet
existing_endpoints = vsc.list_endpoints().get('endpoints') or []
if VS_ENDPOINT_NAME not in [endpoint.get('name') for endpoint in existing_endpoints]:
    print(f"Creating new Vector Search endpoint named {VS_ENDPOINT_NAME}")
    vsc.create_endpoint(VS_ENDPOINT_NAME)
else:
    print(f"Endpoint {VS_ENDPOINT_NAME} already exists.")

# Block until the endpoint is ready before creating indexes on it
vsc.wait_for_endpoint(VS_ENDPOINT_NAME, 600)
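
The existence check above (and the similar one for indexes in the next cell) can be factored into a small helper that treats a missing response key or a `None` list as "nothing exists yet". This is a sketch that assumes, as the cells in this notebook do, that `list_endpoints()` returns a dict with an optional `endpoints` list and `list_indexes()` one with an optional `vector_indexes` list, each entry being a dict with a `name` field. `name_in_listing` is a hypothetical helper, not part of the `databricks-vectorsearch` client.

```python
from typing import Any, Mapping, Optional


def name_in_listing(listing: Optional[Mapping[str, Any]], key: str, name: str) -> bool:
    """Return True if listing[key] holds an entry whose "name" equals `name`.

    A None listing, a missing key, or a None/empty list all count as "not found".
    """
    entries = (listing or {}).get(key) or []
    return any(entry.get("name") == name for entry in entries)


print(name_in_listing({"endpoints": [{"name": "hackathon"}]}, "endpoints", "hackathon"))  # True
print(name_in_listing({}, "endpoints", "hackathon"))  # False
```

In the notebook it could be called as `name_in_listing(vsc.list_endpoints(), "endpoints", VS_ENDPOINT_NAME)` or `name_in_listing(vsc.list_indexes(VS_ENDPOINT_NAME), "vector_indexes", VS_INDEX_FULLNAME)`. Keeping the parsing in a pure function makes the edge cases testable without a workspace.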

### Create the Vector Index

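Next, we define a helper that creates a Delta Sync index over a given source Delta table, unless an index with that name already exists. We call it for both source tables further below.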

In [0]:
def create_index(*, VS_INDEX_NAME, source_table_fullname):
    """Create a Delta Sync index with managed embeddings unless it already exists."""
    VS_INDEX_FULLNAME = f"{CATALOG}.{DB}.{VS_INDEX_NAME}"

    def index_exists():
        indexes = vsc.list_indexes(VS_ENDPOINT_NAME).get('vector_indexes') or []
        return VS_INDEX_FULLNAME in [index.get("name") for index in indexes]

    if index_exists():
        print(f"Index {VS_INDEX_FULLNAME} already exists.")
        return

    try:
        # Set up an index with managed embeddings: the model endpoint embeds the "text" column
        print("Creating Vector Index...")
        vsc.create_delta_sync_index_and_wait(
            endpoint_name=VS_ENDPOINT_NAME,
            index_name=VS_INDEX_FULLNAME,
            source_table_name=source_table_fullname,
            pipeline_type="TRIGGERED",
            primary_key="id",
            embedding_source_column="text",
            embedding_model_endpoint_name="databricks-bge-large-en"
        )
    except Exception as e:
        # An INTERNAL_ERROR can be reported even though the index was created,
        # so check again before giving up
        if "INTERNAL_ERROR" in str(e) and index_exists():
            print(f"Index {VS_INDEX_FULLNAME} has been created.")
        else:
            raise
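
The `try`/`except` in `create_index` encodes a policy: create the index, and if that reports `INTERNAL_ERROR`, check whether the index appeared anyway. The sketch below isolates that policy in a hypothetical `create_or_confirm` helper that takes the existence check and the create call as functions, so it can be exercised with fakes instead of a live endpoint. Only the `INTERNAL_ERROR` marker comes from the notebook; the helper and the fake resource are illustrative.

```python
from typing import Callable


def create_or_confirm(*, exists: Callable[[], bool], create: Callable[[], object],
                      marker: str = "INTERNAL_ERROR") -> str:
    """Create a resource unless it exists; tolerate `marker` errors if it appeared anyway.

    Returns "exists" if nothing was done and "created" otherwise. Any other error,
    or a `marker` error after which the resource is still missing, is re-raised.
    """
    if exists():
        return "exists"
    try:
        create()
    except Exception as e:
        if marker in str(e) and exists():
            return "created"
        raise
    return "created"


# Fake resource whose creation succeeds but still reports an internal error.
state = {"present": False}


def flaky_create():
    state["present"] = True
    raise RuntimeError("INTERNAL_ERROR: request timed out")


print(create_or_confirm(exists=lambda: state["present"], create=flaky_create))  # created
print(create_or_confirm(exists=lambda: state["present"], create=flaky_create))  # exists
```

Passing the two operations in as callables keeps the retry policy separate from the Vector Search client, which is what makes it testable in isolation.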

There are a few key points to note about the specific configuration we used in this case:
- We used `pipeline_type="TRIGGERED"`. This requires us to call the index's `sync()` method to manually sync the source Delta table with the index. Alternatively, we could use `pipeline_type="CONTINUOUS"`, which automatically keeps the index in sync with the source table with only seconds of latency. This approach is more costly, though, because a compute cluster must be provisioned for the continuous sync streaming pipeline.
- We specified `embedding_model_endpoint_name="databricks-bge-large-en"`. We can use any embedding model available via model serving; this is the name of the pay-per-token Foundation Model API endpoint for the BGE Large (English) model. By passing an `embedding_source_column` and an `embedding_model_endpoint_name`, we configure the index to automatically use that model to generate embeddings for the texts in the `text` column of the source table, so we do not need to generate embeddings manually.

  If, however, we wanted to manage embeddings manually, we could pass the following arguments instead:

  ```
    embedding_vector_column="<embedding_column>",
    embedding_dimension=<embedding_dimension>
  ```

  In that case, we include a column of precomputed embeddings in the source Delta table, and embeddings are *not* computed automatically from the text column.
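
The two embedding modes differ only in which keyword arguments are passed to `create_delta_sync_index_and_wait`. A hypothetical helper like the one below could build either set, reusing the argument names from `create_index` above. The mutual-exclusion check is this sketch's own rule, included because the text above presents the two modes as alternatives. The index name `symptoms_index_manual` and the column name `embedding` are made up for illustration; 1024 is the output dimension of BGE Large, and a different embedding model would need its own value.

```python
from typing import Any, Dict, Optional


def delta_sync_index_kwargs(*, endpoint_name: str, index_name: str, source_table_name: str,
                            pipeline_type: str = "TRIGGERED",
                            embedding_model_endpoint_name: Optional[str] = None,
                            embedding_vector_column: Optional[str] = None,
                            embedding_dimension: Optional[int] = None) -> Dict[str, Any]:
    """Build keyword arguments for a Delta Sync index in managed or self-managed mode."""
    if pipeline_type not in ("TRIGGERED", "CONTINUOUS"):
        raise ValueError(f"unexpected pipeline_type: {pipeline_type}")
    kwargs: Dict[str, Any] = {
        "endpoint_name": endpoint_name,
        "index_name": index_name,
        "source_table_name": source_table_name,
        "pipeline_type": pipeline_type,
        "primary_key": "id",
    }
    uses_model = embedding_model_endpoint_name is not None
    uses_vectors = embedding_vector_column is not None or embedding_dimension is not None
    if uses_model == uses_vectors:
        raise ValueError("pass either an embedding model endpoint or a vector column + dimension")
    if uses_model:
        # Managed mode: the model endpoint embeds the "text" column automatically.
        kwargs.update(embedding_source_column="text",
                      embedding_model_endpoint_name=embedding_model_endpoint_name)
    else:
        if embedding_vector_column is None or embedding_dimension is None:
            raise ValueError("self-managed mode needs both a vector column and a dimension")
        # Self-managed mode: embeddings are precomputed and stored in the source table.
        kwargs.update(embedding_vector_column=embedding_vector_column,
                      embedding_dimension=embedding_dimension)
    return kwargs


managed_kwargs = delta_sync_index_kwargs(
    endpoint_name="hackathon",
    index_name="workspace.vs_demo.symptoms_index_demo_small",
    source_table_name="workspace.vs_demo.symptoms",
    embedding_model_endpoint_name="databricks-bge-large-en")
manual_kwargs = delta_sync_index_kwargs(
    endpoint_name="hackathon",
    index_name="workspace.vs_demo.symptoms_index_manual",
    source_table_name="workspace.vs_demo.symptoms",
    embedding_vector_column="embedding",
    embedding_dimension=1024)
print(sorted(set(managed_kwargs) ^ set(manual_kwargs)))
```

In the notebook, the result would be unpacked into the client call, for example `vsc.create_delta_sync_index_and_wait(**managed_kwargs)`.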

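## Create the indexes and prepare the texts

Before loading any text, we make sure Change Data Feed is enabled on both source tables (Delta Sync indexes require it) and then create one index per table.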

In [0]:
%sql
-- Delta Sync indexes require Change Data Feed on the source table; enable it explicitly in case the table was created without it
ALTER TABLE workspace.vs_demo.symptoms SET TBLPROPERTIES (delta.enableChangeDataFeed = true)

In [0]:
%sql
ALTER TABLE workspace.vs_demo.procedures SET TBLPROPERTIES (delta.enableChangeDataFeed = true)

In [0]:
SYMPTOMS_INDEX = "symptoms_index_demo_small"
VS_SYMPTOMS_INDEX_FULLNAME = f"{CATALOG}.{DB}.{SYMPTOMS_INDEX}"
create_index(VS_INDEX_NAME=SYMPTOMS_INDEX, source_table_fullname=SYMPTOMS_SOURCE_TABLE_FULLNAME)

In [0]:
PROCEDURES_INDEX = "procedures_index_demo_small"
VS_PROCEDURES_INDEX_FULLNAME = f"{CATALOG}.{DB}.{PROCEDURES_INDEX}"
create_index(VS_INDEX_NAME=PROCEDURES_INDEX, source_table_fullname=PROCEDURES_SOURCE_TABLE_FULLNAME)

### Chunk the texts
Typically, when using a vector database for RAG, we break the texts into smaller (and sometimes overlapping) chunks so that each search result returns focused, relevant information without an excessive amount of text.

In the code below, we define helpers that split each document into chunks of at least 150 words, extend each chunk to the next word that ends a sentence, and start a new chunk every 125 words (150 minus the 25-word overlap).

In [0]:
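import re

def chunk_text(text, chunk_size, overlap):
    """Split `text` into overlapping word chunks of about `chunk_size` words.

    Each chunk is extended to the next word ending in '.', '!' or '?', and
    consecutive chunks start `chunk_size - overlap` words apart.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    chunks = []
    index = 0

    while index < len(words):
        end = index + chunk_size
        # Extend the chunk until it ends on a sentence-final word (or the text ends)
        while end < len(words) and not re.match(r'.*[.!?]\s*$', words[end]):
            end += 1
        chunk = ' '.join(words[index:end+1])
        chunks.append(chunk)
        index += chunk_size - overlap

    return chunks

def make_chunks(documents):
    """Chunk each document and attach its title, date, and a per-chunk id."""
    chunks = []

    for document in documents:
        for i, c in enumerate(chunk_text(document["text"], chunk_size=150, overlap=25)):
            chunks.append({
                "text": c,
                "title": document["title"],
                "date": document["date"],
                # The id must be unique because it is the index's primary key
                "id": document["title"] + "_" + str(i),
            })
    return chunks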
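
To see what the chunker actually produces, the sketch below reruns the same chunking loop on a tiny hypothetical document, with `chunk_size=3` and `overlap=1` instead of the notebook's 150 and 25 so that every chunk is visible.

```python
import re


def chunk_text(text, chunk_size, overlap):
    # Same chunking loop as chunk_text above (without the argument check).
    words = text.split()
    chunks = []
    index = 0
    while index < len(words):
        end = index + chunk_size
        # Push the chunk end forward to the next sentence-final word.
        while end < len(words) and not re.match(r'.*[.!?]\s*$', words[end]):
            end += 1
        chunks.append(' '.join(words[index:end + 1]))
        index += chunk_size - overlap
    return chunks


# Hypothetical nine-word document made of three short sentences.
demo_chunks = chunk_text("A b c. D e f. G h i.", chunk_size=3, overlap=1)
for chunk in demo_chunks:
    print(repr(chunk))
```

Chunk starts advance by `chunk_size - overlap` words regardless of where the previous chunk ended, while chunk ends are pushed forward to a sentence boundary. As a result, chunks usually begin mid-sentence, the effective overlap is often larger than `overlap`, and the last chunks can be short fragments (`'G h i.'`, `'i.'`) already contained in earlier chunks. With 150/25 the effect is proportionally smaller but still present. Separately, because `make_chunks` builds each `id` as `<title>_<i>`, two source rows that share a title (here, a code) would produce colliding primary keys.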

## Insert the text chunks into the source delta table

Now we save the chunks, along with some metadata (a document title, a date, and a unique id), to the source Delta table.

In [0]:
from datetime import date
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DateType

schema = StructType(
    [
        StructField("id", StringType(), True),
        StructField("text", StringType(), True),
        StructField("title", StringType(), True),
        StructField("date", DateType(), True),
    ]
)

def write_chunks_to_table(*, table, chunks):
    """Append the chunk dicts to the Delta table `table` (no-op if there are none)."""
    if chunks:
        result_df = spark.createDataFrame(chunks, schema=schema)
        result_df.write.format("delta").mode("append").saveAsTable(table)

def chunk_and_write(*, src_tbl, src_title_col, src_txt_col, dst_tbl):
    """Read (title, text) pairs from `src_tbl`, chunk the texts, and append the chunks to `dst_tbl`."""
    df = spark.table(src_tbl).select(
        F.col(src_title_col).alias("title"),
        F.col(src_txt_col).alias("text"),
    )
    # Skip rows that have no text or no title
    valid_df = (df
                .where(F.col("text").isNotNull())
                .where(F.col("title").isNotNull())
    )

    documents = [
        {
            "text": row["text"],
            "title": row["title"],
            # Fixed placeholder date applied to every document
            "date": date(2024, 1, 16),
        }
        for row in valid_df.collect()
    ]
    chunks = make_chunks(documents)
    print("num chunks: " + repr(len(chunks)))
    write_chunks_to_table(table=dst_tbl, chunks=chunks)
    print(f"chunks written to table {dst_tbl}")

In [0]:
%sql
-- Clear both source tables so that re-running the notebook does not append duplicate chunks
DELETE FROM workspace.vs_demo.symptoms;
DELETE FROM workspace.vs_demo.procedures;

In [0]:
chunk_and_write(src_tbl="workspace.default.aienrichmentbriefer",
                src_title_col="CODE",
                src_txt_col="conditionDescriptionAndSymptomsAI",
#                 src_title_col="code",
#                 src_txt_col="LONG DESCRIPTION (VALID ICD-10 FY2024)",
                dst_tbl=SYMPTOMS_SOURCE_TABLE_FULLNAME)

In [0]:
chunk_and_write(src_tbl="workspace.default.procedure_enhanced",
                src_title_col="code",
                src_txt_col="longer_description",
                dst_tbl=PROCEDURES_SOURCE_TABLE_FULLNAME)

## Sync the Vector Search Index
Because we specified `pipeline_type="TRIGGERED"` when configuring the Delta Sync index, we still need to tell the index to sync with the Delta table manually. Syncing takes a few minutes.

Syncing fails if the index is not ready yet, so we first call the index's `wait_until_ready()` method to block until it is.

In [0]:
VS_ENDPOINT_NAME

In [0]:
# Sync symptoms: wait until the index is ready, then trigger a sync
symptoms_index = vsc.get_index(endpoint_name=VS_ENDPOINT_NAME,
                               index_name=VS_SYMPTOMS_INDEX_FULLNAME)
symptoms_index.wait_until_ready()
symptoms_index.sync()

In [0]:
# Sync procedures: wait until the index is ready, then trigger a sync
procedures_index = vsc.get_index(endpoint_name=VS_ENDPOINT_NAME,
                                 index_name=VS_PROCEDURES_INDEX_FULLNAME)
procedures_index.wait_until_ready()
procedures_index.sync()

## Query the Vector Indexes

Now that we have added our text chunks to the source Delta tables and synced them with the Vector Search indexes, we're ready to query the indexes! We do this with the `index.similarity_search()` method.

The `columns` argument takes a list of the columns we want returned; in this case, we request the text and title columns. `num_results` caps the number of matches returned, and `score_threshold` filters out matches whose similarity score is below 0.4.

**NOTE**: If the cell below does not return any results, wait a couple of minutes and try again. The index may still be syncing.

In [0]:
chats = [
    ["I have pains in the upper right abdomen, and I have been losing weight. What could it be?"],
    ["I have abdominal pain. Can it be fever?"],
]

In [0]:
# Symptoms query
question = chats[0][0]
symptoms_index.similarity_search(columns=["text", "title"],
                                 query_text=question,
                                 num_results=3,
                                 score_threshold=0.4)


In [0]:
# Procedures query
procedures_index.similarity_search(columns=["text", "title"],
                                   query_text=question,
                                   num_results=3,
                                   score_threshold=0.4)
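
The calls above return the raw response. A small hypothetical helper can turn it into one dict per match, keyed by column name, so later code does not depend on column positions such as `doc[0]`. The response shape is an assumption: this notebook reads rows from `result`, then `data_array`, and the sketch further assumes a `manifest` that lists the column names, with the similarity score appended as an extra last column. Check it against a real response before relying on it. The sample response below is made up.

```python
from typing import Any, Dict, List


def response_rows(response: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn a similarity_search response into a list of {column: value} dicts.

    Assumes {"manifest": {"columns": [{"name": ...}, ...]},
             "result": {"data_array": [[...], ...]}};
    a missing data_array is treated as "no matches".
    """
    names = [col["name"] for col in (response.get("manifest") or {}).get("columns", [])]
    rows = (response.get("result") or {}).get("data_array") or []
    return [dict(zip(names, row)) for row in rows]


# Hypothetical response with one match, shaped as assumed above.
sample = {
    "manifest": {"columns": [{"name": "text"}, {"name": "title"}, {"name": "score"}]},
    "result": {"row_count": 1,
               "data_array": [["Example chunk text.", "EXAMPLE_CODE", 0.71]]},
}
for row in response_rows(sample):
    print(row["title"], row["score"], row["text"])
```

With the columns named, a context builder can use `row["text"]` instead of `doc[0]`, which keeps working if the `columns` list passed to `similarity_search` is reordered.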


## Answering Questions without context

As a baseline, we first ask the LLM the question without providing any retrieved context.

In [0]:
from databricks_genai_inference import ChatSession

chat = ChatSession(model="databricks-meta-llama-3-70b-instruct",
                   system_message="You are a helpful hospital assistant.",
                   max_tokens=128)

In [0]:
question = chats[0][0]
print(f"question: {question}\n")
chat.reply(question)
print(chat.last)

## Chat with context

Now let's see what kind of reply we get when we provide context from vector search.

In [0]:
# Start a new session to reset the chat history
chat = ChatSession(model="databricks-meta-llama-3-70b-instruct",
                   system_message="You are a helpful hospital assistant. Answer the user's question based on the provided context.",
                   max_tokens=128)

def get_context(*, index, question):
    """Retrieve the top matching chunks for `question` and format them as one string."""
    raw_context = index.similarity_search(columns=["text", "title"],
                                          query_text=question,
                                          num_results=3,
                                          score_threshold=0.4)
    # data_array may be missing when no match clears the score threshold
    docs = (raw_context.get('result') or {}).get('data_array') or []
    context_string = ""
    for i, doc in enumerate(docs):
        context_string += f"Retrieved context {i+1}:\n"
        context_string += doc[0]  # doc[0] is the "text" column
        context_string += "\n\n"
    return context_string

question = chats[0][0]
symptoms_context = get_context(index=symptoms_index, question=question)
# procedures_context = get_context(index=procedures_index, question=question)
procedures_context = ""
print(f"question: {question}\n")
chat.reply(f"User question: {question}\n\nContext: {symptoms_context}\n\n{procedures_context}")
print(chat.last)
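
The message above always includes a `Context:` section, even when retrieval returns nothing (for example, when no match clears `score_threshold`) or when the procedures context is left empty. One possible variation, sketched below with hypothetical helper names, formats retrieved texts the same way as `get_context` and drops empty context blocks before building the user message.

```python
def format_context(texts):
    """Format retrieved texts the same way get_context does."""
    return "".join(f"Retrieved context {i}:\n{text}\n\n" for i, text in enumerate(texts, start=1))


def build_user_message(question, *context_blocks):
    """Combine the question with any non-empty context blocks."""
    blocks = [block.strip() for block in context_blocks if block and block.strip()]
    if not blocks:
        return f"User question: {question}"
    return f"User question: {question}\n\nContext:\n" + "\n\n".join(blocks)


demo_context = format_context(["Example symptom text."])
print(build_user_message("What could it be?", demo_context, ""))
```

Sending the bare question when nothing was retrieved is one choice; an alternative is to tell the model explicitly that no relevant context was found.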

With the retrieved chunks in the prompt, the model can now base its answer on the provided context.

### Congratulations! Demo complete.


# Additional information

## Using the UI
Most of the Vector Database management steps above can also be done in the Databricks Catalog Explorer UI: you can [create an endpoint](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-endpoint-using-the-ui), [create an index](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-index-using-the-ui), sync the index, and more.

## Experimenting in the AI Playground
The [Databricks AI Playground](https://docs.databricks.com/en/large-language-models/ai-playground.html) provides a GUI for quickly experimenting with LLMs available via the FMAPI, enabling you to compare the outputs of those models and determine which model best serves your needs.