# Create a Vector Search Index with the Fine-tuned Adapter Model and Evaluate the Retriever

In [0]:
%pip install --quiet databricks-sdk==0.24.0
%pip install databricks-vectorsearch tiktoken
%pip install -U mlflow
%pip install python-snappy==0.7.3
%pip install torch==2.4.0 torchvision==0.19.0
# NOTE(review): databricks-vectorsearch, tiktoken and mlflow are installed without a version pin,
# unlike the other packages above — consider pinning them so re-runs resolve the same versions.
# Restart the Python process after the installs (presumably so the freshly installed packages are
# the ones imported by the cells below).
dbutils.library.restartPython()

In [0]:
# Notebook parameters: Unity Catalog location, source table prefix and embedding endpoint.
dbutils.widgets.text(name="target_catalog", label="Catalog", defaultValue="dev_catalog")
dbutils.widgets.text(name="target_schema", label="Schema", defaultValue="dev_schema")
# A dropdown's default has to be one of its choices, otherwise the widget cannot be created;
# default to the first listed table.
dbutils.widgets.dropdown(
    name="source_table",
    label="Table",
    defaultValue="explore_chunks",
    choices=["explore_chunks", "all_fulltext", "long_chunks", "short_chunks"],
)
dbutils.widgets.text(name="embedding_model_endpoint", label="Embedding Model Endpoint", defaultValue="snowflake-arctic-embed-m-long-linear-adapter")

In [0]:
# Read every notebook parameter from its widget in one pass; the names on the left are the
# globals used by the cells below.
catalog, schema, table, embedding_model_endpoint = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("target_catalog", "target_schema", "source_table", "embedding_model_endpoint")
)
# Echo the resolved parameters so the run output records what was used.
print(
    f"Parameters: target_catalog: {catalog}, target_schema: {schema}, "
    f"source_table: {table}, embedding_model_endpoint: {embedding_model_endpoint}"
)

# Prepare Vector Search Index

Use the `<source_table>_high_quality_chunks` table as the source for the index.

In [0]:
from databricks.vector_search.client import VectorSearchClient
import time

# Vector Search client shared by the helper functions and cells below (built with default arguments).
vsc = VectorSearchClient()

def endpoint_exists(vsc, vs_endpoint_name):
  """Tell whether a Vector Search endpoint with the given name is listed by the client.

  Args:
    vsc: Client exposing ``list_endpoints()``; its result is read as a dict whose optional
      ``'endpoints'`` entry is a list of dicts carrying a ``'name'`` key.
    vs_endpoint_name: Name of the endpoint to look for.

  Returns:
    True when an endpoint with that name is listed, False otherwise (including when the
    response has no ``'endpoints'`` entry).

  Raises:
    Whatever ``list_endpoints()`` raises; errors are not swallowed.
  """
  listed_endpoints = vsc.list_endpoints().get('endpoints', [])
  return any(endpoint['name'] == vs_endpoint_name for endpoint in listed_endpoints)


def wait_for_vs_endpoint_to_be_ready(vsc, vs_endpoint_name):
  """Poll a Vector Search endpoint until it reports an ONLINE state.

  The endpoint is fetched up to 180 times with a 10 second pause between polls.

  Args:
    vsc: Client exposing ``get_endpoint(name)``, whose result is read as a dict with an
      ``'endpoint_status'`` (or ``'status'``) entry containing a ``'state'`` string.
    vs_endpoint_name: Name of the endpoint to wait for.

  Returns:
    The endpoint description once its state contains "ONLINE", or None when the status
    lookup fails with a REQUEST_LIMIT_EXCEEDED error (a warning is printed instead).

  Raises:
    Exception: if the endpoint reports a state that is neither ONLINE nor PROVISIONING
      after the first six polls, or if it is still not ONLINE after the last poll.
      Any other error from ``get_endpoint`` is re-raised.
  """
  max_polls = 180
  for attempt in range(max_polls):
    try:
      endpoint = vsc.get_endpoint(vs_endpoint_name)
    except Exception as err:
      # Only the rate-limit error is tolerated; everything else goes back to the caller.
      if "REQUEST_LIMIT_EXCEEDED" not in str(err):
        raise err
      print("WARN: couldn't get endpoint status due to REQUEST_LIMIT_EXCEEDED error. Please manually check your endpoint status")
      return

    status_info = endpoint.get("endpoint_status", endpoint.get("status"))
    state = status_info["state"].upper()
    if "ONLINE" in state:
      return endpoint

    # During the first six polls any non-ONLINE state is treated as "still starting".
    still_starting = "PROVISIONING" in state or attempt < 6
    if not still_starting:
      raise Exception(f'''Error with the endpoint {vs_endpoint_name}. - this shouldn't happen: {endpoint}.\n Please delete it and re-run the previous cell: vsc.delete_endpoint("{vs_endpoint_name}")''')

    # Progress message every 20th poll only, to keep the cell output short.
    if attempt % 20 == 0:
      print(f"Waiting for endpoint to be ready, this can take a few min... {endpoint}")
    time.sleep(10)

  raise Exception(f"Timeout, your endpoint isn't ready yet: {vsc.get_endpoint(vs_endpoint_name)}")


def index_exists(vsc, endpoint_name, index_full_name):
    """Check whether a Vector Search index can be described on the given endpoint.

    Args:
        vsc: Client exposing ``get_index(endpoint_name, index_name)``; the returned object
            must offer ``describe()``.
        endpoint_name: Name of the endpoint hosting the index.
        index_full_name: Full name of the index to look up.

    Returns:
        True when ``describe()`` succeeds, False when the failure message contains
        ``RESOURCE_DOES_NOT_EXIST``.

    Raises:
        Any other lookup error is re-raised after a hint is printed.
    """
    try:
        index = vsc.get_index(endpoint_name, index_full_name)
        index.describe()
    except Exception as err:
        # A missing index is the expected "no" answer; anything else is a real failure.
        if 'RESOURCE_DOES_NOT_EXIST' in str(err):
            return False
        print('Unexpected error describing the index. This could be a permission issue.')
        raise err
    else:
        return True


def wait_for_index_to_be_ready(vsc, vs_endpoint_name, index_name, max_polls=180, poll_interval_seconds=10):
  """Poll a Vector Search index until it reports an ONLINE state.

  Args:
    vsc: Client exposing ``get_index(endpoint_name, index_name)``; the returned object must
      offer ``describe()`` returning a dict with a ``'status'`` (or ``'index_status'``) entry.
    vs_endpoint_name: Name of the endpoint hosting the index.
    index_name: Full name of the index to wait for.
    max_polls: Maximum number of status checks before giving up (default 180).
    poll_interval_seconds: Pause between checks while the index is provisioning (default 10).

  Returns:
    None, once the state contains "ONLINE". Also returns (after printing a note) when no
    state can be read from the description, in which case the index is assumed ready.

  Raises:
    Exception: if the index reports a state that is neither ONLINE nor PROVISIONING, or if
      it is still provisioning after ``max_polls`` checks.
  """
  idx = None
  for attempt in range(max_polls):
    # Endpoint name first, index name second — the same order on every lookup.
    idx = vsc.get_index(vs_endpoint_name, index_name).describe()
    index_status = idx.get('status', idx.get('index_status', {}))
    status = index_status.get('detailed_state', index_status.get('status', 'UNKNOWN')).upper()
    url = index_status.get('index_url', index_status.get('url', 'UNKNOWN'))
    if "ONLINE" in status:
      return
    if "UNKNOWN" in status:
      print(f"Can't get the status - will assume index is ready {idx} - url: {url}")
      return
    if "PROVISIONING" not in status:
      # The hint shows a directly runnable call: endpoint and index as two separate arguments.
      raise Exception(f'''Error with the index - this shouldn't happen. DLT pipeline might have been killed.\n Please delete it and re-run the previous cell: vsc.delete_index("{vs_endpoint_name}", "{index_name}") \nIndex details: {idx}''')
    # Progress message every 40th poll only, to keep the cell output short.
    if attempt % 40 == 0:
      print(f"Waiting for index to be ready, this can take a few min... {index_status} - pipeline url:{url}")
    time.sleep(poll_interval_seconds)
  # Report the last description seen instead of issuing another lookup that could itself
  # fail and hide the timeout.
  raise Exception(f"Timeout, your index isn't ready yet: {idx}")

In [0]:
# Name of the Vector Search endpoint that hosts the indexes used in this notebook.
VECTOR_SEARCH_ENDPOINT_NAME = "embedding_optimization"

# Create the endpoint only when it is not listed yet, so re-running the cell is safe.
if not endpoint_exists(vsc, VECTOR_SEARCH_ENDPOINT_NAME):
    vsc.create_endpoint(name=VECTOR_SEARCH_ENDPOINT_NAME, endpoint_type="STANDARD")

# The wait helper returns None when it could not read the status (rate limiting), so only
# announce readiness when an endpoint description actually came back.
vs_endpoint = wait_for_vs_endpoint_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME)
if vs_endpoint is not None:
    print(f"Endpoint named {VECTOR_SEARCH_ENDPOINT_NAME} is ready.")
else:
    print(f"WARN: could not confirm that endpoint {VECTOR_SEARCH_ENDPOINT_NAME} is ready; check its status before continuing.")

In [0]:
from databricks.sdk import WorkspaceClient
import databricks.sdk.service.catalog as c

# Table whose rows are indexed.
source_table_fullname = f"{catalog}.{schema}.{table}_high_quality_chunks"

# Index embedded with the endpoint chosen in the "embedding_model_endpoint" widget.
vs_index_fullname = f"{catalog}.{schema}.{table}_high_quality_ft_chunks_vs_index"

# Decide once whether the index has to be created or only refreshed.
index_already_exists = index_exists(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname)

if not index_already_exists:
  print(f"Creating index {vs_index_fullname} on endpoint {VECTOR_SEARCH_ENDPOINT_NAME}...")
  vsc.create_delta_sync_index(
    endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME,
    index_name=vs_index_fullname,
    source_table_name=source_table_fullname,
    pipeline_type="TRIGGERED",
    primary_key="id",
    embedding_source_column='text',
    embedding_model_endpoint_name=embedding_model_endpoint
  )

# Both paths wait for the index to come online before going further.
wait_for_index_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname)

# An index that was already there gets a sync request once it is online.
if index_already_exists:
  vsc.get_index(VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname).sync()

print(f"index {vs_index_fullname} on table {source_table_fullname} is ready")

# Evaluate Fine-tuned Adapter Embedding Model and Compare to Base Embedding Model

In [0]:
import time

# WorkspaceClient created on first use and shared by every call that does not pass its own client.
_DEFAULT_WORKSPACE_CLIENT = None


def _default_workspace_client():
    """Create the shared ``WorkspaceClient`` on first use and return it."""
    global _DEFAULT_WORKSPACE_CLIENT
    if _DEFAULT_WORKSPACE_CLIENT is None:
        # Imported here so this cell does not depend on an import made in another cell.
        from databricks.sdk import WorkspaceClient
        _DEFAULT_WORKSPACE_CLIENT = WorkspaceClient()
    return _DEFAULT_WORKSPACE_CLIENT


def get_relevant_documents(question : str, index_name : str, k : int = 3, filters : str = None,
                           max_retries : int = 3, columns : list = None, client=None) -> list[dict]:
    """Query a Vector Search index by text and return the hits as a list of dicts.

    Args:
        question: Text to search for.
        index_name: Full name of the index to query.
        k: Number of results to request.
        filters: Optional filter expression, passed through as ``filters_json``.
        max_retries: Total number of query attempts (at least one attempt is always made).
        columns: Index columns to return. Defaults to ``["id", "text"]``, the primary key
            and embedding source column the index in this notebook is created with.
        client: Object exposing ``vector_search_indexes.query_index(...)``. Defaults to a
            shared ``databricks.sdk.WorkspaceClient`` created on first use.

    Returns:
        One dict per hit, keyed by the column names reported in the response manifest.
        An empty list when the query returns no rows.

    Raises:
        RuntimeError: when every attempt failed; the last underlying error is chained as
            the cause.
    """
    query_columns = list(columns) if columns is not None else ["id", "text"]
    workspace = client if client is not None else _default_workspace_client()
    attempts = max(1, max_retries)

    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            response = workspace.vector_search_indexes.query_index(
                index_name=index_name,
                columns=query_columns,
                filters_json=filters,
                num_results=k,
                query_text=question,
            )
        except Exception as err:
            last_error = err
            print(err)
            # Linear back-off between attempts; no point sleeping after the final one.
            if attempt < attempts:
                time.sleep(1 * attempt)
            continue

        column_names = [column.name for column in response.manifest.columns]
        result = getattr(response, "result", None)
        # A query without hits may carry no row data at all; treat that as an empty list.
        rows = (result.data_array if result is not None else None) or []
        return [dict(zip(column_names, row)) for row in rows]

    raise RuntimeError(
        f"Vector search query against {index_name} failed after {attempts} attempt(s)"
    ) from last_error

In [0]:
# Index created earlier in this notebook (same name pattern as vs_index_fullname); it is
# embedded with the endpoint chosen in the "embedding_model_endpoint" widget.
index_full_new = f"{catalog}.{schema}.{table}_high_quality_ft_chunks_vs_index"
# Index used as the comparison baseline. It is not created in this notebook —
# presumably built beforehand with the base embedding model (TODO confirm).
index_full_old = f"{catalog}.{schema}.{table}_high_quality_chunks_vs_index"

def get_relevant_doc_ids(question : str, index_name : str) -> list[str]:
    """Return the ``id`` of each of the top 10 documents retrieved for a question.

    Args:
        question: Text to search for.
        index_name: Full name of the index to query.

    Returns:
        The ``"id"`` value of every retrieved document, in retrieval order.
    """
    retrieved = get_relevant_documents(question, index_name=index_name, k=10)
    doc_ids = []
    for document in retrieved:
        doc_ids.append(document["id"])
    return doc_ids

In [0]:
eval_df = spark.table(f"{catalog}.{schema}.{table}_quality_chunk_question_context_pairs")

# Collect the evaluation set to the driver once; both index variants start from the same rows.
eval_pd_base = eval_df.toPandas()


def _with_retrieved_docs(base_pd, index_name):
    """Return a copy of ``base_pd`` prepared for retriever evaluation against ``index_name``.

    The ``id`` column is wrapped into a one-element list per row, and a ``retrieved_docs``
    column is added holding the ids returned for each ``generated_question``.
    ``base_pd`` itself is left untouched.
    """
    return base_pd.assign(
        id=base_pd["id"].map(lambda doc_id: [doc_id]),
        retrieved_docs=base_pd["generated_question"].map(
            lambda generated_question: get_relevant_doc_ids(generated_question, index_name)
        ),
    )


eval_pd_new = _with_retrieved_docs(eval_pd_base, index_full_new)
eval_pd_old = _with_retrieved_docs(eval_pd_base, index_full_old)

display(eval_pd_new)

In [0]:
# Imported in this cell so it runs on a fresh kernel.
import mlflow

# Recall@k and precision@k for k = 1..9, shared by both evaluations.
retriever_k_values = range(1, 10)
retriever_metrics = (
    [mlflow.metrics.recall_at_k(k) for k in retriever_k_values]
    + [mlflow.metrics.precision_at_k(k) for k in retriever_k_values]
)


def _evaluate_retriever(eval_pd, run_name):
    """Evaluate one set of retrieval results in its own nested MLflow run.

    ``eval_pd`` must hold the expected ids (as lists) in its ``id`` column and the
    retrieved ids in ``retrieved_docs``. A separate nested run per variant keeps the two
    sets of metrics from being logged under the same keys of a single run.
    """
    with mlflow.start_run(run_name=run_name, nested=True):
        return mlflow.evaluate(
            data=eval_pd,
            model_type="retriever",
            # The ground-truth column is "id": it is the one wrapped into a list when the
            # evaluation frames are prepared.
            targets="id",
            predictions="retrieved_docs",
            evaluators="default",
            extra_metrics=retriever_metrics,
        )


with mlflow.start_run() as run:
    eval_results_ft = _evaluate_retriever(eval_pd_new, "adapter_finetuned")
    eval_results = _evaluate_retriever(eval_pd_old, "baseline")

In [0]:
from matplotlib import pyplot as plt

# Explicit x coordinates so the axis shows the real k (1-9) rather than list positions (0-8).
k_values = list(range(1, 10))

plt.plot(k_values, [eval_results_ft.metrics[f"recall_at_{k}/mean"] for k in k_values], label="adapter_finetuned")
plt.plot(k_values, [eval_results.metrics[f"recall_at_{k}/mean"] for k in k_values], label="baseline")
plt.title("Recall at k")
plt.xlabel("k")
plt.ylabel("mean recall")
plt.legend()
plt.show()

In [0]:
# Explicit x coordinates so the axis shows the real k (1-9) rather than list positions (0-8).
k_values = list(range(1, 10))

plt.plot(k_values, [eval_results_ft.metrics[f"precision_at_{k}/mean"] for k in k_values], label="adapter_finetuned")
plt.plot(k_values, [eval_results.metrics[f"precision_at_{k}/mean"] for k in k_values], label="baseline")
# This chart shows precision, so the title says so.
plt.title("Precision at k")
plt.xlabel("k")
plt.ylabel("mean precision")
plt.legend()
plt.show()

In [0]:
def f1_score(precision, recall):
    """Return the harmonic mean of precision and recall.

    Args:
        precision: Precision value.
        recall: Recall value.

    Returns:
        ``2 * precision * recall / (precision + recall)``, or ``0.0`` when both inputs are
        zero (where the formula would otherwise divide by zero).
    """
    total = precision + recall
    if total == 0:
        return 0.0
    return 2 * (precision * recall) / total

# Explicit x coordinates so the axis shows the real k (1-9) rather than list positions (0-8).
k_values = list(range(1, 10))

# Each curve combines precision and recall from the SAME evaluation result.
f1_finetuned = [
    f1_score(eval_results_ft.metrics[f"precision_at_{k}/mean"], eval_results_ft.metrics[f"recall_at_{k}/mean"])
    for k in k_values
]
f1_baseline = [
    f1_score(eval_results.metrics[f"precision_at_{k}/mean"], eval_results.metrics[f"recall_at_{k}/mean"])
    for k in k_values
]

# Distinct legend labels so the two curves can be told apart.
plt.plot(k_values, f1_finetuned, label="adapter_finetuned")
plt.plot(k_values, f1_baseline, label="baseline")
plt.title("F1 at k")
plt.xlabel("k")
plt.ylabel("F1 of mean precision and mean recall")
plt.legend()
plt.show()