# This notebook performs the following steps:

1. Read a Policy PDF document from a UC volume
2. Parse the PDF documents into text with PyPDF
3. Perform Chunking
4. Create a vector search index from the chunks

In [0]:
%pip install -U --quiet databricks-sdk==0.28.0 databricks-vectorsearch
%pip install --quiet pypdf==4.1.0 tiktoken langchain-text-splitters==0.2.2
%pip install transformers==4.41.1 torch==2.3.0 --quiet
dbutils.library.restartPython()

In [0]:
%run ./config

In [0]:
# Notebook parameter: the prefix shared by every table name built in tables_config.
dbutils.widgets.text(
    name="table_prefix",
    defaultValue="policy_docs",
    label="Source Table",
)

In [0]:
# Read the widget value back and echo it so the run output records which prefix was used.
table_prefix = dbutils.widgets.get("table_prefix")
print("table_prefix:", table_prefix)

In [0]:
# Make the configured catalog and schema the session defaults so that the
# unqualified table names used below resolve inside them.
for _use_kind, _use_target in (("CATALOG", catalog), ("SCHEMA", schema)):
    spark.sql(f"USE {_use_kind} {_use_target};")

# Create a Vector Search Index from PDF Files in a Volume Directory

In [0]:
# One Delta table per pipeline stage, all sharing the widget-supplied prefix:
#   <prefix>_raw_files, <prefix>_parsed_files, <prefix>_chunked_files
tables_config = {
    f"{stage}_files_table_name": f"{table_prefix}_{stage}_files"
    for stage in ("raw", "parsed", "chunked")
}

# Embedding model serving endpoint plus the tokenizer that corresponds to it.
embedding_config = dict(
    embedding_endpoint_name="databricks-gte-large-en",
    embedding_tokenizer=dict(
        tokenizer_model_name="Alibaba-NLP/gte-large-en-v1.5",
        tokenizer_source="hugging_face",
    ),
)

# Chunking strategy; both sizes are expressed in tokens.
chunker_config = dict(
    name="langchain_recursive_char",
    config=dict(
        chunk_size_tokens=1024,
        chunk_overlap_tokens=256,
    ),
)

## Load Raw PDF from the UC Volume

In [0]:
# Load the raw files: every PDF found (recursively) under the policy_doc folder of the volume
SOURCE_PATH = f"/Volumes/{catalog}/{schema}/{volume_name_policies}/policy_doc"

raw_files_df = (
    spark.read.format("binaryFile")
    .option("recursiveFileLookup", "true")
    .option("pathGlobFilter", "*.pdf")
    .load(SOURCE_PATH)
)

# Save to a table
raw_files_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    tables_config["raw_files_table_name"]
)

# reload to get correct lineage in UC
raw_files_df = spark.read.table(tables_config["raw_files_table_name"])

# For debugging, show the list of files, but hide the binary content
display(raw_files_df.drop("content"))

# Check that files were present and loaded; fail fast so the later stages
# never run against an empty table.
if raw_files_df.count() == 0:
    display(
        f"`{SOURCE_PATH}` does not contain any PDF files.  Open the volume and upload at least one file."
    )
    raise Exception(f"`{SOURCE_PATH}` does not contain any PDF files.")

## Parse PDF with PyPDF

In [0]:
import io
import warnings
from typing import Dict, List, TypedDict

from pypdf import PdfReader
from pyspark.sql.functions import col, concat_ws, explode, md5, posexplode, udf
from pyspark.sql.types import ArrayType, MapType, StringType, StructField, StructType

In [0]:
class ParserReturnValue(TypedDict):
    """Result record produced by `parse_bytes_pypdf` for a single document."""

    # Map holding the keys "num_pages" and "parsed_content" (both as strings).
    doc_parsed_contents: Dict[str, str]
    # "SUCCESS", or "ERROR: <exception text>" when parsing failed.
    parser_status: str


def parse_bytes_pypdf(
    raw_doc_contents_bytes: bytes,
) -> ParserReturnValue:
    """Extract the text of a PDF, given as raw bytes, with pypdf.

    The text of all pages is joined with newlines and every occurrence of
    "Allstate" is replaced by "Autosure".

    Args:
        raw_doc_contents_bytes: Binary content of one PDF file.

    Returns:
        A dict with `doc_parsed_contents` (page count and text) and
        `parser_status`. No exception escapes: any failure is reported as a
        warning and returned as empty contents with an "ERROR: ..." status.
    """
    try:
        reader = PdfReader(io.BytesIO(raw_doc_contents_bytes))

        page_texts = []
        for page in reader.pages:
            page_texts.append(page.extract_text())

        full_text = "\n".join(page_texts).replace("Allstate", "Autosure")
        return {
            "doc_parsed_contents": {
                "num_pages": str(len(page_texts)),
                "parsed_content": full_text,
            },
            "parser_status": "SUCCESS",
        }
    except Exception as exc:
        # Report the failure in the row instead of failing the whole Spark job.
        warnings.warn(f"Exception {exc} has been thrown during parsing")
        empty_contents = {"num_pages": "", "parsed_content": ""}
        return {
            "doc_parsed_contents": empty_contents,
            "parser_status": f"ERROR: {exc}",
        }

# Create UDF: Spark wrapper around parse_bytes_pypdf that returns a
# struct<doc_parsed_contents: map<string,string>, parser_status: string>.
parser_udf = udf(
    f=parse_bytes_pypdf,
    returnType=(
        StructType()
        .add("doc_parsed_contents", MapType(StringType(), StringType()), nullable=True)
        .add("parser_status", StringType(), nullable=True)
    ),
)

In [0]:
# Run the parsing: apply the parser UDF to the raw bytes, then drop the bytes
parsed_files_staging_df = (
    raw_files_df
    .withColumn("parsing", parser_udf("content"))
    .drop("content")
)

# Check and warn on any errors (anything whose status is not exactly "SUCCESS")
errors_df = parsed_files_staging_df.filter(col("parsing.parser_status") != "SUCCESS")
num_errors = errors_df.count()
if num_errors > 0:
    print(f"{num_errors} documents had parse errors.  Please review.")
    display(errors_df)

# Filter for successfully parsed files and promote the parsed map to a top-level column
parsed_files_df = (
    parsed_files_staging_df
    .filter(col("parsing.parser_status") == "SUCCESS")
    .withColumn("doc_parsed_contents", col("parsing.doc_parsed_contents"))
    .drop("parsing")
)

# Write to Delta Table
(
    parsed_files_df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(tables_config["parsed_files_table_name"])
)

# reload to get correct lineage in UC
parsed_files_df = spark.table(tables_config["parsed_files_table_name"])

# Display for debugging
print(f"Parsed {parsed_files_df.count()} documents.")

display(parsed_files_df)

## Chunk the parsed text

In [0]:
from functools import partial
from langchain_text_splitters import RecursiveCharacterTextSplitter
import tiktoken

In [0]:
class ChunkerReturnValue(TypedDict):
    """Result record produced by `chunk_parsed_content_langrecchar`."""

    # The chunks of the document, in order; empty when chunking failed.
    chunked_text: List[str]
    # "SUCCESS", or "ERROR: <exception text>" when chunking failed.
    chunker_status: str


def chunk_parsed_content_langrecchar(
    doc_parsed_contents: str, chunk_size: int, chunk_overlap: int, embedding_config
) -> ChunkerReturnValue:
    """Split a document's text into overlapping chunks measured in tokens.

    Uses LangChain's `RecursiveCharacterTextSplitter` with the tiktoken
    encoding registered for 'text-embedding-3-large' as the length function.

    Args:
        doc_parsed_contents: Full text of one parsed document.
        chunk_size: Maximum chunk length, in tokens.
        chunk_overlap: Number of tokens shared between consecutive chunks.
        embedding_config: Accepted for interface compatibility; it is not
            read by this function.
            NOTE(review): token counts therefore come from tiktoken, not from
            the tokenizer named in this config - confirm that is intended.

    Returns:
        A dict with `chunked_text` (list of chunk strings) and
        `chunker_status`. No exception escapes: any failure is reported as a
        warning and returned as an empty list with an "ERROR: ..." status.
    """
    try:
        # encoding_name_for_model returns the *name* of the encoding (a string),
        # which is what from_tiktoken_encoder takes as `encoding_name`.
        encoding_name = tiktoken.encoding_name_for_model('text-embedding-3-large')
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            encoding_name=encoding_name,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

        return {
            "chunked_text": list(text_splitter.split_text(doc_parsed_contents)),
            "chunker_status": "SUCCESS",
        }
    except Exception as e:
        # Report the failure in the row instead of failing the whole Spark job.
        warnings.warn(f"Exception {e} has been thrown during chunking")
        return {
            "chunked_text": [],
            "chunker_status": f"ERROR: {e}",
        }


# Spark wrapper around the chunker. Chunk size and overlap are bound from
# chunker_config here, so the UDF itself takes only the text column and
# returns struct<chunked_text: array<string>, chunker_status: string>.
chunker_udf = udf(
    f=partial(
        chunk_parsed_content_langrecchar,
        chunk_size=chunker_config.get('config').get("chunk_size_tokens"),
        chunk_overlap=chunker_config.get('config').get("chunk_overlap_tokens"),
        embedding_config=embedding_config,
    ),
    returnType=(
        StructType()
        .add("chunked_text", ArrayType(StringType()), nullable=True)
        .add("chunker_status", StringType(), nullable=True)
    ),
)

In [0]:
# Run the chunker on the parsed text of every document
chunked_files_staging_df = parsed_files_df.withColumn(
    "chunked",
    chunker_udf("doc_parsed_contents.parsed_content"),
)

# Check and warn on any errors (anything whose status is not exactly "SUCCESS")
errors_df = chunked_files_staging_df.filter(col("chunked.chunker_status") != "SUCCESS")

num_errors = errors_df.count()
if num_errors > 0:
    print(f"{num_errors} documents had chunking errors.  Please review.")
    display(errors_df)

# Keep successfully chunked documents and emit one row per chunk.
# chunk_id is used as the primary key of the vector search index, so it must
# be unique per row: hashing the text alone collides whenever the same text
# occurs in two files (or twice in one file), so the file path and the
# chunk's position within the document are hashed along with it.
chunked_files_df = (
    chunked_files_staging_df
    .filter(col("chunked.chunker_status") == "SUCCESS")
    .select(
        "path",
        posexplode("chunked.chunked_text").alias("chunk_pos", "chunked_text"),
    )
    .select(
        "path",
        "chunked_text",
        md5(
            concat_ws(
                "|",
                col("path"),
                col("chunk_pos").cast("string"),
                col("chunked_text"),
            )
        ).alias("chunk_id"),
    )
)

# Write to Delta Table
chunked_files_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    tables_config["chunked_files_table_name"]
)

%md
## Create a Vector Search Index from Chunked Text

In [0]:
from databricks.vector_search.client import VectorSearchClient
# Client for the Vector Search service, created with disable_notice=True.
vsc = VectorSearchClient(disable_notice=True)

# Create the endpoint only when endpoint_exists reports that it is missing.
endpoint_missing = not endpoint_exists(vsc, VECTOR_SEARCH_ENDPOINT_NAME)
if endpoint_missing:
    vsc.create_endpoint(endpoint_type="STANDARD", name=VECTOR_SEARCH_ENDPOINT_NAME)

# Helper is defined outside this notebook section; presumably it blocks until
# the endpoint is usable - confirm against its definition.
wait_for_vs_endpoint_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME)
print("Endpoint named {} is ready.".format(VECTOR_SEARCH_ENDPOINT_NAME))

In [0]:
# The chunk table feeds the vector search index below; turn on the Delta
# Change Data Feed on it, which the index needs from its source table.
spark.sql(
    f"ALTER TABLE {catalog}.{schema}.{tables_config['chunked_files_table_name']} "
    "SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
)

In [0]:
from databricks.sdk import WorkspaceClient
import databricks.sdk.service.catalog as c

# The table we'd like to index
source_table_fullname = f"{catalog}.{schema}.{tables_config['chunked_files_table_name']}"
# Where we want to store our index
vs_index_fullname = f"{source_table_fullname}_vs_index"

if not index_exists(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname):
    print(f"Creating index {vs_index_fullname} on endpoint {VECTOR_SEARCH_ENDPOINT_NAME}...")
    vsc.create_delta_sync_index(
        endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME,
        index_name=vs_index_fullname,
        source_table_name=source_table_fullname,
        pipeline_type="TRIGGERED",
        primary_key="chunk_id",
        # The column containing our text
        embedding_source_column="chunked_text",
        # Embed with the endpoint declared in embedding_config, so the index
        # uses the same model the rest of the notebook is configured for.
        embedding_model_endpoint_name=embedding_config["embedding_endpoint_name"],
    )
    # Let's wait for the index to be ready and all our embeddings to be created and indexed
    wait_for_index_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname)
else:
    # The index already exists: wait until it is ready, then trigger a sync
    # to update its content with the new data saved in the table.
    wait_for_index_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname)
    vsc.get_index(VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname).sync()

print(f"index {vs_index_fullname} on table {source_table_fullname} is ready")