In [0]:
%pip install databricks-vectorsearch databricks-langchain
%restart_python

In [0]:
# Notebook parameters. Each widget is declared with its default value and then
# read back into a plain Python variable of the same name.
dbutils.widgets.text("catalog_name", "rohitb_demo")
catalog_name = dbutils.widgets.get("catalog_name")

dbutils.widgets.text("schema_name", "pdf_chat")
schema_name = dbutils.widgets.get("schema_name")

dbutils.widgets.text("table_name", "parsed_pdf_docs")
table_name = dbutils.widgets.get("table_name")

dbutils.widgets.text("chunked_table_name", "chunked_pdf_docs")
chunked_table_name = dbutils.widgets.get("chunked_table_name")

dbutils.widgets.text("vector_index_path", "chunked_pdf_docs_index")
vector_index_path = dbutils.widgets.get("vector_index_path")

# Three-part names (catalog.schema.table) for the input table and the chunked output table.
full_table_path = ".".join((catalog_name, schema_name, table_name))
chunked_table_path = ".".join((catalog_name, schema_name, chunked_table_name))

In [0]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pyspark.sql.functions import expr, col, explode
import pyspark.pandas as pd
import logging

def read_and_cast_table(spark, full_table_path):
    """Read a table and cast its ``parsed:elements`` field to a typed array of structs.

    Args:
        spark: Session object; only ``spark.read.table`` is used.
        full_table_path: Name of the table to read.

    Returns:
        A DataFrame with two columns: ``path`` (taken as-is) and ``elements``
        (the result of the cast expression below).

    Raises:
        Exception: Any error raised while reading or selecting is logged with
            ``logging.error`` and then re-raised unchanged.
    """
    # Target type for each element: id, page indices, text/markdown
    # representation, title, summary and type.
    elements_cast_sql = (
        "CAST(parsed:elements AS ARRAY<STRUCT<"
        "id: BIGINT, "
        "page_indices: ARRAY<BIGINT>, "
        "representation: STRUCT<text: STRING, markdown: STRING>, "
        "title: STRING, "
        "summary: STRING, "
        "type: STRING"
        ">>)"
    )
    try:
        source_df = spark.read.table(full_table_path)
        return source_df.select("path", expr(elements_cast_sql).alias("elements"))
    except Exception as e:
        logging.error(f"Error reading or casting table: {e}")
        raise

def explode_elements(df_casted):
    """Turn the ``elements`` array into one row per element and flatten its fields.

    Args:
        df_casted: DataFrame with a ``path`` column and an ``elements`` array column.

    Returns:
        A DataFrame with columns ``path``, ``element_id``, ``page`` (first entry
        of the element's ``page_indices``), ``title``, ``summary``, ``type`` and
        ``text``, keeping only rows whose ``text`` is not null.

    Raises:
        Exception: Any error is logged with ``logging.error`` and re-raised unchanged.
    """
    try:
        # One output row per array entry; the struct is exposed as column "el".
        per_element = df_casted.select("path", explode("elements").alias("el"))
        flattened = per_element.select(
            col("path"),
            col("el.id").alias("element_id"),
            col("el.page_indices").getItem(0).alias("page"),
            col("el.title"),
            col("el.summary"),
            col("el.type"),
            col("el.representation.text").alias("text"),
        )
        # Elements without text cannot be chunked later, so they are dropped here.
        return flattened.filter(col("text").isNotNull())
    except Exception as e:
        logging.error(f"Error exploding elements: {e}")
        raise

def chunk_text_elements(pdf_df, chunk_size=500, chunk_overlap=50):
    """Split each element's text into overlapping chunks, keeping the element metadata.

    Args:
        pdf_df: Frame iterated with ``iterrows()``. Every row must provide
            ``path``, ``element_id``, ``page``, ``type``, ``title``, ``summary``
            and ``text``.
        chunk_size: Passed to ``RecursiveCharacterTextSplitter`` as ``chunk_size``.
            Defaults to 500, the value that used to be hard-coded.
        chunk_overlap: Passed to ``RecursiveCharacterTextSplitter`` as
            ``chunk_overlap``. Defaults to 50, the value that used to be hard-coded.

    Returns:
        A ``pd.DataFrame`` (``pd`` is this notebook's alias for ``pyspark.pandas``)
        with one row per chunk and the columns ``path``, ``element_id``, ``page``,
        ``type``, ``title``, ``summary``, ``chunk_text`` and ``chunk_id``.
        ``chunk_id`` has the form ``"<path>_e<element_id>_c<chunk index>"``.

    Raises:
        Exception: Any error is logged with ``logging.error`` and re-raised unchanged.
    """
    try:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        # Build all records first and create the frame once at the end.
        chunks = [
            {
                "path": row["path"],
                "element_id": row["element_id"],
                "page": row["page"],
                "type": row["type"],
                "title": row["title"],
                "summary": row["summary"],
                "chunk_text": chunk,
                # The chunk index restarts at 0 for every element.
                "chunk_id": f"{row['path']}_e{row['element_id']}_c{i}",
            }
            for _, row in pdf_df.iterrows()
            for i, chunk in enumerate(splitter.split_text(row["text"]))
        ]
        return pd.DataFrame(chunks)
    except Exception as e:
        logging.error(f"Error chunking text elements: {e}")
        raise

In [0]:
import logging

from pyspark.sql.functions import monotonically_increasing_id

# Configure the root logger at ERROR level for the helper functions' error logs.
logging.basicConfig(level=logging.ERROR)

# Stage 1 (Spark): read the parsed table, cast the elements, one row per element.
df_casted = read_and_cast_table(spark, full_table_path)
df_chunks = explode_elements(df_casted)

# Stage 2: convert to a pandas frame and split every element's text into
# chunks while preserving the element metadata.
pdf_df = df_chunks.toPandas()
chunk_df = chunk_text_elements(pdf_df)

# Stage 3 (Spark): back to a Spark DataFrame and add the numeric "id" column
# that the vector index later uses as its primary key.
chunk_sdf = chunk_df.to_spark().withColumn("id", monotonically_increasing_id())
display(chunk_sdf)

In [0]:
# Persist the chunks, replacing any previous contents of the table.
(
    chunk_sdf.write
    .mode("overwrite")
    .saveAsTable(chunked_table_path)
)

# Turn on the Delta change data feed for the table that backs the vector index.
spark.sql(
    f"ALTER TABLE {chunked_table_path} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
)

# Set up the Vector Search endpoint and a UC function to retrieve similar documents

In [0]:
from databricks.vector_search.client import VectorSearchClient

# Vector Search client used by the following cells to create or fetch the index.
client = VectorSearchClient()
# Debug helper, left disabled (presumably lists the available endpoints):
# client.list_endpoints()

In [0]:
# Fully qualified name (catalog.schema.index) of the vector search index.
vector_index = f"{catalog_name}.{schema_name}.{vector_index_path}"

# Create the delta-sync index on the chunked table, or reuse it if it is already there.
index_already_existed = False
try:
    index = client.create_delta_sync_index(
        endpoint_name="one-env-shared-endpoint-17",
        source_table_name=chunked_table_path,
        index_name=vector_index,
        pipeline_type="TRIGGERED",
        primary_key="id",
        embedding_source_column="chunk_text",
        embedding_model_endpoint_name="databricks-gte-large-en"
    )
except Exception as e:
    # The client reports an existing index through the error message, so the
    # text is matched; anything else is a real failure.
    if "already exists" not in str(e):
        raise  # bare raise keeps the original traceback
    index = client.get_index(index_name=vector_index)
    index_already_existed = True

# Block until the index is online; otherwise the cells that query it fail on a
# fresh "Run All" because a newly created index is still being provisioned.
index.wait_until_ready()

if index_already_existed:
    # A TRIGGERED pipeline only refreshes on request. The source table is
    # overwritten earlier in this notebook, so an existing index must be
    # synced explicitly to pick up the new chunks.
    index.sync()

index.describe()

In [0]:
# Smoke test: run one query against the index and show the id and text of the hits.
sample_query = "International roaming"
index.similarity_search(query_text=sample_query, columns=["id", "chunk_text"])

## Create a Function in UC to query Vector Search (Optional)

In [0]:
%sql
-- Defines a table function that takes a query string and returns rows from the
-- vector search index as (page_content, metadata).
-- NOTE(review): the catalog, schema and index names in this cell are hard-coded
-- to the notebook's widget defaults; changing the widgets does not change them,
-- so edit the names here as well when running against another catalog/schema.
CREATE OR REPLACE FUNCTION rohitb_demo.pdf_chat.pdf_chat_vector_search (
  -- The agent uses this comment to determine how to generate the query string parameter.
  query STRING
  COMMENT 'The query string for searching for answers from billing invoices.'
) RETURNS TABLE
-- Executes a search on a vector search index containing chunked text from various PDF files, specifically ATT billing invoices from different months. Each PDF has been parsed, and the chunked indexes are stored in the vector search index. This function retrieves relevant parts of the PDFs to assist an LLM in answering questions from the PDF.
COMMENT 'Executes a search on historical billing invoices to retrieve most relevant to the input query.' RETURN
SELECT
  -- Each hit is returned as the chunk text plus a map holding the source path and chunk id.
  chunk_text as page_content,
  map('doc_uri', path, 'chunk_id', chunk_id) as metadata
FROM
  vector_search(
    -- Specify your Vector Search index name here (must be a literal; see the note at the top of the cell).
    index => 'rohitb_demo.pdf_chat.chunked_pdf_docs_index',
    query => query,
    -- Return at most 5 rows per call.
    num_results => 5
  )